26 #ifndef TILEDARRAY_CONVERSIONS_FOREACH_H__INCLUDED
27 #define TILEDARRAY_CONVERSIONS_FOREACH_H__INCLUDED
42 template <
typename,
typename>
44 template <
typename,
typename>
/// Helper functor for element operations that return `void`. Only declared
/// here; the behaviour is supplied by specializations on `inplace`.
/// `Result` defaults to `void`.
55 template <
bool inplace,
typename Result =
void>
56 struct void_op_helper;
/// `inplace == false` case: `op` is called with `result` as its first
/// argument, followed by the perfectly-forwarded `arg` and `args...`.
/// The call operator is declared to return `Result`.
// NOTE(review): presumably `result` is a local `Result` that is handed back
// to the caller -- confirm.
58 template <
typename Result>
59 struct void_op_helper<false, Result> {
60 template <
typename Op,
typename Arg,
typename... Args>
61 Result operator()(
Op&& op, Arg&& arg, Args&&... args) {
// The value of the call expression is discarded; `op` acts through `result`.
63 std::forward<Op>(op)(result, std::forward<Arg>(arg),
64 std::forward<Args>(args)...);
/// `inplace == true` case: `op` is called with the perfectly-forwarded `arg`
/// and `args...` only; there is no separate result argument.
// NOTE(review): the return type is deduced (`decltype(auto)`); presumably the
// modified `arg` is handed back -- confirm.
68 template <
typename Result>
69 struct void_op_helper<true, Result> {
70 template <
typename Op,
typename Arg,
typename... Args>
71 decltype(
auto) operator()(
Op&& op, Arg&& arg, Args&&... args) {
// The value of the call expression is discarded.
72 std::forward<Op>(op)(std::forward<Arg>(arg), std::forward<Args>(args)...);
/// Helper functor for element operations that return a value. Only declared
/// here; the behaviour is supplied by specializations on `inplace`.
/// `Result` defaults to `void`.
77 template <
bool inplace,
typename Result =
void>
78 struct nonvoid_op_helper;
/// `inplace == false` case: `op` is called with `result` followed by the
/// perfectly-forwarded `arg` and `args...`, and the value `op` returns is
/// stored in the caller-supplied `op_result`. The call operator is declared
/// to return `Result`.
// NOTE(review): presumably `result` is a local `Result` that is handed back
// to the caller -- confirm.
80 template <
typename Result>
81 struct nonvoid_op_helper<false, Result> {
82 template <
typename Op,
typename OpResult,
typename Arg,
typename... Args>
83 Result operator()(
Op&& op, OpResult& op_result, Arg&& arg, Args&&... args) {
85 op_result = std::forward<Op>(op)(result, std::forward<Arg>(arg),
86 std::forward<Args>(args)...);
/// `inplace == true` case: `op` is called with the perfectly-forwarded `arg`
/// and `args...`, and the value `op` returns is stored in the caller-supplied
/// `op_result`. The call operator is declared to return `std::decay_t<Arg>`.
// NOTE(review): presumably the returned object is the modified `arg` -- confirm.
90 template <
typename Result>
91 struct nonvoid_op_helper<true, Result> {
92 template <
typename Op,
typename OpResult,
typename Arg,
typename... Args>
93 std::decay_t<Arg> operator()(
Op&& op, OpResult& op_result, Arg&& arg,
95 op_result = std::forward<Op>(op)(std::forward<Arg>(arg),
96 std::forward<Args>(args)...);
/// Dispatches an element operation either to `void_op_helper` or to
/// `nonvoid_op_helper`, based on `detail::is_invocable_void` evaluated for the
/// signatures `(Result&, const Arg&, const Args&...)` and
/// `(Arg&, const Args&...)`.
// NOTE(review): the two conditions below are exact negations of each other,
// so presumably they make the two call operators mutually exclusive
// overloads -- confirm.
101 template <
bool inplace,
typename Result>
// Overload for void-returning `Op`: `op_result` is accepted but not passed on.
103 template <
typename Op,
typename OpResult,
typename Arg,
typename... Args>
105 detail::is_invocable_void<Op, Result&, const std::decay_t<Arg>&,
106 const std::decay_t<Args>&...>::value ||
107 detail::is_invocable_void<Op, std::decay_t<Arg>&,
108 const std::decay_t<Args>&...>::value,
110 operator()(
Op&& op, OpResult& op_result, Arg&& arg, Args&&... args) {
111 void_op_helper<inplace, Result> op_caller;
112 return op_caller(std::forward<Op>(op), std::forward<Arg>(arg),
113 std::forward<Args>(args)...);
// Overload for every other `Op`: `op_result` is passed through to
// `nonvoid_op_helper`.
115 template <
typename Op,
typename OpResult,
typename Arg,
typename... Args>
117 !(detail::is_invocable_void<Op, Result&, const std::decay_t<Arg>&,
118 const std::decay_t<Args>&...>::value ||
119 detail::is_invocable_void<Op, std::decay_t<Arg>&,
120 const std::decay_t<Args>&...>::value),
122 operator()(
Op&& op, OpResult& op_result, Arg&& arg, Args&&... args) {
123 nonvoid_op_helper<inplace, Result> op_caller;
124 return op_caller(std::forward<Op>(op), op_result, std::forward<Arg>(arg),
125 std::forward<Args>(args)...);
/// Single-array overload of `compare_trange`.
// NOTE(review): presumably the terminating case of the variadic overload's
// recursion (a lone array trivially matches itself) -- confirm.
129 template <
typename Tile,
typename Policy>
130 inline bool compare_trange(
const DistArray<Tile, Policy>& array1) {
134 template <
typename Tile1,
typename Tile2,
typename Policy,
typename... Arrays>
135 inline bool compare_trange(
const DistArray<Tile1, Policy>& array1,
136 const DistArray<Tile2, Policy>& array2,
137 const Arrays&... arrays) {
138 return (array1.trange() == array2.trange() &&
139 compare_trange(array1, arrays...));
/// Tells whether an intersection of tiles is zero.
///
/// \param is_zero_list one flag per tile; `true` marks a zero tile
/// \return `true` if at least one flag in \p is_zero_list is `true`
///         (`false` for an empty list)
inline bool is_zero_intersection(
    const std::initializer_list<bool>& is_zero_list) {
  const auto last = is_zero_list.end();
  return std::find(is_zero_list.begin(), last, true) != last;
}
/// Tells whether a union of tiles is zero.
///
/// \param is_zero_list one flag per tile; `true` marks a zero tile
/// \return `true` if every flag in \p is_zero_list is `true`
///         (`true` for an empty list)
inline bool is_zero_union(const std::initializer_list<bool>& is_zero_list) {
  const auto last = is_zero_list.end();
  return std::find(is_zero_list.begin(), last, false) == last;
}
/// Returns a `Future` for the tile of a const `array` at `index`.
///
/// When `array.is_zero(index)` is true, the future wraps a
/// default-constructed `A::value_type`.
// NOTE(review): for a non-zero index the result is presumably the tile found
// in `array` -- confirm.
/// \param index the tile index to query
/// \param array the array whose `is_zero(index)` decides which future is built
152 template <
typename I,
typename A>
153 Future<typename A::value_type> get_sparse_tile(
const I& index,
const A& array) {
154 return (!array.is_zero(index)
156 : Future<typename A::value_type>(
typename A::value_type()));
/// Overload of `get_sparse_tile` taking the array by non-const reference.
///
/// When `array.is_zero(index)` is true, the future wraps a
/// default-constructed `A::value_type`.
// NOTE(review): for a non-zero index the result is presumably the tile found
// in `array` -- confirm.
/// \param index the tile index to query
/// \param array the array whose `is_zero(index)` decides which future is built
158 template <
typename I,
typename A>
159 Future<typename A::value_type> get_sparse_tile(
const I& index, A& array) {
160 return (!array.is_zero(index)
162 : Future<typename A::value_type>(
typename A::value_type()));
/// Implementation helper: applies a void-returning `Op` to the tiles of `arg`
/// (together with the matching tiles of `args...`) and stores the outcome in
/// a result array built on `arg`'s world, tiled range and process map.
///
/// `inplace` defaults to `false`.
170 template <
bool inplace =
false,
typename Op,
typename ResultTile,
171 typename ArgTile,
typename Policy,
typename... ArgTiles>
176 constexpr
const bool op_returns_void =
178 const ArgTiles&...>::value ||
// Compile-time contract: in-place use needs identical tile types, and in both
// modes `Op` must be a void-returning callable with the stated signature.
180 static_assert(!inplace || std::is_same<ResultTile, ArgTile>::value,
181 "if inplace==true, ResultTile and ArgTile must be the same");
182 static_assert(!inplace || op_returns_void,
183 "if inplace==true, Op must be callable with signature "
184 "void(ArgTile&, const ArgTiles&...)");
185 static_assert(inplace || op_returns_void,
186 "if inplace==false, Op must be callable with signature "
187 "void(ResultTile&,const ArgTile&, const ArgTiles&...)");
// All input arrays must share one tiled range.
189 TA_ASSERT(compare_trange(arg, args...) &&
"Tiled ranges of args must match");
194 World& world = arg.
world();
197 result_array_type result(world, arg.trange(), arg.pmap());
// Visit every index listed in arg's process map.
212 for (
auto index : *(arg.pmap())) {
218 const ArgTiles&... arg_tiles) {
219 void_op_helper<inplace, typename result_array_type::value_type>
221 return op_caller(std::move(op_shared_handle), arg_tile, arg_tiles...);
223 arg.find(index), args.
find(index)...);
226 result.set(index, tile);
/// Implementation helper variant driven by a `shape_reduction` method: it
/// collects per-tile results and norms (`tile_norms`), then builds the result
/// array with a shape constructed from those norms.
///
/// `inplace` defaults to `false`.
240 template <
bool inplace =
false,
typename Op,
typename ResultTile,
241 typename ArgTile,
typename Policy,
typename... ArgTiles>
247 constexpr
const bool op_returns_void =
249 const ArgTiles&...>::value ||
// Compile-time contract on tile types and on the signature of `Op`; here `Op`
// may return void or `Policy::shape_type::value_type` (see the messages).
251 static_assert(!inplace || std::is_same<ResultTile, ArgTile>::value,
252 "if inplace==true, ResultTile and ArgTile must be the same");
255 "if inplace==true, Op must be callable with signature ret(ArgTile&, "
256 "const ArgTiles&...), where ret={void,Policy::shape_type::value_type}");
258 const ArgTiles&...>::value,
259 "if inplace==false, Op must be callable with signature "
260 "ret(ResultTile&,const ArgTile&, const ArgTiles&...), where "
261 "ret={void,Policy::shape_type::value_type}");
// All input arrays must share one tiled range.
263 TA_ASSERT(detail::compare_trange(arg, args...) &&
264 "Tiled ranges of args must match");
269 typedef typename arg_array_type::value_type arg_value_type;
270 typedef typename result_array_type::value_type result_value_type;
271 typedef typename arg_array_type::ordinal_type ordinal_type;
272 typedef typename arg_array_type::shape_type shape_type;
// A datum pairs a tile ordinal with the future of its result tile.
273 typedef std::pair<ordinal_type, Future<result_value_type>> datum_type;
// Room for one entry per index in arg's process map.
276 std::vector<datum_type> tiles;
277 tiles.reserve(arg.pmap()->size());
282 tile_norms(arg.trange().tiles_range(), 0);
// NOTE(review): presumably counts finished tasks; it is compared against
// `task_count` further down -- confirm.
285 madness::AtomicInt counter;
// Per-tile task: runs `op` through `op_helper`, which receives
// `tile_norms[index]` as the slot for op's return value.
289 const auto task = [op_shared_handle, &counter, &tile_norms](
290 const ordinal_type index,
292 const ArgTiles&... arg_tiles) -> result_value_type {
293 op_helper<inplace, result_value_type> op_caller;
294 auto result_tile = op_caller(std::move(op_shared_handle), tile_norms[index],
295 arg_tile, arg_tiles...);
300 World& world = arg.world();
302 const auto& arg_shape_data = arg.shape().data();
303 switch (shape_reduction) {
// Tiles are tested with `is_zero_intersection` in this loop.
306 for (
auto index : *(arg.pmap())) {
307 if (is_zero_intersection({arg.is_zero(index), args.
is_zero(index)...}))
310 world.taskq.add(task, index, arg.find(index), args.
find(index)...);
312 tiles.emplace_back(index, std::move(result_tile));
// NOTE(review): presumably applies when `Op` returns void, so the norm is
// taken from arg's shape data rather than from `op` -- confirm.
315 tile_norms[index] = arg_shape_data[index];
318 case ShapeReductionMethod::Union:
// Tiles are tested with `is_zero_union` in this loop; inputs are fetched via
// `get_sparse_tile`.
320 for (
auto index : *(arg.pmap())) {
321 if (is_zero_union({arg.is_zero(index), args.
is_zero(index)...}))
324 world.taskq.add(task, index, detail::get_sparse_tile(index, arg),
325 detail::get_sparse_tile(index, args)...);
327 tiles.emplace_back(index, std::move(result_tile));
332 "ShapeReductionMethod::Union not supported with "
333 "void-returning Op");
// Predicate that becomes true once `counter` equals `task_count`.
// NOTE(review): presumably passed to a wait on the world's task queue -- confirm.
344 [&counter, task_count]() ->
bool {
return counter == task_count; });
347 result_array_type result(
349 shape_type(world, tile_norms, arg.trange(), op_returns_void),
// Store only the tiles that the result reports as non-zero.
352 for (
typename std::vector<datum_type>::const_iterator it = tiles.begin();
353 it != tiles.end(); ++it) {
354 const auto index = it->first;
355 if (!result.is_zero(index)) result.set(it->first, it->second);
/// Unary wrapper, enabled only when `ResultTile` and `ArgTile` differ.
/// Forwards `op` and `arg` to `detail::foreach` with `inplace == false`.
421 template <
typename ResultTile,
typename ArgTile,
typename Policy,
typename Op,
422 typename =
typename std::enable_if<
423 !std::is_same<ResultTile, ArgTile>::value>::type>
427 return detail::foreach<false, Op, ResultTile, ArgTile, Policy>(
428 std::forward<Op>(op), arg);
/// Unary wrapper for the case where result and argument share the tile type
/// `Tile`. Forwards to `detail::foreach` with `inplace == false`.
435 template <
typename Tile,
typename Policy,
typename Op>
438 return detail::foreach<false, Op, Tile, Tile, Policy>(std::forward<Op>(op),
/// In-place unary wrapper: when `fence` is set, `arg`'s world is fenced
/// first; then `detail::foreach` runs with `inplace == true`. The value
/// returned by `detail::foreach` is not used on this call.
473 template <
typename Tile,
typename Policy,
typename Op,
475 typename std::decay<Op>::type>::value>::type>
480 if (fence) arg.
world().gop.fence();
483 detail::foreach<true, Op, Tile, Tile, Policy>(std::forward<Op>(op), arg);
/// Unary wrapper, enabled only when `ResultTile` and `ArgTile` differ.
/// Forwards to `detail::foreach` with `inplace == false`.
// NOTE(review): same template head as the earlier unary wrapper; presumably
// an overload for a different array policy -- confirm.
520 template <
typename ResultTile,
typename ArgTile,
typename Policy,
typename Op,
521 typename =
typename std::enable_if<
522 !std::is_same<ResultTile, ArgTile>::value>::type>
526 return detail::foreach<false, Op, ResultTile, ArgTile, Policy>(
/// Unary wrapper for the case where result and argument share the tile type
/// `Tile`. Forwards to `detail::foreach` with `inplace == false`.
// NOTE(review): same template head as the earlier same-tile wrapper;
// presumably an overload for a different array policy -- confirm.
534 template <
typename Tile,
typename Policy,
typename Op>
537 return detail::foreach<false, Op, Tile, Tile, Policy>(
/// In-place unary wrapper: when `fence` is set, `arg`'s world is fenced
/// first; then `detail::foreach` runs with `inplace == true` and its return
/// value is assigned back to `arg`.
581 template <
typename Tile,
typename Policy,
typename Op,
583 typename std::decay<Op>::type>::value>::type>
588 if (fence) arg.
world().gop.fence();
591 arg = detail::foreach<true, Op, Tile, Tile, Policy>(
/// Binary wrapper, enabled only when `ResultTile` and `LeftTile` differ.
/// Forwards `op`, `left` and `right` to `detail::foreach` with
/// `inplace == false`.
597 template <
typename ResultTile,
typename LeftTile,
typename RightTile,
598 typename Policy,
typename Op,
599 typename =
typename std::enable_if<
600 !std::is_same<ResultTile, LeftTile>::value>::type>
605 return detail::foreach<false, Op, ResultTile, LeftTile, Policy, RightTile>(
606 std::forward<Op>(op), left, right);
/// Binary wrapper whose result tile type is `LeftTile`. Forwards `op`, `left`
/// and `right` to `detail::foreach` with `inplace == false`.
611 template <
typename LeftTile,
typename RightTile,
typename Policy,
typename Op>
616 return detail::foreach<false, Op, LeftTile, LeftTile, Policy, RightTile>(
617 std::forward<Op>(op), left, right);
/// In-place binary wrapper: when `fence` is set, `left`'s world is fenced
/// first; then `detail::foreach` runs with `inplace == true` on `left` and
/// `right`, and its return value is assigned back to `left`.
621 template <
typename LeftTile,
typename RightTile,
typename Policy,
typename Op>
627 if (fence) left.
world().gop.fence();
629 left = detail::foreach<true, Op, LeftTile, LeftTile, Policy, RightTile>(
630 std::forward<Op>(op), left, right);
/// Binary wrapper taking a `shape_reduction`, enabled only when `ResultTile`
/// and `LeftTile` differ. Forwards `op`, `shape_reduction`, `left` and
/// `right` to `detail::foreach` with `inplace == false`.
635 template <
typename ResultTile,
typename LeftTile,
typename RightTile,
636 typename Policy,
typename Op,
637 typename =
typename std::enable_if<
638 !std::is_same<ResultTile, LeftTile>::value>::type>
645 return detail::foreach<false, Op, ResultTile, LeftTile, Policy, RightTile>(
646 std::forward<Op>(op), shape_reduction, left, right);
/// Binary wrapper taking a `shape_reduction`, with result tile type
/// `LeftTile`. Forwards `op`, `shape_reduction`, `left` and `right` to
/// `detail::foreach` with `inplace == false`.
651 template <
typename LeftTile,
typename RightTile,
typename Policy,
typename Op>
658 return detail::foreach<false, Op, LeftTile, LeftTile, Policy, RightTile>(
659 std::forward<Op>(op), shape_reduction, left, right);
/// In-place binary wrapper taking a `shape_reduction`: when `fence` is set,
/// `left`'s world is fenced first; then `detail::foreach` runs with
/// `inplace == true`, and its return value is assigned back to `left`.
663 template <
typename LeftTile,
typename RightTile,
typename Policy,
typename Op>
672 if (fence) left.
world().gop.fence();
675 left = detail::foreach<true, Op, LeftTile, LeftTile, Policy, RightTile>(
676 std::forward<Op>(op), shape_reduction, left, right);
683 #endif // TILEDARRAY_CONVERSIONS_FOREACH_H__INCLUDED