expr_engine.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2014 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * Justus Calvin
19  * Department of Chemistry, Virginia Tech
20  *
21  * expr_engine.h
22  * Mar 31, 2014
23  *
24  */
25 
26 #ifndef TILEDARRAY_EXPRESSIONS_EXPR_ENGINE_H__INCLUDED
27 #define TILEDARRAY_EXPRESSIONS_EXPR_ENGINE_H__INCLUDED
28 
31 
32 namespace TiledArray {
33 namespace expressions {
34 
35 // Forward declarations
36 template <typename>
37 struct EngineParamOverride;
38 template <typename>
39 class Expr;
40 template <typename>
41 struct EngineTrait;
42 
44 template <typename Derived>
45 class ExprEngine : private NO_DEFAULTS {
46  public:
48  typedef Derived derived_type;
49 
50  // Operational typedefs
51  typedef typename EngineTrait<Derived>::value_type
53  typedef
55  typedef
59 
60  // Meta data typedefs
62  typedef typename EngineTrait<Derived>::trange_type
64  typedef typename EngineTrait<Derived>::shape_type
68 
69  protected:
70  // The member variables of this class are protected because derived
71  // classes will customize initialization.
72 
73  World* world_;
76  bool permute_tiles_;
83  std::shared_ptr<pmap_interface>
85  std::shared_ptr<EngineParamOverride<Derived> >
87 
88  public:
90 
92  template <typename D>
93  ExprEngine(const Expr<D>& expr)
94  : world_(NULL),
95  indices_(),
96  permute_tiles_(true),
97  perm_(),
98  trange_(),
99  shape_(),
100  pmap_(),
102 
104 
112  void init(World& world, std::shared_ptr<pmap_interface> pmap,
113  const BipartiteIndexList& target_indices) {
114  if (target_indices.size()) {
115  derived().init_indices(target_indices);
116  derived().init_struct(target_indices);
117  } else {
118  derived().init_indices();
119  derived().init_struct(indices_);
120  }
121 
122  auto override_world = override_ptr_ != nullptr && override_ptr_->world;
123  auto override_pmap = override_ptr_ != nullptr && override_ptr_->pmap;
124  world_ = override_world ? override_ptr_->world : &world;
125  pmap_ = override_pmap ? override_ptr_->pmap : pmap;
126 
127  // Check for a valid process map.
128  if (pmap_) {
129  // If process map is not valid, use the process map constructed by the
130  // expression engine.
131  if ((typename pmap_interface::size_type(world_->size()) !=
132  pmap_->procs()) ||
133  (trange_.tiles_range().volume() != pmap_->size()))
134  pmap_.reset();
135  }
136 
137  derived().init_distribution(world_, pmap_);
138  }
139 
141 
150  void init_struct(const BipartiteIndexList& target_indices) {
151  if (target_indices != indices_) {
152  perm_ = derived().make_perm(target_indices);
153  trange_ = derived().make_trange(outer(perm_));
154  shape_ = derived().make_shape(outer(perm_));
155  } else {
156  trange_ = derived().make_trange();
157  shape_ = derived().make_shape();
158  }
159 
160  if (override_ptr_ && override_ptr_->shape)
161  shape_ = shape_.mask(*override_ptr_->shape);
162  }
163 
165 
172  const std::shared_ptr<pmap_interface>& pmap) {
173  TA_ASSERT(world);
174  TA_ASSERT(pmap);
175  TA_ASSERT(pmap->procs() ==
176  typename pmap_interface::size_type(world->size()));
177  TA_ASSERT(pmap->size() == trange_.tiles_range().volume());
178 
179  world_ = world;
180  pmap_ = pmap;
181  }
182 
184 
189  const BipartiteIndexList& target_indices) const {
190  return target_indices.permutation(indices_);
191  }
192 
194 
199  op_type make_op() const {
200  if (perm_ && permute_tiles_)
201  // permutation can only be applied to the tile, not to its element (if
202  // tile = tensor-of-tensors)
203  return derived().make_tile_op(perm_);
204  else
205  return derived().make_tile_op();
206  }
207 
209  derived_type& derived() { return *static_cast<derived_type*>(this); }
210 
212  const derived_type& derived() const {
213  return *static_cast<const derived_type*>(this);
214  }
215 
217 
219  World* world() const { return world_; }
220 
222 
224  const BipartiteIndexList& indices() const { return indices_; }
225 
227 
229  const BipartitePermutation& perm() const { return perm_; }
230 
232 
234  const trange_type& trange() const { return trange_; }
235 
237 
239  const shape_type& shape() const { return shape_; }
240 
242 
244  const std::shared_ptr<pmap_interface>& pmap() const { return pmap_; }
245 
247 
250  void permute_tiles(const bool status) { permute_tiles_ = status; }
251 
253 
256  void print(ExprOStream& os, const BipartiteIndexList& target_indices) const {
257  if (perm_) {
258  os << "[P " << target_indices << "]"
259  << (permute_tiles_ ? " " : " [no permute tiles] ")
260  << derived().make_tag() << indices_ << "\n";
261  } else {
262  os << derived().make_tag() << indices_ << "\n";
263  }
264  }
265 
267 
269  const char* make_tag() const { return ""; }
270 
271 }; // class ExprEngine
272 
273 } // namespace expressions
274 } // namespace TiledArray
275 
276 #endif // TILEDARRAY_EXPRESSIONS_EXPR_ENGINE_H__INCLUDED
void init_struct(const BipartiteIndexList &target_indices)
Initialize result tensor structure.
Definition: expr_engine.h:150
void init(World &world, std::shared_ptr< pmap_interface > pmap, const BipartiteIndexList &target_indices)
Construct and initialize the expression engine.
Definition: expr_engine.h:112
World * world() const
World accessor.
Definition: expr_engine.h:219
EngineTrait< Derived >::shape_type shape_type
Tensor shape type.
Definition: expr_engine.h:65
derived_type & derived()
Cast this object to its derived type.
Definition: expr_engine.h:209
trange_type trange_
The tiled range of the result tensor.
Definition: expr_engine.h:81
EngineTrait< Derived >::op_type op_type
Tile operation type.
Definition: expr_engine.h:54
auto outer(const IndexList &p)
Definition: index_list.h:879
EngineTrait< Derived >::pmap_interface pmap_interface
Process map interface type.
Definition: expr_engine.h:67
std::shared_ptr< pmap_interface > pmap_
The process map for the result tensor.
Definition: expr_engine.h:84
std::shared_ptr< EngineParamOverride< Derived > > override_ptr_
The engine parameters that override the defaults.
Definition: expr_engine.h:86
EngineTrait< Derived >::value_type value_type
Tensor value type.
Definition: expr_engine.h:52
World * world_
The world where this expression will be evaluated.
Definition: expr_engine.h:73
const BipartiteIndexList & indices() const
Index list accessor.
Definition: expr_engine.h:224
Base class for expression evaluation.
Definition: expr.h:97
EngineTrait< Derived >::dist_eval_type dist_eval_type
This expression's distributed evaluator type.
Definition: expr_engine.h:58
const shape_type & shape() const
Shape accessor.
Definition: expr_engine.h:239
EngineTrait< Derived >::trange_type trange_type
Tiled range type.
Definition: expr_engine.h:63
#define TA_ASSERT(EXPR,...)
Definition: error.h:39
ExprEngine(const Expr< D > &expr)
Construct the engine from an expression.
Definition: expr_engine.h:93
const BipartitePermutation & perm() const
Permutation accessor.
Definition: expr_engine.h:229
Expression output stream.
Definition: expr_trace.h:41
void init_distribution(World *world, const std::shared_ptr< pmap_interface > &pmap)
Initialize result tensor distribution.
Definition: expr_engine.h:171
EngineTrait< Derived >::policy policy
The result policy type.
Definition: expr_engine.h:56
Derived derived_type
The derived object type.
Definition: expr_engine.h:48
BipartitePermutation make_perm(const BipartiteIndexList &target_indices) const
Permutation factory function.
Definition: expr_engine.h:188
const trange_type & trange() const
Tiled range accessor.
Definition: expr_engine.h:234
EngineTrait< Derived >::size_type size_type
Size type.
Definition: expr_engine.h:61
const char * make_tag() const
Expression identification tag.
Definition: expr_engine.h:269
shape_type shape_
The shape of the result tensor.
Definition: expr_engine.h:82
op_type make_op() const
Tile operation factory function.
Definition: expr_engine.h:199
void permute_tiles(const bool status)
Set the permute tiles flag.
Definition: expr_engine.h:250
Permutation of a bipartite set.
Definition: permutation.h:610
void print(ExprOStream &os, const BipartiteIndexList &target_indices) const
Expression print.
Definition: expr_engine.h:256
BipartitePermutation permutation(const V &other) const
Computes the permutation to go from other to this instance.
Definition: index_list.h:684
const std::shared_ptr< pmap_interface > & pmap() const
Process map accessor.
Definition: expr_engine.h:244
ExprEngine< Derived > ExprEngine_
Definition: expr_engine.h:47
BipartitePermutation perm_
The permutation that will be applied to the outer tensor of tensors.
Definition: expr_engine.h:80