20 #ifndef TILEDARRAY_DIST_EVAL_CONTRACTION_EVAL_H__INCLUDED
21 #define TILEDARRAY_DIST_EVAL_CONTRACTION_EVAL_H__INCLUDED
25 #include <TiledArray/config.h>
// Summa (fragment): distributed-memory SUMMA-style contraction evaluator
// parameterized on the Left/Right argument evaluators, the tile-pair
// reduction Op, and the result Policy. Inherits enable_shared_from_this so
// long-lived tasks can hold shared ownership of the evaluator. The class
// head is truncated in this excerpt (base-class list is partially missing).
54 template <
typename Left,
typename Right,
typename Op,
typename Policy>
57 public std::enable_shared_from_this<Summa<Left, Right, Op, Policy>> {
// Process groups used for the row/column broadcasts of the SUMMA loop
// (dense case keeps one persistent group per dimension).
89 madness::Group row_group_;
90 madness::Group col_group_;
// Datum typedefs pair a local ordinal index with the future of a tile
// (right/left operand respectively); the alias names are truncated here.
116 typedef std::pair<ordinal_type, right_future>
118 typedef std::pair<ordinal_type, left_future>
// Compile-time switch: task tracing is on only when
// TILEDARRAY_ENABLE_TASK_DEBUG_TRACE is defined (value lines not visible).
121 static constexpr
const bool trace_tasks =
122 #ifdef TILEDARRAY_ENABLE_TASK_DEBUG_TRACE
// Bring the base-class shared_from_this() into scope.
131 using std::enable_shared_from_this<Summa_>::shared_from_this;
// init_max_memory (fragment): parses the TA_SUMMA_MAX_MEMORY environment
// variable as "<number> <unit>" and converts the value to bytes. The
// multiplier statements for the KB/KiB/MB/MiB branches fall on lines not
// visible in this excerpt.
138 const char* max_memory = getenv(
"TA_SUMMA_MAX_MEMORY");
141 std::stringstream ss(max_memory);
147 if (unit ==
"KB" || unit ==
"kB") {
149 }
else if (unit ==
"KiB" || unit ==
"kiB") {
151 }
else if (unit ==
"MB") {
153 }
else if (unit ==
"MiB") {
155 }
else if (unit ==
"GB") {
// Decimal gigabyte: 10^9 bytes.
156 memory *= 1000000000.0;
157 }
else if (unit ==
"GiB") {
// Binary gibibyte: 2^30 bytes.
158 memory *= 1073741824.0;
// Enforce a floor of 100 MiB (104857600 bytes) on the configured limit.
164 memory =
std::max(memory, 104857600.0);
// init_max_depth (fragment): optional pipeline-depth cap from the
// TA_SUMMA_MAX_DEPTH environment variable.
// NOTE(review): std::stoul throws std::invalid_argument/std::out_of_range
// on a malformed value -- confirm callers tolerate an exception here.
172 const char* max_depth = getenv(
"TA_SUMMA_MAX_DEPTH");
173 if (max_depth)
return std::stoul(max_depth);
// make_group (fragment): builds a madness::Group from the processes that
// own at least one non-zero tile of `shape` over a strided index range,
// restricted by `process_mask`; `proc_map` maps a group slot to a world
// rank. Several set-up lines are missing from this excerpt.
202 template <
typename Shape,
typename ProcMap>
203 madness::Group make_group(
const Shape&
shape,
204 const std::vector<bool>& process_mask,
209 const ProcMap& proc_map)
const {
// Slots start out unassigned (-1).
211 std::vector<ProcessID> proc_list(max_group_size, -1);
216 proc_list[p] = proc_map(p);
// Walk the tile indices, assigning each owning slot at most once;
// already-assigned slots, zero tiles, and masked-out slots are skipped.
220 for (p = 0ul; (index < end) && (count < max_group_size);
221 index += stride, p = (p + 1u) % max_group_size) {
222 if ((proc_list[p] != -1) || (
shape.is_zero(index)) || !process_mask.at(p))
225 proc_list[p] = proc_map(p);
// Compact the assigned ranks to the front and trim to `count`.
231 if (proc_list[p] == -1)
continue;
232 proc_list[x++] = proc_list[p];
236 proc_list.resize(count);
238 return madness::Group(
// make_row_group (fragment): group of processes participating in the
// broadcast of right-operand row k within this process row; returns an
// empty group when this rank's process column is masked out.
247 madness::Group make_row_group(
const ordinal_type k)
const {
253 auto result_row_mask_k = make_row_mask(k);
256 if (result_row_mask_k[proc_grid_.
rank_col()])
257 return make_group(right_.shape(), result_row_mask_k, right_begin_k,
258 right_end_k, right_stride_, proc_grid_.
proc_cols(), k,
260 return proc_grid_.map_col(col);
263 return madness::Group();
// make_col_group (fragment): analogous group for the broadcast of
// left-operand column k within this process column.
270 madness::Group make_col_group(
const ordinal_type k)
const {
273 auto result_col_mask_k = make_col_mask(k);
276 if (result_col_mask_k[proc_grid_.
rank_row()])
278 left_.shape(), result_col_mask_k, k, left_end_, left_stride_,
280 [&](
const ordinal_type row) { return proc_grid_.map_row(row); });
282 return madness::Group();
// make_row_mask (fragment): for iteration k, marks which process columns
// need data broadcast from this process row. A dense result shape means
// every column participates.
291 std::vector<bool> make_row_mask(
const ordinal_type k)
const {
296 const auto nproc_cols = proc_grid_.
proc_cols();
297 const auto my_proc_row = proc_grid_.
rank_row();
// Dense result: all process columns are needed.
303 if (result_shape.is_dense())
return std::vector<bool>(nproc_cols,
true);
306 std::vector<bool> mask(nproc_cols,
false);
309 const auto nj = proc_grid_.
cols();
// Scan the result rows owned by this process row; `ik` indexes the left
// operand tile (row i, column k).
315 std::tie(i_start, i_fence, i_stride) = result_row_range(my_proc_row);
316 const auto ik_stride = i_stride * nk;
317 for (
ordinal_type i = i_start, ik = i_start * nk + k; i < i_fence;
318 i += i_stride, ik += ik_stride) {
320 if (!left_.shape().is_zero(ik)) {
// The process column that owns k always participates.
322 const auto k_proc_col = k % nproc_cols;
323 mask[k_proc_col] =
true;
// Other columns are needed only if they own a non-zero result tile in
// result row i.
325 for (
ordinal_type proc_col = 0; proc_col != nproc_cols; ++proc_col) {
327 if (proc_col != k_proc_col) {
330 std::tie(j_start, j_fence, j_stride) = result_col_range(proc_col);
331 const auto ij_stride = j_stride;
332 for (
ordinal_type j = j_start, ij = i * nj + j_start; j < j_fence;
333 j += j_stride, ij += ij_stride) {
336 if (!result_shape.is_zero(
338 mask[proc_col] =
true;
// make_col_mask (fragment): mirror of make_row_mask for process rows,
// driven by right-operand row k.
356 std::vector<bool> make_col_mask(
const ordinal_type k)
const {
363 const auto nproc_rows = proc_grid_.
proc_rows();
364 const auto my_proc_col = proc_grid_.
rank_col();
370 if (result_shape.is_dense())
return std::vector<bool>(nproc_rows,
true);
373 std::vector<bool> mask(nproc_rows,
false);
376 const auto nj = proc_grid_.
cols();
// Scan the result columns owned by this process column; `kj` indexes the
// right operand tile (row k, column j).
380 std::tie(j_start, j_fence, j_stride) = result_col_range(my_proc_col);
381 const auto kj_stride = j_stride;
382 for (
ordinal_type j = j_start, kj = k * nj + j_start; j < j_fence;
383 j += j_stride, kj += kj_stride) {
385 if (!right_.shape().is_zero(kj)) {
387 auto k_proc_row = k % nproc_rows;
388 mask[k_proc_row] =
true;
390 for (
ordinal_type proc_row = 0; proc_row != nproc_rows; ++proc_row) {
392 if (proc_row != k_proc_row) {
395 std::tie(i_start, i_fence, i_stride) = result_row_range(proc_row);
396 const auto ij_stride = i_stride * nj;
397 for (
ordinal_type i = i_start, ij = i_start * nj + j; i < i_fence;
398 i += i_stride, ij += ij_stride) {
401 if (!result_shape.is_zero(
403 mask[proc_row] =
true;
// result_row_range / result_col_range (fragments): return the
// (start, fence, stride) triple describing the result tiles assigned to a
// given process row/column; the computations of start/fence/stride are on
// lines not visible in this excerpt.
421 inline std::tuple<ordinal_type, ordinal_type, ordinal_type> result_row_range(
426 return std::make_tuple(start, fence, stride);
435 std::tuple<ordinal_type, ordinal_type, ordinal_type> result_col_range(
440 return std::make_tuple(start, fence, stride);
// convert_tile (fragment): converts a (lazy) tile to its evaluated form;
// the body is not visible here.
450 template <
typename Tile>
451 static auto convert_tile(
const Tile& tile) {
// get_tile overload for non-lazy tiles: the stored tile is already in its
// evaluated form, so simply return the future held by the argument.
463 template <
typename Arg>
464 static typename std::enable_if<!is_lazy_tile<typename Arg::value_type>::value,
465 Future<typename Arg::eval_type>>::type
466 get_tile(Arg& arg,
const typename Arg::ordinal_type index) {
467 return arg.get(index);
// get_tile overload for lazy (non-CUDA) tiles: schedule a high-priority
// conversion task that evaluates the tile once its future is ready.
478 template <
typename Arg>
479 static typename std::enable_if<
480 is_lazy_tile<typename Arg::value_type>::value
481 #ifdef TILEDARRAY_HAS_CUDA
482 && !detail::is_cuda_tile_v<typename Arg::value_type>
485 Future<typename Arg::eval_type>>::type
486 get_tile(Arg& arg,
const typename Arg::ordinal_type index) {
487 auto convert_tile_fn =
488 &Summa_::template convert_tile<typename Arg::value_type>;
489 return arg.world().taskq.add(convert_tile_fn, arg.get(index),
490 madness::TaskAttributes::hipri());
// get_tile overload for lazy CUDA tiles: same conversion, but routed
// through madness::add_cuda_task.
493 #ifdef TILEDARRAY_HAS_CUDA
502 template <
typename Arg>
503 static typename std::enable_if<
504 is_lazy_tile<typename Arg::value_type>::value &&
505 detail::is_cuda_tile_v<typename Arg::value_type>,
506 Future<typename Arg::eval_type>>::type
507 get_tile(Arg& arg,
const typename Arg::ordinal_type index) {
508 auto convert_tile_fn =
509 &Summa_::template convert_tile<typename Arg::value_type>;
510 return madness::add_cuda_task(arg.world(), convert_tile_fn, arg.get(index),
511 madness::TaskAttributes::hipri());
// get_vector (fragment): collects the non-zero tiles of a strided index
// range into `vec`; locally-owned tiles get real futures via get_tile,
// remote ones get empty futures to be filled by a later broadcast.
524 template <
typename Arg,
typename Datum>
526 const ordinal_type stride, std::vector<Datum>& vec)
const {
530 if (arg.is_local(index)) {
532 if (arg.shape().is_zero(index))
continue;
533 vec.emplace_back(i, get_tile(arg, index));
537 if (arg.shape().is_zero(index))
continue;
538 vec.emplace_back(i, Future<typename Arg::eval_type>());
// get_col: gather this rank's portion of left-operand column k.
549 void get_col(
const ordinal_type k, std::vector<col_datum>& col)
const {
551 get_vector(left_, left_start_local_ + k, left_end_, left_stride_local_,
// get_row: gather this rank's portion of right-operand row k.
559 void get_row(
const ordinal_type k, std::vector<row_datum>& row)
const {
567 get_vector(right_, begin, end, right_stride_local_, row);
// bcast (fragment): broadcasts each tile future in `vec` within `group`
// from `group_root`, keying the messages by tile index plus `key_offset`.
578 template <
typename Datum>
580 const madness::Group& group,
const ProcessID group_root,
581 const ordinal_type key_offset, std::vector<Datum>& vec)
const {
// Optional tracing of group membership and broadcast keys.
586 #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_BCAST
587 std::stringstream ss;
589 <<
" root=" << group.world_rank(group_root) <<
" groupid=("
590 << group.id().first <<
"," << group.id().second
591 <<
") keyoffset=" << key_offset <<
" group={ ";
592 for (ProcessID group_proc = 0; group_proc < group.size(); ++group_proc)
593 ss << group.world_rank(group_proc) <<
" ";
595 #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_BCAST
598 for (
typename std::vector<Datum>::iterator it = vec.begin();
599 it != vec.end(); ++it) {
606 #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_BCAST
608 #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_BCAST
613 #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_BCAST
// NOTE(review): printf(ss.str().c_str()) passes traced text as the format
// string; a '%' in it would be misinterpreted. Prefer
// printf("%s", ss.str().c_str()).
615 printf(ss.str().c_str());
616 #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_BCAST
// get_row_group_root (fragment): pick the root for the row broadcast at
// iteration k; if the group was pruned (sparse right operand), translate
// the owner's world rank into a rank within the smaller group.
622 const madness::Group& row_group)
const {
623 ProcessID group_root = k % proc_grid_.
proc_cols();
624 if (!right_.shape().is_dense() &&
625 row_group.size() <
static_cast<ProcessID
>(proc_grid_.
proc_cols())) {
626 const ProcessID world_root =
628 group_root = row_group.rank(world_root);
// get_col_group_root (fragment): same root selection for the column
// broadcast, keyed on the left operand's shape.
634 const madness::Group& col_group)
const {
635 ProcessID group_root = k % proc_grid_.
proc_rows();
636 if (!left_.shape().is_dense() &&
637 col_group.size() <
static_cast<ProcessID
>(proc_grid_.
proc_rows())) {
638 const ProcessID world_root =
640 group_root = col_group.rank(world_root);
// bcast_col: broadcast the local left-operand column tiles within the
// row group (no-op when the group is empty).
649 void bcast_col(
const ordinal_type k, std::vector<col_datum>& col,
650 const madness::Group& row_group)
const {
652 if (!row_group.empty()) {
654 ProcessID group_root = get_row_group_root(k, row_group);
655 bcast(left_start_local_ + k, left_stride_local_, row_group, group_root,
// bcast_row: broadcast the local right-operand row tiles within the
// column group; keys are offset by left_.size() so they cannot collide
// with the column-broadcast keys.
664 void bcast_row(
const ordinal_type k, std::vector<row_datum>& row,
665 const madness::Group& col_group)
const {
667 if (!col_group.empty()) {
669 ProcessID group_root = get_col_group_root(k, col_group);
672 bcast(k * proc_grid_.
cols() + proc_grid_.
rank_col(), right_stride_local_,
673 col_group, group_root, left_.size(), row);
// bcast_col_range_task (fragment): for each iteration k in [k, end) that
// this process column owns, broadcast the locally-held non-zero left
// tiles; the row group is built lazily on first need.
680 k += (Pcols - ((k + Pcols - proc_grid_.
rank_col()) % Pcols)) % Pcols;
682 for (; k <
end; k += Pcols) {
687 bool have_group =
false;
688 madness::Group row_group;
689 ProcessID group_root;
693 for (; index < left_end_; index += left_stride_local_) {
694 if (left_.shape().is_zero(index))
continue;
699 row_group = make_row_group(k);
// Only broadcast when more than one process participates.
701 do_broadcast = !row_group.empty() && row_group.size() > 1;
702 if (do_broadcast) group_root = get_row_group_root(k, row_group);
708 auto tile = get_tile(left_, index);
// Tile will not be used locally again; release it.
712 left_.discard(index);
// bcast_row_range_task (fragment): mirror of the above for the
// right-operand rows owned by this process row.
721 k += (Prows - ((k + Prows - proc_grid_.
rank_row()) % Prows)) % Prows;
723 for (; k <
end; k += Prows) {
730 bool have_group =
false;
731 madness::Group col_group;
732 ProcessID group_root;
736 for (; index < row_end; index += right_stride_local_) {
737 if (right_.shape().is_zero(index))
continue;
742 col_group = make_col_group(k);
744 do_broadcast = !col_group.empty() && col_group.size() > 1;
745 if (do_broadcast) group_root = get_col_group_root(k, col_group);
751 index + left_.size());
752 auto tile = get_tile(right_, index);
756 right_.discard(index);
// iterate_row (fragment): advance k to the next iteration whose
// right-operand row contains any non-zero tile.
776 for (; k < k_; ++k) {
780 for (; i <
end; i += right_stride_local_)
781 if (!right_.shape().is_zero(i))
return k;
// iterate_col (fragment): equivalent scan over the left-operand column.
800 for (
ordinal_type i = left_start_local_ + k; i < left_end_;
801 i += left_stride_local_)
802 if (!left_.shape().is_zero(i))
return k;
// iterate_sparse (fragment): alternate the row/column scans until both
// agree on the next contributing k, then schedule catch-up broadcast
// tasks for the iterations that were skipped.
824 while (k_col != k_row) {
826 k_col = iterate_col(k_row);
828 k_row = iterate_row(k_col);
835 &Summa_::bcast_col_range_task, k, k_row,
836 madness::TaskAttributes::hipri());
840 &Summa_::bcast_row_range_task, k, k_col,
841 madness::TaskAttributes::hipri());
// iterate: dense operands need no scanning; otherwise defer to
// iterate_sparse.
859 return (left_.shape().is_dense() && right_.shape().is_dense()
861 : iterate_sparse(k));
// initialize (dense-shape fragment): allocate one ReducePairTask per
// local result tile and placement-new construct them.
874 #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE
875 std::stringstream ss;
877 << col_did.first <<
", " << col_did.second <<
") { ";
878 for (ProcessID gproc = 0ul; gproc < col_group_.size(); ++gproc)
879 ss << col_group_.world_rank(gproc) <<
" ";
880 ss <<
"}\n row_group_=(" << row_did.first <<
", " << row_did.second
882 for (ProcessID gproc = 0ul; gproc < row_group_.size(); ++gproc)
883 ss << row_group_.world_rank(gproc) <<
" ";
// NOTE(review): non-literal printf format string; prefer printf("%s", ...).
885 printf(ss.str().c_str());
886 #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE
// Raw buffer; tasks are placement-new constructed below and destroyed /
// deallocated in finalize().
889 std::allocator<ReducePairTask<op_type>> alloc;
890 reduce_tasks_ = alloc.allocate(proc_grid_.
local_size());
896 ReducePairTask<op_type>* MADNESS_RESTRICT
const reduce_task =
// initialize (sparse-shape fragment): same allocation, but reduce tasks
// are constructed only for result tiles that are non-zero under `shape`.
905 template <
typename Shape>
907 #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE
908 std::stringstream ss;
910 #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE
913 std::allocator<ReducePairTask<op_type>> alloc;
914 reduce_tasks_ = alloc.allocate(proc_grid_.
local_size());
928 ReducePairTask<op_type>* MADNESS_RESTRICT reduce_task = reduce_tasks_;
931 for (; row_start <
end; row_start += col_stride, row_end += col_stride) {
933 index += row_stride, ++reduce_task) {
938 #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE
940 #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE
946 new (reduce_task) ReducePairTask<op_type>();
951 #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE
// NOTE(review): non-literal printf format string; prefer printf("%s", ...).
953 printf(ss.str().c_str());
954 #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE
960 #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE
962 #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE
966 #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE
968 #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE
// finalize(DenseShape) (fragment): submit every reduce task, destroy it,
// and release the task buffer allocated in initialize().
976 void finalize(
const DenseShape&) {
988 for (ReducePairTask<op_type>* reduce_task = reduce_tasks_; row_start <
end;
989 row_start += col_stride, row_end += col_stride) {
991 index += row_stride, ++reduce_task) {
994 reduce_task->submit());
// Manual destruction matches the placement-new in initialize().
997 reduce_task->~ReducePairTask<
op_type>();
1002 std::allocator<ReducePairTask<op_type>>().deallocate(
// finalize(Shape) (fragment): sparse variant -- only non-zero result
// tiles (under the permuted index) carry live reduce tasks to submit.
1007 template <
typename Shape>
1008 void finalize(
const Shape&
shape) {
1009 #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_FINALIZE
1010 std::stringstream ss;
1012 #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_FINALIZE
1017 row_start += proc_grid_.
rank_col();
1025 for (ReducePairTask<op_type>* reduce_task = reduce_tasks_; row_start <
end;
1026 row_start += col_stride, row_end += col_stride) {
1028 index += row_stride, ++reduce_task) {
1034 if (!
shape.is_zero(perm_index)) {
1035 #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_FINALIZE
1037 #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_FINALIZE
1044 reduce_task->~ReducePairTask<
op_type>();
1048 std::allocator<ReducePairTask<op_type>>().deallocate(
1051 #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_FINALIZE
// NOTE(review): non-literal printf format string; prefer printf("%s", ...).
1053 printf(ss.str().c_str());
1054 #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_FINALIZE
1058 #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_FINALIZE
1060 #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_FINALIZE
1064 #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_FINALIZE
1066 #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_FINALIZE
// FinalizeTask (fragment): a task with `ndep` dependencies that, once all
// SUMMA steps have notified it, runs owner_->finalize(). Holds shared
// ownership of the evaluator so it outlives the step chain.
1072 class FinalizeTask :
public madness::TaskInterface {
1074 std::shared_ptr<Summa_> owner_;
1077 FinalizeTask(
const std::shared_ptr<Summa_>&
owner,
const int ndep)
1078 :
madness::TaskInterface(ndep,
madness::TaskAttributes::hipri()),
1081 virtual ~FinalizeTask() {}
1083 virtual void run(
const madness::TaskThreadEnv&) { owner_->finalize(); }
// contract (dense fragment): schedule every (column tile, row tile) pair
// on its ReducePairTask; `task`'s dependency count tracks the number of
// outstanding pair contributions.
1098 const std::vector<col_datum>& col,
1099 const std::vector<row_datum>& row,
1100 madness::TaskInterface*
const task) {
1110 reduce_task_offset + row[j].first;
1113 if (task) task->inc();
1114 const left_future left = col[i].second;
1115 const right_future right = row[j].second;
1116 reduce_tasks_[reduce_task_index].
add(left, right, task);
// contract (sparse-shape fragment): skip pairs whose result tile has no
// live reduce task.
1129 template <
typename Shape>
1131 const std::vector<col_datum>& col,
1132 const std::vector<row_datum>& row,
1133 madness::TaskInterface*
const task) {
1143 reduce_task_offset + row[j].first;
1146 if (!reduce_tasks_[reduce_task_index])
continue;
1151 task->inc_debug(
"destroy(*ReduceObject)");
1155 const left_future left = col[i].second;
1156 const right_future right = row[j].second;
1157 reduce_tasks_[reduce_task_index].
add(left, right, task);
// Tile-level contraction filtering is compiled out here; the overload
// below (floating-point shapes only) would drop pairs whose estimated
// contribution norm falls below threshold_k.
1162 #define TILEDARRAY_DISABLE_TILE_CONTRACTION_FILTER
1163 #ifndef TILEDARRAY_DISABLE_TILE_CONTRACTION_FILTER
1176 template <
typename T>
1177 typename std::enable_if<std::is_floating_point<T>::value>::type contract(
1179 const std::vector<col_datum>& col,
const std::vector<row_datum>& row,
1180 madness::TaskInterface*
const task) {
// Cache the shape norms of this row block once.
1182 std::vector<typename SparseShape<T>::value_type> row_shape_values;
1183 row_shape_values.reserve(row.size());
1187 row_shape_values.push_back(
1188 right_.shape()[row_start + (row[j].first * right_stride_local_)]);
1200 left_.shape()[col_start + (col[i].first * left_stride_local_)];
// Drop contributions whose estimated norm is below the threshold.
1204 if ((col_shape_value * row_shape_values[j]) < threshold_k)
continue;
1206 const ordinal_type reduce_task_index = offset + row[j].first;
1209 if (!reduce_tasks_[reduce_task_index])
continue;
1211 if (task) task->inc();
1212 reduce_tasks_[reduce_task_index].
add(col[i].second, row[j].second,
1217 #endif // TILEDARRAY_DISABLE_TILE_CONTRACTION_FILTER
// contract dispatch (fragment): forwards to the shape-appropriate
// overload above.
1219 void contract(
const ordinal_type k,
const std::vector<col_datum>& col,
1220 const std::vector<row_datum>& row,
1221 madness::TaskInterface*
const task) {
// StepTask (fragment): one SUMMA iteration. Each step gathers its local
// column/row tiles, broadcasts them, feeds the contraction, and spawns a
// successor; a chain of `depth` steps is kept in flight to pipeline the
// iterations.
1231 class StepTask :
public madness::TaskInterface {
1234 std::shared_ptr<Summa_> owner_;
1236 std::vector<col_datum> col_{};
1237 std::vector<row_datum> row_{};
1238 FinalizeTask* finalize_task_;
1239 StepTask* next_step_task_ =
nullptr;
1240 StepTask* tail_step_task_ =
// Gather callbacks: fill col_/row_ and then release one of this task's
// dependencies.
1244 owner_->get_col(k, col_);
1246 this->notify_debug(
"StepTask::spawn_col");
1252 owner_->get_row(k, row_);
1254 this->notify_debug(
"StepTask::spawn_row");
// First-step constructor: also creates and submits the FinalizeTask.
1260 StepTask(
const std::shared_ptr<Summa_>&
owner,
int finalize_ndep)
1262 #ifdef TILEDARRAY_ENABLE_TASK_DEBUG_TRACE
1263 madness::TaskInterface(0ul,
"StepTask 1st ctor",
1264 madness::TaskAttributes::hipri()),
1270 finalize_task_(new FinalizeTask(
owner, finalize_ndep)) {
1272 owner_->world().taskq.add(finalize_task_);
// Subsequent-step constructor: links itself into the parent's chain.
1279 StepTask(StepTask*
const parent,
const int ndep)
1281 #ifdef TILEDARRAY_ENABLE_TASK_DEBUG_TRACE
1282 madness::TaskInterface(ndep,
"StepTask nth ctor",
1283 madness::TaskAttributes::hipri()),
1287 owner_(parent->owner_),
1288 world_(parent->world_),
1289 finalize_task_(parent->finalize_task_) {
1291 parent->next_step_task_ =
this;
1294 virtual ~StepTask() {}
// Spawn the high-priority gather tasks; this step then depends on both.
1299 madness::DependencyInterface::inc_debug(
"StepTask::spawn_col");
1301 madness::DependencyInterface::inc();
1302 world_.taskq.add(
this, &StepTask::get_col, k,
1303 madness::TaskAttributes::hipri());
1307 madness::DependencyInterface::inc_debug(
"StepTask::spawn_row");
1309 madness::DependencyInterface::inc();
1310 world_.taskq.add(
this, &StepTask::get_row, k,
1311 madness::TaskAttributes::hipri());
// Pre-construct `depth` chained step tasks (clamped to the number of
// iterations k_); the final one is created with an extra dependency so
// the chain can throttle.
1314 template <
typename Derived>
1315 void make_next_step_tasks(Derived* task,
ordinal_type depth) {
1318 if (depth > owner_->k_) depth = owner_->k_;
1321 for (; depth > 0ul; --depth) {
1324 Derived*
const next =
new Derived(task, depth == 1 ? 1 : 0);
1329 tail_step_task_ = task;
// Core step body shared by the dense and sparse derived tasks.
1332 template <
typename Derived,
typename GroupType>
1333 void run(
const ordinal_type k,
const GroupType& row_group,
1334 const GroupType& col_group) {
1335 #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_STEP
1336 printf(
"step: start rank=%i k=%lu\n", owner_->world().rank(), k);
1337 #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_STEP
1339 if (k < owner_->k_) {
// Still iterating: extend the chain with a new tail, then release the
// already-constructed next step.
1342 next_step_task_->tail_step_task_ =
new Derived(
1343 static_cast<Derived*
>(tail_step_task_),
1349 world_.taskq.add(next_step_task_);
1350 next_step_task_ =
nullptr;
// Broadcast this step's column and row tiles, then feed the
// contraction; tail_step_task_ is notified as contributions complete.
1353 world_.taskq.add(owner_, &Summa_::bcast_col, k, col_, row_group,
1354 madness::TaskAttributes::hipri());
1355 world_.taskq.add(owner_, &Summa_::bcast_row, k, row_, col_group,
1356 madness::TaskAttributes::hipri());
1359 owner_->contract(k, col_, row_, tail_step_task_);
1364 tail_step_task_->notify_debug(
"StepTask nth ctor");
1366 tail_step_task_->notify();
1367 finalize_task_->notify();
1369 }
else if (finalize_task_) {
// Past the last iteration: release finalize and flush the remaining
// queued steps without doing work.
1372 finalize_task_->notify();
1375 StepTask* step_task = next_step_task_;
1377 StepTask*
const next_step_task = step_task->next_step_task_;
1378 step_task->next_step_task_ =
nullptr;
1379 step_task->finalize_task_ =
nullptr;
1380 world_.taskq.add(step_task);
1381 step_task = next_step_task;
1385 tail_step_task_->notify_debug(
"StepTask nth ctor");
1387 tail_step_task_->notify();
1390 #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_STEP
1391 printf(
"step: finish rank=%i k=%lu\n", owner_->world().rank(), k);
1392 #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_STEP
// DenseStepTask (fragment): dense-shape step task. Every iteration k
// contributes, so k advances by one per step and the evaluator's
// persistent row_group_/col_group_ are reused for every broadcast.
1397 class DenseStepTask :
public StepTask {
1400 using StepTask::owner_;
1403 DenseStepTask(
const std::shared_ptr<Summa_>&
owner,
1406 StepTask::make_next_step_tasks(
this, depth);
1407 StepTask::spawn_get_row_col_tasks(k_);
1410 DenseStepTask(DenseStepTask*
const parent,
const int ndep)
1411 : StepTask(parent, ndep), k_(parent->k_ + 1ul) {
// Only gather when there is still an iteration left to perform.
1413 if (k_ < owner_->k_) StepTask::spawn_get_row_col_tasks(k_);
1416 virtual ~DenseStepTask() {}
1418 virtual void run(
const madness::TaskThreadEnv&) {
1419 StepTask::template run<DenseStepTask>(k_, owner_->row_group_,
1420 owner_->col_group_);
// SparseStepTask (fragment): sparse-shape step task. The next
// contributing iteration k and the pruned broadcast groups are computed
// asynchronously, so they are held as futures.
1424 class SparseStepTask :
public StepTask {
1426 Future<ordinal_type> k_{};
1427 Future<madness::Group> row_group_{};
1428 Future<madness::Group> col_group_{};
1429 using StepTask::finalize_task_;
1430 using StepTask::next_step_task_;
1431 using StepTask::owner_;
1432 using StepTask::world_;
// iterate_task body (fragment): find the next non-zero iteration, then
// spawn the gathers and the row/column group construction for it.
1438 k = owner_->iterate_sparse(k + offset);
1441 if (k < owner_->k_) {
1446 StepTask::spawn_get_row_col_tasks(k);
1449 row_group_ = world_.taskq.add(owner_, &Summa_::make_row_group, k,
1450 madness::TaskAttributes::hipri());
1451 col_group_ = world_.taskq.add(owner_, &Summa_::make_col_group, k,
1452 madness::TaskAttributes::hipri());
1457 finalize_task_->inc();
1461 madness::DependencyInterface::notify_debug(
"SparseStepTask ctor");
1463 madness::DependencyInterface::notify();
// First-step constructor: build the chain and start the iteration scan.
1468 : StepTask(
owner, 1ul) {
1469 StepTask::make_next_step_tasks(
this, depth);
1473 madness::DependencyInterface::inc_debug(
"SparseStepTask ctor");
1475 madness::DependencyInterface::inc();
1476 world_.taskq.add(
this, &SparseStepTask::iterate_task, 0ul, 0ul,
1477 madness::TaskAttributes::hipri());
// Subsequent-step constructor: short-circuit when the parent already
// knows the iteration space is exhausted.
1480 SparseStepTask(SparseStepTask*
const parent,
const int ndep)
1481 : StepTask(parent, ndep) {
1482 if (parent->k_.probe() && (parent->k_.get() >= owner_->k_)) {
1484 k_.set(parent->k_.get());
1490 madness::DependencyInterface::inc_debug(
"SparseStepTask ctor");
1492 madness::DependencyInterface::inc();
1493 world_.taskq.add(
this, &SparseStepTask::iterate_task, parent->k_, 1ul,
1494 madness::TaskAttributes::hipri());
1498 virtual ~SparseStepTask() {}
1500 virtual void run(
const madness::TaskThreadEnv&) {
1501 StepTask::template run<SparseStepTask>(k_, row_group_, col_group_);
// Summa constructor (fragment): caches the process grid and the local
// strides/offsets into the left/right operand tile ranges. Most of the
// parameter list and initializer list is missing from this excerpt.
1522 template <
typename Perm,
typename = std::enable_if_t<
1523 TiledArray::detail::is_permutation_v<Perm>>>
1526 const std::shared_ptr<pmap_interface>&
pmap,
const Perm& perm,
1535 proc_grid_(proc_grid),
1536 reduce_tasks_(NULL),
1537 left_start_local_(proc_grid_.rank_row() * k),
1538 left_end_(left.
size()),
1540 left_stride_local_(proc_grid.proc_rows() * k),
1542 right_stride_local_(proc_grid.proc_cols()) {}
// Fragment: row-major rank of the (proc_row, proc_col) grid position.
1565 const ProcessID source = proc_row * proc_grid_.
proc_cols() + proc_col;
// mem_bound_depth (fragment): clamp the SUMMA pipeline depth so the tiles
// buffered per iteration fit within the configured memory limit; operand
// sparsity discounts the estimated per-iteration footprint.
1588 const float right_sparsity) {
1591 if (available_memory) {
1593 const std::size_t local_memory_per_iter_left =
1594 (left_.trange().elements_range().volume() /
1595 left_.trange().tiles_range().volume()) *
1597 proc_grid_.
local_rows() * (1.0f - left_sparsity);
1598 const std::size_t local_memory_per_iter_right =
1599 (right_.trange().elements_range().volume() /
1600 right_.trange().tiles_range().volume()) *
1602 proc_grid_.
local_cols() * (1.0f - right_sparsity);
1606 ((local_memory_per_iter_left + local_memory_per_iter_right) /
1610 if (depth > mem_bound_depth) {
1612 switch (mem_bound_depth) {
// Even a depth-1 pipeline does not fit: abort.
1615 TA_EXCEPTION(
"Insufficient memory available for SUMMA");
1620 "!! WARNING TiledArray: Memory constraints limit the SUMMA "
1621 "depth depth to 1.\n"
1622 "!! WARNING TiledArray: Performance may be slow.\n");
1624 depth = mem_bound_depth;
// internal_eval (fragment): evaluate the operand expressions, set up the
// reduce tasks, choose a pipeline depth (dense heuristic, or the sparse
// heuristic scaled by log2 of the estimated non-zero fraction), and kick
// off the dense/sparse step-task chain.
1639 virtual int internal_eval() {
1640 #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_EVAL
1642 #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_EVAL
1648 #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_EVAL
1649 printf(
"eval: finished eval children rank=%i\n",
1651 #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_EVAL
1655 tile_count = initialize();
// Dense path: clamp depth by k_, by memory, and by TA_SUMMA_MAX_DEPTH.
1670 if (depth > k_) depth = k_;
1674 depth = mem_bound_depth(depth, 0.0f, 0.0f);
1677 if (max_depth_) depth =
std::min(depth, max_depth_);
1680 new DenseStepTask(shared_from_this(), depth));
// Sparse path: grow depth to compensate for skipped iterations;
// sparsity is capped at 0.9 to bound the correction factor.
1685 const float left_sparsity = left_.shape().sparsity();
1686 const float right_sparsity = right_.shape().sparsity();
1690 const float frac_non_zero = (1.0f -
std::min(left_sparsity, 0.9f)) *
1691 (1.0f -
std::min(right_sparsity, 0.9f));
1695 float(depth) * (1.0f - 1.35638f * std::log2(frac_non_zero)) + 0.5f;
1699 if (depth > k_) depth = k_;
1703 depth = mem_bound_depth(depth, left_sparsity, right_sparsity);
1706 if (max_depth_) depth =
std::min(depth, max_depth_);
1709 new SparseStepTask(shared_from_this(), depth));
1713 #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_EVAL
1715 #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_EVAL
1721 #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_EVAL
1722 printf(
"eval: finished wait children rank=%i\n",
1724 #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_EVAL
// Out-of-class definitions of the static configuration members; each is
// initialized once from its corresponding environment variable
// (TA_SUMMA_MAX_DEPTH / TA_SUMMA_MAX_MEMORY) via the init_* helpers.
1733 template <
typename Left,
typename Right,
typename Op,
typename Policy>
1735 Summa<Left, Right, Op, Policy>::max_depth_ =
1736 Summa<Left, Right, Op, Policy>::init_max_depth();
1738 template <
typename Left,
typename Right,
typename Op,
typename Policy>
1740 Summa<Left, Right, Op, Policy>::max_memory_ =
1741 Summa<Left, Right, Op, Policy>::init_max_memory();
1745 #endif // TILEDARRAY_DIST_EVAL_CONTRACTION_EVAL_H__INCLUDED