template<typename Result, typename Left, typename Right, typename Scalar>
class TiledArray::detail::ContractReduce< Result, Left, Right, Scalar >
Contract and (sum) reduce operation.
This encodes a binary tensor contraction mapped to a GEMM, as well as the sum reduction and post-processing.
- Template Parameters
-
Result | The result tile type |
Left | The left-hand tile type |
Right | The right-hand tile type |
Scalar | The scaling factor type |
Definition at line 232 of file contract_reduce.h.
|
typedef ContractReduce< Result, Left, Right, Scalar > | ContractReduce_ |
| This class type. More...
|
|
typedef ContractReduceBase< Result, Left, Right, Scalar > | ContractReduceBase_ |
| The base class type. More...
|
|
typedef ContractReduceBase_::first_argument_type | first_argument_type |
| The left tile type. More...
|
|
typedef ContractReduceBase_::second_argument_type | second_argument_type |
| The right tile type. More...
|
|
typedef Result | result_type |
| The result tile type. More...
|
|
typedef Scalar | scalar_type |
| The scaling factor type.
using | elem_muladd_op_type = void(result_value_type &, const left_value_type &, const right_value_type &) |
|
using | left_value_type = typename Left::value_type |
|
using | result_value_type = typename Result::value_type |
|
using | right_value_type = typename Right::value_type |
|
typedef ContractReduceBase< Result, Left, Right, Scalar > | ContractReduceBase_ |
| This class type. More...
|
|
typedef const Left & | first_argument_type |
| The left tile type. More...
|
|
typedef const Right & | second_argument_type |
| The right tile type. More...
|
|
typedef Result | result_type |
| The result tile type. More...
|
|
typedef Scalar | scalar_type |
| The scaling factor type. More...
|
|
using | left_value_type = typename Left::value_type |
|
using | right_value_type = typename Right::value_type |
|
using | result_value_type = typename Result::value_type |
|
using | elem_muladd_op_type = void(result_value_type &, const left_value_type &, const right_value_type &) |
|
|
| ContractReduce ()=default |
|
| ContractReduce (const ContractReduce_ &)=default |
|
| ContractReduce (ContractReduce_ &&)=default |
|
| ~ContractReduce ()=default |
|
ContractReduce_ & | operator= (const ContractReduce_ &)=default |
|
ContractReduce_ & | operator= (ContractReduce_ &&)=default |
|
template<typename Perm = BipartitePermutation, typename ElemMultAddOp = TiledArray::function_ref<elem_muladd_op_type>, typename = std::enable_if_t< TiledArray::detail::is_permutation_v<Perm> && std::is_invocable_r_v<void, std::remove_reference_t<ElemMultAddOp>, result_value_type&, const left_value_type&, const right_value_type&>>> |
| ContractReduce (const math::blas::Op left_op, const math::blas::Op right_op, const scalar_type alpha, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank, const Perm &perm={}, ElemMultAddOp &&elem_muladd_op={}) |
| Construct a contract/reduce functor. More...
|
|
result_type | operator() () const |
| Create a result type object. More...
|
|
result_type | operator() (const result_type &temp) const |
| Post processing step. More...
|
|
void | operator() (result_type &result, const result_type &arg) const |
| Reduce two result objects. More...
|
|
void | operator() (result_type &result, const first_argument_type &left, const second_argument_type &right) const |
| Contract a pair of tiles and add to a target tile. More...
|
|
| ContractReduceBase ()=default |
|
| ContractReduceBase (const ContractReduceBase_ &)=default |
|
| ContractReduceBase (ContractReduceBase_ &&)=default |
|
| ~ContractReduceBase ()=default |
|
ContractReduceBase_ & | operator= (const ContractReduceBase_ &)=default |
|
ContractReduceBase_ & | operator= (ContractReduceBase_ &&)=default |
|
template<typename Perm = BipartitePermutation, typename ElemMultAddOp = TiledArray::function_ref<elem_muladd_op_type>, typename = std::enable_if_t< TiledArray::detail::is_permutation_v<Perm> && std::is_invocable_r_v<void, std::remove_reference_t<ElemMultAddOp>, result_value_type&, const left_value_type&, const right_value_type&>>> |
| ContractReduceBase (const math::blas::Op left_op, const math::blas::Op right_op, const scalar_type alpha, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank, const Perm &perm={}, ElemMultAddOp &&elem_muladd_op={}) |
| Construct a contract/reduce functor. More...
|
|
const math::GemmHelper & | gemm_helper () const |
| Gemm meta data accessor. More...
|
|
const BipartitePermutation & | perm () const |
| Permutation accessor. More...
|
|
scalar_type | factor () const |
| Scaling factor accessor. More...
|
|
const auto & | elem_muladd_op () const |
| Element multiply-add op accessor. More...
|
|
unsigned int | num_contract_ranks () const |
| Compute the number of contracted ranks. More...
|
|
unsigned int | result_rank () const |
| Result rank accessor. More...
|
|
unsigned int | left_rank () const |
| Left-hand argument rank accessor. More...
|
|
unsigned int | right_rank () const |
| Right-hand argument rank accessor. More...
|
|
template<typename Result , typename Left , typename Right , typename Scalar >
template<typename Perm = BipartitePermutation, typename ElemMultAddOp = TiledArray::function_ref<elem_muladd_op_type>, typename = std::enable_if_t< TiledArray::detail::is_permutation_v<Perm> && std::is_invocable_r_v<void, std::remove_reference_t<ElemMultAddOp>, result_value_type&, const left_value_type&, const right_value_type&>>>
TiledArray::detail::ContractReduce< Result, Left, Right, Scalar >::ContractReduce (
    const math::blas::Op left_op,
    const math::blas::Op right_op,
    const scalar_type alpha,
    const unsigned int result_rank,
    const unsigned int left_rank,
    const unsigned int right_rank,
    const Perm & perm = {},
    ElemMultAddOp && elem_muladd_op = {} )
inline
Construct a contract/reduce functor.
- Template Parameters
-
Perm | a permutation type |
ElemMultAddOp | a callable with signature elem_muladd_op_type, i.e. void(result_value_type &, const left_value_type &, const right_value_type &) |
- Parameters
-
left_op | The left-hand BLAS matrix operation |
right_op | The right-hand BLAS matrix operation |
alpha | The scaling factor applied to the contracted tiles |
result_rank | The rank of the result tensor |
left_rank | The rank of the left-hand tensor |
right_rank | The rank of the right-hand tensor |
perm | The permutation to be applied to the result tensor (default: no permutation) |
elem_muladd_op | The element multiply-add op |
Definition at line 280 of file contract_reduce.h.