contraction_helpers.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2013 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  */
19 
20 #ifndef TILEDARRAY_EXPRESSIONS_CONTRACTION_HELPERS_H__INCLUDED
21 #define TILEDARRAY_EXPRESSIONS_CONTRACTION_HELPERS_H__INCLUDED
22 
27 
28 namespace TiledArray::expressions {
29 
/// Builds the element range spanned by a list of annotated modes.
///
/// Each index in \p target_idxs is looked up in \p lhs_idxs and \p rhs_idxs.
/// The extent is taken from the first matching mode of `lhs.range()` when the
/// left annotation contains the index, otherwise from the first matching mode
/// of `rhs.range()`. Every mode carrying the same index, on either operand,
/// is asserted to have that same extent.
///
/// \tparam IndexList_ annotation list type; must be iterable and provide
///         `positions(idx)`.
/// \tparam LHSType left operand type; must provide `range()`.
/// \tparam RHSType right operand type; must provide `range()`.
/// \param[in] target_idxs indices whose extents are wanted, in output order.
/// \param[in] lhs_idxs annotation of \p lhs.
/// \param[in] rhs_idxs annotation of \p rhs.
/// \param[in] lhs left operand.
/// \param[in] rhs right operand.
/// \return an object of the same type as `lhs.range()` built from one extent
///         per target index.
template <typename IndexList_, typename LHSType, typename RHSType>
auto range_from_annotation(const IndexList_& target_idxs,
                           const IndexList_& lhs_idxs,
                           const IndexList_& rhs_idxs, LHSType&& lhs,
                           RHSType&& rhs) {
  using range_type = std::decay_t<decltype(lhs.range())>;
  using size_type = typename range_type::size_type;
  using extent_type = std::pair<size_type, size_type>;

  const auto& lhs_range = lhs.range();
  const auto& rhs_range = rhs.range();
  std::vector<extent_type> extents;  // one entry per target index

  for (const auto& label : target_idxs) {
    const auto lhs_modes = lhs_idxs.positions(label);
    const auto rhs_modes = rhs_idxs.positions(label);
    const bool on_lhs = lhs_modes.size() != 0;
    // The index has to show up on at least one of the operands
    TA_ASSERT(on_lhs || rhs_modes.size());

    // Prefer the left operand as the source of the extent
    auto extent =
        on_lhs ? lhs_range.dim(lhs_modes[0]) : rhs_range.dim(rhs_modes[0]);

    // All modes sharing this label must agree on the extent
    for (auto mode : lhs_modes) TA_ASSERT(lhs_range.dim(mode) == extent);
    for (auto mode : rhs_modes) TA_ASSERT(rhs_range.dim(mode) == extent);

    extents.emplace_back(std::move(extent));
  }
  return range_type(extents);
}
79 
80 template <typename IndexList_, typename LHSType, typename RHSType>
81 auto trange_from_annotation(const IndexList_& target_idxs,
82  const IndexList_& lhs_idxs,
83  const IndexList_& rhs_idxs, LHSType&& lhs,
84  RHSType&& rhs) {
85  std::vector<TiledRange1> ranges; // Will be the ranges for each extent
86  const auto& lrange = lhs.trange();
87  const auto& rrange = rhs.trange();
88 
89  for (const auto& idx : target_idxs) {
90  const auto lmodes = lhs_idxs.positions(idx);
91  const auto rmodes = rhs_idxs.positions(idx);
92  TA_ASSERT(lmodes.size() || rmodes.size()); // One of them better have it
93 
94  auto corr_extent =
95  lmodes.size() ? lrange.dim(lmodes[0]) : rrange.dim(rmodes[0]);
96  for (auto lmode : lmodes) TA_ASSERT(lrange.dim(lmode) == corr_extent);
97  for (auto rmode : rmodes) TA_ASSERT(rrange.dim(rmode) == corr_extent);
98  ranges.emplace_back(std::move(corr_extent));
99  }
100  return TiledRange(ranges.begin(), ranges.end());
101 }
102 
/// Assembles the coordinate index of a tensor from free and bound indices.
///
/// For every mode of \p tensor_vars the annotation is searched in
/// \p free_vars first; if present the matching component of \p free_idx is
/// used, otherwise the component of \p bound_idx located through
/// \p bound_vars is used. The annotation must occur exactly once in the list
/// it is taken from.
///
/// \tparam IndexList_ annotation list type; must provide `size()`,
///         `operator[]`, `count(x)` and `positions(x)`.
/// \tparam IndexType coordinate index type (deduced identically for both
///         index arguments); its decayed type must be constructible from a
///         size and be indexable.
/// \param[in] free_vars annotations of the free (output) modes.
/// \param[in] bound_vars annotations of the bound (summed) modes.
/// \param[in] tensor_vars annotation of the tensor the index is meant for.
/// \param[in] free_idx current coordinate along the free modes.
/// \param[in] bound_idx current coordinate along the bound modes.
/// \return an index with `tensor_vars.size()` components.
template <typename IndexList_, typename IndexType>
auto make_index(const IndexList_& free_vars, const IndexList_& bound_vars,
                const IndexList_& tensor_vars, IndexType&& free_idx,
                IndexType&& bound_idx) {
  std::decay_t<IndexType> rv(tensor_vars.size());
  const std::size_t rank = tensor_vars.size();
  for (std::size_t mode = 0; mode < rank; ++mode) {
    const auto& label = tensor_vars[mode];
    if (free_vars.count(label)) {
      const auto where = free_vars.positions(label);
      TA_ASSERT(where.size() == 1);  // Annotation should only appear once
      rv[mode] = free_idx[where[0]];
    } else {
      const auto where = bound_vars.positions(label);
      TA_ASSERT(where.size() == 1);  // Annotation should only appear once
      rv[mode] = bound_idx[where[0]];
    }
  }
  return rv;
}
144 
146 inline auto make_bound_annotation(const BipartiteIndexList& free_vars,
147  const BipartiteIndexList& lhs_vars,
148  const BipartiteIndexList& rhs_vars) {
149  const auto bound_temp = bound_annotations(free_vars, lhs_vars, rhs_vars);
150  BipartiteIndexList bound_vars(
151  std::vector<std::string>(bound_temp.begin(), bound_temp.end()),
152  std::vector<std::string>{});
153  return bound_vars;
154 }
155 
157 inline auto make_bound_annotation(const IndexList& free_vars,
158  const IndexList& lhs_vars,
159  const IndexList& rhs_vars) {
160  const auto bound_temp = bound_annotations(free_vars, lhs_vars, rhs_vars);
161  IndexList bound_vars(bound_temp.begin(), bound_temp.end());
162  return bound_vars;
163 }
164 
165 namespace kernels {
166 
/// Contracts two tensors over all of their modes, producing a scalar.
///
/// The result annotation must be empty, neither operand annotation may be
/// empty or have inner indices, and the operand annotations must be
/// permutations of one another (all checked with TA_ASSERT).
///
/// \param[in] free_vars annotation of the result (must be empty).
/// \param[in] lhs_vars annotation of \p lhs.
/// \param[in] rhs_vars annotation of \p rhs.
/// \param[in] lhs left tensor; elements are read with `lhs(index)`.
/// \param[in] rhs right tensor; elements are read with `rhs(index)`.
/// \return the sum over the bound range of `lhs(i) * rhs(j)`, as the
///         `value_type` of the decayed \p LHSType.
template <typename IndexList_, typename LHSType, typename RHSType>
auto s_t_t_contract_(const IndexList_& free_vars, const IndexList_& lhs_vars,
                     const IndexList_& rhs_vars, LHSType&& lhs, RHSType&& rhs) {
  using value_type = typename std::decay_t<LHSType>::value_type;

  TA_ASSERT(free_vars.size() == 0);
  TA_ASSERT(lhs_vars.size() > 0);
  TA_ASSERT(rhs_vars.size() > 0);
  TA_ASSERT(inner_size(lhs_vars) == 0);
  TA_ASSERT(inner_size(rhs_vars) == 0);
  TA_ASSERT(lhs_vars.is_permutation(rhs_vars));

  // Every index is summed over
  const auto bound_vars = make_bound_annotation(free_vars, lhs_vars, rhs_vars);

  // Translates a coordinate of the bound range into a coordinate of the
  // operand annotated with `operand_vars`; there are no free indices, so a
  // default-constructed index stands in for them.
  const auto coord_of = [&](const IndexList_& operand_vars,
                            const auto& bound_idx) {
    const std::decay_t<decltype(bound_idx)> no_free_idx;
    return make_index(free_vars, bound_vars, operand_vars, no_free_idx,
                      bound_idx);
  };

  auto bound_range =
      range_from_annotation(bound_vars, lhs_vars, rhs_vars, lhs, rhs);

  value_type total = 0;
  for (const auto& bound_idx : bound_range) {
    const auto& lhs_elem = lhs(coord_of(lhs_vars, bound_idx));
    const auto& rhs_elem = rhs(coord_of(rhs_vars, bound_idx));
    total += lhs_elem * rhs_elem;
  }
  return total;
}
205 
/// Multiplies a tensor by a scalar, producing a tensor.
///
/// \p lhs is used as a plain multiplicative factor; all range information is
/// taken from \p rhs. The result is laid out according to \p free_vars, and
/// each result element is `lhs` times the \p rhs element found at the
/// coordinate that \p rhs_vars maps the free index to.
///
/// \param[in] free_vars annotation of the result.
/// \param[in] lhs_vars annotation of \p lhs (only used for range lookup).
/// \param[in] rhs_vars annotation of \p rhs.
/// \param[in] lhs the scalar factor.
/// \param[in] rhs the tensor; elements are read with `rhs(index)`.
/// \return a tensor of the decayed \p RHSType over the free range.
template <typename IndexList_, typename LHSType, typename RHSType>
auto t_s_t_contract_(const IndexList_& free_vars, const IndexList_& lhs_vars,
                     const IndexList_& rhs_vars, LHSType&& lhs, RHSType&& rhs) {
  using tensor_type = std::decay_t<RHSType>;

  // No bound indices exist here, hence the empty bound annotation
  const auto rhs_coord = [&](const auto& free_idx, const auto& bound_idx) {
    return make_index(free_vars, std::decay_t<IndexList_>{}, rhs_vars, free_idx,
                      bound_idx);
  };

  // lhs is a scalar and has no range, so rhs is handed in for both operands;
  // every extent is found through rhs anyway
  auto out_range =
      range_from_annotation(free_vars, lhs_vars, rhs_vars, rhs, rhs);
  tensor_type result(out_range, 0.0);

  std::decay_t<decltype(*rhs.range().begin())> no_bound_idx;
  for (const auto& free_idx : out_range) {
    auto& out_elem = result(free_idx);
    const auto& rhs_elem = rhs(rhs_coord(free_idx, no_bound_idx));
    out_elem += lhs * rhs_elem;
  }
  return result;
}
227 
/// Contracts two tensors to a tensor.
///
/// Indices appearing in \p free_vars label the modes of the result; the
/// indices reported by make_bound_annotation() are summed over. When there
/// are no bound indices the operation is an element-wise product.
///
/// \param[in] free_vars annotation of the result.
/// \param[in] lhs_vars annotation of \p lhs.
/// \param[in] rhs_vars annotation of \p rhs.
/// \param[in] lhs left tensor; elements are read with `lhs(index)`.
/// \param[in] rhs right tensor; elements are read with `rhs(index)`.
/// \return a tensor of the decayed \p LHSType over the free range.
template <typename IndexList_, typename LHSType, typename RHSType>
auto t_t_t_contract_(const IndexList_& free_vars, const IndexList_& lhs_vars,
                     const IndexList_& rhs_vars, LHSType&& lhs, RHSType&& rhs) {
  // Indices that are summed over
  const auto bound_vars = make_bound_annotation(free_vars, lhs_vars, rhs_vars);

  // Translates a (free, bound) coordinate pair into a coordinate of the
  // operand annotated with `operand_vars`
  const auto coord_of = [&](const IndexList_& operand_vars,
                            const auto& free_idx, const auto& bound_idx) {
    return make_index(free_vars, bound_vars, operand_vars, free_idx, bound_idx);
  };

  auto out_range =
      range_from_annotation(free_vars, lhs_vars, rhs_vars, lhs, rhs);
  std::decay_t<LHSType> result(out_range, 0.0);

  // Nothing to sum over: plain element-wise product
  if (bound_vars.size() == 0) {
    std::decay_t<decltype(*lhs.range().begin())> no_bound_idx;
    for (const auto& free_idx : out_range) {
      auto& out_elem = result(free_idx);
      const auto& lhs_elem = lhs(coord_of(lhs_vars, free_idx, no_bound_idx));
      const auto& rhs_elem = rhs(coord_of(rhs_vars, free_idx, no_bound_idx));
      out_elem += lhs_elem * rhs_elem;
    }
    return result;
  }

  auto bound_range =
      range_from_annotation(bound_vars, lhs_vars, rhs_vars, lhs, rhs);
  for (const auto& free_idx : out_range) {
    auto& out_elem = result(free_idx);
    for (const auto& bound_idx : bound_range) {
      const auto& lhs_elem = lhs(coord_of(lhs_vars, free_idx, bound_idx));
      const auto& rhs_elem = rhs(coord_of(rhs_vars, free_idx, bound_idx));
      out_elem += lhs_elem * rhs_elem;
    }
  }
  return result;
}
270 
271 // Contract two ToTs to a ToT
272 template <typename IndexList_, typename LHSType, typename RHSType>
273 auto t_tot_tot_contract_(const IndexList_& free_vars,
274  const IndexList_& lhs_vars, const IndexList_& rhs_vars,
275  LHSType&& lhs, RHSType&& rhs) {
276  // Break the annotations up into their inner and outer parts
277  const auto free_ovars = outer(free_vars);
278  const auto lhs_ovars = outer(lhs_vars);
279  const auto lhs_ivars = inner(lhs_vars);
280  const auto rhs_ovars = outer(rhs_vars);
281  const auto rhs_ivars = inner(rhs_vars);
282 
283  // We assume there's no operation going across the outer and inner tensors
284  // (i.e., the set of outer annotations must be disjoint from the inner)
285  {
286  auto all_outer = all_annotations(free_vars, lhs_ovars, rhs_ovars);
287  auto all_inner = all_annotations(lhs_ivars, rhs_ivars);
288  TA_ASSERT(common_annotations(all_outer, all_inner).size() == 0);
289  }
290 
291  // Get the outer indices being contracted over
292  const auto bound_ovars =
293  make_bound_annotation(free_ovars, lhs_ovars, rhs_ovars);
294 
295  // lambdas to bind annotations, making it easier to get coordinate indices
296  auto lhs_idx = [=](const auto& free_idx, const auto& bound_idx) {
297  return make_index(free_ovars, bound_ovars, lhs_ovars, free_idx, bound_idx);
298  };
299 
300  auto rhs_idx = [=](const auto& free_idx, const auto& bound_idx) {
301  return make_index(free_ovars, bound_ovars, rhs_ovars, free_idx, bound_idx);
302  };
303 
304  auto orange =
305  range_from_annotation(free_ovars, lhs_ovars, rhs_ovars, lhs, rhs);
307 
308  // If bound_vars is empty we're doing Hadamard on the outside
309  if (bound_ovars.size() == 0) { // Hadamard on the outside
310  std::decay_t<decltype(*lhs.range().begin())> empty;
311  for (const auto& free_idx : orange) {
312  const auto& inner_lhs = lhs(lhs_idx(free_idx, empty));
313  const auto& inner_rhs = rhs(rhs_idx(free_idx, empty));
314  rv(free_idx) += s_t_t_contract_(IndexList{}, lhs_ivars, rhs_ivars,
315  inner_lhs, inner_rhs);
316  }
317  } else {
318  auto bound_range =
319  range_from_annotation(bound_ovars, lhs_ovars, rhs_ovars, lhs, rhs);
320 
321  for (const auto& free_idx : orange) {
322  auto& inner_out = rv(free_idx);
323  for (const auto& bound_idx : bound_range) {
324  const auto& inner_lhs = lhs(lhs_idx(free_idx, bound_idx));
325  const auto& inner_rhs = rhs(rhs_idx(free_idx, bound_idx));
326  inner_out += s_t_t_contract_(IndexList{}, lhs_ivars, rhs_ivars,
327  inner_lhs, inner_rhs);
328  }
329  }
330  }
331  return rv;
332 }
333 
334 // Contract a tensor and a ToTs to a ToT
335 template <typename IndexList_, typename LHSType, typename RHSType>
336 auto tot_t_tot_contract_(const IndexList_& out_vars, const IndexList_& lhs_vars,
337  const IndexList_& rhs_vars, LHSType&& lhs,
338  RHSType&& rhs) {
339  // Break the annotations up into their inner and outer parts
340  const auto out_ovars = outer(out_vars);
341  const auto out_ivars = inner(out_vars);
342  const auto lhs_ovars = outer(lhs_vars);
343  const auto lhs_ivars = inner(lhs_vars);
344  const auto rhs_ovars = outer(rhs_vars);
345  const auto rhs_ivars = inner(rhs_vars);
346 
347  // We assume lhs is either being contracted with the outer indices or the
348  // inner indices
349  bool with_outer = common_annotations(lhs_ovars, rhs_ovars).size() != 0;
350  bool with_inner = common_annotations(lhs_ovars, rhs_ivars).size() != 0;
351  TA_ASSERT(!(with_outer && with_inner));
352 
353  if (with_inner) {
354  auto orange =
355  range_from_annotation(out_ovars, lhs_ivars, rhs_ovars, lhs, rhs);
356  using tot_type = std::decay_t<RHSType>;
357  typename tot_type::value_type default_tile;
358  tot_type rv(orange, default_tile);
359  const auto bound_vars =
360  make_bound_annotation(out_ovars, lhs_ivars, rhs_ovars);
361  auto rhs_idx = [=](const auto& free_idx, const auto& bound_idx) {
362  return make_index(out_ovars, bound_vars, rhs_ovars, free_idx, bound_idx);
363  };
364  for (const auto& free_idx : orange) {
365  auto& inner_out = rv(free_idx);
366  std::decay_t<decltype(free_idx)> empty;
367  const auto& inner_rhs = rhs(rhs_idx(free_idx, empty));
368  const auto elem =
369  t_t_t_contract_(out_ivars, lhs_ovars, rhs_ivars, lhs, inner_rhs);
370  if (inner_out != default_tile) {
371  inner_out += elem;
372  } else {
373  rv(free_idx) = elem;
374  }
375  }
376  return rv;
377  }
378 
379  // We assume there's no operation going across the outer and inner tensors
380  // (i.e., the set of outer annotations must be disjoint from the inner)
381  {
382  auto all_outer = all_annotations(out_ovars, lhs_ovars, rhs_ovars);
383  auto all_inner = all_annotations(out_ivars, lhs_ivars, rhs_ivars);
384  TA_ASSERT(common_annotations(all_outer, all_inner).size() == 0);
385  }
386 
387  // Get the outer indices being contracted over
388  const auto bound_vars =
389  make_bound_annotation(out_ovars, lhs_ovars, rhs_ovars);
390 
391  // lambdas to bind annotations, making it easier to get coordinate indices
392  auto lhs_idx = [=](const auto& free_idx, const auto& bound_idx) {
393  return make_index(out_ovars, bound_vars, lhs_ovars, free_idx, bound_idx);
394  };
395 
396  auto rhs_idx = [=](const auto& free_idx, const auto& bound_idx) {
397  return make_index(out_ovars, bound_vars, rhs_ovars, free_idx, bound_idx);
398  };
399 
400  auto orange =
401  range_from_annotation(out_ovars, lhs_ovars, rhs_ovars, lhs, rhs);
402  using tot_type = std::decay_t<RHSType>;
403  typename tot_type::value_type default_tile;
404  tot_type rv(orange, default_tile);
405 
406  // If bound_vars is empty we're doing Hadamard on the outside
407  if (bound_vars.size() == 0) { // Hadamard on the outside
408  std::decay_t<decltype(*lhs.range().begin())> empty;
409  for (const auto& free_idx : orange) {
410  auto& inner_out = rv(free_idx);
411  const auto& inner_lhs = lhs(lhs_idx(free_idx, empty));
412  const auto& inner_rhs = rhs(rhs_idx(free_idx, empty));
413  const auto elem = t_s_t_contract_(out_ivars, lhs_ivars, rhs_ivars,
414  inner_lhs, inner_rhs);
415  if (inner_out != default_tile) {
416  inner_out += elem;
417  } else {
418  rv(free_idx) = elem;
419  }
420  }
421  } else {
422  auto bound_range =
423  range_from_annotation(bound_vars, lhs_ovars, rhs_ovars, lhs, rhs);
424  for (const auto& free_idx : orange) {
425  auto& inner_out = rv(free_idx);
426  for (const auto& bound_idx : bound_range) {
427  const auto& inner_lhs = lhs(lhs_idx(free_idx, bound_idx));
428  const auto& inner_rhs = rhs(rhs_idx(free_idx, bound_idx));
429  const auto elem = t_s_t_contract_(out_ivars, lhs_ivars, rhs_ivars,
430  inner_lhs, inner_rhs);
431  if (inner_out != default_tile) {
432  inner_out += elem;
433  } else {
434  rv(free_idx) = elem;
435  }
436  }
437  }
438  }
439  return rv;
440 }
441 
/// Contracts two tensors-of-tensors (ToTs) to a ToT.
///
/// Outer modes are combined element-wise when there are no outer bound
/// indices and summed over the outer bound range otherwise; each pair of
/// inner tensors is combined with t_t_t_contract_() using the inner
/// annotations. Outer and inner annotations must be disjoint (asserted).
///
/// Result elements start out as a default-constructed inner tile; the first
/// contribution is assigned and later ones are added.
///
/// \param[in] out_vars annotation of the result (outer and inner parts).
/// \param[in] lhs_vars annotation of \p lhs (outer and inner parts).
/// \param[in] rhs_vars annotation of \p rhs (outer and inner parts).
/// \param[in] lhs left ToT; inner tensors are read with `lhs(index)`.
/// \param[in] rhs right ToT; inner tensors are read with `rhs(index)`.
/// \return a ToT of the decayed \p LHSType over the outer free range.
template <typename IndexList_, typename LHSType, typename RHSType>
auto tot_tot_tot_contract_(const IndexList_& out_vars,
                           const IndexList_& lhs_vars,
                           const IndexList_& rhs_vars, LHSType&& lhs,
                           RHSType&& rhs) {
  using tot_type = std::decay_t<LHSType>;

  // Inner and outer parts of each annotation
  const auto out_ovars = outer(out_vars);
  const auto out_ivars = inner(out_vars);
  const auto lhs_ovars = outer(lhs_vars);
  const auto lhs_ivars = inner(lhs_vars);
  const auto rhs_ovars = outer(rhs_vars);
  const auto rhs_ivars = inner(rhs_vars);

  // No operation may go across the outer and inner tensors, i.e. the outer
  // annotations must be disjoint from the inner ones
  {
    auto all_outer = all_annotations(out_ovars, lhs_ovars, rhs_ovars);
    auto all_inner = all_annotations(out_ivars, lhs_ivars, rhs_ivars);
    TA_ASSERT(common_annotations(all_outer, all_inner).size() == 0);
  }

  // Outer indices that are summed over
  const auto bound_vars =
      make_bound_annotation(out_ovars, lhs_ovars, rhs_ovars);

  // Translates a (free, bound) coordinate pair into an outer coordinate of
  // the operand annotated with `operand_ovars`
  const auto coord_of = [&](const auto& operand_ovars, const auto& free_idx,
                            const auto& bound_idx) {
    return make_index(out_ovars, bound_vars, operand_ovars, free_idx,
                      bound_idx);
  };

  auto orange =
      range_from_annotation(out_ovars, lhs_ovars, rhs_ovars, lhs, rhs);
  typename tot_type::value_type default_tile;
  tot_type rv(orange, default_tile);

  // Adds `term` to `dest`, or overwrites `dest` while it still holds the
  // placeholder tile
  const auto accumulate = [&default_tile](auto& dest, const auto& term) {
    if (dest != default_tile) {
      dest += term;
    } else {
      dest = term;
    }
  };

  // Without bound indices the outer operation is a Hadamard product
  if (bound_vars.size() == 0) {
    std::decay_t<decltype(*lhs.range().begin())> no_bound_idx;
    for (const auto& free_idx : orange) {
      auto& inner_out = rv(free_idx);
      const auto& inner_lhs = lhs(coord_of(lhs_ovars, free_idx, no_bound_idx));
      const auto& inner_rhs = rhs(coord_of(rhs_ovars, free_idx, no_bound_idx));
      const auto term = t_t_t_contract_(out_ivars, lhs_ivars, rhs_ivars,
                                        inner_lhs, inner_rhs);
      accumulate(inner_out, term);
    }
    return rv;
  }

  auto bound_range =
      range_from_annotation(bound_vars, lhs_ovars, rhs_ovars, lhs, rhs);
  for (const auto& free_idx : orange) {
    auto& inner_out = rv(free_idx);
    for (const auto& bound_idx : bound_range) {
      const auto& inner_lhs = lhs(coord_of(lhs_ovars, free_idx, bound_idx));
      const auto& inner_rhs = rhs(coord_of(rhs_ovars, free_idx, bound_idx));
      const auto term = t_t_t_contract_(out_ivars, lhs_ivars, rhs_ivars,
                                        inner_lhs, inner_rhs);
      accumulate(inner_out, term);
    }
  }
  return rv;
}
518 
/// Selects the tile-level contraction kernel from the kind of each tile.
///
/// Only declared here; each supported combination of flags is provided as an
/// explicit specialization with a call operator that forwards to the
/// matching `*_contract_` function.
///
/// \tparam out_is_tot true if the result tile is a tensor-of-tensors.
/// \tparam lhs_is_tot true if the left tile is a tensor-of-tensors.
/// \tparam rhs_is_tot true if the right tile is a tensor-of-tensors.
template <bool out_is_tot, bool lhs_is_tot, bool rhs_is_tot>
struct KernelSelector;
521 
522 template <>
523 struct KernelSelector<true, true, true> {
524  template <typename IndexList_, typename LTileType, typename RTileType>
525  auto operator()(const IndexList_& ovars, const IndexList_& lvars,
526  const IndexList_& rvars, LTileType&& ltile,
527  RTileType&& rtile) const {
528  return tot_tot_tot_contract_(ovars, lvars, rvars,
529  std::forward<LTileType>(ltile),
530  std::forward<RTileType>(rtile));
531  }
532 };
533 
534 template <>
535 struct KernelSelector<false, true, true> {
536  template <typename IndexList_, typename LTileType, typename RTileType>
537  auto operator()(const IndexList_& ovars, const IndexList_& lvars,
538  const IndexList_& rvars, LTileType&& ltile,
539  RTileType&& rtile) const {
540  return t_tot_tot_contract_(ovars, lvars, rvars,
541  std::forward<LTileType>(ltile),
542  std::forward<RTileType>(rtile));
543  }
544 };
545 
546 template <>
547 struct KernelSelector<true, false, true> {
548  template <typename IndexList_, typename LTileType, typename RTileType>
549  auto operator()(const IndexList_& ovars, const IndexList_& lvars,
550  const IndexList_& rvars, LTileType&& ltile,
551  RTileType&& rtile) const {
552  return tot_t_tot_contract_(ovars, lvars, rvars,
553  std::forward<LTileType>(ltile),
554  std::forward<RTileType>(rtile));
555  }
556 };
557 
558 } // namespace kernels
559 
560 template <typename ResultType, typename LHSType, typename RHSType>
562  const TsrExpr<RHSType, true>& rhs) {
563  const BipartiteIndexList ovars(out.annotation());
564  const BipartiteIndexList lvars(lhs.annotation());
565  const BipartiteIndexList rvars(rhs.annotation());
566 
567  using out_tile_type = typename ResultType::value_type;
568  using lhs_tile_type = typename LHSType::value_type;
569  using rhs_tile_type = typename RHSType::value_type;
570 
571  constexpr bool out_is_tot =
572  TiledArray::detail::is_tensor_of_tensor_v<out_tile_type>;
573  constexpr bool lhs_is_tot =
574  TiledArray::detail::is_tensor_of_tensor_v<lhs_tile_type>;
575  constexpr bool rhs_is_tot =
576  TiledArray::detail::is_tensor_of_tensor_v<rhs_tile_type>;
577 
578  const auto out_ovars = outer(ovars);
579  const auto lhs_ovars = outer(lvars);
580  const auto rhs_ovars = outer(rvars);
581 
582  const auto bound_vars =
583  make_bound_annotation(out_ovars, lhs_ovars, rhs_ovars);
584 
585  const auto& ltensor = lhs.array();
586  const auto& rtensor = rhs.array();
587 
588  const auto orange =
589  trange_from_annotation(out_ovars, lhs_ovars, rhs_ovars, ltensor, rtensor);
590  const auto brange = trange_from_annotation(bound_vars, lhs_ovars, rhs_ovars,
591  ltensor, rtensor);
592 
594 
595  auto l = [=](auto& tile, const Range& r) {
596  const auto oidx =
597  orange.tiles_range().idx(orange.element_to_tile(r.lobound()));
598  auto bitr = brange.tiles_range().begin();
599  const auto eitr = brange.tiles_range().end();
600  do {
601  const bool have_bound = bitr != eitr;
602  decltype(oidx) bidx = have_bound ? *bitr : oidx;
603  auto lidx = make_index(out_ovars, bound_vars, lhs_ovars, oidx, bidx);
604  auto ridx = make_index(out_ovars, bound_vars, rhs_ovars, oidx, bidx);
605  if (!ltensor.shape().is_zero(lidx) && !rtensor.shape().is_zero(ridx)) {
606  const auto& ltile = ltensor.find(lidx).get();
607  const auto& rtile = rtensor.find(ridx).get();
608  if (tile.empty())
609  tile = selector(ovars, lvars, rvars, ltile, rtile);
610  else
611  tile += selector(ovars, lvars, rvars, ltile, rtile);
612  }
613  if (have_bound) ++bitr;
614  } while (bitr != brange.tiles_range().end());
615  return !tile.empty() ? tile.norm() : 0.0;
616  };
617 
618  auto rv = make_array<ResultType>(ltensor.world(), orange, l);
619  out.array() = rv;
620  ltensor.world().gop.fence();
621 }
622 
623 } // namespace TiledArray::expressions
624 
625 #endif // TILEDARRAY_EXPRESSIONS_CONTRACTION_HELPERS_H__INCLUDED
auto tot_tot_tot_contract_(const IndexList_ &out_vars, const IndexList_ &lhs_vars, const IndexList_ &rhs_vars, LHSType &&lhs, RHSType &&rhs)
auto make_bound_annotation(const BipartiteIndexList &free_vars, const BipartiteIndexList &lhs_vars, const BipartiteIndexList &rhs_vars)
Wraps process of getting a list with the bound variables.
auto t_tot_tot_contract_(const IndexList_ &free_vars, const IndexList_ &lhs_vars, const IndexList_ &rhs_vars, LHSType &&lhs, RHSType &&rhs)
auto t_s_t_contract_(const IndexList_ &free_vars, const IndexList_ &lhs_vars, const IndexList_ &rhs_vars, LHSType &&lhs, RHSType &&rhs)
array_type & array() const
Array accessor.
Definition: tsr_expr.h:175
auto operator()(const IndexList_ &ovars, const IndexList_ &lvars, const IndexList_ &rvars, LTileType &&ltile, RTileType &&rtile) const
Expression wrapper for array objects.
Definition: tsr_expr.h:83
auto trange_from_annotation(const IndexList_ &target_idxs, const IndexList_ &lhs_idxs, const IndexList_ &rhs_idxs, LHSType &&lhs, RHSType &&rhs)
auto outer(const IndexList &p)
Definition: index_list.h:879
auto operator()(const IndexList_ &ovars, const IndexList_ &lvars, const IndexList_ &rvars, LTileType &&ltile, RTileType &&rtile) const
auto t_t_t_contract_(const IndexList_ &free_vars, const IndexList_ &lhs_vars, const IndexList_ &rhs_vars, LHSType &&lhs, RHSType &&rhs)
auto common_annotations(T &&v, Args &&... args)
Returns the set of annotations found in all of the index lists.
Definition: index_list.h:731
auto tot_t_tot_contract_(const IndexList_ &out_vars, const IndexList_ &lhs_vars, const IndexList_ &rhs_vars, LHSType &&lhs, RHSType &&rhs)
auto s_t_t_contract_(const IndexList_ &free_vars, const IndexList_ &lhs_vars, const IndexList_ &rhs_vars, LHSType &&lhs, RHSType &&rhs)
auto inner(const IndexList &p)
Definition: index_list.h:872
#define TA_ASSERT(EXPR,...)
Definition: error.h:39
auto make_index(const IndexList_ &free_vars, const IndexList_ &bound_vars, const IndexList_ &tensor_vars, IndexType &&free_idx, IndexType &&bound_idx)
Range data of a tiled array.
Definition: tiled_range.h:32
auto all_annotations(T &&v, Args &&... args)
Returns a set of each annotation found in at least one of the index lists.
Definition: index_list.h:720
auto bound_annotations(const IndexList_ &out, Args &&... args)
Definition: index_list.h:747
const std::string & annotation() const
Tensor annotation accessor.
Definition: tsr_expr.h:304
bool empty(const Tile< Arg > &arg)
Check that arg is empty (no data)
Definition: tile.h:646
auto range_from_annotation(const IndexList_ &target_idxs, const IndexList_ &lhs_idxs, const IndexList_ &rhs_idxs, LHSType &&lhs, RHSType &&rhs)
An N-dimensional tensor object.
Definition: tensor.h:50
auto inner_size(const IndexList &p)
Definition: index_list.h:881
void einsum(TsrExpr< ResultType, true > out, const TsrExpr< LHSType, true > &lhs, const TsrExpr< RHSType, true > &rhs)
auto operator()(const IndexList_ &ovars, const IndexList_ &lvars, const IndexList_ &rvars, LTileType &&ltile, RTileType &&rtile) const
A (hyperrectangular) interval on $\mathbb{Z}^n$, the space of integer $n$-indices.
Definition: range.h:46