binary_engine.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2014 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * Justus Calvin
19  * Department of Chemistry, Virginia Tech
20  *
21  * binary_engine.h
22  * Mar 31, 2014
23  *
24  */
25 
26 #ifndef TILEDARRAY_EXPRESSIONS_BINARY_ENGINE_H__INCLUDED
27 #define TILEDARRAY_EXPRESSIONS_BINARY_ENGINE_H__INCLUDED
28 
32 
33 namespace TiledArray {
34 namespace expressions {
35 
36 // Forward declarations
// BinaryExpr<D> is the expression type accepted by the BinaryEngine
// constructor; BinaryEngine is declared here ahead of its definition below.
37 template <typename>
38 class BinaryExpr;
39 template <typename>
40 class BinaryEngine;
41 
42 template <typename Derived>
// BinaryEngine<Derived>: CRTP engine base for expressions with two arguments.
// It derives from ExprEngine<Derived> and holds the left-hand (left_) and
// right-hand (right_) argument engines. Argument, tile-operation and
// meta-data types are taken from EngineTrait<Derived>.
// NOTE(review): this listing has numbering gaps (e.g. 45->48, 50->52,
// 182, 283-285), so several declarations are not visible here.
43 class BinaryEngine : public ExprEngine<Derived> {
44  public:
45  // Class hierarchy typedefs
48 
49  // Argument typedefs
50  typedef typename EngineTrait<Derived>::left_type
52  typedef typename EngineTrait<Derived>::right_type
54 
55  // Operational typedefs
56  typedef typename EngineTrait<Derived>::value_type
58  typedef typename EngineTrait<Derived>::op_type
60  typedef
64 
65  // Meta data typedefs
67  typedef typename EngineTrait<Derived>::trange_type
72 
// Leaf count reported by EngineTrait<Derived>. init_indices_ compares the
// arguments' counts to decide which argument it prefers to permute.
74  static constexpr unsigned int leaves = EngineTrait<Derived>::leaves;
75 
76  protected:
77  // Import base class variables to this scope
79  using ExprEngine_::perm_;
81  using ExprEngine_::pmap_;
82  using ExprEngine_::shape_;
84  using ExprEngine_::world_;
85 
98 
// init_indices_<ProductType>(target_indices)
// Runs one permutation optimizer over the outer indices and one over the
// inner indices of the two arguments. From their results it sets:
//   left_indices_, right_indices_, indices_ (target lists for left, right
//   and result), and left/right outer/inner permutation types.
// When `!target_indices` holds (presumably: no target was supplied), the
// optimizers choose the result indices themselves. Otherwise they are also
// given target_indices.
// Afterwards it may call permute_tiles(false) on an argument, depending on
// its permutation types and on whether the tiles are plain tensors or
// tensors-of-tensors (ToT).
// Only Contraction and Hadamard products are accepted (static_assert).
99  template <TensorProduct ProductType>
100  void init_indices_(const BipartiteIndexList& target_indices = {}) {
101  static_assert(ProductType == TensorProduct::Contraction ||
102  ProductType == TensorProduct::Hadamard);
103  // prefer to permute the arg with fewest leaves to try to minimize the
104  // number of possible permutations
105  using permopt_type =
106  std::conditional_t<ProductType == TensorProduct::Contraction,
109 
110  std::shared_ptr<BinaryOpPermutationOptimizer> outer_opt, inner_opt;
111  if (!target_indices) {
112  outer_opt = std::make_shared<permopt_type>(
113  outer(left_.indices()), outer(right_.indices()),
114  left_type::leaves <= right_type::leaves);
115  inner_opt = make_permutation_optimizer(
116  inner(left_.indices()), inner(right_.indices()),
117  left_type::leaves <= right_type::leaves);
118  } else {
119  outer_opt = std::make_shared<permopt_type>(
120  outer(target_indices), outer(left_.indices()),
121  outer(right_.indices()), left_type::leaves <= right_type::leaves);
122  inner_opt = make_permutation_optimizer(
123  inner(target_indices), inner(left_.indices()),
124  inner(right_.indices()), left_type::leaves <= right_type::leaves);
125  }
126 
// Each bipartite list combines the outer-optimizer result with the
// inner-optimizer result.
127  left_indices_ = BipartiteIndexList(outer_opt->target_left_indices(),
128  inner_opt->target_left_indices());
129  right_indices_ = BipartiteIndexList(outer_opt->target_right_indices(),
130  inner_opt->target_right_indices());
131  indices_ = BipartiteIndexList(outer_opt->target_result_indices(),
132  inner_opt->target_result_indices());
133 
134  left_outer_permtype_ = outer_opt->left_permtype();
135  right_outer_permtype_ = outer_opt->right_permtype();
136  left_inner_permtype_ = inner_opt->left_permtype();
137  right_inner_permtype_ = inner_opt->right_permtype();
138 
139  // Here we set the type of permutation that will be applied to the
140  // argument tensors. If both arguments are plain tensors
141  // (tensors-of-scalars) and their permutations can be fused into GEMM,
142  // disable their permutation
143  using left_tile_type = typename EngineTrait<left_type>::eval_type;
144  using right_tile_type = typename EngineTrait<right_type>::eval_type;
145  constexpr bool left_tile_is_tot =
146  TiledArray::detail::is_tensor_of_tensor_v<left_tile_type>;
147  constexpr bool right_tile_is_tot =
148  TiledArray::detail::is_tensor_of_tensor_v<right_tile_type>;
// Both arguments must have the same nested-ness.
// NOTE(review): the message names "ContEngine", but this check lives in
// BinaryEngine; consider rewording it.
149  static_assert(!(left_tile_is_tot ^ right_tile_is_tot),
150  "ContEngine can only handle tensors of same nested-ness "
151  "(both plain or both ToT)");
152  constexpr bool args_are_plain_tensors =
153  !left_tile_is_tot && !right_tile_is_tot;
// Left argument, plain tiles: test the outer permutation type only.
154  if (args_are_plain_tensors &&
155  (left_outer_permtype_ == PermutationType::matrix_transpose ||
157  left_.permute_tiles(false);
158  }
// Left argument, ToT tiles: test the outer and the inner permutation types.
159  if (!args_are_plain_tensors &&
160  ((left_outer_permtype_ == PermutationType::matrix_transpose ||
162  (left_inner_permtype_ == PermutationType::matrix_transpose ||
164  left_.permute_tiles(false);
165  }
// Right argument, plain tiles: test right_outer_permtype_.
166  if (args_are_plain_tensors &&
167  (right_outer_permtype_ == PermutationType::matrix_transpose ||
169  right_.permute_tiles(false);
170  }
// Right argument, ToT tiles.
// NOTE(review): this branch tests left_outer_permtype_, whereas the
// plain-tile branch for right_ above tests right_outer_permtype_. This looks
// like a copy-paste slip from the left-argument branch; confirm intent.
171  if (!args_are_plain_tensors &&
172  ((left_outer_permtype_ == PermutationType::matrix_transpose ||
174  (right_inner_permtype_ == PermutationType::matrix_transpose ||
176  right_.permute_tiles(false);
177  }
178  }
179 
180  public:
// Constructor (signature line not shown in this listing). Initializes the
// ExprEngine base from expr, and the argument engines from expr.left() and
// expr.right().
181  template <typename D>
183  : ExprEngine_(expr), left_(expr.left()), right_(expr.right()) {}
184 
186 
// Sets this expression's index list to target_indices.
// It acts only when permute_tiles_ is set:
//   - both arguments must have as many indices as target_indices (asserted);
//   - the index lists are recomputed with the Hadamard optimizer;
//   - each argument is re-permuted only if its current indices differ from
//     the newly computed target list.
192  void perm_indices(const BipartiteIndexList& target_indices) {
193  if (permute_tiles_) {
194  TA_ASSERT(left_.indices().size() == target_indices.size());
195  TA_ASSERT(right_.indices().size() == target_indices.size());
196 
197  init_indices_<TensorProduct::Hadamard>(target_indices);
198 
201 
202  if (left_.indices() != left_indices_) left_.perm_indices(left_indices_);
203  if (right_.indices() != right_indices_)
204  right_.perm_indices(right_indices_);
205  }
206  }
207 
209 
// Initializes both arguments with target_indices, then applies
// perm_indices(target_indices) to this engine.
211  void init_indices(const BipartiteIndexList& target_indices) {
212  left_.init_indices(target_indices);
213  right_.init_indices(target_indices);
214  perm_indices(target_indices);
215  }
216 
// Initializes the arguments' indices, unless children_initialized is true.
// Then computes this engine's index lists with the Hadamard optimizer and
// no target.
218  void init_indices(bool children_initialized = false) {
219  if (!children_initialized) {
220  left_.init_indices();
221  right_.init_indices();
222  }
223 
224  init_indices_<TensorProduct::Hadamard>();
227  }
228 
230 
// Initializes the arguments' structure with the target index lists computed
// by init_indices_, then initializes the base structure with target_indices.
// In builds without NDEBUG, mismatched argument tiled ranges are checked:
// rank 0 of the default world reports both ranges, then TA_EXCEPTION is
// raised.
234  void init_struct(const BipartiteIndexList& target_indices) {
235  left_.init_struct(left_indices_);
236  right_.init_struct(right_indices_);
237 #ifndef NDEBUG
238  if (left_.trange() != right_.trange()) {
239  if (TiledArray::get_default_world().rank() == 0) {
241  "The TiledRanges of the left- and right-hand arguments of the "
242  "binary operation are not equal:"
243  << "\n left = " << left_.trange()
244  << "\n right = " << right_.trange());
245  }
246 
247  TA_EXCEPTION(
248  "The TiledRanges of the left- and right-hand arguments "
249  "of the binary operation are not equal.");
250  }
251 #endif // NDEBUG
252  ExprEngine_::init_struct(target_indices);
253  }
254 
256 
// Initializes the left argument with pmap. The right argument and this
// engine then reuse left_.pmap(), so all three share the left argument's
// process map.
261  void init_distribution(World* world,
262  const std::shared_ptr<pmap_interface>& pmap) {
263  left_.init_distribution(world, pmap);
264  right_.init_distribution(world, left_.pmap());
265  ExprEngine_::init_distribution(world, left_.pmap());
266  }
267 
269 
// The result's tiled range is the left argument's tiled range. In debug
// builds, init_struct requires both arguments' tiled ranges to match.
271  trange_type make_trange() const { return left_.trange(); }
272 
274 
// Returns the left argument's tiled range permuted by perm.
277  trange_type make_trange(const Permutation& perm) const {
278  return perm * left_.trange();
279  }
280 
282 
// make_dist_eval (signature lines not shown in this listing):
//   - builds distributed evaluators for both arguments;
//   - wraps them, together with this engine's world, tiled range, shape,
//     process map, permutation and the tile operation from
//     derived().make_op(), in a shared impl_type;
//   - returns that impl_type as a dist_eval_type.
286  typename left_type::dist_eval_type, typename right_type::dist_eval_type,
287  op_type, policy>
288  impl_type;
289 
290  // Construct left and right distributed evaluators
291  const typename left_type::dist_eval_type left = left_.make_dist_eval();
292  const typename right_type::dist_eval_type right = right_.make_dist_eval();
293 
294  // Construct the distributed evaluator type
295  std::shared_ptr<impl_type> pimpl =
296  std::make_shared<impl_type>(left, right, *world_, trange_, shape_,
297  pmap_, perm_, this->derived().make_op());
298 
299  return dist_eval_type(pimpl);
300  }
301 
303 
// Prints this expression (via the base class), then both arguments one
// indentation level deeper.
// The arguments are printed with this engine's indices_, not with
// target_indices. os is taken by value here.
306  void print(ExprOStream os, const BipartiteIndexList& target_indices) const {
307  ExprEngine_::print(os, target_indices);
308  os.inc();
309  left_.print(os, indices_);
310  right_.print(os, indices_);
311  os.dec();
312  }
313 }; // class BinaryEngine
314 
315 } // namespace expressions
316 } // namespace TiledArray
317 
318 #endif // TILEDARRAY_EXPRESSIONS_BINARY_ENGINE_H__INCLUDED
void init_struct(const BipartiteIndexList &target_indices)
Initialize result tensor structure.
Definition: expr_engine.h:150
trange_type make_trange(const Permutation &perm) const
Permuting tiled range factory function.
trange_type trange_
The tiled range of the result tensor.
Definition: expr_engine.h:81
void inc()
Increment the number of tabs.
Definition: expr_trace.h:71
void init_indices_(const BipartiteIndexList &target_indices={})
BipartiteIndexList right_indices_
Target right-hand index list.
Definition: binary_engine.h:89
dist_eval_type make_dist_eval() const
Construct the distributed evaluator for this expression.
PermutationType left_inner_permtype_
Left-hand inner permutation type.
Definition: binary_engine.h:94
trange_type trange_
The tiled range of the result tensor.
Definition: expr_engine.h:81
Permutation of a sequence of objects indexed by base-0 indices.
Definition: permutation.h:130
void perm_indices(const BipartiteIndexList &target_indices)
Set the index list for this expression.
auto outer(const IndexList &p)
Definition: index_list.h:879
std::shared_ptr< pmap_interface > pmap_
The process map for the result tensor.
Definition: expr_engine.h:84
std::shared_ptr< pmap_interface > pmap_
The process map for the result tensor.
Definition: expr_engine.h:84
World * world_
The world where this expression will be evaluated.
Definition: expr_engine.h:73
ExprEngine< Derived > ExprEngine_
Base class type.
Definition: binary_engine.h:47
trange_type make_trange() const
Non-permuting tiled range factory function.
void init_indices(const BipartiteIndexList &target_indices)
Initialize the index list of this expression.
Binary expression object.
Definition: binary_expr.h:39
BipartitePermutation perm_
The permutation that will be applied to the outer tensor of tensors.
Definition: expr_engine.h:80
std::shared_ptr< BinaryOpPermutationOptimizer > make_permutation_optimizer(TensorProduct product_type, const IndexList &left_indices, const IndexList &right_indices, bool prefer_to_permute_left)
Definition: permopt.h:530
#define TA_EXCEPTION(m)
Definition: error.h:83
BipartiteIndexList left_indices_
Target left-hand index list.
Definition: binary_engine.h:88
auto rank(const DistArray< Tile, Policy > &a)
Definition: dist_array.h:1617
BinaryEngine< Derived > BinaryEngine_
This class type.
Definition: binary_engine.h:46
void init_struct(const BipartiteIndexList &target_indices)
Initialize result tensor structure.
auto inner(const IndexList &p)
Definition: index_list.h:872
#define TA_ASSERT(EXPR,...)
Definition: error.h:39
World & get_default_world()
Definition: madness.h:90
World * world_
The world where this expression will be evaluated.
Definition: expr_engine.h:73
EngineTrait< Derived >::size_type size_type
Size type.
Definition: binary_engine.h:66
EngineTrait< Derived >::trange_type trange_type
Tiled range type.
Definition: binary_engine.h:68
right_type right_
The right-hand argument.
Definition: binary_engine.h:87
EngineTrait< Derived >::right_type right_type
The right-hand expression type.
Definition: binary_engine.h:53
void init_indices(bool children_initialized=false)
Initialize the index list of this expression.
left_type left_
The left-hand argument.
Definition: binary_engine.h:86
Expression output stream.
Definition: expr_trace.h:41
void init_distribution(World *world, const std::shared_ptr< pmap_interface > &pmap)
Initialize result tensor distribution.
Definition: expr_engine.h:171
void dec()
Decrement the number of tabs.
Definition: expr_trace.h:74
EngineTrait< Derived >::op_type op_type
The tile operation type.
Definition: binary_engine.h:59
EngineTrait< Derived >::policy policy
The result policy type.
Definition: binary_engine.h:61
PermutationType left_outer_permtype_
Left-hand outer permutation type.
Definition: binary_engine.h:90
void print(ExprOStream os, const BipartiteIndexList &target_indices) const
Expression print.
EngineTrait< Derived >::dist_eval_type dist_eval_type
The distributed evaluator type.
Definition: binary_engine.h:63
EngineTrait< Derived >::shape_type shape_type
Shape type.
Definition: binary_engine.h:69
shape_type shape_
The shape of the result tensor.
Definition: expr_engine.h:82
void init_distribution(World *world, const std::shared_ptr< pmap_interface > &pmap)
Initialize result tensor distribution.
BinaryEngine(const BinaryExpr< D > &expr)
PermutationType right_inner_permtype_
Right-hand inner permutation type.
Definition: binary_engine.h:96
#define TA_USER_ERROR_MESSAGE(m)
Definition: error.h:93
EngineTrait< Derived >::value_type value_type
The result tile type.
Definition: binary_engine.h:57
Binary, distributed tensor evaluator.
Definition: binary_eval.h:42
shape_type shape_
The shape of the result tensor.
Definition: expr_engine.h:82
static constexpr unsigned int leaves
Definition: binary_engine.h:74
EngineTrait< Derived >::left_type left_type
The left-hand expression type.
Definition: binary_engine.h:51
void print(ExprOStream &os, const BipartiteIndexList &target_indices) const
Expression print.
Definition: expr_engine.h:256
EngineTrait< Derived >::pmap_interface pmap_interface
Process map interface type.
Definition: binary_engine.h:71
PermutationType right_outer_permtype_
Right-hand outer permutation type.
Definition: binary_engine.h:92
BipartitePermutation perm_
The permutation that will be applied to the outer tensor of tensors.
Definition: expr_engine.h:80