26 #ifndef TILEDARRAY_TILE_OP_CONTRACT_REDUCE_H__INCLUDED 27 #define TILEDARRAY_TILE_OP_CONTRACT_REDUCE_H__INCLUDED 32 #include "../tile_interface/add.h" 33 #include "../tile_interface/permute.h" 47 template <
typename Result,
typename Left,
typename Right,
typename Scalar>
60 Impl(
const madness::cblas::CBLAS_TRANSPOSE left_op,
61 const madness::cblas::CBLAS_TRANSPOSE right_op,
66 alpha_(alpha), perm_(
perm)
76 std::shared_ptr<Impl> pimpl_;
100 const madness::cblas::CBLAS_TRANSPOSE right_op,
114 return pimpl_->gemm_helper_;
122 return pimpl_->perm_;
131 return pimpl_->alpha_;
141 return pimpl_->gemm_helper_.num_contract_ranks();
149 return pimpl_->gemm_helper_.result_rank();
157 return pimpl_->gemm_helper_.left_rank();
165 return pimpl_->gemm_helper_.right_rank();
177 template <
typename Result,
typename Left,
typename Right,
typename Scalar>
213 const madness::cblas::CBLAS_TRANSPOSE right_op,
281 template <
typename Result,
typename Left,
typename Right>
285 TiledArray::detail::ComplexConjugate<void> >
298 typedef decltype(
gemm(std::declval<Left>(), std::declval<Right>(), 1,
299 std::declval<math::GemmHelper>()))
323 const
madness::cblas::CBLAS_TRANSPOSE right_op,
393 template <
typename Result,
typename Left,
typename Right,
typename Scalar>
397 TiledArray::detail::ComplexConjugate<Scalar> >
410 typedef decltype(
gemm(std::declval<Left>(), std::declval<Right>(), 1,
411 std::declval<math::GemmHelper>()))
434 const
madness::cblas::CBLAS_TRANSPOSE right_op,
499 #endif // TILEDARRAY_CONTRACT_REDUCE_H__INCLUDED ContractReduceBase_::first_argument_type first_argument_type
The left tile type.
ContractReduceBase< Result, Left, Right, TiledArray::detail::ComplexConjugate< void > > ContractReduceBase_
This class type.
void operator()(result_type &result, const result_type &arg) const
Reduce two result objects.
ContractReduceBase< Result, Left, Right, TiledArray::detail::ComplexConjugate< Scalar > > ContractReduceBase_
This class type.
result_type operator()() const
Create a result type object.
void operator()(result_type &result, const result_type &arg) const
Reduce two result objects.
ContractReduce_ & operator=(const ContractReduce_ &)=default
ContractReduceBase_::first_argument_type first_argument_type
The left tile type.
void operator()(result_type &result, const result_type &arg) const
Reduce two result objects.
result_type operator()(result_type &temp) const
Post processing step.
const Left & first_argument_type
The left tile type.
void permute(InputOp &&input_op, OutputOp &&output_op, Result &result, const Permutation &perm, const Arg0 &arg0, const Args &... args)
Construct a permuted tensor copy.
Result result_type
The result type.
unsigned int num_contract_ranks() const
Compute the number of contracted ranks.
result_type operator()(const result_type &temp) const
Post processing step.
result_type operator()() const
Create a result type object.
result_type operator()() const
Create a result type object.
ContractReduce< Result, Left, Right, Scalar > ContractReduce_
This class type.
void operator()(result_type &result, first_argument_type left, second_argument_type right) const
Contract a pair of tiles and add to a target tile.
Tile< Result > & add_to(Tile< Result > &result, const Tile< Arg > &arg)
Add to the result tile.
ContractReduceBase< Result, Left, Right, Scalar > ContractReduceBase_
This class type.
unsigned int right_rank() const
Right-hand argument rank accessor.
const math::GemmHelper & gemm_helper() const
Gemm meta data accessor.
decltype(auto) conj(const Tile< Arg > &arg)
Create a complex conjugated copy of a tile.
Contract and (sum) reduce operation.
result_type operator()(result_type &temp) const
Post processing step.
unsigned int result_rank() const
Result rank accessor.
ContractReduceBase()=default
ContractReduceBase_::second_argument_type second_argument_type
The right tile type.
const Right & second_argument_type
The right tile type.
ContractReduceBase< Result, Left, Right, Scalar > ContractReduceBase_
This class type.
ContractReduceBase(const madness::cblas::CBLAS_TRANSPOSE left_op, const madness::cblas::CBLAS_TRANSPOSE right_op, const scalar_type alpha, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank, const Permutation &perm=Permutation())
Construct contract/reduce functor.
Result result_type
The result tile type.
Contract and (sum) reduce base.
ContractReduceBase_ & operator=(const ContractReduceBase_ &)=default
constexpr bool empty()
Test for empty tensors in an empty list.
ContractReduceBase_::second_argument_type second_argument_type
The right tile type.
Scalar scalar_type
The scaling factor type.
scalar_type factor() const
Scaling factor accessor.
~ContractReduceBase()=default
void operator()(result_type &result, first_argument_type left, second_argument_type right) const
Contract a pair of tiles and add to a target tile.
ContractReduceBase_::second_argument_type second_argument_type
The right tile type.
TILEDARRAY_FORCE_INLINE R conj(const R r)
Wrapper function for std::conj
Contraction to *GEMM helper.
Permutation of a sequence of objects indexed by base-0 indices.
bool empty(const Tile< Arg > &arg)
Check that arg is empty (no data)
ContractReduceBase_::first_argument_type first_argument_type
The left tile type.
Contract and (sum) reduce operation.
unsigned int left_rank() const
Left-hand argument rank accessor.
ContractReduce(const madness::cblas::CBLAS_TRANSPOSE left_op, const madness::cblas::CBLAS_TRANSPOSE right_op, const scalar_type alpha, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank, const Permutation &perm=Permutation())
Construct contract/reduce functor.
void operator()(result_type &result, first_argument_type left, second_argument_type right) const
Contract a pair of tiles and add to a target tile.
ContractReduce< Result, Left, Right, TiledArray::detail::ComplexConjugate< Scalar > > ContractReduce_
This class type.
ContractReduce< Result, Left, Right, TiledArray::detail::ComplexConjugate< void > > ContractReduce_
This class type.
Contract and reduce operation.
const Permutation & perm() const
Permutation accessor.
decltype(auto) gemm(const Tile< Left > &left, const Tile< Right > &right, const Scalar factor, const math::GemmHelper &gemm_config)
Contract and scale tile arguments.
Result & conj_to(Tile< Result > &result)
In-place complex conjugate a tile.
~ContractReduce()=default
Specialization of ComplexConjugate<S> for the case of a unit/identity factor.