/// NOTE(review): source-line numbers (26, 27, 39) are fused into the line
/// below, so `#ifndef`/`#define` do not begin their own lines and will not be
/// recognised as preprocessor directives. Each needs its own line for the
/// include guard (closed by the `#endif` at the end of this file) to work.
/// The same line begins a forward declaration of the two-parameter class
/// template DistArray, which the compare_trange overloads below take as
/// `const DistArray<Tile, Policy>&`.
26 #ifndef TILEDARRAY_CONVERSIONS_FOREACH_H__INCLUDED 27 #define TILEDARRAY_CONVERSIONS_FOREACH_H__INCLUDED 39 template <
typename,
typename>
class DistArray;
/// Forward declaration of the class template Tensor (two type parameters).
/// It lets this header name Tensor without pulling in its definition.
template <typename, typename>
class Tensor;
/// Invokes a tile operation whose return value is discarded.
///
/// Only declared here; the behaviour comes from the partial specializations
/// on `inplace`:
///  - `inplace == false`: the op is called with a `result` object as its
///    first argument, followed by the forwarded arguments. The helper's
///    return type is `Result`.
///  - `inplace == true`: the op is called directly on the forwarded arguments.
///
/// \tparam inplace selects the specialization; `foreach<inplace, ...>`
///                 forwards its own `inplace` flag here
/// \tparam Result  result tile type for the `inplace == false` path
///                 (defaults to `void`)
template <bool inplace, typename Result = void>
struct void_op_helper;
53 template <
typename Result>
54 struct void_op_helper<false, Result> {
55 template <
typename Op,
typename Arg,
typename... Args>
56 Result operator()(Op&&op, Arg&& arg, Args&&... args) {
58 op(result, std::forward<Arg>(arg), std::forward<Args>(args)...);
62 template <
typename Result>
63 struct void_op_helper<true, Result> {
64 template <
typename Op,
typename Arg,
typename... Args>
65 decltype(
auto) operator()(Op&&op, Arg&& arg, Args&&... args) {
66 op(std::forward<Arg>(arg), std::forward<Args>(args)...);
/// Invokes a tile operation and stores its return value in a caller-supplied
/// `op_result` (taken by non-const reference).
///
/// Only declared here; the behaviour comes from the partial specializations
/// on `inplace`:
///  - `inplace == false`: the op is called with a `result` object as its
///    first argument, followed by the forwarded arguments. The helper's
///    return type is `Result`.
///  - `inplace == true`: the op is called directly on the forwarded
///    arguments, and the helper returns `std::decay_t<Arg>`.
/// In the sparse foreach path, `op_result` is the per-tile entry
/// `tile_norms[index]`.
///
/// \tparam inplace selects the specialization
/// \tparam Result  result tile type for the `inplace == false` path
///                 (defaults to `void`)
template <bool inplace, typename Result = void>
struct nonvoid_op_helper;
74 template <
typename Result>
75 struct nonvoid_op_helper<false, Result> {
76 template <
typename Op,
typename OpResult,
77 typename Arg,
typename... Args>
78 Result operator()(Op&&op, OpResult& op_result,
79 Arg&& arg, Args&&... args) {
81 op_result = op(result, std::forward<Arg>(arg), std::forward<Args>(args)...);
85 template <
typename Result>
86 struct nonvoid_op_helper<true, Result> {
87 template <
typename Op,
typename OpResult,
88 typename Arg,
typename... Args>
89 std::decay_t<Arg> operator()(Op&&op, OpResult& op_result,
90 Arg&& arg, Args&&... args) {
91 op_result = op(std::forward<Arg>(arg), std::forward<Args>(args)...);
96 template <
typename Tile,
typename Policy>
97 inline bool compare_trange(
const DistArray<Tile, Policy>& array1) {
/// Variadic overload: returns true when array2 and every array in `arrays`
/// has the same tiled range (`trange()`) as array1.
/// It compares array1 with array2, then recurses on (array1, arrays...). The
/// `&&` stops at the first mismatch.
/// Each element of `arrays` is bound to the `array2` parameter on recursion,
/// so it must itself be a DistArray with the same Policy. Tile types may
/// differ.
/// When `arrays` is empty, the recursive call resolves to the single-array
/// overload declared just above. Its body is not visible here; presumably it
/// returns true.
101 template <
typename Tile1,
typename Tile2,
typename Policy,
typename... Arrays>
102 inline bool compare_trange(
const DistArray<Tile1, Policy>& array1,
103 const DistArray<Tile2, Policy>& array2,
const Arrays&... arrays) {
104 return(array1.trange() == array2.trange()
105 && compare_trange(array1, arrays...));
/// Returns true when at least one is-zero flag in `is_zero_list` is set
/// (std::any_of with an identity predicate). An empty list yields false.
/// The sparse foreach path passes one flag per argument tile at the same
/// index, i.e. the intersection of those tiles is zero if any one is zero.
108 inline bool is_zero_intersection(
const std::initializer_list<bool>& is_zero_list ) {
109 return std::any_of(is_zero_list.begin(), is_zero_list.end(),
110 [](
const bool val) ->
bool {
return val;});
/// Returns true only when every is-zero flag in `is_zero_list` is set
/// (std::all_of with an identity predicate). Note that an empty list yields
/// true (vacuous truth).
/// The sparse foreach path passes one flag per argument tile at the same
/// index, i.e. the union of those tiles is zero only if all of them are zero.
112 inline bool is_zero_union(
const std::initializer_list<bool>& is_zero_list ) {
113 return std::all_of(is_zero_list.begin(), is_zero_list.end(),
114 [](
const bool val) ->
bool {
return val;});
/// Overload for const arrays.
/// Returns `array.find(index)` when the tile at `index` is non-zero.
/// Otherwise it returns a Future already holding a default-constructed
/// `A::value_type`, and `find` is not called.
/// Assumes A provides `is_zero(index)`, `find(index)` and a
/// default-constructible `value_type`.
117 template <
typename I,
typename A>
118 Future<typename A::value_type> get_sparse_tile(
const I& index,
const A& array) {
119 return (!array.is_zero(index)? array.find(index)
120 : Future<typename A::value_type>(
typename A::value_type()));
/// Non-const overload of get_sparse_tile.
/// Returns `array.find(index)` for a non-zero tile, otherwise a Future
/// holding a default-constructed `A::value_type`.
/// NOTE(review): the body is identical to the const overload, so this
/// overload looks redundant unless A's `is_zero`/`find` behave differently
/// on non-const objects. Confirm before removing.
122 template <
typename I,
typename A>
123 Future<typename A::value_type> get_sparse_tile(
const I& index, A& array) {
124 return (!array.is_zero(index)? array.find(index)
125 : Future<typename A::value_type>(
typename A::value_type()));
133 template <
bool inplace =
false,
typename Op,
134 typename ResultTile,
typename ArgTile,
typename... ArgTiles>
139 TA_USER_ASSERT(compare_trange(arg, args...),
"Tiled ranges of args must match");
144 World& world = arg.
world();
147 result_array_type result(world, arg.trange(), arg.pmap());
151 const ArgTiles&... arg_tiles) {
152 void_op_helper<inplace, typename result_array_type::value_type> op_caller;
153 return op_caller(std::forward<Op>(op), arg_tile, arg_tiles...);
157 for (
auto index: *(arg.pmap())) {
159 Future<typename result_array_type::value_type> tile =
160 world.taskq.add(task, arg.find(index), args.
find(index)...);
163 result.set(index, tile);
172 template <
bool inplace =
false,
typename Op,
173 typename ResultTile,
typename ArgTile,
typename... ArgTiles>
178 TA_USER_ASSERT(detail::compare_trange(arg, args...),
"Tiled ranges of args must match");
183 typedef typename arg_array_type::value_type arg_value_type;
184 typedef typename result_array_type::value_type result_value_type;
185 typedef typename arg_array_type::size_type size_type;
186 typedef typename arg_array_type::shape_type shape_type;
187 typedef std::pair<size_type, Future<result_value_type>> datum_type;
190 std::vector<datum_type> tiles;
191 tiles.reserve(arg.pmap()->size());
196 tile_norms(arg.trange().tiles_range(), 0);
199 madness::AtomicInt counter; counter = 0;
201 auto task = [&op,&counter,&tile_norms](
const size_type index,
203 const ArgTiles&... arg_tiles) -> result_value_type {
204 nonvoid_op_helper<inplace, result_value_type> op_caller;
205 auto result_tile = op_caller(std::forward<Op>(op), tile_norms[index],
206 arg_tile, arg_tiles...);
208 return std::move(result_tile);
211 World& world = arg.world();
213 switch (shape_reduction) {
216 for(
auto index: *(arg.pmap())) {
217 if(is_zero_intersection({arg.is_zero(index), args.
is_zero(index)...}))
219 auto result_tile = world.taskq.add(task, index, arg.find(index),
220 args.
find(index)...);
222 tiles.emplace_back(index, std::move(result_tile));
227 for(
auto index: *(arg.pmap())) {
228 if(is_zero_union({arg.is_zero(index), args.
is_zero(index)...}))
230 auto result_tile = world.taskq.add(task, index, detail::get_sparse_tile(index, arg),
231 detail::get_sparse_tile(index, args)...);
233 tiles.emplace_back(index, std::move(result_tile));
243 world.await([&counter,task_count] () ->
bool {
return counter == task_count; });
246 result_array_type result(world, arg.trange(),
247 shape_type(world, tile_norms, arg.trange()), arg.pmap());
248 for(
typename std::vector<datum_type>::const_iterator it = tiles.begin(); it != tiles.end(); ++it) {
249 const size_type index = it->first;
250 if(! result.is_zero(index))
251 result.set(it->first, it->second);
285 template <
typename ResultTile,
typename ArgTile,
typename Op,
286 typename =
typename std::enable_if<!std::is_same<ResultTile,ArgTile>::value>::type>
287 inline DistArray<ResultTile, DensePolicy>
289 return detail::foreach<false, Op, ResultTile, ArgTile>(std::forward<Op>(op), arg);
296 template <
typename Tile,
typename Op>
297 inline DistArray<Tile, DensePolicy>
299 return detail::foreach<false, Op, Tile, Tile>(std::forward<Op>(op), arg);
332 template <
typename Tile,
typename Op,
333 typename =
typename std::enable_if<! TiledArray::detail::is_array<typename std::decay<Op>::type>::value>::type>
339 arg.
world().gop.fence();
341 arg = detail::foreach<true, Op, Tile, Tile>(std::forward<Op>(op), arg);
379 template <
typename ResultTile,
typename ArgTile,
typename Op,
380 typename =
typename std::enable_if<!std::is_same<ResultTile,ArgTile>::value>::type>
381 inline DistArray<ResultTile, SparsePolicy>
390 template <
typename Tile,
typename Op>
391 inline DistArray<Tile, SparsePolicy>
436 template <
typename Tile,
typename Op,
437 typename =
typename std::enable_if<! TiledArray::detail::is_array<typename std::decay<Op>::type>::value>::type>
444 arg.
world().gop.fence();
452 template <
typename ResultTile,
typename LeftTile,
typename RightTile,
typename Op,
453 typename =
typename std::enable_if<!std::is_same<ResultTile, LeftTile>::value>::type>
454 inline DistArray<ResultTile, DensePolicy>
457 return detail::foreach<false, Op, ResultTile, LeftTile, RightTile>(std::forward<Op>(op),
463 template <
typename LeftTile,
typename RightTile,
typename Op>
464 inline DistArray<LeftTile, DensePolicy>
467 return detail::foreach<false, Op, LeftTile, LeftTile, RightTile>(std::forward<Op>(op),
472 template <
typename LeftTile,
typename RightTile,
typename Op>
479 left.
world().gop.fence();
481 left = detail::foreach<true, Op, LeftTile, LeftTile, RightTile>(std::forward<Op>(op),
487 template <
typename ResultTile,
typename LeftTile,
typename RightTile,
typename Op,
488 typename =
typename std::enable_if<!std::is_same<ResultTile, LeftTile>::value>::type>
489 inline DistArray<ResultTile, SparsePolicy>
493 return detail::foreach<false, Op, ResultTile, LeftTile, RightTile>(std::forward<Op>(op),
494 shape_reduction, left, right);
499 template <
typename LeftTile,
typename RightTile,
typename Op>
500 inline DistArray<LeftTile, SparsePolicy>
504 return detail::foreach<false, Op, LeftTile, LeftTile, RightTile>(std::forward<Op>(op),
505 shape_reduction, left, right);
509 template <
typename LeftTile,
typename RightTile,
typename Op>
519 left.
world().gop.fence();
522 left = detail::foreach<true, Op, LeftTile, LeftTile, RightTile>(std::forward<Op>(op),
523 shape_reduction, left, right);
528 #endif // TILEDARRAY_CONVERSIONS_FOREACH_H__INCLUDED Future< value_type > find(const Index &i) const
Find local or remote tile.
void foreach_inplace(DistArray< Tile, DensePolicy > &arg, Op &&op, bool fence=true)
Modify each tile of a dense Array.
An N-dimensional tensor object.
typename std::conditional< B, const T, T >::type const_if_t
prepends const to T if B is true
bool is_zero(const Index &i) const
Check for zero tiles.
World & world() const
World accessor.
#define TA_USER_ASSERT(a, m)