cont_engine.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2014 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * Justus Calvin
19  * Department of Chemistry, Virginia Tech
20  *
21  * cont_engine.h
22  * Mar 31, 2014
23  *
24  */
25 
26 #ifndef TILEDARRAY_EXPRESSIONS_CONT_ENGINE_H__INCLUDED
27 #define TILEDARRAY_EXPRESSIONS_CONT_ENGINE_H__INCLUDED
28 
32 #include <TiledArray/proc_grid.h>
36 
37 namespace TiledArray {
38 namespace expressions {
39 
40 // Forward declarations
41 template <typename, typename>
42 class MultExpr;
43 template <typename, typename, typename>
44 class ScalMultExpr;
45 
47 
49 template <typename Derived>
50 class ContEngine : public BinaryEngine<Derived> {
51  public:
52  // Class hierarchy typedefs
55  typedef ExprEngine<Derived>
57 
58  // Argument typedefs
59  typedef typename EngineTrait<Derived>::left_type
61  typedef typename EngineTrait<Derived>::right_type
63 
64  // Operational typedefs
65  typedef typename EngineTrait<Derived>::value_type
67  typedef typename EngineTrait<Derived>::scalar_type
74  typedef
78 
79  // Meta data typedefs
81  typedef typename EngineTrait<Derived>::trange_type
86 
87  protected:
88  // Import base class variables to this scope
98  using ExprEngine_::perm_;
100  using ExprEngine_::pmap_;
101  using ExprEngine_::shape_;
102  using ExprEngine_::trange_;
103  using ExprEngine_::world_;
104 
105  protected:
107 
108  protected:
110  using tile_element_type = typename value_type::value_type;
111  std::function<void(tile_element_type&, const tile_element_type&,
112  const tile_element_type&)>
114  std::function<tile_element_type(const tile_element_type&,
116  const tile_element_type&)>
121  size_type K_ = 1;
122 
123  static unsigned int find(const BipartiteIndexList& indices,
124  const std::string& index_label, unsigned int i,
125  const unsigned int n) {
126  for (; i < n; ++i) {
127  if (indices[i] == index_label) break;
128  }
129 
130  return i;
131  }
132 
135 
139  TensorProduct::Invalid); // init_indices() must initialize this
141  TA_ASSERT(product_type_ == TensorProduct::Hadamard ||
142  product_type_ == TensorProduct::Contraction);
143  return product_type_;
144  }
145 
149  TensorProduct::Invalid); // init_indices() must initialize this
151  TA_ASSERT(inner_product_type_ == TensorProduct::Hadamard ||
152  inner_product_type_ == TensorProduct::Contraction);
153  return inner_product_type_;
154  }
155 
156  public:
158 
162  template <typename L, typename R>
163  ContEngine(const MultExpr<L, R>& expr) : BinaryEngine_(expr), factor_(1) {}
164 
166 
171  template <typename L, typename R, typename S>
173  : BinaryEngine_(expr), factor_(expr.factor()) {}
174 
175  // Pull base class functions into this class.
176  using ExprEngine_::derived;
177  using ExprEngine_::indices;
178 
180 
190  void perm_indices(const BipartiteIndexList& target_indices) {
191  // assert that init_indices has been called
192  TA_ASSERT(left_.indices() && right_.indices());
193  if (permute_tiles_) {
194  this->template init_indices_<TensorProduct::Contraction>(target_indices);
195 
196  // propagate the indices down the tree, if needed
197  if (left_indices_ != left_.indices()) {
198  left_.perm_indices(left_indices_);
199  }
200  if (right_indices_ != right_.indices()) {
201  right_.perm_indices(right_indices_);
202  }
203  }
204  }
205 
207 
212  void init_indices(bool children_initialized = false) {
213  if (!children_initialized) {
214  left_.init_indices();
215  right_.init_indices();
216  }
217 
218  this->template init_indices_<TensorProduct::Contraction>();
219  }
220 
222 
228  void init_indices(const BipartiteIndexList& target_indices) {
229  init_indices();
230  perm_indices(target_indices);
231  }
232 
234 
/// Initialize result tensor structure
///
/// Initializes both children with their target index lists, then builds the
/// tile operation op_. When the requested outer index order differs from
/// this engine's own, it also builds the result permutation perm_. For
/// tensor-of-tensor tiles, op_ is built with a unit factor because factor_ is
/// applied by inner_tile_nonreturn_op_.
/// \param target_indices The target index list for the result
238  void init_struct(const BipartiteIndexList& target_indices) {
239  // precondition checks
240  // 1. if ToT inner tile op has been initialized
241  if constexpr (TiledArray::detail::is_tensor_of_tensor_v<value_type>) {
244  }
245 
246  // Initialize children
247  left_.init_struct(left_indices_);
248  right_.init_struct(right_indices_);
249 
250  // Initialize the tile operation in this function because it is used to
251  // evaluate the tiled range and shape.
252 
// A matrix_transpose permutation of an argument's outer modes is folded into
// the GEMM as a transpose flag rather than applied explicitly.
253  const math::blas::Op left_op =
254  (left_outer_permtype_ == PermutationType::matrix_transpose
255  ? math::blas::Transpose
256  : math::blas::NoTranspose);
257  const math::blas::Op right_op =
258  (right_outer_permtype_ == PermutationType::matrix_transpose
259  ? math::blas::Transpose
260  : math::blas::NoTranspose);
261 
// A result permutation is needed only when the requested outer index order
// differs from this engine's outer indices.
262  if (outer(target_indices) != outer(indices_)) {
263  // Initialize permuted structure
264  perm_ = ExprEngine_::make_perm(target_indices);
265  if constexpr (!TiledArray::detail::is_tensor_of_tensor_v<value_type>) {
266  op_ = op_type(left_op, right_op, factor_, outer_size(indices_),
269  } else {
270  // factor_ is absorbed into inner_tile_nonreturn_op_
271  op_ = op_type(left_op, right_op, scalar_type(1), outer_size(indices_),
275  }
278  } else {
279  // Initialize non-permuted structure
280  if constexpr (!TiledArray::detail::is_tensor_of_tensor_v<value_type>) {
281  op_ = op_type(left_op, right_op, factor_, outer_size(indices_),
283  } else {
284  // factor_ is absorbed into inner_tile_nonreturn_op_
285  op_ = op_type(left_op, right_op, scalar_type(1), outer_size(indices_),
288  }
291  }
292 
// Mask the computed shape with the user-supplied override shape.
294  shape_ = shape_.mask(*ExprEngine_::override_ptr_->shape);
295  }
296  }
297 
299 
304  void init_distribution(World* world, std::shared_ptr<pmap_interface> pmap) {
305  const unsigned int inner_rank = op_.gemm_helper().num_contract_ranks();
306  const unsigned int left_rank = op_.gemm_helper().left_rank();
307  const unsigned int right_rank = op_.gemm_helper().right_rank();
308  const unsigned int left_outer_rank = left_rank - inner_rank;
309 
310  // Get pointers to the argument sizes
311  const auto* MADNESS_RESTRICT const left_tiles_size =
312  left_.trange().tiles_range().extent_data();
313  const auto* MADNESS_RESTRICT const left_element_size =
314  left_.trange().elements_range().extent_data();
315  const auto* MADNESS_RESTRICT const right_tiles_size =
316  right_.trange().tiles_range().extent_data();
317  const auto* MADNESS_RESTRICT const right_element_size =
318  right_.trange().elements_range().extent_data();
319 
320  // Compute the fused sizes of the contraction
321  size_type M = 1ul, m = 1ul, N = 1ul, n = 1ul;
322  unsigned int i = 0u;
323  for (; i < left_outer_rank; ++i) {
324  M *= left_tiles_size[i];
325  m *= left_element_size[i];
326  }
327  for (; i < left_rank; ++i) K_ *= left_tiles_size[i];
328  for (i = inner_rank; i < right_rank; ++i) {
329  N *= right_tiles_size[i];
330  n *= right_element_size[i];
331  }
332 
333  // Construct the process grid.
334  proc_grid_ = TiledArray::detail::ProcGrid(*world, M, N, m, n);
335 
336  // Initialize children
337  left_.init_distribution(world, proc_grid_.make_row_phase_pmap(K_));
338  right_.init_distribution(world, proc_grid_.make_col_phase_pmap(K_));
339 
340  // Initialize the process map in not already defined
341  if (!pmap) pmap = proc_grid_.make_pmap();
342  ExprEngine_::init_distribution(world, pmap);
343  }
344 
346 
/// Tiled range factory function
///
/// The result modes are the left argument's outer modes followed by the right
/// argument's outer modes (the right argument's first inner_rank modes are
/// the contracted ones). When perm evaluates true, result mode i is stored
/// at position perm[i].
/// In builds without NDEBUG, the function also checks that each contracted
/// mode pair has congruent tilings, and throws TA_EXCEPTION if not.
/// \param perm Optional permutation applied to the result modes
/// \return The tiled range of the (optionally permuted) result
349  trange_type make_trange(const Permutation& perm = {}) const {
350  // Compute iteration limits
351  const unsigned int left_rank = op_.gemm_helper().left_rank();
352  const unsigned int right_rank = op_.gemm_helper().right_rank();
353  const unsigned int inner_rank = op_.gemm_helper().num_contract_ranks();
354  const unsigned int left_outer_rank = left_rank - inner_rank;
355 
356  // Construct the trange input: left outer modes, then right outer modes
357  typename trange_type::Ranges ranges(op_.gemm_helper().result_rank());
358  unsigned int i = 0ul;
359  for (unsigned int x = 0ul; x < left_outer_rank; ++x, ++i) {
360  const unsigned int pi = (perm ? perm[i] : i);
361  ranges[pi] = left_.trange().data()[x];
362  }
// Skip the right argument's leading inner_rank (contracted) modes.
363  for (unsigned int x = inner_rank; x < right_rank; ++x, ++i) {
364  const unsigned int pi = (perm ? perm[i] : i);
365  ranges[pi] = right_.trange().data()[x];
366  }
367 
368 #ifndef NDEBUG
369 
370  // Check that the contracted dimensions have congruent tilings
// Left contracted modes start at left_outer_rank; right ones start at 0.
371  for (unsigned int l = left_outer_rank, r = 0ul; l < left_rank; ++l, ++r) {
372  if (!is_congruent(left_.trange().data()[l], right_.trange().data()[r])) {
// Rank 0 additionally reports both tiled ranges before throwing.
373  if (TiledArray::get_default_world().rank() == 0) {
375  "The contracted dimensions of the left- "
376  "and right-hand arguments are not congruent:"
377  << "\n left = " << left_.trange()
378  << "\n right = " << right_.trange());
379 
380  TA_EXCEPTION(
381  "The contracted dimensions of the left- and "
382  "right-hand expressions are not congruent.");
383  }
384 
// All other ranks throw the same exception without the diagnostic.
385  TA_EXCEPTION(
386  "The contracted dimensions of the left- and "
387  "right-hand expressions are not congruent.");
388  }
389  }
390 #endif // NDEBUG
391 
392  return trange_type(ranges.begin(), ranges.end());
393  }
394 
396 
399  const TiledArray::math::GemmHelper shape_gemm_helper(
400  math::blas::NoTranspose, math::blas::NoTranspose,
403  return left_.shape().gemm(right_.shape(), factor_, shape_gemm_helper);
404  }
405 
407 
/// Permuting shape factory function
///
/// The result shape is the shape-level GEMM of the argument shapes, scaled
/// by factor_ and permuted by perm. Both argument shapes are combined
/// without transposition.
/// \param perm The permutation applied to the result shape
/// \return The shape of the permuted result
410  shape_type make_shape(const Permutation& perm) const {
411  const TiledArray::math::GemmHelper shape_gemm_helper(
412  math::blas::NoTranspose, math::blas::NoTranspose,
415  return left_.shape().gemm(right_.shape(), factor_, shape_gemm_helper, perm);
416  }
417 
419  // Define the impl type
420  typedef TiledArray::detail::Summa<typename left_type::dist_eval_type,
421  typename right_type::dist_eval_type,
422  op_type, typename Derived::policy>
423  impl_type;
424 
// Distributed evaluators for the two arguments, built by the children.
425  typename left_type::dist_eval_type left = left_.make_dist_eval();
426  typename right_type::dist_eval_type right = right_.make_dist_eval();
427 
// K_ and proc_grid_ are the values computed by init_distribution().
428  std::shared_ptr<impl_type> pimpl =
429  std::make_shared<impl_type>(left, right, *world_, trange_, shape_,
430  pmap_, perm_, op_, K_, proc_grid_);
431 
432  return dist_eval_type(pimpl);
433  }
434 
436 
438  std::string make_tag() const {
439  std::stringstream ss;
440  ss << "[*]";
441  if (factor_ != scalar_type(1)) ss << "[" << factor_ << "]";
442  return ss.str();
443  }
444 
446 
449  void print(ExprOStream os, const BipartiteIndexList& target_indices) const {
450  ExprEngine_::print(os, target_indices);
451  os.inc();
452  left_.print(os, left_indices_);
453  right_.print(os, right_indices_);
454  os.dec();
455  }
456 
457  protected:
/// Build the inner-tile operations used when value_type is a tensor of tensors
///
/// Sets two operations:
/// - inner_tile_nonreturn_op_, which writes into a caller-provided result.
/// - inner_tile_return_op_, which returns a new tile.
/// factor_ is applied here rather than by the outer tile operation. For
/// plain tensors this function does nothing.
/// \param inner_target_indices The requested inner index order of the result
458  void init_inner_tile_op(const IndexList& inner_target_indices) {
459  if constexpr (TiledArray::detail::is_tensor_of_tensor_v<value_type>) {
460  using inner_tile_type = typename value_type::value_type;
461  const auto inner_prod = this->inner_product_type();
462  TA_ASSERT(inner_prod == TensorProduct::Contraction ||
463  inner_prod == TensorProduct::Hadamard);
464  if (inner_prod == TensorProduct::Contraction) {
// NOTE(review): redundant; inner_tile_type is already declared identically above.
465  using inner_tile_type = typename value_type::value_type;
466  using contract_inner_tile_type =
467  TiledArray::detail::ContractReduce<inner_tile_type, inner_tile_type,
468  inner_tile_type, scalar_type>;
469  // factor_ is absorbed into inner_tile_nonreturn_op_
// A permuting op is built only when the requested inner order differs from
// this engine's inner indices; the permutation is used only if permute_tiles_.
470  auto contrreduce_op =
471  (inner_target_indices != inner(this->indices_))
472  ? contract_inner_tile_type(
475  inner_size(this->indices_),
476  inner_size(this->left_indices_),
477  inner_size(this->right_indices_),
478  (this->permute_tiles_ ? inner(this->perm_)
479  : Permutation{}))
480  : contract_inner_tile_type(
482  to_cblas_op(this->right_inner_permtype_), this->factor_,
483  inner_size(this->indices_),
484  inner_size(this->left_indices_),
485  inner_size(this->right_indices_));
486  this->inner_tile_nonreturn_op_ = [contrreduce_op](
487  inner_tile_type& result,
488  const inner_tile_type& left,
489  const inner_tile_type& right) {
490  contrreduce_op(result, left, right);
491  };
492  } else if (inner_prod == TensorProduct::Hadamard) {
493  // inner tile op depends on the outer op ... e.g. if outer op
494  // is contract then inner must implement (ternary) multiply-add;
495  // if the outer is hadamard then the inner is binary multiply
496  const auto outer_prod = this->product_type();
// factor_ == 1: plain Mult; otherwise ScalMult (below) applies factor_.
497  if (this->factor_ == 1) {
498  using base_op_type =
499  TiledArray::detail::Mult<inner_tile_type, inner_tile_type,
500  inner_tile_type, false, false>;
502  base_op_type>; // can't consume inputs if they are used multiple
503  // times, e.g. when outer op is gemm
504  auto mult_op = (inner_target_indices != inner(this->indices_))
505  ? op_type(base_op_type(), this->permute_tiles_
506  ? inner(this->perm_)
507  : Permutation{})
508  : op_type(base_op_type());
509  this->inner_tile_nonreturn_op_ = [mult_op, outer_prod](
510  inner_tile_type& result,
511  const inner_tile_type& left,
512  const inner_tile_type& right) {
513  if (outer_prod == TensorProduct::Hadamard)
514  result = mult_op(left, right);
515  else {
516  TA_ASSERT(outer_prod == TensorProduct::Hadamard ||
517  outer_prod == TensorProduct::Contraction);
518  // there is currently no fused MultAdd ternary Op, only Add and
519  // Mult thus implement this as 2 separate steps
520  // TODO optimize by implementing (ternary) MultAdd
521  if (empty(result))
522  result = mult_op(left, right);
523  else {
524  auto result_increment = mult_op(left, right);
525  add_to(result, result_increment);
526  }
527  }
528  };
529  } else {
530  using base_op_type =
531  TiledArray::detail::ScalMult<inner_tile_type, inner_tile_type,
532  inner_tile_type, scalar_type, false,
533  false>;
535  base_op_type>; // can't consume inputs if they are used multiple
536  // times, e.g. when outer op is gemm
537  auto mult_op = (inner_target_indices != inner(this->indices_))
538  ? op_type(base_op_type(this->factor_),
539  this->permute_tiles_ ? inner(this->perm_)
540  : Permutation{})
541  : op_type(base_op_type(this->factor_));
542  this->inner_tile_nonreturn_op_ = [mult_op, outer_prod](
543  inner_tile_type& result,
544  const inner_tile_type& left,
545  const inner_tile_type& right) {
546  TA_ASSERT(outer_prod == TensorProduct::Hadamard ||
547  outer_prod == TensorProduct::Contraction);
548  if (outer_prod == TensorProduct::Hadamard)
549  result = mult_op(left, right);
550  else {
551  // there is currently no fused MultAdd ternary Op, only Add and
552  // Mult thus implement this as 2 separate steps
553  // TODO optimize by implementing (ternary) MultAdd
554  if (empty(result))
555  result = mult_op(left, right);
556  else {
557  auto result_increment = mult_op(left, right);
558  add_to(result, result_increment);
559  }
560  }
561  };
562  }
563  } else
564  abort(); // unsupported TensorProduct type
// Value-returning wrapper: applies the non-returning op to a default-constructed
// result tile and returns that tile.
566  this->inner_tile_return_op_ =
567  [inner_tile_nonreturn_op = this->inner_tile_nonreturn_op_](
568  const inner_tile_type& left, const inner_tile_type& right) {
569  inner_tile_type result;
570  inner_tile_nonreturn_op(result, left, right);
571  return result;
572  };
573  }
574  }
575 
576 }; // class ContEngine
577 
578 } // namespace expressions
579 } // namespace TiledArray
580 
581 #endif // TILEDARRAY_EXPRESSIONS_CONT_ENGINE_H__INCLUDED
void init_indices(bool children_initialized=false)
Initialize the index list of this expression.
Definition: cont_engine.h:212
dist_eval_type make_dist_eval() const
Definition: cont_engine.h:418
auto outer_size(const IndexList &p)
Definition: index_list.h:883
Contraction to *GEMM helper.
Definition: gemm_helper.h:40
shape_type shape_
The shape of the result tensor.
Definition: expr_engine.h:82
std::function< void(tile_element_type &, const tile_element_type &, const tile_element_type &)> inner_tile_nonreturn_op_
Definition: cont_engine.h:113
::blas::Op Op
Definition: blas.h:46
void inc()
Increment the number of tabs.
Definition: expr_trace.h:71
unsigned int left_rank() const
Left-hand argument rank accessor.
Definition: gemm_helper.h:138
BipartiteIndexList right_indices_
Target right-hand index list.
Definition: binary_engine.h:89
PermutationType right_inner_permtype_
Right-hand permutation type.
Definition: binary_engine.h:96
PermutationType left_inner_permtype_
Left-hand permutation type.
Definition: binary_engine.h:94
right_type right_
The right-hand argument.
Definition: binary_engine.h:87
Tile scale-multiplication operation.
Definition: mult.h:269
EngineTrait< Derived >::pmap_interface pmap_interface
Process map interface type.
Definition: cont_engine.h:85
PermutationType left_inner_permtype_
Left-hand permutation type.
Definition: binary_engine.h:94
derived_type & derived()
Cast this object to its derived type.
Definition: expr_engine.h:209
Contract and (sum) reduce operation.
A 2D processor grid.
Definition: proc_grid.h:58
PermutationType left_outer_permtype_
Left-hand permutation type.
Definition: binary_engine.h:90
std::shared_ptr< pmap_interface > pmap_
The process map for the result tensor.
Definition: expr_engine.h:84
trange_type trange_
The tiled range of the result tensor.
Definition: expr_engine.h:81
TiledArray::detail::ContractReduce< value_type, typename eval_trait< typename left_type::value_type >::type, typename eval_trait< typename right_type::value_type >::type, scalar_type > op_type
The tile operation type.
Definition: cont_engine.h:73
Permutation of a sequence of objects indexed by base-0 indices.
Definition: permutation.h:130
Tile multiplication operation.
Definition: mult.h:57
const math::GemmHelper & gemm_helper() const
Gemm meta data accessor.
auto outer(const IndexList &p)
Definition: index_list.h:879
trange_type trange_
The tiled range of the result tensor.
Definition: expr_engine.h:81
std::shared_ptr< pmap_interface > pmap_
The process map for the result tensor.
Definition: expr_engine.h:84
TensorProduct inner_product_type() const
Definition: cont_engine.h:147
std::shared_ptr< EngineParamOverride< Derived > > override_ptr_
The engine params overriding the default.
Definition: expr_engine.h:86
std::shared_ptr< Pmap > make_pmap() const
Construct a cyclic process map.
Definition: proc_grid.h:515
BipartiteIndexList left_indices_
Target left-hand index list.
Definition: binary_engine.h:88
World * world_
The world where this expression will be evaluated.
Definition: expr_engine.h:73
std::string make_tag() const
Expression identification tag.
Definition: cont_engine.h:438
trange_type make_trange() const
Non-permuting tiled range factory function.
const BipartiteIndexList & indices() const
Index list accessor.
Definition: expr_engine.h:224
ContEngine< Derived > ContEngine_
This class type.
Definition: cont_engine.h:53
shape_type make_shape() const
Non-permuting shape factory function.
Definition: cont_engine.h:398
EngineTrait< Derived >::trange_type trange_type
Tiled range type.
Definition: cont_engine.h:82
#define TA_EXCEPTION(m)
Definition: error.h:83
EngineTrait< Derived >::dist_eval_type dist_eval_type
The distributed evaluator type.
Definition: cont_engine.h:77
BipartiteIndexList left_indices_
Target left-hand index list.
Definition: binary_engine.h:88
BipartitePermutation perm_
The permutation that will be applied to the outer tensor of tensors.
Definition: expr_engine.h:80
TiledArray::detail::ProcGrid proc_grid_
Process grid for the contraction.
Definition: cont_engine.h:120
EngineTrait< Derived >::scalar_type scalar_type
Tile scalar type.
Definition: cont_engine.h:68
Distributed contraction evaluator implementation.
BipartiteIndexList right_indices_
Target right-hand index list.
Definition: binary_engine.h:89
void perm_indices(const BipartiteIndexList &target_indices)
Set the index list for this expression.
Definition: cont_engine.h:190
auto inner(const IndexList &p)
Definition: index_list.h:872
Binary tile operation wrapper.
blas::Op to_cblas_op(PermutationType permtype)
Definition: permopt.h:46
Multiplication expression.
Definition: mult_expr.h:143
#define TA_ASSERT(EXPR,...)
Definition: error.h:39
std::shared_ptr< Pmap > make_col_phase_pmap(const size_type rows) const
Construct a column-phased cyclic process map.
Definition: proc_grid.h:528
World & get_default_world()
Definition: madness.h:90
unsigned int result_rank() const
Result rank accessor.
Definition: gemm_helper.h:133
std::function< tile_element_type(const tile_element_type &, const tile_element_type &)> inner_tile_return_op_
Definition: cont_engine.h:117
EngineTrait< Derived >::size_type size_type
Size type.
Definition: binary_engine.h:66
ContEngine(const ScalMultExpr< L, R, S > &expr)
Constructor.
Definition: cont_engine.h:172
EngineTrait< Derived >::trange_type trange_type
Tiled range type.
Definition: binary_engine.h:68
right_type right_
The right-hand argument.
Definition: binary_engine.h:87
Tile< Result > & add_to(Tile< Result > &result, const Tile< Arg > &arg)
Add to the result tile.
Definition: tile.h:831
EngineTrait< Derived >::left_type left_type
The left-hand expression type.
Definition: cont_engine.h:60
left_type left_
The left-hand argument.
Definition: binary_engine.h:86
World * world_
The world where this expression will be evaluated.
Definition: expr_engine.h:73
EngineTrait< Derived >::size_type size_type
Size type.
Definition: cont_engine.h:80
ExprEngine< Derived > ExprEngine_
Expression engine base class type.
Definition: cont_engine.h:56
static unsigned int find(const BipartiteIndexList &indices, const std::string &index_label, unsigned int i, const unsigned int n)
Definition: cont_engine.h:123
void print(ExprOStream os, const BipartiteIndexList &target_indices) const
Expression print.
Definition: cont_engine.h:449
Expression output stream.
Definition: expr_trace.h:41
void init_indices(const BipartiteIndexList &target_indices)
Initialize the index list of this expression.
Definition: cont_engine.h:228
void init_distribution(World *world, const std::shared_ptr< pmap_interface > &pmap)
Initialize result tensor distribution.
Definition: expr_engine.h:171
void init_struct(const BipartiteIndexList &target_indices)
Initialize result tensor structure.
Definition: cont_engine.h:238
void dec()
Decrement the number of tabs.
Definition: expr_trace.h:74
bool is_congruent(const BlockRange &r1, const BlockRange &r2)
Test that two BlockRange objects are congruent.
Definition: block_range.h:400
scalar_type factor_
Contraction scaling factor.
Definition: cont_engine.h:106
EngineTrait< Derived >::op_type op_type
The tile operation type.
Definition: binary_engine.h:59
op_type op_
Tile operation.
Definition: cont_engine.h:109
EngineTrait< Derived >::value_type value_type
The result tile type.
Definition: cont_engine.h:66
unsigned int right_rank() const
Right-hand argument rank accessor.
Definition: gemm_helper.h:143
PermutationType left_outer_permtype_
Left-hand permutation type.
Definition: binary_engine.h:90
EngineTrait< Derived >::shape_type shape_type
Shape type.
Definition: cont_engine.h:83
TensorProduct
types of binary tensor products known to TiledArray
Definition: product.h:35
BipartitePermutation make_perm(const BipartiteIndexList &target_indices) const
Permutation factory function.
Definition: expr_engine.h:188
ContEngine(const MultExpr< L, R > &expr)
Constructor.
Definition: cont_engine.h:163
void init_inner_tile_op(const IndexList &inner_target_indices)
Definition: cont_engine.h:458
Multiplication expression engine.
Definition: cont_engine.h:50
EngineTrait< Derived >::dist_eval_type dist_eval_type
The distributed evaluator type.
Definition: binary_engine.h:63
EngineTrait< Derived >::shape_type shape_type
Shape type.
Definition: binary_engine.h:69
unsigned int num_contract_ranks() const
Compute the number of contracted ranks.
Definition: gemm_helper.h:126
TensorProduct product_type() const
Definition: cont_engine.h:137
bool empty(const Tile< Arg > &arg)
Check that arg is empty (no data)
Definition: tile.h:646
EngineTrait< Derived >::policy policy
The result policy type.
Definition: cont_engine.h:75
EngineTrait< Derived >::right_type right_type
The right-hand expression type.
Definition: cont_engine.h:62
std::shared_ptr< Pmap > make_row_phase_pmap(const size_type cols) const
Construct a row-phased cyclic process map.
Definition: proc_grid.h:541
const BipartiteIndexList & indices() const
Index list accessor.
Definition: expr_engine.h:224
PermutationType right_inner_permtype_
Right-hand permutation type.
Definition: binary_engine.h:96
#define TA_USER_ERROR_MESSAGE(m)
Definition: error.h:93
trange_type make_trange(const Permutation &perm={}) const
Tiled range factory function.
Definition: cont_engine.h:349
shape_type shape_
The shape of the result tensor.
Definition: expr_engine.h:82
Determine the object type used in the evaluation of tensor expressions.
Definition: type_traits.h:580
Multiplication expression.
Definition: mult_expr.h:88
shape_type make_shape(const Permutation &perm) const
Permuting shape factory function.
Definition: cont_engine.h:410
BinaryEngine< Derived > BinaryEngine_
Binary base class type.
Definition: cont_engine.h:54
PermutationType right_outer_permtype_
Right-hand permutation type.
Definition: binary_engine.h:92
typename value_type::value_type tile_element_type
Definition: cont_engine.h:110
Permutation of a bipartite set.
Definition: permutation.h:610
void print(ExprOStream &os, const BipartiteIndexList &target_indices) const
Expression print.
Definition: expr_engine.h:256
auto inner_size(const IndexList &p)
Definition: index_list.h:881
void init_distribution(World *world, std::shared_ptr< pmap_interface > pmap)
Initialize result tensor distribution.
Definition: cont_engine.h:304
left_type left_
The left-hand argument.
Definition: binary_engine.h:86
size_type K_
Inner dimension size.
Definition: cont_engine.h:121
PermutationType right_outer_permtype_
Right-hand permutation type.
Definition: binary_engine.h:92
BipartitePermutation perm_
The permutation that will be applied to the outer tensor of tensors.
Definition: expr_engine.h:80