foreach.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2015 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * Justus Calvin
19  * Department of Chemistry, Virginia Tech
20  *
21  * foreach.h
22  * Apr 15, 2015
23  *
24  */
25 
26 #ifndef TILEDARRAY_CONVERSIONS_FOREACH_H__INCLUDED
27 #define TILEDARRAY_CONVERSIONS_FOREACH_H__INCLUDED
28 
29 #include <TiledArray/shape.h>
30 #include <TiledArray/type_traits.h>
32 
// Forward declarations in namespace Eigen.
// NOTE(review): the declaration that follows `template <typename>` (original
// line 36) is missing from this listing, so the template below is incomplete
// as shown -- restore it from the source file.
34 namespace Eigen {
35 template <typename>
37 } // namespace Eigen
38 
39 namespace TiledArray {
40 
// Forward declarations of the array, tensor and policy types used below.
template <typename Tile, typename Policy>
class DistArray;
template <typename T, typename A>
class Tensor;
class DensePolicy;
class SparsePolicy;

/// Selects how the zero/non-zero shapes of several sparse operand arrays are
/// combined to decide which result tiles get computed:
///  - \c Union: a result tile is computed unless every operand tile is zero;
///  - \c Intersect: a result tile is computed only if no operand tile is zero.
enum class ShapeReductionMethod { Union, Intersect };
50 
51 namespace detail {
52 
53 namespace {
54 
/// Invokes a void-returning tile operation and returns the resulting tile.
///
/// \tparam inplace if true, the operation modifies its first argument in
///         place and a copy of that modified argument is returned; otherwise
///         a fresh \c Result is created, passed to the operation as its first
///         (output) argument, and returned.
/// \tparam Result the result tile type (only used when \c inplace is false)
template <bool inplace, typename Result = void>
struct void_op_helper;

/// Out-of-place variant: calls `op(result, arg, args...)` and returns `result`.
template <typename Result>
struct void_op_helper<false, Result> {
  template <typename Op, typename Arg, typename... Args>
  Result operator()(Op&& op, Arg&& arg, Args&&... args) {
    // Value-initialize so that scalar/trivial result types never start out
    // indeterminate before the op fills them in.
    Result result{};
    std::forward<Op>(op)(result, std::forward<Arg>(arg),
                         std::forward<Args>(args)...);
    return result;
  }
};

/// In-place variant: calls `op(arg, args...)` and returns a copy of the
/// modified `arg`.
template <typename Result>
struct void_op_helper<true, Result> {
  template <typename Op, typename Arg, typename... Args>
  std::decay_t<Arg> operator()(Op&& op, Arg&& arg, Args&&... args) {
    // `arg` is deliberately passed as an lvalue: the in-place op is required
    // to accept `ArgTile&` and mutate it, and `arg` is read again below, so it
    // must not be forwarded (potentially moved) into the call.
    std::forward<Op>(op)(arg, std::forward<Args>(args)...);
    return arg;
  }
};
76 
/// Invokes a tile operation that returns a value (e.g. the norm of the
/// result tile), stores that value in \p op_result, and returns the
/// resulting tile.
///
/// \tparam inplace if true, the operation modifies its first argument in
///         place and a copy of that modified argument is returned; otherwise
///         a fresh \c Result is created, passed to the operation as its first
///         (output) argument, and returned.
/// \tparam Result the result tile type (only used when \c inplace is false)
template <bool inplace, typename Result = void>
struct nonvoid_op_helper;

/// Out-of-place variant: `op_result = op(result, arg, args...)`; returns
/// `result`.
template <typename Result>
struct nonvoid_op_helper<false, Result> {
  template <typename Op, typename OpResult, typename Arg, typename... Args>
  Result operator()(Op&& op, OpResult& op_result, Arg&& arg, Args&&... args) {
    // Value-initialize so that scalar/trivial result types never start out
    // indeterminate before the op fills them in.
    Result result{};
    op_result = std::forward<Op>(op)(result, std::forward<Arg>(arg),
                                     std::forward<Args>(args)...);
    return result;
  }
};

/// In-place variant: `op_result = op(arg, args...)`; returns a copy of the
/// modified `arg`.
template <typename Result>
struct nonvoid_op_helper<true, Result> {
  template <typename Op, typename OpResult, typename Arg, typename... Args>
  std::decay_t<Arg> operator()(Op&& op, OpResult& op_result, Arg&& arg,
                               Args&&... args) {
    // `arg` is deliberately passed as an lvalue: the in-place op mutates it
    // and it is returned afterwards, so it must not be forwarded (moved).
    op_result = std::forward<Op>(op)(arg, std::forward<Args>(args)...);
    return arg;
  }
};
100 
101 template <bool inplace, typename Result>
102 struct op_helper {
103  template <typename Op, typename OpResult, typename Arg, typename... Args>
104  std::enable_if_t<
105  detail::is_invocable_void<Op, Result&, const std::decay_t<Arg>&,
106  const std::decay_t<Args>&...>::value ||
107  detail::is_invocable_void<Op, std::decay_t<Arg>&,
108  const std::decay_t<Args>&...>::value,
109  Result>
110  operator()(Op&& op, OpResult& op_result, Arg&& arg, Args&&... args) {
111  void_op_helper<inplace, Result> op_caller;
112  return op_caller(std::forward<Op>(op), std::forward<Arg>(arg),
113  std::forward<Args>(args)...);
114  }
115  template <typename Op, typename OpResult, typename Arg, typename... Args>
116  std::enable_if_t<
117  !(detail::is_invocable_void<Op, Result&, const std::decay_t<Arg>&,
118  const std::decay_t<Args>&...>::value ||
119  detail::is_invocable_void<Op, std::decay_t<Arg>&,
120  const std::decay_t<Args>&...>::value),
121  Result>
122  operator()(Op&& op, OpResult& op_result, Arg&& arg, Args&&... args) {
123  nonvoid_op_helper<inplace, Result> op_caller;
124  return op_caller(std::forward<Op>(op), op_result, std::forward<Arg>(arg),
125  std::forward<Args>(args)...);
126  }
127 };
128 
129 template <typename Tile, typename Policy>
130 inline bool compare_trange(const DistArray<Tile, Policy>& array1) {
131  return true;
132 }
133 
134 template <typename Tile1, typename Tile2, typename Policy, typename... Arrays>
135 inline bool compare_trange(const DistArray<Tile1, Policy>& array1,
136  const DistArray<Tile2, Policy>& array2,
137  const Arrays&... arrays) {
138  return (array1.trange() == array2.trange() &&
139  compare_trange(array1, arrays...));
140 }
141 
/// Returns true if the intersection of the operand shapes is zero at a tile,
/// i.e. if at least one of the given "tile is zero" flags is set.
/// An empty list yields false.
inline bool is_zero_intersection(
    const std::initializer_list<bool>& is_zero_list) {
  for (const bool tile_is_zero : is_zero_list) {
    if (tile_is_zero) return true;
  }
  return false;
}

/// Returns true if the union of the operand shapes is zero at a tile, i.e. if
/// every one of the given "tile is zero" flags is set.
/// An empty list yields true.
inline bool is_zero_union(const std::initializer_list<bool>& is_zero_list) {
  for (const bool tile_is_zero : is_zero_list) {
    if (!tile_is_zero) return false;
  }
  return true;
}
151 
152 template <typename I, typename A>
153 Future<typename A::value_type> get_sparse_tile(const I& index, const A& array) {
154  return (!array.is_zero(index)
155  ? array.find(index)
156  : Future<typename A::value_type>(typename A::value_type()));
157 }
158 template <typename I, typename A>
159 Future<typename A::value_type> get_sparse_tile(const I& index, A& array) {
160  return (!array.is_zero(index)
161  ? array.find(index)
162  : Future<typename A::value_type>(typename A::value_type()));
163 }
164 
165 } // namespace
166 
168 
// Dense-policy implementation behind the public foreach / foreach_inplace
// wrappers. For each tile index in arg.pmap() it spawns a task that applies
// `op` to the tile of `arg` and the matching tiles of `args`. The resulting
// futures are stored in a new array that has arg's tiled range and pmap.
// Op must return void (see the static_asserts):
//   - out of place (inplace == false): called as op(result_tile, arg_tile, arg_tiles...);
//   - in place (inplace == true): called as op(arg_tile, arg_tiles...), and
//     ResultTile must equal ArgTile.
// All arrays must have matching tiled ranges (checked with TA_ASSERT).
170 template <bool inplace = false, typename Op, typename ResultTile,
171  typename ArgTile, typename Policy, typename... ArgTiles>
172 inline std::
173  enable_if_t<is_dense_v<Policy>, DistArray<ResultTile, Policy>> foreach (
174  Op&& op, const_if_t<not inplace, DistArray<ArgTile, Policy>> & arg,
175  const DistArray<ArgTiles, Policy>&... args) {
176  constexpr const bool op_returns_void =
177  detail::is_invocable_void<Op, ResultTile&, const ArgTile&,
178  const ArgTiles&...>::value ||
179  detail::is_invocable_void<Op, ArgTile&, const ArgTiles&...>::value;
180  static_assert(!inplace || std::is_same<ResultTile, ArgTile>::value,
181  "if inplace==true, ResultTile and ArgTile must be the same");
182  static_assert(!inplace || op_returns_void,
183  "if inplace==true, Op must be callable with signature "
184  "void(ArgTile&, const ArgTiles&...)");
185  static_assert(inplace || op_returns_void,
186  "if inplace==false, Op must be callable with signature "
187  "void(ResultTile&,const ArgTile&, const ArgTiles&...)");
188
189  TA_ASSERT(compare_trange(arg, args...) && "Tiled ranges of args must match");
190
191  typedef DistArray<ArgTile, Policy> arg_array_type;
192  typedef DistArray<ResultTile, Policy> result_array_type;
193
194  World& world = arg.world();
195
196  // Make an empty result array
197  result_array_type result(world, arg.trange(), arg.pmap());
198
199  // lifetime management of op depends on whether it is a lvalue ref (i.e. has
200  // an external owner) or an rvalue ref it also depends on whether we want to
201  // fire_same op for each tile (fire_op) or fire clones of op (fire_op_clone)
202  // - if op is an lvalue ref
203  // - if fire_op: pass op to tasks
204  // - if fire_op_clone: pass Op_(op) to tasks
205  // - if op is an rvalue ref
206  // - if fire_op: pass make_shared_function(op) to tasks
207  // - if fire_op_clone: pass copy of std::make_function(op) to tasks
208  // currently only fire_op is implemented
209  auto op_shared_handle = make_op_shared_handle(std::forward<Op>(op));
210
211  // Iterate over local tiles of arg
212  for (auto index : *(arg.pmap())) {
213  // Spawn a task to evaluate the tile
// NOTE(review): original line 216 (the declared type of the lambda parameter
// `arg_tile`) is missing from this listing. It presumably references
// arg_array_type::value_type, const-qualified when !inplace -- confirm
// against the source file.
// The lambda is not `mutable`, so the captured op_shared_handle is const.
// std::move() on it therefore yields a const rvalue and does not actually
// move from the handle: every task calls the same shared op.
214  Future<typename result_array_type::value_type> tile = world.taskq.add(
215  [op_shared_handle](
217  arg_tile,
218  const ArgTiles&... arg_tiles) {
219  void_op_helper<inplace, typename result_array_type::value_type>
220  op_caller;
221  return op_caller(std::move(op_shared_handle), arg_tile, arg_tiles...);
222  },
223  arg.find(index), args.find(index)...);
224
225  // Store result tile
226  result.set(index, tile);
227  }
228
229  return result;
230 }
231 
233 
// Sparse-policy implementation behind the public foreach / foreach_inplace
// wrappers.
// For each tile index in arg.pmap() that survives `shape_reduction`, a task
// applies `op` through op_helper:
//   - Intersect: an index is skipped if any operand tile is zero;
//   - Union: an index is skipped only if all operand tiles are zero, and
//     missing operand tiles are replaced by default-constructed tiles via
//     get_sparse_tile.
// How the result norms are obtained:
//   - if `op` returns a value, it is written to tile_norms[index] and becomes
//     the result shape's norm for that tile;
//   - if `op` returns void, the Intersect path copies the first argument's
//     shape data into tile_norms instead.
// After every spawned task has finished (counter == task_count), the result
// array is built from tile_norms. Only tiles that are non-zero in the new
// shape are set.
240 template <bool inplace = false, typename Op, typename ResultTile,
241  typename ArgTile, typename Policy, typename... ArgTiles>
242 inline std::
243  enable_if_t<!is_dense_v<Policy>, DistArray<ResultTile, Policy>> foreach (
244  Op&& op, const ShapeReductionMethod shape_reduction,
245  const_if_t<not inplace, DistArray<ArgTile, Policy>> & arg,
246  const DistArray<ArgTiles, Policy>&... args) {
247  constexpr const bool op_returns_void =
248  detail::is_invocable_void<Op, ResultTile&, const ArgTile&,
249  const ArgTiles&...>::value ||
250  detail::is_invocable_void<Op, ArgTile&, const ArgTiles&...>::value;
251  static_assert(!inplace || std::is_same<ResultTile, ArgTile>::value,
252  "if inplace==true, ResultTile and ArgTile must be the same");
// NOTE(review): the condition of the following static_assert (original line
// 254) is missing from this listing -- restore it from the source file.
253  static_assert(
255  "if inplace==true, Op must be callable with signature ret(ArgTile&, "
256  "const ArgTiles&...), where ret={void,Policy::shape_type::value_type}");
257  static_assert(inplace || detail::is_invocable<Op, ResultTile&, const ArgTile&,
258  const ArgTiles&...>::value,
259  "if inplace==false, Op must be callable with signature "
260  "ret(ResultTile&,const ArgTile&, const ArgTiles&...), where "
261  "ret={void,Policy::shape_type::value_type}");
262
263  TA_ASSERT(detail::compare_trange(arg, args...) &&
264  "Tiled ranges of args must match");
265
266  typedef DistArray<ArgTile, Policy> arg_array_type;
267  typedef DistArray<ResultTile, Policy> result_array_type;
268
269  typedef typename arg_array_type::value_type arg_value_type;
270  typedef typename result_array_type::value_type result_value_type;
271  typedef typename arg_array_type::ordinal_type ordinal_type;
272  typedef typename arg_array_type::shape_type shape_type;
273  typedef std::pair<ordinal_type, Future<result_value_type>> datum_type;
274
275  // Create a vector to hold local tiles
276  std::vector<datum_type> tiles;
277  tiles.reserve(arg.pmap()->size());
278
279  // Construct a tensor to hold updated tile norms for the result shape.
// NOTE(review): the second template argument of this Tensor and the closing
// `>` (original line 281) are missing from this listing.
280  TiledArray::Tensor<typename shape_type::value_type,
282  tile_norms(arg.trange().tiles_range(), 0);
283
284  // Construct the task function used to construct the result tiles.
285  madness::AtomicInt counter;
286  counter = 0;
287  int task_count = 0;
288  auto op_shared_handle = make_op_shared_handle(std::forward<Op>(op));
// NOTE(review): original line 291 (the declaration of the `arg_tile`
// parameter) is missing from this listing.
// Each task writes only tile_norms[index] for its own index (through
// op_caller's op_result argument) and then bumps `counter`, which the await
// below uses to detect completion. The lambda is not `mutable`, so
// std::move(op_shared_handle) yields a const rvalue and does not move from the
// captured handle.
289  const auto task = [op_shared_handle, &counter, &tile_norms](
290  const ordinal_type index,
292  const ArgTiles&... arg_tiles) -> result_value_type {
293  op_helper<inplace, result_value_type> op_caller;
294  auto result_tile = op_caller(std::move(op_shared_handle), tile_norms[index],
295  arg_tile, arg_tiles...);
296  ++counter;
297  return result_tile;
298  };
299
300  World& world = arg.world();
301
302  const auto& arg_shape_data = arg.shape().data();
303  switch (shape_reduction) {
// NOTE(review): the first case label (original line 304) is missing from this
// listing. Given the is_zero_intersection test below, it is presumably
// `case ShapeReductionMethod::Intersect:`.
305  // Get local tile index iterator
306  for (auto index : *(arg.pmap())) {
307  if (is_zero_intersection({arg.is_zero(index), args.is_zero(index)...}))
308  continue;
309  auto result_tile =
310  world.taskq.add(task, index, arg.find(index), args.find(index)...);
311  ++task_count;
312  tiles.emplace_back(index, std::move(result_tile));
313  if (op_returns_void) // if Op does not evaluate norms, use the (scaled)
314  // norms of the first arg
315  tile_norms[index] = arg_shape_data[index];
316  }
317  break;
// NOTE(review): with a void-returning Op, the Union path hits TA_ASSERT(false)
// only after the task for that index has been spawned. If TA_ASSERT is
// compiled to a no-op, tile_norms[index] keeps its initial 0 for such tiles,
// so the result shape presumably drops them -- confirm against error.h.
318  case ShapeReductionMethod::Union:
319  // Get local tile index iterator
320  for (auto index : *(arg.pmap())) {
321  if (is_zero_union({arg.is_zero(index), args.is_zero(index)...}))
322  continue;
323  auto result_tile =
324  world.taskq.add(task, index, detail::get_sparse_tile(index, arg),
325  detail::get_sparse_tile(index, args)...);
326  ++task_count;
327  tiles.emplace_back(index, std::move(result_tile));
328  if (op_returns_void) // if Op does not evaluate norms, use the (scaled)
329  // norms of the first arg need max reduction here,
330  // hence c++17, until then just assert false
331  TA_ASSERT(false &&
332  "ShapeReductionMethod::Union not supported with "
333  "void-returning Op");
334  }
335  break;
336  default:
337  TA_ASSERT(false);
338  break;
339  }
340
341  // Wait for tile norm data to be collected.
342  if (task_count > 0)
343  world.await(
344  [&counter, task_count]() -> bool { return counter == task_count; });
345
346  // Construct the new array
347  result_array_type result(
348  world, arg.trange(),
349  shape_type(world, tile_norms, arg.trange(), op_returns_void),
350  arg.pmap()); // if Op returns void tile_norms contains scaled norms, so
351  // do not scale again
// Store only the computed tiles that are non-zero in the new result shape.
352  for (typename std::vector<datum_type>::const_iterator it = tiles.begin();
353  it != tiles.end(); ++it) {
354  const auto index = it->first;
355  if (!result.is_zero(index)) result.set(it->first, it->second);
356  }
357
358  return result;
359 }
360 
361 } // namespace detail
362 
391 
393 
395 
421 template <typename ResultTile, typename ArgTile, typename Policy, typename Op,
422  typename = typename std::enable_if<
423  !std::is_same<ResultTile, ArgTile>::value>::type>
424 inline std::
425  enable_if_t<is_dense_v<Policy>, DistArray<ResultTile, Policy>> foreach (
426  const DistArray<ArgTile, Policy>& arg, Op && op) {
427  return detail::foreach<false, Op, ResultTile, ArgTile, Policy>(
428  std::forward<Op>(op), arg);
429 }
430 
432 
435 template <typename Tile, typename Policy, typename Op>
436 inline std::enable_if_t<is_dense_v<Policy>, DistArray<Tile, Policy>> foreach (
437  const DistArray<Tile, Policy>& arg, Op && op) {
438  return detail::foreach<false, Op, Tile, Tile, Policy>(std::forward<Op>(op),
439  arg);
440 }
441 
443 
473 template <typename Tile, typename Policy, typename Op,
474  typename = typename std::enable_if<!TiledArray::detail::is_array<
475  typename std::decay<Op>::type>::value>::type>
476 inline std::enable_if_t<is_dense_v<Policy>, void> foreach_inplace(
477  DistArray<Tile, Policy>& arg, Op&& op, bool fence = true) {
478  // The tile data is being modified in place, which means we may need to
479  // fence to ensure no other threads are using the data.
480  if (fence) arg.world().gop.fence();
481 
482  arg =
483  detail::foreach<true, Op, Tile, Tile, Policy>(std::forward<Op>(op), arg);
484 }
485 
487 
520 template <typename ResultTile, typename ArgTile, typename Policy, typename Op,
521  typename = typename std::enable_if<
522  !std::is_same<ResultTile, ArgTile>::value>::type>
523 inline std::
524  enable_if_t<!is_dense_v<Policy>, DistArray<ResultTile, Policy>> foreach (
525  const DistArray<ArgTile, Policy> arg, Op && op) {
526  return detail::foreach<false, Op, ResultTile, ArgTile, Policy>(
527  std::forward<Op>(op), ShapeReductionMethod::Intersect, arg);
528 }
529 
531 
534 template <typename Tile, typename Policy, typename Op>
535 inline std::enable_if_t<!is_dense_v<Policy>, DistArray<Tile, Policy>> foreach (
536  const DistArray<Tile, Policy>& arg, Op && op) {
537  return detail::foreach<false, Op, Tile, Tile, Policy>(
538  std::forward<Op>(op), ShapeReductionMethod::Intersect, arg);
539 }
540 
542 
581 template <typename Tile, typename Policy, typename Op,
582  typename = typename std::enable_if<!TiledArray::detail::is_array<
583  typename std::decay<Op>::type>::value>::type>
584 inline std::enable_if_t<!is_dense_v<Policy>, void> foreach_inplace(
585  DistArray<Tile, Policy>& arg, Op&& op, bool fence = true) {
586  // The tile data is being modified in place, which means we may need to
587  // fence to ensure no other threads are using the data.
588  if (fence) arg.world().gop.fence();
589 
590  // Set the arg with the new array
591  arg = detail::foreach<true, Op, Tile, Tile, Policy>(
592  std::forward<Op>(op), ShapeReductionMethod::Intersect, arg);
593 }
594 
597 template <typename ResultTile, typename LeftTile, typename RightTile,
598  typename Policy, typename Op,
599  typename = typename std::enable_if<
600  !std::is_same<ResultTile, LeftTile>::value>::type>
601 inline std::
602  enable_if_t<is_dense_v<Policy>, DistArray<ResultTile, Policy>> foreach (
603  const DistArray<LeftTile, Policy>& left,
604  const DistArray<RightTile, Policy>& right, Op && op) {
605  return detail::foreach<false, Op, ResultTile, LeftTile, Policy, RightTile>(
606  std::forward<Op>(op), left, right);
607 }
608 
611 template <typename LeftTile, typename RightTile, typename Policy, typename Op>
612 inline std::
613  enable_if_t<is_dense_v<Policy>, DistArray<LeftTile, Policy>> foreach (
614  const DistArray<LeftTile, Policy>& left,
615  const DistArray<RightTile, Policy>& right, Op && op) {
616  return detail::foreach<false, Op, LeftTile, LeftTile, Policy, RightTile>(
617  std::forward<Op>(op), left, right);
618 }
619 
// Modifies each tile of the dense array `left` in place, using the matching
// tile of `right`. The op is called as op(left_tile, right_tile) and must
// return void (required by detail::foreach when inplace == true); `left` is
// then replaced by the resulting array. If `fence` is true (the default), a
// global fence runs first.
// NOTE(review): original line 623 (the declaration of the `left` parameter)
// is missing from this listing -- restore it from the source file.
621 template <typename LeftTile, typename RightTile, typename Policy, typename Op>
622 inline std::enable_if_t<is_dense_v<Policy>, void> foreach_inplace(
624  const DistArray<RightTile, Policy>& right, Op&& op, bool fence = true) {
625  // The tile data is being modified in place, which means we may need to
626  // fence to ensure no other threads are using the data.
627  if (fence) left.world().gop.fence();
628
629  left = detail::foreach<true, Op, LeftTile, LeftTile, Policy, RightTile>(
630  std::forward<Op>(op), left, right);
631 }
632 
// Applies `op` to corresponding tiles of two sparse arrays and returns a new
// array whose tiles have type ResultTile (enabled only when ResultTile differs
// from LeftTile).
//   - `op` is called as op(result_tile, left_tile, right_tile) and may return
//     void or the norm of the result tile.
//   - `shape_reduction` chooses whether a tile is computed when only one of
//     the two operand tiles is non-zero (Union) or only when both are
//     non-zero (Intersect).
// NOTE(review): original line 644 (the default value of `shape_reduction` and
// the start of the function body) is missing from this listing.
635 template <typename ResultTile, typename LeftTile, typename RightTile,
636  typename Policy, typename Op,
637  typename = typename std::enable_if<
638  !std::is_same<ResultTile, LeftTile>::value>::type>
639 inline std::
640  enable_if_t<!is_dense_v<Policy>, DistArray<ResultTile, Policy>> foreach (
641  const DistArray<LeftTile, Policy>& left,
642  const DistArray<RightTile, Policy>& right, Op && op,
643  const ShapeReductionMethod shape_reduction =
645  return detail::foreach<false, Op, ResultTile, LeftTile, Policy, RightTile>(
646  std::forward<Op>(op), shape_reduction, left, right);
647 }
648 
// Applies `op` to corresponding tiles of two sparse arrays and returns a new
// array with the tile type of `left`.
//   - `op` is called as op(result_tile, left_tile, right_tile) and may return
//     void or the norm of the result tile.
//   - `shape_reduction` chooses between Union and Intersect of the operand
//     shapes.
// NOTE(review): original line 657 (the default value of `shape_reduction` and
// the start of the function body) is missing from this listing.
651 template <typename LeftTile, typename RightTile, typename Policy, typename Op>
652 inline std::
653  enable_if_t<!is_dense_v<Policy>, DistArray<LeftTile, Policy>> foreach (
654  const DistArray<LeftTile, Policy>& left,
655  const DistArray<RightTile, Policy>& right, Op && op,
656  const ShapeReductionMethod shape_reduction =
658  return detail::foreach<false, Op, LeftTile, LeftTile, Policy, RightTile>(
659  std::forward<Op>(op), shape_reduction, left, right);
660 }
661 
// Modifies each tile of the sparse array `left` in place, using the matching
// tile of `right`.
//   - `op` is called as op(left_tile, right_tile) and may return void or the
//     norm of the modified tile.
//   - `left` is then replaced by the resulting array, whose shape is rebuilt
//     according to `shape_reduction`.
//   - If `fence` is true (the default), a global fence runs first.
// NOTE(review): original lines 665 (the declaration of the `left` parameter)
// and 668 (the default value of `shape_reduction`) are missing from this
// listing.
663 template <typename LeftTile, typename RightTile, typename Policy, typename Op>
664 inline std::enable_if_t<!is_dense_v<Policy>, void> foreach_inplace(
666  const DistArray<RightTile, Policy>& right, Op&& op,
667  const ShapeReductionMethod shape_reduction =
669  bool fence = true) {
670  // The tile data is being modified in place, which means we may need to
671  // fence to ensure no other threads are using the data.
672  if (fence) left.world().gop.fence();
673
674  // Set the arg with the new array
675  left = detail::foreach<true, Op, LeftTile, LeftTile, Policy, RightTile>(
676  std::forward<Op>(op), shape_reduction, left, right);
677 }
678 
680 
681 } // namespace TiledArray
682 
683 #endif // TILEDARRAY_CONVERSIONS_FOREACH_H__INCLUDED
::blas::Op Op
Definition: blas.h:46
Forward declarations.
Definition: foreach.h:34
ShapeReductionMethod
Definition: foreach.h:49
auto make_op_shared_handle(Op &&op)
Definition: function.h:129
std::enable_if_t< is_dense_v< Policy >, void > foreach_inplace(DistArray< Tile, Policy > &arg, Op &&op, bool fence=true)
Modify each tile of a dense Array.
Definition: foreach.h:476
#define TA_ASSERT(EXPR,...)
Definition: error.h:39
Future< value_type > find(const Index &i) const
Find local or remote tile by index.
Definition: dist_array.h:524
Forward declarations.
Definition: dist_array.h:57
World & world() const
World accessor.
Definition: dist_array.h:1007
typename std::conditional< B, const T, T >::type const_if_t
prepends const to T if B is true
Definition: type_traits.h:966
bool is_zero(const Index &i) const
Check for zero tiles.
Definition: dist_array.h:1137
An N-dimensional tensor object.
Definition: tensor.h:50