template<typename Result, typename Left, typename Right, typename Scalar>
class TiledArray::detail::ContractReduce< Result, Left, Right, Scalar >
Contract and (sum) reduce operation.
This encodes a binary tensor contraction mapped to a GEMM, as well as the sum reduction and post-processing.
- Template Parameters
-
| Result | The result tile type |
| Left | The left-hand tile type |
| Right | The right-hand tile type |
| Scalar | The scaling factor type |
Definition at line 232 of file contract_reduce.h.
|
| typedef ContractReduce< Result, Left, Right, Scalar > | ContractReduce_ |
| | This class type. More...
|
| |
| typedef ContractReduceBase< Result, Left, Right, Scalar > | ContractReduceBase_ |
| | The base class type. More...
|
| |
| typedef ContractReduceBase_::first_argument_type | first_argument_type |
| | The left tile type. More...
|
| |
| typedef ContractReduceBase_::second_argument_type | second_argument_type |
| | The right tile type. More...
|
| |
| typedef Result | result_type |
| | The result tile type. More...
|
| |
| typedef Scalar | scalar_type |
| |
| using | elem_muladd_op_type = void(result_value_type &, const left_value_type &, const right_value_type &) |
| |
| using | left_value_type = typename Left::value_type |
| |
| using | result_value_type = typename Result::value_type |
| |
| using | right_value_type = typename Right::value_type |
| |
| typedef ContractReduceBase< Result, Left, Right, Scalar > | ContractReduceBase_ |
| | This class type. More...
|
| |
| typedef const Left & | first_argument_type |
| | The left tile type. More...
|
| |
| typedef const Right & | second_argument_type |
| | The right tile type. More...
|
| |
| typedef Result | result_type |
| | The result type. More...
|
| |
| typedef Scalar | scalar_type |
| | The scaling factor type. More...
|
| |
| using | left_value_type = typename Left::value_type |
| |
| using | right_value_type = typename Right::value_type |
| |
| using | result_value_type = typename Result::value_type |
| |
| using | elem_muladd_op_type = void(result_value_type &, const left_value_type &, const right_value_type &) |
| |
|
| | ContractReduce ()=default |
| |
| | ContractReduce (const ContractReduce_ &)=default |
| |
| | ContractReduce (ContractReduce_ &&)=default |
| |
| | ~ContractReduce ()=default |
| |
| ContractReduce_ & | operator= (const ContractReduce_ &)=default |
| |
| ContractReduce_ & | operator= (ContractReduce_ &&)=default |
| |
| template<typename Perm = BipartitePermutation, typename ElemMultAddOp = TiledArray::function_ref<elem_muladd_op_type>, typename = std::enable_if_t< TiledArray::detail::is_permutation_v<Perm> && std::is_invocable_r_v<void, std::remove_reference_t<ElemMultAddOp>, result_value_type&, const left_value_type&, const right_value_type&>>> |
| | ContractReduce (const math::blas::Op left_op, const math::blas::Op right_op, const scalar_type alpha, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank, const Perm &perm={}, ElemMultAddOp &&elem_muladd_op={}) |
| | Construct contract/reduce functor. More...
|
| |
| result_type | operator() () const |
| | Create a result type object. More...
|
| |
| result_type | operator() (const result_type &temp) const |
| | Post processing step. More...
|
| |
| void | operator() (result_type &result, const result_type &arg) const |
| | Reduce two result objects. More...
|
| |
| void | operator() (result_type &result, const first_argument_type &left, const second_argument_type &right) const |
| | Contract a pair of tiles and add to a target tile. More...
|
| |
| | ContractReduceBase ()=default |
| |
| | ContractReduceBase (const ContractReduceBase_ &)=default |
| |
| | ContractReduceBase (ContractReduceBase_ &&)=default |
| |
| | ~ContractReduceBase ()=default |
| |
| ContractReduceBase_ & | operator= (const ContractReduceBase_ &)=default |
| |
| ContractReduceBase_ & | operator= (ContractReduceBase_ &&)=default |
| |
| template<typename Perm = BipartitePermutation, typename ElemMultAddOp = TiledArray::function_ref<elem_muladd_op_type>, typename = std::enable_if_t< TiledArray::detail::is_permutation_v<Perm> && std::is_invocable_r_v<void, std::remove_reference_t<ElemMultAddOp>, result_value_type&, const left_value_type&, const right_value_type&>>> |
| | ContractReduceBase (const math::blas::Op left_op, const math::blas::Op right_op, const scalar_type alpha, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank, const Perm &perm={}, ElemMultAddOp &&elem_muladd_op={}) |
| | Construct contract/reduce functor. More...
|
| |
| const math::GemmHelper & | gemm_helper () const |
| | Gemm meta data accessor. More...
|
| |
| const BipartitePermutation & | perm () const |
| | Permutation accessor. More...
|
| |
| scalar_type | factor () const |
| | Scaling factor accessor. More...
|
| |
| const auto & | elem_muladd_op () const |
| | Element multiply-add op accessor. More...
|
| |
| unsigned int | num_contract_ranks () const |
| | Compute the number of contracted ranks. More...
|
| |
| unsigned int | result_rank () const |
| | Result rank accessor. More...
|
| |
| unsigned int | left_rank () const |
| | Left-hand argument rank accessor. More...
|
| |
| unsigned int | right_rank () const |
| | Right-hand argument rank accessor. More...
|
| |
template<typename Result , typename Left , typename Right , typename Scalar >
template<typename Perm = BipartitePermutation, typename ElemMultAddOp = TiledArray::function_ref<elem_muladd_op_type>, typename = std::enable_if_t< TiledArray::detail::is_permutation_v<Perm> && std::is_invocable_r_v<void, std::remove_reference_t<ElemMultAddOp>, result_value_type&, const left_value_type&, const right_value_type&>>>
TiledArray::detail::ContractReduce< Result, Left, Right, Scalar >::ContractReduce (
    const math::blas::Op left_op,
    const math::blas::Op right_op,
    const scalar_type alpha,
    const unsigned int result_rank,
    const unsigned int left_rank,
    const unsigned int right_rank,
    const Perm & perm = {},
    ElemMultAddOp && elem_muladd_op = {}
)
inline
Construct contract/reduce functor.
- Template Parameters
-
| Perm | a permutation type |
| ElemMultAddOp | a callable with signature elem_muladd_op_type |
- Parameters
-
| left_op | The left-hand BLAS matrix operation |
| right_op | The right-hand BLAS matrix operation |
| alpha | The scaling factor applied to the contracted tiles |
| result_rank | The rank of the result tensor |
| left_rank | The rank of the left-hand tensor |
| right_rank | The rank of the right-hand tensor |
| perm | The permutation to be applied to the result tensor (default = no permutation) |
| elem_muladd_op | The element multiply-add op |
Definition at line 280 of file contract_reduce.h.