TiledArray  0.7.0
mult_engine.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2014 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * Justus Calvin
19  * Department of Chemistry, Virginia Tech
20  *
21  * mult_engine.h
22  * Mar 31, 2014
23  *
24  */
25 
26 #ifndef TILEDARRAY_EXPRESSIONS_MULT_ENGINE_H__INCLUDED
27 #define TILEDARRAY_EXPRESSIONS_MULT_ENGINE_H__INCLUDED
28 
32 
33 
34 namespace TiledArray {
35  namespace expressions {
36 
37  // Forward declarations
// These templates are named by the EngineTrait specializations and the engine
// constructors further down, so they are declared here ahead of their definitions.
// Expression templates (accepted by the engine constructors below).
38  template <typename, typename> class MultExpr;
39  template <typename, typename, typename> class ScalMultExpr;
// Expression engines defined later in this header (scaled variant takes an
// additional scalar type parameter).
40  template <typename, typename, typename> class MultEngine;
41  template <typename, typename, typename, typename> class ScalMultEngine;
42 
43  template <typename Left, typename Right, typename Result>
44  struct EngineTrait<MultEngine<Left, Right, Result> > {
45  static_assert(std::is_same<typename EngineTrait<Left>::policy,
46  typename EngineTrait<Right>::policy>::value,
47  "The left- and right-hand expressions must use the same policy class");
48 
49  // Argument typedefs
50  typedef Left left_type;
51  typedef Right right_type;
52 
53  // Operational typedefs
54  typedef TiledArray::detail::Mult<Result,
61  typedef typename op_type::result_type
63  typedef typename eval_trait<value_type>::type
67  typedef typename Left::policy policy;
70 
71  // Meta data typedefs
72  typedef typename policy::size_type size_type;
73  typedef typename policy::trange_type trange_type;
74  typedef typename policy::shape_type shape_type;
75  typedef typename policy::pmap_interface
77 
78  static constexpr bool consumable = is_consumable_tile<eval_type>::value;
79  static constexpr unsigned int leaves =
81  };
82 
83  template <typename Left, typename Right, typename Scalar, typename Result>
84  struct EngineTrait<ScalMultEngine<Left, Right, Scalar, Result> > {
85  static_assert(std::is_same<typename EngineTrait<Left>::policy,
86  typename EngineTrait<Right>::policy>::value,
87  "The left- and right-hand expressions must use the same policy class");
88 
89  // Argument typedefs
90  typedef Left left_type;
91  typedef Right right_type;
92 
93  // Operational typedefs
94  typedef Scalar scalar_type;
95  typedef TiledArray::detail::ScalMult<Result,
102  typedef typename op_type::result_type
104  typedef typename eval_trait<value_type>::type
106  typedef typename Left::policy policy;
109 
110  // Meta data typedefs
111  typedef typename policy::size_type size_type;
112  typedef typename policy::trange_type trange_type;
113  typedef typename policy::shape_type shape_type;
114  typedef typename policy::pmap_interface
116 
117  static constexpr bool consumable = is_consumable_tile<eval_type>::value;
118  static constexpr unsigned int leaves =
120  };
121 
122 
124 
133  template <typename Left, typename Right, typename Result>
134  class MultEngine : public ContEngine<MultEngine<Left, Right, Result> > {
135  public:
136  // Class hierarchy typedefs
144 
145  // Argument typedefs
146  typedef typename EngineTrait<MultEngine_>::left_type
148  typedef typename EngineTrait<MultEngine_>::right_type
150 
151  // Operational typedefs
152  typedef typename EngineTrait<MultEngine_>::value_type
156  typedef typename EngineTrait<MultEngine_>::op_type
158  typedef typename EngineTrait<MultEngine_>::policy
162 
163  // Meta data typedefs
164  typedef typename EngineTrait<MultEngine_>::size_type
168  typedef typename EngineTrait<MultEngine_>::shape_type
172 
173  private:
174 
175  bool contract_;
176 
178  public:
179 
181 
// Constructor.
// Template parameters L and R are the two argument types of the MultExpr
// being wrapped.
// `expr` is forwarded unchanged to the ContEngine_ base-class constructor.
// contract_ starts out false; only the init_vars() overloads set it to true.
185  template <typename L, typename R>
186  MultEngine(const MultExpr<L, R>& expr) :
187  ContEngine_(expr), contract_(false)
188  { }
189 
190 
192 
// Set the variable list for this expression.
// `target_vars` is passed through untouched to exactly one base class:
// ContEngine_::perm_vars when contract_ is set, BinaryEngine_::perm_vars
// otherwise. contract_ is decided earlier by init_vars().
198  void perm_vars(const VariableList& target_vars) {
199  if(contract_)
200  ContEngine_::perm_vars(target_vars);
201  else {
202  BinaryEngine_::perm_vars(target_vars);
203  }
204  }
205 
207 
210  void perm_vars() {
211  if(contract_)
213  else {
215  }
216  }
217 
219 
// Initialize the variable list of this expression for a given target list.
// `target_vars` is the variable list handed on to the selected base class.
221  void init_vars(const VariableList& target_vars) {
// Both arguments initialize their own variable lists first; the check below
// reads them through vars().
222  BinaryEngine_::left_.init_vars();
223  BinaryEngine_::right_.init_vars();
224 
// If the left variable list is reported to be a permutation of the right one,
// the binary-engine base handles the target list and contract_ stays as is.
225  if(BinaryEngine_::left_.vars().is_permutation(BinaryEngine_::right_.vars())) {
226  BinaryEngine_::perm_vars(target_vars);
227  } else {
// Otherwise switch this engine to the contraction-engine code path; every
// later contract_ check in this class follows this flag.
228  contract_ = true;
// NOTE(review): this extract jumps from listing line 228 to 230, so a
// statement may be missing at this point - confirm against the original header.
230  ContEngine_::perm_vars(target_vars);
231  }
232  }
233 
235  void init_vars() {
236  BinaryEngine_::left_.init_vars();
237  BinaryEngine_::right_.init_vars();
238 
239  if(BinaryEngine_::left_.vars().is_permutation(BinaryEngine_::right_.vars())) {
240  if(left_type::leaves <= right_type::leaves)
242  else
244  } else {
245  contract_ = true;
247  }
248  }
249 
251 
// Initialize result tensor structure.
// `target_vars` is forwarded as-is to ContEngine_::init_struct when contract_
// is set, and to BinaryEngine_::init_struct otherwise.
255  void init_struct(const VariableList& target_vars) {
256  if(contract_)
257  ContEngine_::init_struct(target_vars);
258  else
259  BinaryEngine_::init_struct(target_vars);
260  }
261 
263 
268  void init_distribution(World* world, std::shared_ptr<pmap_interface> pmap) {
269  if(contract_)
271  else
273  }
274 
276 
279  if(contract_)
280  return ContEngine_::make_trange();
281  else
283  }
284 
286 
290  if(contract_)
292  else
294  }
295 
297 
300  return BinaryEngine_::left_.shape().mult(BinaryEngine_::right_.shape());
301  }
302 
304 
308  return BinaryEngine_::left_.shape().mult(BinaryEngine_::right_.shape(), perm);
309  }
310 
312 
// Non-permuting tile operation factory function.
// Static: uses no engine state. Returns an op_type constructed from a
// default-constructed op_base_type.
314  static op_type make_tile_op() { return op_type(op_base_type()); }
315 
317 
321 
323 
326  if(contract_)
328  else
330  }
331 
333 
// Expression identification tag.
// Returns the fixed string literal "[*] " (note the trailing space).
335  const char* make_tag() const { return "[*] "; }
336 
338 
// Expression print.
// `os` is the expression output stream (taken by value) and `target_vars` the
// variable list; both are forwarded to ContEngine_::print when contract_ is
// set and to BinaryEngine_::print otherwise.
341  void print(ExprOStream os, const VariableList& target_vars) const {
342  if(contract_)
343  return ContEngine_::print(os, target_vars);
344  else
345  return BinaryEngine_::print(os, target_vars);
346  }
347 
348  }; // class MultEngine
349 
350 
352 
358  template <typename Left, typename Right, typename Scalar, typename Result>
359  class ScalMultEngine :
360  public ContEngine<ScalMultEngine<Left, Right, Scalar, Result> >
361  {
362  public:
363  // Class hierarchy typedefs
364  typedef ScalMultEngine<Left, Right, Scalar, Result>
372 
373  // Argument typedefs
378 
379  // Operational typedefs
388  typedef typename EngineTrait<ScalMultEngine_>::policy
392 
393  // Meta data typedefs
402 
403  private:
404 
405  bool contract_;
406 
408  public:
409 
411 
// Constructor.
// Template parameters L, R and S are the three argument types of the
// ScalMultExpr being wrapped.
// `expr` is forwarded unchanged to the ContEngine_ base-class constructor.
// contract_ starts out false; only the init_vars() overloads set it to true.
416  template <typename L, typename R, typename S>
417  ScalMultEngine(const ScalMultExpr<L, R, S>& expr) : ContEngine_(expr), contract_(false) { }
418 
420 
// Set the variable list for this expression.
// `target_vars` is passed through untouched to exactly one base class:
// ContEngine_::perm_vars when contract_ is set, BinaryEngine_::perm_vars
// otherwise. contract_ is decided earlier by init_vars().
426  void perm_vars(const VariableList& target_vars) {
427  if(contract_)
428  ContEngine_::perm_vars(target_vars);
429  else {
430  BinaryEngine_::perm_vars(target_vars);
431  }
432  }
433 
435 
438  void perm_vars() {
439  if(contract_)
441  else {
443  }
444  }
445 
447 
// Initialize the variable list of this expression for a given target list.
// `target_vars` is the variable list handed on to the selected base class.
449  void init_vars(const VariableList& target_vars) {
// Both arguments initialize their own variable lists first; the check below
// reads them through vars().
450  BinaryEngine_::left_.init_vars();
451  BinaryEngine_::right_.init_vars();
452 
// If the left variable list is reported to be a permutation of the right one,
// the binary-engine base handles the target list and contract_ stays as is.
453  if(BinaryEngine_::left_.vars().is_permutation(BinaryEngine_::right_.vars())) {
454  BinaryEngine_::perm_vars(target_vars);
455  } else {
// Otherwise switch this engine to the contraction-engine code path; every
// later contract_ check in this class follows this flag.
456  contract_ = true;
// NOTE(review): this extract jumps from listing line 456 to 458, so a
// statement may be missing at this point - confirm against the original header.
458  ContEngine_::perm_vars(target_vars);
459  }
460  }
461 
463  void init_vars() {
464  BinaryEngine_::left_.init_vars();
465  BinaryEngine_::right_.init_vars();
466 
467  if(BinaryEngine_::left_.vars().is_permutation(BinaryEngine_::right_.vars())) {
468  if(left_type::leaves <= right_type::leaves)
470  else
472  } else {
473  contract_ = true;
475  }
476  }
477 
479 
// Initialize result tensor structure.
// `target_vars` is forwarded as-is to ContEngine_::init_struct when contract_
// is set, and to BinaryEngine_::init_struct otherwise.
483  void init_struct(const VariableList& target_vars) {
484  if(contract_)
485  ContEngine_::init_struct(target_vars);
486  else
487  BinaryEngine_::init_struct(target_vars);
488  }
489 
491 
496  void init_distribution(World* world, std::shared_ptr<pmap_interface> pmap) {
497  if(contract_)
499  else
501  }
502 
504 
507  if(contract_)
509  else
511  }
512 
514 
517  if(contract_)
518  return ContEngine_::make_trange();
519  else
521  }
522 
524 
528  if(contract_)
530  else
532  }
533 
535 
538  return BinaryEngine_::left_.shape().mult(BinaryEngine_::right_.shape(),
540  }
541 
543 
547  return BinaryEngine_::left_.shape().mult(BinaryEngine_::right_.shape(),
549  }
550 
552 
555 
557 
562  }
563 
564 
565 
567 
// Expression identification tag.
// Builds and returns the string "[*] [<factor>] ", where <factor> is
// ContEngine_::factor_ formatted by its stream insertion operator.
569  std::string make_tag() const {
570  std::stringstream ss;
571  ss << "[*] [" << ContEngine_::factor_ << "] ";
572  return ss.str();
573  }
574 
576 
// Expression print.
// `os` is the expression output stream (taken by value) and `target_vars` the
// variable list; both are forwarded to ContEngine_::print when contract_ is
// set and to BinaryEngine_::print otherwise.
579  void print(ExprOStream os, const VariableList& target_vars) const {
580  if(contract_)
581  return ContEngine_::print(os, target_vars);
582  else
583  return BinaryEngine_::print(os, target_vars);
584  }
585 
586  }; // class ScalMultEngine
587 
588  } // namespace expressions
589 } // namespace TiledArray
590 
591 #endif // TILEDARRAY_EXPRESSIONS_MULT_ENGINE_H__INCLUDED
BinaryEngine< MultEngine_ > BinaryEngine_
Binary base class type.
Definition: mult_engine.h:141
Multiplication expression.
Definition: cont_engine.h:39
dist_eval_type make_dist_eval() const
Construct the distributed evaluator for this expression.
right_type right_
The right-hand argument.
Definition: binary_engine.h:78
void init_vars()
Initialize the variable list of this expression.
Definition: mult_engine.h:463
Op::result_type result_type
The result tile type.
ContEngine< ScalMultEngine_ > ContEngine_
Contraction engine base class.
Definition: mult_engine.h:367
TiledArray::detail::BinaryWrapper< op_base_type > op_type
The tile operation type.
Definition: mult_engine.h:60
void init_struct(const VariableList &target_vars)
Initialize result tensor structure.
Definition: mult_engine.h:255
void print(ExprOStream os, const VariableList &target_vars) const
Expression print.
Definition: cont_engine.h:558
void init_struct(const VariableList &target_vars)
Initialize result tensor structure.
eval_trait< value_type >::type eval_type
Evaluation tile type.
Definition: mult_engine.h:64
EngineTrait< ScalMultEngine_ >::right_type right_type
The right-hand expression type.
Definition: mult_engine.h:377
const Permutation & perm() const
Permutation accessor.
Definition: expr_engine.h:204
trange_type make_trange(const Permutation &perm) const
Permuting tiled range factory function.
Definition: mult_engine.h:527
EngineTrait< MultEngine_ >::shape_type shape_type
Shape type.
Definition: mult_engine.h:169
scalar_type factor_
Contraction scaling factor.
Definition: cont_engine.h:106
BinaryEngine< MultEngine_ > ExprEngine_
Expression engine base class type.
Definition: mult_engine.h:143
TiledArray::detail::DistEval< value_type, policy > dist_eval_type
The distributed evaluator type.
Definition: mult_engine.h:108
EngineTrait< ScalMultEngine_ >::policy policy
The result policy type.
Definition: mult_engine.h:389
void init_struct(const VariableList &target_vars)
Initialize result tensor structure.
Definition: cont_engine.h:360
Type trait for extracting the numeric type of tensors and arrays.
Definition: type_traits.h:479
shape_type make_shape(const Permutation &perm) const
Permuting shape factory function.
Definition: mult_engine.h:307
static op_type make_tile_op()
Non-permuting tile operation factory function.
Definition: mult_engine.h:314
void perm_vars(const VariableList &target_vars)
Set the variable list for this expression.
Definition: mult_engine.h:426
EngineTrait< ScalMultEngine_ >::left_type left_type
The left-hand expression type.
Definition: mult_engine.h:375
EngineTrait< MultEngine_ >::value_type value_type
The result tile type.
Definition: mult_engine.h:153
void init_distribution(World *world, const std::shared_ptr< pmap_interface > &pmap)
Initialize result tensor distribution.
op_type make_tile_op() const
Non-permuting tile operation factory function.
Definition: mult_engine.h:554
ContEngine< MultEngine_ > ContEngine_
Contraction engine base class.
Definition: mult_engine.h:139
EngineTrait< MultEngine_ >::size_type size_type
Size type.
Definition: mult_engine.h:165
EngineTrait< ScalMultEngine_ >::op_base_type op_base_type
The tile operation type.
Definition: mult_engine.h:385
Multiplication expression.
Definition: cont_engine.h:38
const std::shared_ptr< pmap_interface > & pmap() const
Process map accessor.
Definition: expr_engine.h:219
VariableList vars_
The variable list of this expression.
Definition: expr_engine.h:64
void init_vars()
Initialize the variable list of this expression.
Definition: mult_engine.h:235
EngineTrait< MultEngine_ >::dist_eval_type dist_eval_type
The distributed evaluator type.
Definition: mult_engine.h:161
EngineTrait< MultEngine_ >::op_type op_type
The tile operation type.
Definition: mult_engine.h:157
const VariableList & vars() const
Variable list accessor.
Definition: expr_engine.h:199
TiledArray::detail::Mult< Result, typename EngineTrait< Left >::eval_type, typename EngineTrait< Right >::eval_type, EngineTrait< Left >::consumable, EngineTrait< Right >::consumable > op_base_type
The base tile operation type.
Definition: mult_engine.h:58
void print(ExprOStream os, const VariableList &target_vars) const
Expression print.
Definition: mult_engine.h:579
EngineTrait< MultEngine_ >::op_base_type op_base_type
The tile operation type.
Definition: mult_engine.h:155
void perm_vars(const VariableList &target_vars)
Set the variable list for this expression.
Definition: cont_engine.h:167
EngineTrait< MultEngine_ >::pmap_interface pmap_interface
Process map interface type.
Definition: mult_engine.h:171
void init_vars(const VariableList &target_vars)
Initialize the variable list of this expression.
Definition: mult_engine.h:449
EngineTrait< ScalMultEngine_ >::value_type value_type
The result tile type.
Definition: mult_engine.h:381
EngineTrait< ScalMultEngine_ >::op_type op_type
The tile operation type.
Definition: mult_engine.h:387
shape_type make_shape(const Permutation &perm) const
Permuting shape factory function.
Definition: mult_engine.h:546
EngineTrait< ScalMultEngine_ >::pmap_interface pmap_interface
Process map interface type.
Definition: mult_engine.h:401
MultEngine< Left, Right, Result > MultEngine_
This class type.
Definition: mult_engine.h:137
MultEngine(const MultExpr< L, R > &expr)
Constructor.
Definition: mult_engine.h:186
trange_type make_trange(const Permutation &perm) const
Permuting tiled range factory function.
Definition: mult_engine.h:289
Tensor expression object.
Definition: dist_eval.h:241
void init_distribution(World *world, std::shared_ptr< pmap_interface > pmap)
Initialize result tensor distribution.
Definition: mult_engine.h:268
Binary tile operation wrapper.
Variable list manages a list of variable strings.
void init_distribution(World *world, std::shared_ptr< pmap_interface > pmap)
Initialize result tensor distribution.
Definition: mult_engine.h:496
BinaryEngine< ScalMultEngine_ > BinaryEngine_
Binary base class type.
Definition: mult_engine.h:369
void perm_vars()
Set the variable list for this expression.
Definition: mult_engine.h:438
Tile scale-multiplication operation.
Definition: mult.h:233
EngineTrait< ScalMultEngine_ >::size_type size_type
Size type.
Definition: mult_engine.h:395
shape_type make_shape() const
Non-permuting shape factory function.
Definition: mult_engine.h:537
void init_struct(const VariableList &target_vars)
Initialize result tensor structure.
Definition: mult_engine.h:483
const char * make_tag() const
Expression identification tag.
Definition: mult_engine.h:335
std::string make_tag() const
Expression identification tag.
Definition: mult_engine.h:569
void init_vars(const VariableList &target_vars)
Initialize the variable list of this expression.
Definition: mult_engine.h:221
TiledArray::detail::DistEval< value_type, policy > dist_eval_type
The distributed evaluator type.
Definition: mult_engine.h:69
ScalMultEngine< Left, Right, Scalar, Result > ScalMultEngine_
This class type.
Definition: mult_engine.h:365
trange_type make_trange() const
Non-permuting tiled range factory function.
Definition: mult_engine.h:278
void perm_vars(const VariableList &target_vars)
Set the variable list for this expression.
Definition: binary_engine.h:94
Multiplication expression engine.
Definition: cont_engine.h:45
BinaryEngine< ScalMultEngine_ > ExprEngine_
Expression engine base class type.
Definition: mult_engine.h:371
dist_eval_type make_dist_eval() const
Construct the distributed evaluator for this expression.
Definition: mult_engine.h:325
Consumable tile type trait.
Definition: type_traits.h:406
Tile multiplication operation.
Definition: mult.h:51
static op_type make_tile_op(const Permutation &perm)
Permuting tile operation factory function.
Definition: mult_engine.h:320
EngineTrait< MultEngine_ >::trange_type trange_type
Tiled range type.
Definition: mult_engine.h:167
EngineTrait< ScalMultEngine_ >::dist_eval_type dist_eval_type
The distributed evaluator type.
Definition: mult_engine.h:391
ScalMultEngine(const ScalMultExpr< L, R, S > &expr)
Constructor.
Definition: mult_engine.h:417
Permutation of a sequence of objects indexed by base-0 indices.
Definition: permutation.h:119
trange_type make_trange() const
Non-permuting tiled range factory function.
void init_vars()
Initialize the variable list of this expression.
Definition: cont_engine.h:228
void init_distribution(World *world, std::shared_ptr< pmap_interface > pmap)
Initialize result tensor distribution.
Definition: cont_engine.h:400
shape_type make_shape() const
Non-permuting shape factory function.
Definition: mult_engine.h:299
trange_type make_trange() const
Non-permuting tiled range factory function.
Definition: mult_engine.h:516
op_type make_tile_op(const Permutation &perm) const
Permuting tile operation factory function.
Definition: mult_engine.h:560
TiledArray::detail::numeric_type< value_type >::type scalar_type
Tile scalar type.
Definition: mult_engine.h:66
EngineTrait< MultEngine_ >::right_type right_type
The right-hand expression type.
Definition: mult_engine.h:149
EngineTrait< MultEngine_ >::left_type left_type
The left-hand expression type.
Definition: mult_engine.h:147
EngineTrait< ScalMultEngine_ >::shape_type shape_type
Shape type.
Definition: mult_engine.h:399
void perm_vars()
Set the variable list for this expression.
Definition: mult_engine.h:210
EngineTrait< ScalMultEngine_ >::scalar_type scalar_type
Tile scalar type.
Definition: mult_engine.h:383
Expression output stream.
Definition: expr_trace.h:39
Multiplication expression engine.
Definition: mult_engine.h:40
dist_eval_type make_dist_eval() const
Definition: cont_engine.h:527
void perm_vars(const VariableList &target_vars)
Set the variable list for this expression.
Definition: mult_engine.h:198
EngineTrait< MultEngine_ >::policy policy
The result policy type.
Definition: mult_engine.h:159
void print(ExprOStream os, const VariableList &target_vars) const
Expression print.
Definition: mult_engine.h:341
EngineTrait< ScalMultEngine_ >::trange_type trange_type
Tiled range type.
Definition: mult_engine.h:397
TiledArray::detail::ScalMult< Result, typename EngineTrait< Left >::eval_type, typename EngineTrait< Right >::eval_type, scalar_type, EngineTrait< Left >::consumable, EngineTrait< Right >::consumable > op_base_type
The base tile operation type.
Definition: mult_engine.h:99
left_type left_
The left-hand argument.
Definition: binary_engine.h:77
Scaled multiplication expression engine.
Definition: mult_engine.h:41
TiledArray::detail::BinaryWrapper< op_base_type > op_type
The tile operation type.
Definition: mult_engine.h:101
void print(ExprOStream os, const VariableList &target_vars) const
Expression print.
dist_eval_type make_dist_eval() const
Construct the distributed evaluator for this expression.
Definition: mult_engine.h:506
policy::pmap_interface pmap_interface
Process map interface type.
Definition: mult_engine.h:76