20 #ifndef TILEDARRAY_EXPRESSIONS_CONTRACTION_HELPERS_H__INCLUDED
21 #define TILEDARRAY_EXPRESSIONS_CONTRACTION_HELPERS_H__INCLUDED
// Assembles a range for the indices in `target_idxs`: for every index the
// extent is read from the first lhs mode carrying it, or from the first rhs
// mode when lhs does not carry it. The result is built as
// `range_type(ranges)`, where `range_type` is the decayed type of
// `lhs.range()`.
// NOTE(review): the lines holding the function name, the `target_idxs` and
// `rhs` parameters and the `corr_extent` declaration are not visible in this
// excerpt — confirm against the full definition.
53 template <
typename IndexList_,
typename LHSType,
typename RHSType>
55 const IndexList_& lhs_idxs,
56 const IndexList_& rhs_idxs, LHSType&& lhs,
58 using range_type = std::decay_t<decltype(lhs.range())>;
59 using size_type =
typename range_type::size_type;
60 using extent_type = std::pair<size_type, size_type>;
// One entry per target index, appended in the order of `target_idxs`.
62 std::vector<extent_type> ranges;
63 const auto& lrange = lhs.range();
64 const auto& rrange = rhs.range();
66 for (
const auto& idx : target_idxs) {
// Modes (positions) at which `idx` occurs in each operand's index list.
67 const auto lmodes = lhs_idxs.positions(idx);
68 const auto rmodes = rhs_idxs.positions(idx);
// Every target index has to appear in at least one of the operands.
69 TA_ASSERT(lmodes.size() || rmodes.size());
// Prefer the lhs extent; fall back to rhs only when lhs lacks the index.
72 lmodes.size() ? lrange.dim(lmodes[0]) : rrange.dim(rmodes[0]);
// Every mode carrying this index, on either side, must agree on the extent.
73 for (
auto lmode : lmodes)
TA_ASSERT(lrange.dim(lmode) == corr_extent);
74 for (
auto rmode : rmodes)
TA_ASSERT(rrange.dim(rmode) == corr_extent);
75 ranges.emplace_back(std::move(corr_extent));
77 return range_type(ranges);
// Tiled counterpart of the range-assembling helper: collects one
// `TiledRange1` per index in `target_idxs`, taken from `lhs.trange()` when
// lhs carries the index and from `rhs.trange()` otherwise, and returns a
// `TiledRange` built from the collected sequence.
// NOTE(review): the lines holding the function name, the `target_idxs` and
// `rhs` parameters and the `corr_extent` declaration are not visible in this
// excerpt — confirm against the full definition.
80 template <
typename IndexList_,
typename LHSType,
typename RHSType>
82 const IndexList_& lhs_idxs,
83 const IndexList_& rhs_idxs, LHSType&& lhs,
// One TiledRange1 per target index, appended in the order of `target_idxs`.
85 std::vector<TiledRange1> ranges;
86 const auto& lrange = lhs.trange();
87 const auto& rrange = rhs.trange();
89 for (
const auto& idx : target_idxs) {
// Modes (positions) at which `idx` occurs in each operand's index list.
90 const auto lmodes = lhs_idxs.positions(idx);
91 const auto rmodes = rhs_idxs.positions(idx);
// Every target index has to appear in at least one of the operands.
92 TA_ASSERT(lmodes.size() || rmodes.size());
// Prefer the lhs tiling; fall back to rhs only when lhs lacks the index.
95 lmodes.size() ? lrange.dim(lmodes[0]) : rrange.dim(rmodes[0]);
// Every mode carrying this index, on either side, must have the same tiling.
96 for (
auto lmode : lmodes)
TA_ASSERT(lrange.dim(lmode) == corr_extent);
97 for (
auto rmode : rmodes)
TA_ASSERT(rrange.dim(rmode) == corr_extent);
98 ranges.emplace_back(std::move(corr_extent));
100 return TiledRange(ranges.begin(), ranges.end());
// Builds the coordinate index of one tensor from separate free and bound
// coordinate indices.
//
// free_vars   - index list of the free variables; `free_idx[k]` is read at
//               the position a variable has in this list
// bound_vars  - index list of the bound variables; `bound_idx[k]` is read at
//               the position a variable has in this list
// tensor_vars - index list of the tensor whose coordinate index is built
// free_idx    - current values for the free variables
// bound_idx   - current values for the bound variables
//
// The result `rv` has `tensor_vars.size()` entries and the decayed type of
// `IndexType`.
// NOTE(review): the `modes` declaration and the return statement are not
// visible in this excerpt; presumably `rv` is returned — confirm.
129 template <
typename IndexList_,
typename IndexType>
130 auto make_index(
const IndexList_& free_vars,
const IndexList_& bound_vars,
131 const IndexList_& tensor_vars, IndexType&& free_idx,
132 IndexType&& bound_idx) {
133 std::decay_t<IndexType> rv(tensor_vars.size());
134 for (std::size_t i = 0; i < tensor_vars.size(); ++i) {
135 const auto& x = tensor_vars[i];
// A variable counts as free when it occurs in `free_vars`; otherwise it is
// looked up among the bound variables.
136 const bool is_free = free_vars.count(x);
138 is_free ? free_vars.positions(x) : bound_vars.positions(x);
// Only the first position of the variable is used to pick the value.
140 rv[i] = is_free ? free_idx[modes[0]] : bound_idx[modes[0]];
// NOTE(review): fragment of a definition whose header is not visible in this
// excerpt. The visible lines copy `bound_temp` into a
// `std::vector<std::string>` (passed next to an empty vector of strings) and
// later construct an `IndexList` named `bound_vars` from the same
// `bound_temp` sequence.
151 std::vector<std::string>(bound_temp.begin(), bound_temp.end()),
152 std::vector<std::string>{});
161 IndexList bound_vars(bound_temp.begin(), bound_temp.end());
// Contraction kernel that accumulates the products of matching lhs and rhs
// elements over `bound_range` into a single accumulator `rv`.
// NOTE(review): the function name, the leading parameters (including
// `lhs_vars`), the bodies of the two lambdas and the declarations of
// `bound_range` and `rv` are not visible in this excerpt — confirm against
// the full definition.
168 template <
typename IndexList_,
typename LHSType,
typename RHSType>
170 const IndexList_& rhs_vars, LHSType&& lhs, RHSType&& rhs) {
171 using value_type =
typename std::decay_t<LHSType>::value_type;
// Both operands must be annotated with the same variables, possibly in a
// different order.
178 TA_ASSERT(lhs_vars.is_permutation(rhs_vars));
// Maps a bound coordinate index onto an lhs coordinate index; `empty` is a
// default-constructed index of the same type as `bound_idx`.
184 auto lhs_idx = [=](
const auto& bound_idx) {
185 const std::decay_t<decltype(bound_idx)>
empty;
// Same mapping for the rhs operand.
189 auto rhs_idx = [=](
const auto& bound_idx) {
190 const std::decay_t<decltype(bound_idx)>
empty;
// Accumulate lhs * rhs over every bound coordinate index.
198 for (
const auto& bound_idx : bound_range) {
199 const auto& lhs_elem = lhs(lhs_idx(bound_idx));
200 const auto& rhs_elem = rhs(rhs_idx(bound_idx));
201 rv += lhs_elem * rhs_elem;
// Kernel in which `lhs` is multiplied directly into every element read from
// `rhs`; the result `rv` has the decayed type of `RHSType`, is created over
// `orange` and is initialised with 0.0.
// NOTE(review): the function name, the leading parameters (including
// `free_vars`), the tail of the `make_index` call and the declaration of
// `orange` are not visible in this excerpt — confirm against the full
// definition.
207 template <
typename IndexList_,
typename LHSType,
typename RHSType>
209 const IndexList_& rhs_vars, LHSType&& lhs, RHSType&& rhs) {
// Maps a free coordinate index onto an rhs coordinate index; a
// default-constructed index list is passed as the list of bound variables.
210 auto rhs_idx = [=](
const auto& free_idx,
const auto& bound_idx) {
211 return make_index(free_vars, std::decay_t<IndexList_>{}, rhs_vars, free_idx,
218 std::decay_t<RHSType> rv(orange, 0.0);
// Default-constructed index handed to `rhs_idx` as the bound index.
219 std::decay_t<decltype(*rhs.range().begin())>
empty;
220 for (
const auto& free_idx : orange) {
221 auto& out_elem = rv(free_idx);
222 const auto& rhs_elem = rhs(rhs_idx(free_idx,
empty));
223 out_elem += lhs * rhs_elem;
// Kernel that fills `rv` (decayed `LHSType`, created over `orange`,
// initialised with 0.0) with sums of element products lhs * rhs.
// Two paths are visible: without bound variables each output element
// receives a single product; otherwise the products are accumulated over
// every index in `brange`.
// NOTE(review): the function name, the leading parameters and the
// declarations of `free_vars`, `lhs_vars`, `bound_vars`, `orange` and
// `brange` are not visible in this excerpt — confirm against the full
// definition.
229 template <
typename IndexList_,
typename LHSType,
typename RHSType>
231 const IndexList_& rhs_vars, LHSType&& lhs, RHSType&& rhs) {
// Translate a (free, bound) pair of coordinate indices into the coordinate
// index of the lhs / rhs operand.
236 auto lhs_idx = [=](
const auto& free_idx,
const auto& bound_idx) {
237 return make_index(free_vars, bound_vars, lhs_vars, free_idx, bound_idx);
240 auto rhs_idx = [=](
const auto& free_idx,
const auto& bound_idx) {
241 return make_index(free_vars, bound_vars, rhs_vars, free_idx, bound_idx);
245 std::decay_t<LHSType> rv(orange, 0.0);
// No bound variables: a default-constructed index stands in for the bound
// index.
247 if (bound_vars.size() == 0) {
248 std::decay_t<decltype(*lhs.range().begin())>
empty;
249 for (
const auto& free_idx : orange) {
250 auto& out_elem = rv(free_idx);
251 const auto& lhs_elem = lhs(lhs_idx(free_idx,
empty));
252 const auto& rhs_elem = rhs(rhs_idx(free_idx,
empty));
253 out_elem += lhs_elem * rhs_elem;
// With bound variables: sum over `brange` for every output element.
259 for (
const auto& free_idx : orange) {
260 auto& out_elem = rv(free_idx);
261 for (
const auto& bound_idx : brange) {
262 const auto& lhs_elem = lhs(lhs_idx(free_idx, bound_idx));
263 const auto& rhs_elem = rhs(rhs_idx(free_idx, bound_idx));
264 out_elem += lhs_elem * rhs_elem;
// Kernel operating on annotations that are split into outer and inner parts
// via `outer(...)` / `inner(...)`. Elements of `lhs` and `rhs` are addressed
// with the outer variables only, and each addressed pair (`inner_lhs`,
// `inner_rhs`) is handed to a further call.
// NOTE(review): the function name, the first parameter(s), the definition of
// `bound_ovars`, the declarations of `orange`, `bound_range` and `rv`, and
// the callee receiving `inner_lhs, inner_rhs` are not visible in this
// excerpt — confirm against the full definition.
272 template <
typename IndexList_,
typename LHSType,
typename RHSType>
274 const IndexList_& lhs_vars,
const IndexList_& rhs_vars,
275 LHSType&& lhs, RHSType&& rhs) {
// Split every annotation into its outer and inner index lists.
277 const auto free_ovars =
outer(free_vars);
278 const auto lhs_ovars =
outer(lhs_vars);
279 const auto lhs_ivars =
inner(lhs_vars);
280 const auto rhs_ovars =
outer(rhs_vars);
281 const auto rhs_ivars =
inner(rhs_vars);
292 const auto bound_ovars =
// Translate a (free, bound) pair of outer coordinate indices into the outer
// coordinate index of the lhs / rhs operand.
296 auto lhs_idx = [=](
const auto& free_idx,
const auto& bound_idx) {
297 return make_index(free_ovars, bound_ovars, lhs_ovars, free_idx, bound_idx);
300 auto rhs_idx = [=](
const auto& free_idx,
const auto& bound_idx) {
301 return make_index(free_ovars, bound_ovars, rhs_ovars, free_idx, bound_idx);
// No bound outer variables: a default-constructed index stands in for the
// bound index.
309 if (bound_ovars.size() == 0) {
310 std::decay_t<decltype(*lhs.range().begin())>
empty;
311 for (
const auto& free_idx : orange) {
312 const auto& inner_lhs = lhs(lhs_idx(free_idx,
empty));
313 const auto& inner_rhs = rhs(rhs_idx(free_idx,
empty));
315 inner_lhs, inner_rhs);
// With bound outer variables: visit every index in `bound_range` for each
// output element.
321 for (
const auto& free_idx : orange) {
322 auto& inner_out = rv(free_idx);
323 for (
const auto& bound_idx : bound_range) {
324 const auto& inner_lhs = lhs(lhs_idx(free_idx, bound_idx));
325 const auto& inner_rhs = rhs(rhs_idx(free_idx, bound_idx));
327 inner_lhs, inner_rhs);
// Kernel whose result type `tot_type` is the decayed `RHSType`. The result
// `rv` is created over `orange` and pre-filled with a default-constructed
// `tot_type::value_type` (`default_tile`); output elements are later compared
// against `default_tile`.
// Two separate sections are visible, each declaring its own `bound_vars`,
// `tot_type`, `default_tile` and `rv`: the first addresses only `rhs`, the
// second addresses both `lhs` and `rhs`.
// NOTE(review): the function name, the leading parameters, the condition
// selecting between the two sections, the definitions of `bound_vars`,
// `orange` and `bound_range`, the callee receiving `inner_lhs, inner_rhs`
// and the bodies of the `inner_out != default_tile` branches are not visible
// in this excerpt — confirm against the full definition.
335 template <
typename IndexList_,
typename LHSType,
typename RHSType>
337 const IndexList_& rhs_vars, LHSType&& lhs,
// Split every annotation into its outer and inner index lists.
340 const auto out_ovars =
outer(out_vars);
341 const auto out_ivars =
inner(out_vars);
342 const auto lhs_ovars =
outer(lhs_vars);
343 const auto lhs_ivars =
inner(lhs_vars);
344 const auto rhs_ovars =
outer(rhs_vars);
345 const auto rhs_ivars =
inner(rhs_vars);
// First section: only `rhs` is addressed through an index lambda.
356 using tot_type = std::decay_t<RHSType>;
357 typename tot_type::value_type default_tile;
358 tot_type rv(orange, default_tile);
359 const auto bound_vars =
361 auto rhs_idx = [=](
const auto& free_idx,
const auto& bound_idx) {
362 return make_index(out_ovars, bound_vars, rhs_ovars, free_idx, bound_idx);
364 for (
const auto& free_idx : orange) {
365 auto& inner_out = rv(free_idx);
// A default-constructed index stands in for the bound index.
366 std::decay_t<decltype(free_idx)>
empty;
367 const auto& inner_rhs = rhs(rhs_idx(free_idx,
empty));
370 if (inner_out != default_tile) {
// Second section: both operands are addressed through index lambdas.
388 const auto bound_vars =
392 auto lhs_idx = [=](
const auto& free_idx,
const auto& bound_idx) {
393 return make_index(out_ovars, bound_vars, lhs_ovars, free_idx, bound_idx);
396 auto rhs_idx = [=](
const auto& free_idx,
const auto& bound_idx) {
397 return make_index(out_ovars, bound_vars, rhs_ovars, free_idx, bound_idx);
402 using tot_type = std::decay_t<RHSType>;
403 typename tot_type::value_type default_tile;
404 tot_type rv(orange, default_tile);
// No bound variables: a default-constructed index stands in for the bound
// index.
407 if (bound_vars.size() == 0) {
408 std::decay_t<decltype(*lhs.range().begin())>
empty;
409 for (
const auto& free_idx : orange) {
410 auto& inner_out = rv(free_idx);
411 const auto& inner_lhs = lhs(lhs_idx(free_idx,
empty));
412 const auto& inner_rhs = rhs(rhs_idx(free_idx,
empty));
414 inner_lhs, inner_rhs);
415 if (inner_out != default_tile) {
// With bound variables: visit every index in `bound_range` for each output
// element.
424 for (
const auto& free_idx : orange) {
425 auto& inner_out = rv(free_idx);
426 for (
const auto& bound_idx : bound_range) {
427 const auto& inner_lhs = lhs(lhs_idx(free_idx, bound_idx));
428 const auto& inner_rhs = rhs(rhs_idx(free_idx, bound_idx));
430 inner_lhs, inner_rhs);
431 if (inner_out != default_tile) {
// Kernel whose result type `tot_type` is the decayed `LHSType`. The result
// `rv` is created over `orange` and pre-filled with a default-constructed
// `tot_type::value_type` (`default_tile`); output elements are later compared
// against `default_tile`.
// NOTE(review): the function name, the first parameter(s), the definitions
// of `bound_vars`, `orange` and `bound_range`, the callee receiving
// `inner_lhs, inner_rhs` and the bodies of the `inner_out != default_tile`
// branches are not visible in this excerpt — confirm against the full
// definition.
443 template <
typename IndexList_,
typename LHSType,
typename RHSType>
445 const IndexList_& lhs_vars,
446 const IndexList_& rhs_vars, LHSType&& lhs,
// Split every annotation into its outer and inner index lists.
449 const auto out_ovars =
outer(out_vars);
450 const auto out_ivars =
inner(out_vars);
451 const auto lhs_ovars =
outer(lhs_vars);
452 const auto lhs_ivars =
inner(lhs_vars);
453 const auto rhs_ovars =
outer(rhs_vars);
454 const auto rhs_ivars =
inner(rhs_vars);
465 const auto bound_vars =
// Translate a (free, bound) pair of outer coordinate indices into the outer
// coordinate index of the lhs / rhs operand.
469 auto lhs_idx = [=](
const auto& free_idx,
const auto& bound_idx) {
470 return make_index(out_ovars, bound_vars, lhs_ovars, free_idx, bound_idx);
473 auto rhs_idx = [=](
const auto& free_idx,
const auto& bound_idx) {
474 return make_index(out_ovars, bound_vars, rhs_ovars, free_idx, bound_idx);
479 using tot_type = std::decay_t<LHSType>;
480 typename tot_type::value_type default_tile;
481 tot_type rv(orange, default_tile);
// No bound variables: a default-constructed index stands in for the bound
// index.
484 if (bound_vars.size() == 0) {
485 std::decay_t<decltype(*lhs.range().begin())>
empty;
486 for (
const auto& free_idx : orange) {
487 auto& inner_out = rv(free_idx);
488 const auto& inner_lhs = lhs(lhs_idx(free_idx,
empty));
489 const auto& inner_rhs = rhs(rhs_idx(free_idx,
empty));
491 inner_lhs, inner_rhs);
492 if (inner_out != default_tile) {
// With bound variables: visit every index in `bound_range` for each output
// element.
501 for (
const auto& free_idx : orange) {
502 auto& inner_out = rv(free_idx);
503 for (
const auto& bound_idx : bound_range) {
504 const auto& inner_lhs = lhs(lhs_idx(free_idx, bound_idx));
505 const auto& inner_rhs = rhs(rhs_idx(free_idx, bound_idx));
507 inner_lhs, inner_rhs);
508 if (inner_out != default_tile) {
// Template head with three compile-time flags, one each for the output, the
// lhs and the rhs.
// NOTE(review): the declaration this head belongs to is not visible in this
// excerpt; the `*_is_tot` names suggest "is a tensor-of-tensors" flags —
// confirm.
519 template <
bool out_is_tot,
bool lhs_is_tot,
bool rhs_is_tot>
// Call operator taking the output, lhs and rhs index lists plus the two
// tiles; the tiles are perfectly forwarded to the callee.
// NOTE(review): the enclosing struct and the head of the forwarded call
// (callee name and leading arguments) are not visible in this excerpt —
// confirm which kernel is invoked.
524 template <
typename IndexList_,
typename LTileType,
typename RTileType>
525 auto operator()(
const IndexList_& ovars,
const IndexList_& lvars,
526 const IndexList_& rvars, LTileType&& ltile,
527 RTileType&& rtile)
const {
529 std::forward<LTileType>(ltile),
530 std::forward<RTileType>(rtile));
// Call operator taking the output, lhs and rhs index lists plus the two
// tiles; the tiles are perfectly forwarded to the callee.
// NOTE(review): the enclosing struct and the head of the forwarded call
// (callee name and leading arguments) are not visible in this excerpt —
// confirm which kernel is invoked.
536 template <
typename IndexList_,
typename LTileType,
typename RTileType>
537 auto operator()(
const IndexList_& ovars,
const IndexList_& lvars,
538 const IndexList_& rvars, LTileType&& ltile,
539 RTileType&& rtile)
const {
541 std::forward<LTileType>(ltile),
542 std::forward<RTileType>(rtile));
// Call operator taking the output, lhs and rhs index lists plus the two
// tiles; the tiles are perfectly forwarded to the callee.
// NOTE(review): the enclosing struct and the head of the forwarded call
// (callee name and leading arguments) are not visible in this excerpt —
// confirm which kernel is invoked.
548 template <
typename IndexList_,
typename LTileType,
typename RTileType>
549 auto operator()(
const IndexList_& ovars,
const IndexList_& lvars,
550 const IndexList_& rvars, LTileType&& ltile,
551 RTileType&& rtile)
const {
553 std::forward<LTileType>(ltile),
554 std::forward<RTileType>(rtile));
// Driver that builds a `ResultType` array with `make_array`, computing each
// result tile in the lambda `l` from tiles of `lhs.array()` and
// `rhs.array()`, and then fences on the world of the lhs array.
// NOTE(review): the function name and parameters, the definitions of
// `ovars`, `lvars`, `rvars`, `bound_vars`, `orange`, `brange`, `selector`
// and `oidx`, and the condition choosing between `tile =` and `tile +=` are
// not visible in this excerpt — confirm against the full definition.
560 template <
typename ResultType,
typename LHSType,
typename RHSType>
567 using out_tile_type =
typename ResultType::value_type;
568 using lhs_tile_type =
typename LHSType::value_type;
569 using rhs_tile_type =
typename RHSType::value_type;
// Compile-time classification of each tile type via is_tensor_of_tensor_v.
571 constexpr
bool out_is_tot =
572 TiledArray::detail::is_tensor_of_tensor_v<out_tile_type>;
573 constexpr
bool lhs_is_tot =
574 TiledArray::detail::is_tensor_of_tensor_v<lhs_tile_type>;
575 constexpr
bool rhs_is_tot =
576 TiledArray::detail::is_tensor_of_tensor_v<rhs_tile_type>;
// Only the outer parts of the annotations are used to address tiles.
578 const auto out_ovars =
outer(ovars);
579 const auto lhs_ovars =
outer(lvars);
580 const auto rhs_ovars =
outer(rvars);
582 const auto bound_vars =
585 const auto& ltensor = lhs.
array();
586 const auto& rtensor = rhs.
array();
// Tile-filling callback passed to make_array: receives the tile to fill and
// its element range, and returns the tile's norm (0.0 for an empty tile).
595 auto l = [=](
auto& tile,
const Range& r) {
// Tile index of the output tile, derived from the lower bound of `r`.
597 orange.tiles_range().idx(orange.element_to_tile(r.lobound()));
598 auto bitr = brange.tiles_range().begin();
599 const auto eitr = brange.tiles_range().end();
// When there are no bound tiles, `oidx` is passed in place of a bound tile
// index.
601 const bool have_bound = bitr != eitr;
602 decltype(oidx) bidx = have_bound ? *bitr : oidx;
603 auto lidx =
make_index(out_ovars, bound_vars, lhs_ovars, oidx, bidx);
604 auto ridx =
make_index(out_ovars, bound_vars, rhs_ovars, oidx, bidx);
// Tiles are fetched and multiplied only if neither operand's shape reports
// its tile as zero.
605 if (!ltensor.shape().is_zero(lidx) && !rtensor.shape().is_zero(ridx)) {
606 const auto& ltile = ltensor.find(lidx).get();
607 const auto& rtile = rtensor.find(ridx).get();
609 tile = selector(ovars, lvars, rvars, ltile, rtile);
611 tile += selector(ovars, lvars, rvars, ltile, rtile);
// The iterator is advanced only when a bound tile range exists.
613 if (have_bound) ++bitr;
614 }
while (bitr != brange.tiles_range().end());
615 return !tile.empty() ? tile.norm() : 0.0;
618 auto rv = make_array<ResultType>(ltensor.world(), orange, l);
620 ltensor.world().gop.fence();
625 #endif // TILEDARRAY_EXPRESSIONS_CONTRACTION_HELPERS_H__INCLUDED