Go to the documentation of this file.
26 #ifndef TILEDARRAY_EXPRESSIONS_BINARY_ENGINE_H__INCLUDED
27 #define TILEDARRAY_EXPRESSIONS_BINARY_ENGINE_H__INCLUDED
34 namespace expressions {
42 template <
typename Derived>
99 template <TensorProduct ProductType>
101 static_assert(ProductType == TensorProduct::Contraction ||
102 ProductType == TensorProduct::Hadamard);
106 std::conditional_t<ProductType == TensorProduct::Contraction,
110 std::shared_ptr<BinaryOpPermutationOptimizer> outer_opt, inner_opt;
111 if (!target_indices) {
112 outer_opt = std::make_shared<permopt_type>(
114 left_type::leaves <= right_type::leaves);
117 left_type::leaves <= right_type::leaves);
119 outer_opt = std::make_shared<permopt_type>(
121 outer(
right_.indices()), left_type::leaves <= right_type::leaves);
124 inner(
right_.indices()), left_type::leaves <= right_type::leaves);
127 left_indices_ = BipartiteIndexList(outer_opt->target_left_indices(),
128 inner_opt->target_left_indices());
129 right_indices_ = BipartiteIndexList(outer_opt->target_right_indices(),
130 inner_opt->target_right_indices());
131 indices_ = BipartiteIndexList(outer_opt->target_result_indices(),
132 inner_opt->target_result_indices());
143 using left_tile_type =
typename EngineTrait<left_type>::eval_type;
144 using right_tile_type =
typename EngineTrait<right_type>::eval_type;
145 constexpr
bool left_tile_is_tot =
146 TiledArray::detail::is_tensor_of_tensor_v<left_tile_type>;
147 constexpr
bool right_tile_is_tot =
148 TiledArray::detail::is_tensor_of_tensor_v<right_tile_type>;
149 static_assert(!(left_tile_is_tot ^ right_tile_is_tot),
150 "ContEngine can only handle tensors of same nested-ness "
151 "(both plain or both ToT)");
152 constexpr
bool args_are_plain_tensors =
153 !left_tile_is_tot && !right_tile_is_tot;
154 if (args_are_plain_tensors &&
157 left_.permute_tiles(
false);
159 if (!args_are_plain_tensors &&
164 left_.permute_tiles(
false);
166 if (args_are_plain_tensors &&
169 right_.permute_tiles(
false);
171 if (!args_are_plain_tensors &&
176 right_.permute_tiles(
false);
181 template <
typename D>
197 init_indices_<TensorProduct::Hadamard>(target_indices);
212 left_.init_indices(target_indices);
213 right_.init_indices(target_indices);
219 if (!children_initialized) {
220 left_.init_indices();
224 init_indices_<TensorProduct::Hadamard>();
241 "The TiledRanges of the left- and right-hand arguments of the "
242 "binary operation are not equal:"
243 <<
"\n left = " <<
left_.trange()
244 <<
"\n right = " <<
right_.trange());
248 "The TiledRanges of the left- and right-hand arguments "
249 "of the binary operation are not equal.");
262 const std::shared_ptr<pmap_interface>& pmap) {
263 left_.init_distribution(world, pmap);
278 return perm *
left_.trange();
286 typename left_type::dist_eval_type,
typename right_type::dist_eval_type,
291 const typename left_type::dist_eval_type left =
left_.make_dist_eval();
292 const typename right_type::dist_eval_type right =
right_.make_dist_eval();
295 std::shared_ptr<impl_type> pimpl =
318 #endif // TILEDARRAY_EXPRESSIONS_BINARY_ENGINE_H__INCLUDED
void init_struct(const BipartiteIndexList &target_indices)
Initialize result tensor structure.
trange_type make_trange(const Permutation &perm) const
Permuting tiled range factory function.
trange_type trange_
The tiled range of the result tensor.
void inc()
Increment the number of tabs.
void init_indices_(const BipartiteIndexList &target_indices={})
BipartiteIndexList right_indices_
Target right-hand index list.
dist_eval_type make_dist_eval() const
Construct the distributed evaluator for this expression.
PermutationType left_inner_permtype_
Left-hand inner permutation type.
trange_type trange_
The tiled range of the result tensor.
Permutation of a sequence of objects indexed by base-0 indices.
BipartiteIndexList indices_
void perm_indices(const BipartiteIndexList &target_indices)
Set the index list for this expression.
auto outer(const IndexList &p)
std::shared_ptr< pmap_interface > pmap_
The process map for the result tensor.
std::shared_ptr< pmap_interface > pmap_
The process map for the result tensor.
World * world_
The world where this expression will be evaluated.
ExprEngine< Derived > ExprEngine_
Base class type.
trange_type make_trange() const
Non-permuting tiled range factory function.
void init_indices(const BipartiteIndexList &target_indices)
Initialize the index list of this expression.
Binary expression object.
BipartitePermutation perm_
The permutation that will be applied to the outer tensor of tensors.
std::shared_ptr< BinaryOpPermutationOptimizer > make_permutation_optimizer(TensorProduct product_type, const IndexList &left_indices, const IndexList &right_indices, bool prefer_to_permute_left)
BipartiteIndexList left_indices_
Target left-hand index list.
auto rank(const DistArray< Tile, Policy > &a)
BinaryEngine< Derived > BinaryEngine_
This class type.
void init_struct(const BipartiteIndexList &target_indices)
Initialize result tensor structure.
auto inner(const IndexList &p)
BipartiteIndexList indices_
#define TA_ASSERT(EXPR,...)
World & get_default_world()
static constexpr bool consumable
World * world_
The world where this expression will be evaluated.
EngineTrait< Derived >::size_type size_type
Size type.
EngineTrait< Derived >::trange_type trange_type
Tiled range type.
right_type right_
The right-hand argument.
EngineTrait< Derived >::right_type right_type
The right-hand expression type.
void init_indices(bool children_initialized=false)
Initialize the index list of this expression.
left_type left_
The left-hand argument.
Expression output stream.
void init_distribution(World *world, const std::shared_ptr< pmap_interface > &pmap)
Initialize result tensor distribution.
void dec()
Decrement the number of tabs.
EngineTrait< Derived >::op_type op_type
The tile operation type.
EngineTrait< Derived >::policy policy
The result policy type.
PermutationType left_outer_permtype_
Left-hand outer permutation type.
void print(ExprOStream os, const BipartiteIndexList &target_indices) const
Expression print.
EngineTrait< Derived >::dist_eval_type dist_eval_type
The distributed evaluator type.
EngineTrait< Derived >::shape_type shape_type
Shape type.
shape_type shape_
The shape of the result tensor.
void init_distribution(World *world, const std::shared_ptr< pmap_interface > &pmap)
Initialize result tensor distribution.
BinaryEngine(const BinaryExpr< D > &expr)
PermutationType right_inner_permtype_
Right-hand inner permutation type.
#define TA_USER_ERROR_MESSAGE(m)
EngineTrait< Derived >::value_type value_type
The result tile type.
Binary, distributed tensor evaluator.
shape_type shape_
The shape of the result tensor.
static constexpr unsigned int leaves
EngineTrait< Derived >::left_type left_type
The left-hand expression type.
void print(ExprOStream &os, const BipartiteIndexList &target_indices) const
Expression print.
EngineTrait< Derived >::pmap_interface pmap_interface
Process map interface type.
PermutationType right_outer_permtype_
Right-hand outer permutation type.
BipartitePermutation perm_
The permutation that will be applied to the outer tensor of tensors.