TiledArray  0.7.0
expr_engine.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2014 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * Justus Calvin
19  * Department of Chemistry, Virginia Tech
20  *
21  * expr_engine.h
22  * Mar 31, 2014
23  *
24  */
25 
26 #ifndef TILEDARRAY_EXPRESSIONS_EXPR_ENGINE_H__INCLUDED
27 #define TILEDARRAY_EXPRESSIONS_EXPR_ENGINE_H__INCLUDED
28 
29 #include <TiledArray/madness.h>
31 
32 namespace TiledArray {
33  namespace expressions {
34 
35  // Forward declarations
// Engine parameters that override the defaults; this header reads its
// `world`, `pmap` and `shape` members through override_ptr_.
36  template <typename> struct EngineParamOverride;
// Expression type accepted by the ExprEngine constructor.
37  template <typename> class Expr;
// Traits of an engine type; the ExprEngine typedefs (op_type, trange_type,
// shape_type, pmap_interface, ...) are taken from EngineTrait<Derived>.
38  template <typename> struct EngineTrait;
39 
41  template <typename Derived>
42  class ExprEngine : private NO_DEFAULTS {
43  public:
45  typedef Derived derived_type;
46 
47  // Operational typedefs
52 
53  // Meta data typedefs
58 
59  protected:
60  // The member variables of this class are protected because derived
61  // classes will customize initialization.
62 
/// The world where this expression will be evaluated.
63  World* world_;
// NOTE(review): listing lines 64-68 are absent from this extract. The
// constructor and accessors also use vars_, permute_tiles_, perm_, trange_
// and shape_, which are presumably declared on those lines.
/// The process map for the result tensor.
69  std::shared_ptr<pmap_interface> pmap_;
/// The engine params overriding the default.
70  std::shared_ptr<EngineParamOverride<Derived> > override_ptr_;
71 
72  public:
73 
75 
/// Constructor.
///
/// Sets world_ to NULL and permute_tiles_ to true, and value-initializes
/// vars_, perm_, trange_ and shape_.
/// \tparam D The derived type of the expression argument
/// \param expr The expression this engine is constructed from
// NOTE(review): the initializer list ends with a trailing comma and the
// listing jumps from line 79 to 81, so the last initializer line (listing
// line 80) is missing from this extract; it presumably initializes pmap_ and
// override_ptr_ -- confirm against the real header.
77  template <typename D>
78  ExprEngine(const Expr<D> &expr) :
79  world_(NULL), vars_(), permute_tiles_(true), perm_(), trange_(), shape_(),
81  { }
82 
84 
/// Construct and initialize the expression engine.
///
/// Initializes the variable list and the result structure through the
/// derived class, selects the world and process map (values supplied by
/// override_ptr_ take precedence over the arguments), drops a process map
/// that does not match the selected world or the number of result tiles, and
/// finally calls derived().init_distribution().
/// \param world The world used when override_ptr_ does not supply one
/// \param pmap The process map used when override_ptr_ does not supply one
/// (may be null)
/// \param target_vars The target variable list; when target_vars.dim() is
/// zero, init_vars() is called without arguments and the structure is built
/// from vars_
92  void init(World& world, std::shared_ptr<pmap_interface> pmap,
93  const VariableList& target_vars)
94  {
    // A non-empty target list drives both the variable and structure setup.
95  if(target_vars.dim()) {
96  derived().init_vars(target_vars);
97  derived().init_struct(target_vars);
98  } else {
99  derived().init_vars();
100  derived().init_struct(vars_);
101  }
102 
    // An override is honored only when override_ptr_ is set and the
    // corresponding field tests true.
103  auto override_world = override_ptr_ != nullptr && override_ptr_->world;
104  auto override_pmap = override_ptr_ != nullptr && override_ptr_->pmap;
105  world_ = override_world ? override_ptr_->world : &world;
106  pmap_ = override_pmap ? override_ptr_->pmap : pmap;
107 
108  // Check for a valid process map.
109  if(pmap_) {
110  // If process map is not valid, use the process map constructed by the
111  // expression engine.
    // "Valid" here means: procs() equals the world size and size() equals
    // the number of tiles in trange_.
112  if((typename pmap_interface::size_type(world_->size()) != pmap_->procs()) ||
113  (trange_.tiles_range().volume() != pmap_->size()))
114  pmap_.reset();
115  }
116 
    // pmap_ may be null at this point; the derived class handles that case.
117  derived().init_distribution(world_, pmap_);
118  }
119 
121 
/// Initialize result tensor structure.
///
/// Sets trange_ and shape_ from the derived class factory functions. When
/// target_vars differs from vars_, perm_ is computed first and passed to the
/// factories; otherwise the no-argument factories are used and perm_ is not
/// modified. If override_ptr_ supplies a shape, shape_ is then replaced by
/// shape_.mask(*override_ptr_->shape).
/// \param target_vars The target variable list of the result
130  void init_struct(const VariableList& target_vars) {
131  if(target_vars != vars_) {
132  perm_ = derived().make_perm(target_vars);
133  trange_ = derived().make_trange(perm_);
134  shape_ = derived().make_shape(perm_);
135  } else {
136  trange_ = derived().make_trange();
137  shape_ = derived().make_shape();
138  }
139 
    // Apply the user-provided shape override, if any.
140  if(override_ptr_ && override_ptr_->shape)
141  shape_ = shape_.mask(*override_ptr_->shape);
142  }
143 
145 
152  const std::shared_ptr<pmap_interface>& pmap)
153  {
    // Initialize result tensor distribution: store the given world and
    // process map in world_ and pmap_.
    // NOTE(review): the first line of this signature (listing line 151) is
    // missing from this extract; the generated index gives it as
    // `void init_distribution(World* world, const std::shared_ptr<pmap_interface>& pmap)`.
    // Preconditions checked with TA_ASSERT: world and pmap are non-null, the
    // process map was built for world->size() processes, and it covers
    // exactly the number of tiles in trange_.
154  TA_ASSERT(world);
155  TA_ASSERT(pmap);
156  TA_ASSERT(pmap->procs() == typename pmap_interface::size_type(world->size()));
157  TA_ASSERT(pmap->size() == trange_.tiles_range().volume());
158 
159  world_ = world;
160  pmap_ = pmap;
161  }
162 
164 
/// Permutation factory function.
/// \param target_vars The target variable list
/// \return The result of target_vars.permutation(vars_)
168  Permutation make_perm(const VariableList& target_vars) const {
169  return target_vars.permutation(vars_);
170  }
171 
173 
/// Tile operation factory function.
/// \return derived().make_tile_op(perm_) when perm_ tests true and
/// permute_tiles_ is set; otherwise derived().make_tile_op()
178  op_type make_op() const {
179  if(perm_ && permute_tiles_)
180  return derived().make_tile_op(perm_);
181  else
182  return derived().make_tile_op();
183  }
184 
/// Cast this object to its derived type.
/// \return A reference to this object as derived_type
186  derived_type& derived() { return *static_cast<derived_type*>(this); }
187 
/// Cast this object to its derived type (const overload).
/// \return A const reference to this object as derived_type
189  const derived_type& derived() const { return *static_cast<const derived_type*>(this); }
190 
192 
/// World accessor.
/// \return The stored world pointer, world_
194  World* world() const { return world_; }
195 
197 
/// Variable list accessor.
/// \return A const reference to vars_, the variable list of this expression
199  const VariableList& vars() const { return vars_; }
200 
202 
/// Permutation accessor.
/// \return A const reference to perm_, the permutation that will be applied
/// to the result
204  const Permutation& perm() const { return perm_; }
205 
207 
/// Tiled range accessor.
/// \return A const reference to trange_, the tiled range of the result tensor
209  const trange_type& trange() const { return trange_; }
210 
212 
/// Shape accessor.
/// \return A const reference to shape_, the shape of the result tensor
214  const shape_type& shape() const { return shape_; }
215 
217 
/// Process map accessor.
/// \return A const reference to pmap_, the shared pointer to the process map
/// for the result tensor
219  const std::shared_ptr<pmap_interface>& pmap() const { return pmap_; }
220 
222 
/// Set the permute tiles flag.
/// \param status The new value of permute_tiles_ (true == permute tile)
224  void permute_tiles(const bool status) { permute_tiles_ = status; }
225 
227 
/// Expression print.
///
/// Writes a single line describing this engine to os. When perm_ tests true
/// the line starts with "[P <target_vars>]" followed by " [no permute tiles] "
/// if permute_tiles_ is false (a single space otherwise); in every case the
/// line ends with the derived class tag and vars_.
/// \param os The expression output stream
/// \param target_vars The target variable list printed in the "[P ...]" prefix
230  void print(ExprOStream& os, const VariableList& target_vars) const {
231  if(perm_) {
232  os << "[P " << target_vars << "]" << (permute_tiles_ ? " " : " [no permute tiles] ")
233  << derived().make_tag() << vars_ << "\n";
234  } else {
235  os << derived().make_tag() << vars_ << "\n";
236  }
237  }
238 
240 
/// Expression identification tag.
///
/// print() obtains the tag through derived().make_tag(), so a derived class
/// can provide its own; this base version supplies an empty tag.
/// \return An empty string
242  const char* make_tag() const { return ""; }
243 
244  }; // class ExprEngine
245 
246  } // namespace expressions
247 } // namespace TiledArray
248 
249 #endif // TILEDARRAY_EXPRESSIONS_EXPR_ENGINE_H__INCLUDED
EngineTrait< Derived >::pmap_interface pmap_interface
Process map interface type.
Definition: expr_engine.h:57
EngineTrait< Derived >::value_type value_type
Tensor value type.
Definition: expr_engine.h:48
EngineTrait< Derived >::size_type size_type
Size type.
Definition: expr_engine.h:54
std::shared_ptr< EngineParamOverride< Derived > > override_ptr_
The engine params overriding the default.
Definition: expr_engine.h:70
Derived derived_type
The derived object type.
Definition: expr_engine.h:45
const trange_type & trange() const
Tiled range accessor.
Definition: expr_engine.h:209
bool permute_tiles_
Result tile permutation flag (true == permute tile)
Definition: expr_engine.h:65
const Permutation & perm() const
Permutation accessor.
Definition: expr_engine.h:204
Permutation permutation(const V &other) const
Generate permutation relationship for variable lists.
Permutation make_perm(const VariableList &target_vars) const
Permutation factory function.
Definition: expr_engine.h:168
Scaling expression engine.
Definition: scal_engine.h:38
World * world() const
World accessor.
Definition: expr_engine.h:194
op_type make_op() const
Tile operation factory function.
Definition: expr_engine.h:178
const shape_type & shape() const
Shape accessor.
Definition: expr_engine.h:214
derived_type & derived()
Cast this object to its derived type.
Definition: expr_engine.h:186
const std::shared_ptr< pmap_interface > & pmap() const
Process map accessor.
Definition: expr_engine.h:219
VariableList vars_
The variable list of this expression.
Definition: expr_engine.h:64
unsigned int dim() const
Returns the number of strings in the variable list.
const VariableList & vars() const
Variable list accessor.
Definition: expr_engine.h:199
void init_struct(const VariableList &target_vars)
Initialize result tensor structure.
Definition: expr_engine.h:130
EngineTrait< Derived >::trange_type trange_type
Tiled range type.
Definition: expr_engine.h:55
ExprEngine(const Expr< D > &expr)
Default constructor.
Definition: expr_engine.h:78
EngineTrait< Derived >::shape_type shape_type
Tensor shape type.
Definition: expr_engine.h:56
Variable list manages a list of variable strings.
#define TA_ASSERT(a)
Definition: error.h:107
const char * make_tag() const
Expression identification tag.
Definition: expr_engine.h:242
EngineTrait< Derived >::dist_eval_type dist_eval_type
This expression's distributed evaluator type.
Definition: expr_engine.h:51
std::shared_ptr< pmap_interface > pmap_
The process map for the result tensor.
Definition: expr_engine.h:69
Base class for expression evaluation.
Definition: expr.h:81
EngineTrait< Derived >::policy policy
The result policy type.
Definition: expr_engine.h:50
World * world_
The world where this expression will be evaluated.
Definition: expr_engine.h:63
Permutation of a sequence of objects indexed by base-0 indices.
Definition: permutation.h:119
ExprEngine< Derived > ExprEngine_
Definition: expr_engine.h:44
Permutation perm_
The permutation that will be applied to the result.
Definition: expr_engine.h:66
void init_distribution(World *world, const std::shared_ptr< pmap_interface > &pmap)
Initialize result tensor distribution.
Definition: expr_engine.h:151
shape_type shape_
The shape of the result tensor.
Definition: expr_engine.h:68
Expression output stream.
Definition: expr_trace.h:39
void init(World &world, std::shared_ptr< pmap_interface > pmap, const VariableList &target_vars)
Construct and initialize the expression engine.
Definition: expr_engine.h:92
void permute_tiles(const bool status)
Set the permute tiles flag.
Definition: expr_engine.h:224
trange_type trange_
The tiled range of the result tensor.
Definition: expr_engine.h:67
void print(ExprOStream &os, const VariableList &target_vars) const
Expression print.
Definition: expr_engine.h:230
EngineTrait< Derived >::op_type op_type
Tile operation type.
Definition: expr_engine.h:49