btas.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2018 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * btas.h
19  * Jul 11, 2017
20  *
21  */
22 
23 #ifndef TILEDARRAY_EXTERNAL_BTAS_H__INCLUDED
24 #define TILEDARRAY_EXTERNAL_BTAS_H__INCLUDED
25 
#include <cassert>
#include <cstdlib>

#include <TiledArray/utility.h>
#include "TiledArray/config.h"
#include "TiledArray/math/blas.h"
#include "TiledArray/range.h"

#include <btas/features.h>
#include <btas/generic/axpy_impl.h>
#include <btas/generic/permute.h>
#include <btas/tensor.h>

#include <madness/world/archive.h>
43 
44 namespace btas {
45 template <>
46 struct range_traits<TiledArray::Range> {
47  const static blas::Layout order = blas::Layout::RowMajor;
50  constexpr static const bool is_general_layout = false;
51 };
52 } // namespace btas
53 
54 namespace TiledArray {
55 namespace detail {
56 // these convert any range into TiledArray::Range
57 
58 inline const TiledArray::Range& make_ta_range(const TiledArray::Range& range) {
59  return range;
60 }
61 
63 
66 template <::blas::Layout Order, typename... Args>
68  const btas::RangeNd<Order, Args...>& range) {
69  TA_ASSERT(Order == ::blas::Layout::RowMajor &&
70  "TiledArray::detail::make_ta_range(btas::RangeNd<Order,...>): "
71  "not supported for col-major Order");
72  return TiledArray::Range(range.lobound(), range.upbound());
73 }
74 
75 } // namespace detail
76 
78 
83 inline bool congruent(const Range& r1, const Range& r2) {
84  return is_congruent(r1, r2);
85 }
86 
87 } // namespace TiledArray
88 
89 namespace btas {
90 
92 
97 template <blas::Layout Order, typename... Args>
98 inline bool is_congruent(const btas::RangeNd<Order, Args...>& r1,
99  const btas::RangeNd<Order, Args...>& r2) {
100  return (r1.rank() == r2.rank()) &&
101  std::equal(r1.extent_data(), r1.extent_data() + r1.rank(),
102  r2.extent_data());
103 }
104 
105 template <typename T, typename Range, typename Storage>
106 decltype(auto) make_ti(const btas::Tensor<T, Range, Storage>& arg) {
107  return TiledArray::detail::TensorInterface<const T, Range,
108  btas::Tensor<T, Range, Storage>>(
109  arg.range(), arg.data());
110 }
111 
112 template <typename T, typename Range, typename Storage>
113 decltype(auto) make_ti(btas::Tensor<T, Range, Storage>& arg) {
114  return TiledArray::detail::TensorInterface<T, Range,
115  btas::Tensor<T, Range, Storage>>(
116  arg.range(), arg.data());
117 }
118 
119 template <typename... Args>
120 inline bool operator==(const TiledArray::Range& range1,
121  const btas::BaseRangeNd<Args...>& range2) {
122  const auto rank = range1.rank();
123  if (rank == range2.rank()) {
124  auto range1_lobound_data = range1.lobound_data();
125  using std::cbegin;
126  const auto lobound_match =
127  std::equal(range1_lobound_data, range1_lobound_data + rank,
128  cbegin(range2.lobound()));
129  if (lobound_match) {
130  auto range1_upbound_data = range1.upbound_data();
131  return std::equal(range1_upbound_data, range1_upbound_data + rank,
132  cbegin(range2.upbound()));
133  }
134  }
135  return false;
136 }
137 
138 template <typename T1, typename S1, typename T2, typename S2>
139 bool operator==(const btas::Tensor<T1, TiledArray::Range, S1>& t1,
140  const btas::Tensor<T2, TiledArray::Range, S2>& t2) {
141  auto t1_view = make_ti(t1);
142  auto t2_view = make_ti(t2);
143  using std::data;
144  return t1_view.size() == t2_view.size() &&
145  std::equal(data(t1_view), data(t1_view) + t1_view.size(),
146  data(t2_view));
147 }
148 
150 template <typename T, typename Range, typename Storage>
151 inline btas::Tensor<T, Range, Storage> clone(
152  const btas::Tensor<T, Range, Storage>& arg) {
153  return arg;
154 }
155 
157 template <typename T, typename Range, typename Storage>
158 inline btas::Tensor<T, Range, Storage> permute(
159  const btas::Tensor<T, Range, Storage>& arg,
160  const TiledArray::Permutation& perm) {
161  btas::Tensor<T, Range, Storage> result;
162  btas::permute(arg, perm.inv().data(), result);
163  return result;
164 }
165 
167 template <typename T, typename Range, typename Storage>
168 inline btas::Tensor<T, Range, Storage> permute(
169  const btas::Tensor<T, Range, Storage>& arg,
170  const TiledArray::BipartitePermutation& perm) {
171  btas::Tensor<T, Range, Storage> result;
172  constexpr bool is_tot =
174  if constexpr (!is_tot) {
175  TA_ASSERT(inner_size(perm) ==
176  0); // this must be a plain permutation if not ToT
177  btas::permute(arg, outer(perm).inv().data(), result);
178  } else {
179  btas::permute(arg, outer(perm).inv().data(), result);
180  if (inner_size(perm) != 0) {
181  auto inner_perm = inner(perm);
183  for (auto& x : result) x = p(x, inner_perm);
184  }
185  }
186  return result;
187 }
188 
189 // Shift operations ----------------------------------------------------------
190 
192 
196 template <typename T, typename Range, typename Storage, typename Index>
197 inline btas::Tensor<T, Range, Storage> shift(
198  const btas::Tensor<T, Range, Storage>& arg, const Index& range_shift) {
199  auto shifted_arg = clone(arg);
200  shift_to(shifted_arg, range_shift);
201  return shifted_arg;
202 }
203 
205 
209 template <typename T, typename Range, typename Storage, typename Index>
210 inline btas::Tensor<T, Range, Storage>& shift_to(
211  btas::Tensor<T, Range, Storage>& arg, const Index& range_shift) {
212  const_cast<Range&>(arg.range()).inplace_shift(range_shift);
213  return arg;
214 }
215 
217 template <typename T, typename Range, typename Storage>
218 inline btas::Tensor<T, Range, Storage> add(
219  const btas::Tensor<T, Range, Storage>& arg1,
220  const btas::Tensor<T, Range, Storage>& arg2) {
221  auto arg1_view = make_ti(arg1);
222  auto arg2_view = make_ti(arg2);
223  return arg1_view.add(arg2_view);
224 }
225 
227 template <typename T, typename Range, typename Storage, typename Scalar,
228  typename std::enable_if<
229  TiledArray::detail::is_numeric_v<Scalar>>::type* = nullptr>
230 inline btas::Tensor<T, Range, Storage> add(
231  const btas::Tensor<T, Range, Storage>& arg1,
232  const btas::Tensor<T, Range, Storage>& arg2, const Scalar factor) {
233  auto arg1_view = make_ti(arg1);
234  auto arg2_view = make_ti(arg2);
235  return arg1_view.add(arg2_view, factor);
236 }
237 
239 template <
240  typename T, typename Range, typename Storage, typename Perm,
241  typename = std::enable_if_t<TiledArray::detail::is_permutation_v<Perm>>>
242 inline btas::Tensor<T, Range, Storage> add(
243  const btas::Tensor<T, Range, Storage>& arg1,
244  const btas::Tensor<T, Range, Storage>& arg2, const Perm& perm) {
245  auto arg1_view = make_ti(arg1);
246  auto arg2_view = make_ti(arg2);
247  return arg1_view.add(arg2_view, perm);
248 }
249 
251 template <typename T, typename Range, typename Storage, typename Scalar,
252  typename Perm,
253  typename std::enable_if<
254  TiledArray::detail::is_numeric_v<Scalar> &&
255  TiledArray::detail::is_permutation_v<Perm>>::type* = nullptr>
256 inline btas::Tensor<T, Range, Storage> add(
257  const btas::Tensor<T, Range, Storage>& arg1,
258  const btas::Tensor<T, Range, Storage>& arg2, const Scalar factor,
259  const Perm& perm) {
260  auto arg1_view = make_ti(arg1);
261  auto arg2_view = make_ti(arg2);
262  return arg1_view.add(arg2_view, factor, perm);
263 }
264 
266 template <typename T, typename Range, typename Storage>
267 inline btas::Tensor<T, Range, Storage>& add_to(
268  btas::Tensor<T, Range, Storage>& result,
269  const btas::Tensor<T, Range, Storage>& arg) {
270  auto result_view = make_ti(result);
271  auto arg_view = make_ti(arg);
272  result_view.add_to(arg_view);
273  return result;
274 }
275 
277 template <typename T, typename Range, typename Storage, typename Scalar,
278  typename std::enable_if<
279  TiledArray::detail::is_numeric_v<Scalar>>::type* = nullptr>
280 inline btas::Tensor<T, Range, Storage>& add_to(
281  btas::Tensor<T, Range, Storage>& result,
282  const btas::Tensor<T, Range, Storage>& arg, const Scalar factor) {
283  auto result_view = make_ti(result);
284  auto arg_view = make_ti(arg);
285  result_view.add_to(arg_view, factor);
286  return result;
287 }
288 
290 template <typename T, typename Range, typename Storage>
291 inline btas::Tensor<T, Range, Storage> subt(
292  const btas::Tensor<T, Range, Storage>& arg1,
293  const btas::Tensor<T, Range, Storage>& arg2) {
294  auto arg1_view = make_ti(arg1);
295  auto arg2_view = make_ti(arg2);
296  return arg1_view.subt(arg2_view);
297 }
298 
300 template <typename T, typename Range, typename Storage, typename Scalar,
301  typename std::enable_if<
302  TiledArray::detail::is_numeric_v<Scalar>>::type* = nullptr>
303 inline btas::Tensor<T, Range, Storage> subt(
304  const btas::Tensor<T, Range, Storage>& arg1,
305  const btas::Tensor<T, Range, Storage>& arg2, const Scalar factor) {
306  auto arg1_view = make_ti(arg1);
307  auto arg2_view = make_ti(arg2);
308  return arg1_view.subt(arg2_view, factor);
309 }
310 
312 template <
313  typename T, typename Range, typename Storage, typename Perm,
314  typename = std::enable_if_t<TiledArray::detail::is_permutation_v<Perm>>>
315 inline btas::Tensor<T, Range, Storage> subt(
316  const btas::Tensor<T, Range, Storage>& arg1,
317  const btas::Tensor<T, Range, Storage>& arg2, const Perm& perm) {
318  auto arg1_view = make_ti(arg1);
319  auto arg2_view = make_ti(arg2);
320  return arg1_view.subt(arg2_view, perm);
321 }
322 
324 template <typename T, typename Range, typename Storage, typename Scalar,
325  typename Perm,
326  typename std::enable_if<
327  TiledArray::detail::is_numeric_v<Scalar> &&
328  TiledArray::detail::is_permutation_v<Perm>>::type* = nullptr>
329 inline btas::Tensor<T, Range, Storage> subt(
330  const btas::Tensor<T, Range, Storage>& arg1,
331  const btas::Tensor<T, Range, Storage>& arg2, const Scalar factor,
332  const Perm& perm) {
333  auto arg1_view = make_ti(arg1);
334  auto arg2_view = make_ti(arg2);
335  return arg1_view.subt(arg2_view, factor, perm);
336 }
337 
339 template <typename T, typename Range, typename Storage>
340 inline btas::Tensor<T, Range, Storage>& subt_to(
341  btas::Tensor<T, Range, Storage>& result,
342  const btas::Tensor<T, Range, Storage>& arg) {
343  auto result_view = make_ti(result);
344  auto arg_view = make_ti(arg);
345  result_view.subt_to(arg_view);
346  return result;
347 }
348 
349 template <typename T, typename Range, typename Storage, typename Scalar,
350  typename std::enable_if<
351  TiledArray::detail::is_numeric_v<Scalar>>::type* = nullptr>
352 inline btas::Tensor<T, Range, Storage>& subt_to(
353  btas::Tensor<T, Range, Storage>& result,
354  const btas::Tensor<T, Range, Storage>& arg, const Scalar factor) {
355  auto result_view = make_ti(result);
356  auto arg_view = make_ti(arg);
357  result_view.subt_to(arg_view, factor);
358  return result;
359 }
360 
362 template <typename T, typename Range, typename Storage>
363 inline btas::Tensor<T, Range, Storage> mult(
364  const btas::Tensor<T, Range, Storage>& arg1,
365  const btas::Tensor<T, Range, Storage>& arg2) {
366  auto arg1_view = make_ti(arg1);
367  auto arg2_view = make_ti(arg2);
368  return arg1_view.mult(arg2_view);
369 }
370 
372 template <typename T, typename Range, typename Storage, typename Scalar,
373  typename std::enable_if<
374  TiledArray::detail::is_numeric_v<Scalar>>::type* = nullptr>
375 inline btas::Tensor<T, Range, Storage> mult(
376  const btas::Tensor<T, Range, Storage>& arg1,
377  const btas::Tensor<T, Range, Storage>& arg2, const Scalar factor) {
378  auto arg1_view = make_ti(arg1);
379  auto arg2_view = make_ti(arg2);
380  return arg1_view.mult(arg2_view, factor);
381 }
382 
384 template <
385  typename T, typename Range, typename Storage, typename Perm,
386  typename = std::enable_if_t<TiledArray::detail::is_permutation_v<Perm>>>
387 inline btas::Tensor<T, Range, Storage> mult(
388  const btas::Tensor<T, Range, Storage>& arg1,
389  const btas::Tensor<T, Range, Storage>& arg2, const Perm& perm) {
390  auto arg1_view = make_ti(arg1);
391  auto arg2_view = make_ti(arg2);
392  return arg1_view.mult(arg2_view, perm);
393 }
394 
396 template <typename T, typename Range, typename Storage, typename Scalar,
397  typename Perm,
398  typename std::enable_if<
399  TiledArray::detail::is_numeric_v<Scalar> &&
400  TiledArray::detail::is_permutation_v<Perm>>::type* = nullptr>
401 inline btas::Tensor<T, Range, Storage> mult(
402  const btas::Tensor<T, Range, Storage>& arg1,
403  const btas::Tensor<T, Range, Storage>& arg2, const Scalar factor,
404  const Perm& perm) {
405  auto arg1_view = make_ti(arg1);
406  auto arg2_view = make_ti(arg2);
407  return arg1_view.mult(arg2_view, factor, perm);
408 }
409 
411 template <typename T, typename Range, typename Storage>
412 inline btas::Tensor<T, Range, Storage>& mult_to(
413  btas::Tensor<T, Range, Storage>& result,
414  const btas::Tensor<T, Range, Storage>& arg) {
415  auto result_view = make_ti(result);
416  auto arg_view = make_ti(arg);
417  result_view.mult_to(arg_view);
418  return result;
419 }
420 
422 template <typename T, typename Range, typename Storage, typename Scalar,
423  typename std::enable_if<
424  TiledArray::detail::is_numeric_v<Scalar>>::type* = nullptr>
425 inline btas::Tensor<T, Range, Storage>& mult_to(
426  btas::Tensor<T, Range, Storage>& result,
427  const btas::Tensor<T, Range, Storage>& arg, const Scalar factor) {
428  auto result_view = make_ti(result);
429  auto arg_view = make_ti(arg);
430  result_view.mult_to(arg_view, factor);
431  return result;
432 }
433 
434 // Generic element-wise binary operations
435 // ---------------------------------------------
436 
437 template <typename T, typename Range, typename Storage, typename Op>
438 inline auto binary(const btas::Tensor<T, Range, Storage>& arg1,
439  const btas::Tensor<T, Range, Storage>& arg2, Op&& op) {
440  auto arg1_view = make_ti(arg1);
441  auto arg2_view = make_ti(arg2);
442  return arg1_view.binary(arg2_view, std::forward<Op>(op));
443 }
444 
445 template <
446  typename T, typename Range, typename Storage, typename Op, typename Perm,
447  typename = std::enable_if_t<TiledArray::detail::is_permutation_v<Perm>>>
448 inline auto binary(const btas::Tensor<T, Range, Storage>& arg1,
449  const btas::Tensor<T, Range, Storage>& arg2, Op&& op,
450  const Perm& perm) {
451  auto arg1_view = make_ti(arg1);
452  auto arg2_view = make_ti(arg2);
453  return arg1_view.binary(arg2_view, std::forward<Op>(op), perm);
454 }
455 
456 template <typename T, typename Range, typename Storage, typename Op>
457 inline btas::Tensor<T, Range, Storage>& inplace_binary(
458  btas::Tensor<T, Range, Storage>& arg1,
459  const btas::Tensor<T, Range, Storage>& arg2, Op&& op) {
460  auto arg1_view = make_ti(arg1);
461  auto arg2_view = make_ti(arg2);
462  return arg1_view.inplace_binary(arg2_view, std::forward<Op>(op));
463 }
464 
465 template <typename T, typename Range, typename Storage, typename Scalar,
466  std::enable_if_t<TiledArray::detail::is_numeric_v<Scalar>>* = nullptr>
467 inline btas::Tensor<T, Range, Storage>& scale_to(
468  btas::Tensor<T, Range, Storage>& result, const Scalar factor) {
469  auto result_view = make_ti(result);
470  result_view.scale_to(factor);
471  return result;
472 }
473 
474 template <typename T, typename Range, typename Storage, typename Scalar,
475  std::enable_if_t<TiledArray::detail::is_numeric_v<Scalar>>* = nullptr>
476 inline decltype(auto) scale(const btas::Tensor<T, Range, Storage>& result,
477  const Scalar factor) {
478  auto result_view = make_ti(result);
479  return result_view.scale(factor);
480 }
481 
482 template <
483  typename T, typename Range, typename Storage, typename Scalar,
484  typename Perm,
485  std::enable_if_t<TiledArray::detail::is_numeric_v<Scalar> &&
486  TiledArray::detail::is_permutation_v<Perm>>* = nullptr>
487 inline decltype(auto) scale(const btas::Tensor<T, Range, Storage>& result,
488  const Scalar factor, const Perm& perm) {
489  auto result_view = make_ti(result);
490  return result_view.scale(factor, perm);
491 }
492 
493 template <typename T, typename Range, typename Storage>
494 inline btas::Tensor<T, Range, Storage>& neg_to(
495  btas::Tensor<T, Range, Storage>& result) {
496  auto result_view = make_ti(result);
497  result_view.neg_to();
498  return result;
499 }
500 
501 template <typename T, typename Range, typename Storage>
502 inline btas::Tensor<T, Range, Storage> neg(
503  const btas::Tensor<T, Range, Storage>& arg) {
504  auto arg_view = make_ti(arg);
505  return arg_view.neg();
506 }
507 
508 template <
509  typename T, typename Range, typename Storage, typename Perm,
510  typename = std::enable_if_t<TiledArray::detail::is_permutation_v<Perm>>>
511 inline btas::Tensor<T, Range, Storage> neg(
512  const btas::Tensor<T, Range, Storage>& arg, const Perm& perm) {
513  auto arg_view = make_ti(arg);
514  return arg_view.neg(perm);
515 }
516 
517 template <typename T, typename Range, typename Storage>
518 inline btas::Tensor<T, Range, Storage> conj(
519  const btas::Tensor<T, Range, Storage>& arg) {
520  auto arg_view = make_ti(arg);
521  return arg_view.conj();
522 }
523 
524 template <
525  typename T, typename Range, typename Storage, typename Perm,
526  typename = std::enable_if_t<TiledArray::detail::is_permutation_v<Perm>>>
527 inline btas::Tensor<T, Range, Storage> conj(
528  const btas::Tensor<T, Range, Storage>& arg, const Perm& perm) {
529  auto arg_view = make_ti(arg);
530  return arg_view.conj(perm);
531 }
532 
533 template <typename T, typename Range, typename Storage, typename Scalar,
534  std::enable_if_t<TiledArray::detail::is_numeric_v<Scalar>>* = nullptr>
535 inline btas::Tensor<T, Range, Storage> conj(
536  const btas::Tensor<T, Range, Storage>& arg, const Scalar factor) {
537  auto arg_view = make_ti(arg);
538  return arg_view.conj(factor);
539 }
540 
541 template <
542  typename T, typename Range, typename Storage, typename Scalar,
543  typename Perm,
544  std::enable_if_t<TiledArray::detail::is_numeric_v<Scalar> &&
545  TiledArray::detail::is_permutation_v<Perm>>* = nullptr>
546 inline btas::Tensor<T, Range, Storage> conj(
547  const btas::Tensor<T, Range, Storage>& arg, const Scalar factor,
548  const Perm& perm) {
549  auto arg_view = make_ti(arg);
550  return arg_view.conj(factor, perm);
551 }
552 
553 template <typename T, typename Range, typename Storage>
554 inline btas::Tensor<T, Range, Storage>& conj_to(
555  btas::Tensor<T, Range, Storage>& arg) {
556  auto arg_view = make_ti(arg);
557  arg_view.conj_to();
558  return arg;
559 }
560 
561 template <typename T, typename Range, typename Storage, typename Scalar,
562  std::enable_if_t<TiledArray::detail::is_numeric_v<Scalar>>* = nullptr>
563 inline btas::Tensor<T, Range, Storage>& conj_to(
564  btas::Tensor<T, Range, Storage>& arg, const Scalar factor) {
565  auto arg_view = make_ti(arg);
566  arg_view.conj_to(factor);
567  return arg;
568 }
569 
570 // Generic element-wise unary operations
571 // ---------------------------------------------
572 
573 template <typename T, typename Range, typename Storage, typename Op>
574 inline auto unary(const btas::Tensor<T, Range, Storage>& arg, Op&& op) {
575  auto arg_view = make_ti(arg);
576  return arg_view.unary(std::forward<Op>(op));
577 }
578 
579 template <
580  typename T, typename Range, typename Storage, typename Op, typename Perm,
581  typename = std::enable_if_t<TiledArray::detail::is_permutation_v<Perm>>>
582 inline auto unary(const btas::Tensor<T, Range, Storage>& arg, Op&& op,
583  const Perm& perm) {
584  auto arg_view = make_ti(arg);
585  return arg_view.unary(std::forward<Op>(op), perm);
586 }
587 
588 template <typename T, typename Range, typename Storage, typename Op>
589 inline btas::Tensor<T, Range, Storage>& inplace_unary(
590  const btas::Tensor<T, Range, Storage>& arg, Op&& op) {
591  auto arg_view = make_ti(arg);
592  return arg_view.inplace_unary(std::forward<Op>(op));
593 }
594 
595 template <typename T, typename Range, typename Storage, typename Scalar>
596 inline btas::Tensor<T, Range, Storage> gemm(
597  const btas::Tensor<T, Range, Storage>& left,
598  const btas::Tensor<T, Range, Storage>& right, Scalar factor,
599  const TiledArray::math::GemmHelper& gemm_helper) {
600  // Check that the arguments are not empty and have the correct ranks
601  TA_ASSERT(!left.empty());
602  TA_ASSERT(left.range().rank() == gemm_helper.left_rank());
603  TA_ASSERT(!right.empty());
604  TA_ASSERT(right.range().rank() == gemm_helper.right_rank());
605 
606  // Construct the result Tensor
607  typedef btas::Tensor<T, Range, Storage> Tensor;
608  Tensor result(
609  gemm_helper.make_result_range<Range>(left.range(), right.range()));
610 
611  // Check that the inner dimensions of left and right match
612  TA_ASSERT(
614  gemm_helper.left_right_congruent(std::cbegin(left.range().lobound()),
615  std::cbegin(right.range().lobound())));
616  TA_ASSERT(
618  gemm_helper.left_right_congruent(std::cbegin(left.range().upbound()),
619  std::cbegin(right.range().upbound())));
620  TA_ASSERT(gemm_helper.left_right_congruent(
621  std::cbegin(left.range().extent()), std::cbegin(right.range().extent())));
622 
623  // Compute gemm dimensions
625  integer m = 1, n = 1, k = 1;
626  gemm_helper.compute_matrix_sizes(m, n, k, left.range(), right.range());
627 
628  // Get the leading dimension for left and right matrices.
629  const integer lda =
630  (gemm_helper.left_op() == TiledArray::math::blas::Op::NoTrans ? k : m);
631  const integer ldb =
632  (gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k);
633 
634  T factor_t(factor);
635 
636  TiledArray::math::blas::gemm(gemm_helper.left_op(), gemm_helper.right_op(), m,
637  n, k, factor_t, left.data(), lda, right.data(),
638  ldb, T(0), result.data(), n);
639 
640  return result;
641 }
642 
643 template <typename T, typename Range, typename Storage, typename Scalar>
644 inline void gemm(btas::Tensor<T, Range, Storage>& result,
645  const btas::Tensor<T, Range, Storage>& left,
646  const btas::Tensor<T, Range, Storage>& right, Scalar factor,
647  const TiledArray::math::GemmHelper& gemm_helper) {
648  // Check that this tensor is not empty and has the correct rank
649  TA_ASSERT(!result.empty());
650  TA_ASSERT(result.range().rank() == gemm_helper.result_rank());
651 
652  // Check that the arguments are not empty and have the correct ranks
653  TA_ASSERT(!left.empty());
654  TA_ASSERT(left.range().rank() == gemm_helper.left_rank());
655  TA_ASSERT(!right.empty());
656  TA_ASSERT(right.range().rank() == gemm_helper.right_rank());
657 
658  // Check that the outer dimensions of left match the the corresponding
659  // dimensions in result
660  TA_ASSERT(
662  gemm_helper.left_result_congruent(std::cbegin(left.range().lobound()),
663  std::cbegin(result.range().lobound())));
664  TA_ASSERT(
666  gemm_helper.left_result_congruent(std::cbegin(left.range().upbound()),
667  std::cbegin(result.range().upbound())));
668  TA_ASSERT(
669  gemm_helper.left_result_congruent(std::cbegin(left.range().extent()),
670  std::cbegin(result.range().extent())));
671 
672  // Check that the outer dimensions of right match the the corresponding
673  // dimensions in result
675  gemm_helper.right_result_congruent(
676  std::cbegin(right.range().lobound()),
677  std::cbegin(result.range().lobound())));
679  gemm_helper.right_result_congruent(
680  std::cbegin(right.range().upbound()),
681  std::cbegin(result.range().upbound())));
682  TA_ASSERT(
683  gemm_helper.right_result_congruent(std::cbegin(right.range().extent()),
684  std::cbegin(result.range().extent())));
685 
686  // Check that the inner dimensions of left and right match
687  TA_ASSERT(
689  gemm_helper.left_right_congruent(std::cbegin(left.range().lobound()),
690  std::cbegin(right.range().lobound())));
691  TA_ASSERT(
693  gemm_helper.left_right_congruent(std::cbegin(left.range().upbound()),
694  std::cbegin(right.range().upbound())));
695  TA_ASSERT(gemm_helper.left_right_congruent(
696  std::cbegin(left.range().extent()), std::cbegin(right.range().extent())));
697 
698  // Compute gemm dimensions
700  integer m, n, k;
701  gemm_helper.compute_matrix_sizes(m, n, k, left.range(), right.range());
702 
703  // Get the leading dimension for left and right matrices.
704  const integer lda =
705  (gemm_helper.left_op() == TiledArray::math::blas::Op::NoTrans ? k : m);
706  const integer ldb =
707  (gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k);
708 
709  T factor_t(factor);
710 
711  TiledArray::math::blas::gemm(gemm_helper.left_op(), gemm_helper.right_op(), m,
712  n, k, factor_t, left.data(), lda, right.data(),
713  ldb, T(1), result.data(), n);
714 }
715 
716 // sum of the hyperdiagonal elements
717 template <typename T, typename Range, typename Storage>
718 inline T trace(const btas::Tensor<T, Range, Storage>& arg) {
719  assert(false);
720 }
721 // foreach(i) result += arg[i]
722 template <typename T, typename Range, typename Storage>
723 inline T sum(const btas::Tensor<T, Range, Storage>& arg) {
724  return make_ti(arg).sum();
725 }
726 // foreach(i) result *= arg[i]
727 template <typename T, typename Range, typename Storage>
728 inline T product(const btas::Tensor<T, Range, Storage>& arg) {
729  return make_ti(arg).product();
730 }
731 
732 // foreach(i) result += arg[i] * arg[i]
733 template <typename T, typename Range, typename Storage>
734 inline T squared_norm(const btas::Tensor<T, Range, Storage>& arg) {
735  return make_ti(arg).squared_norm();
736 };
737 
738 // foreach(i) result += arg1[i] * arg2[i]
739 template <typename T, typename Range, typename Storage>
740 inline T dot(const btas::Tensor<T, Range, Storage>& arg1,
741  const btas::Tensor<T, Range, Storage>& arg2) {
742  return make_ti(arg1).dot(make_ti(arg2));
743 };
744 
745 template <typename T, typename Range, typename Storage>
746 inline T inner_product(const btas::Tensor<T, Range, Storage>& arg1,
747  const btas::Tensor<T, Range, Storage>& arg2) {
748  return make_ti(arg1).inner_product(make_ti(arg2));
749 };
750 
751 // sqrt(squared_norm(arg))
752 template <typename T, typename Range, typename Storage>
753 inline T norm(const btas::Tensor<T, Range, Storage>& arg) {
754  return make_ti(arg).norm();
755 }
756 // sqrt(squared_norm(arg))
757 template <typename T, typename Range, typename Storage, typename ResultType>
758 inline void norm(const btas::Tensor<T, Range, Storage>& arg,
759  ResultType& result) {
760  result = make_ti(arg).template norm<ResultType>();
761 }
762 // foreach(i) result = max(result, arg[i])
763 template <typename T, typename Range, typename Storage>
764 inline T max(const btas::Tensor<T, Range, Storage>& arg) {
765  return make_ti(arg).max();
766 }
767 // foreach(i) result = min(result, arg[i])
768 template <typename T, typename Range, typename Storage>
769 inline T min(const btas::Tensor<T, Range, Storage>& arg) {
770  return make_ti(arg).min();
771 }
772 // foreach(i) result = max(result, abs(arg[i]))
773 template <typename T, typename Range, typename Storage>
774 inline T abs_max(const btas::Tensor<T, Range, Storage>& arg) {
775  return make_ti(arg).abs_max();
776 }
777 // foreach(i) result = max(result, abs(arg[i]))
778 template <typename T, typename Range, typename Storage>
779 inline T abs_min(const btas::Tensor<T, Range, Storage>& arg) {
780  return make_ti(arg).abs_min();
781 }
782 } // namespace btas
783 
784 namespace TiledArray {
785 
786 namespace detail {
787 
790 template <typename T, typename Range, typename Storage>
791 struct TraceIsDefined<btas::Tensor<T, Range, Storage>, enable_if_numeric_t<T>>
792  : std::true_type {};
793 
794 } // namespace detail
795 
800 template <typename Perm>
801 typename std::enable_if<!TiledArray::detail::is_permutation_v<Perm>,
802  TiledArray::Range>::type
803 permute(const TiledArray::Range& r, const Perm& p) {
804  TiledArray::Permutation pp(p.begin(), p.end());
805  return pp * r;
806 }
807 
808 } // namespace TiledArray
809 
810 namespace TiledArray {
811 namespace detail {
812 
813 template <typename T, typename... Args>
814 struct is_tensor_helper<btas::Tensor<T, Args...>> : public std::true_type {};
815 
816 template <typename T, typename... Args>
818  : public std::true_type {};
819 
/// Primary template of the "is this a btas::Tensor?" trait: false for any
/// type that has no dedicated specialization.
template <typename T, typename Enabler = void>
struct is_btas_tensor : std::false_type {
};
822 
823 template <typename T, typename... Args>
824 struct is_btas_tensor<btas::Tensor<T, Args...>> : public std::true_type {};
825 
826 template <typename T>
828 
830 template <::blas::Layout _Order, typename _Index, typename _Ordinal>
831 struct ordinal_traits<btas::RangeNd<_Order, _Index, _Ordinal>> {
832  static constexpr const auto type = _Order == ::blas::Layout::RowMajor
833  ? OrdinalType::RowMajor
834  : OrdinalType::ColMajor;
835 };
836 
837 } // namespace detail
838 } // namespace TiledArray
839 
840 namespace TiledArray {
842 template <typename T, typename Allocator, typename Range_, typename Storage>
843 struct Cast<TiledArray::Tensor<T, Allocator>,
844  btas::Tensor<T, Range_, Storage>> {
845  auto operator()(const btas::Tensor<T, Range_, Storage>& arg) const {
847  using std::begin;
848  std::copy(btas::cbegin(arg), btas::cend(arg), begin(result));
849  return result;
850  }
851 };
852 } // namespace TiledArray
853 
854 #endif /* TILEDARRAY_EXTERNAL_BTAS_H__INCLUDED */
R make_result_range(const Left &left, const Right &right) const
Construct a result range based on left and right ranges.
Definition: gemm_helper.h:165
const index1_type * lobound_data() const
Range lower bound data accessor.
Definition: range.h:685
decltype(auto) make_ti(const btas::Tensor< T, Range, Storage > &arg)
Definition: btas.h:106
std::enable_if<!TiledArray::detail::is_permutation_v< Perm >, TiledArray::Range >::type permute(const TiledArray::Range &r, const Perm &p)
Definition: btas.h:803
Contraction to *GEMM helper.
Definition: gemm_helper.h:40
T trace(const btas::Tensor< T, Range, Storage > &arg)
Definition: btas.h:718
T product(const btas::Tensor< T, Range, Storage > &arg)
Definition: btas.h:728
::blas::Op Op
Definition: blas.h:46
unsigned int left_rank() const
Left-hand argument rank accessor.
Definition: gemm_helper.h:138
btas::Tensor< T, Range, Storage > & shift_to(btas::Tensor< T, Range, Storage > &arg, const Index &range_shift)
Shift the range of arg in place.
Definition: btas.h:210
T abs_min(const btas::Tensor< T, Range, Storage > &arg)
Definition: btas.h:779
T max(const btas::Tensor< T, Range, Storage > &arg)
Definition: btas.h:764
blas::Op left_op() const
Definition: gemm_helper.h:275
int64_t integer
Definition: blas.h:44
bool right_result_congruent(const Right &right, const Result &result) const
Definition: gemm_helper.h:221
auto operator()(const btas::Tensor< T, Range_, Storage > &arg) const
Definition: btas.h:845
Permutation inv() const
Construct the inverse of this permutation.
Definition: permutation.h:334
Permutation of a sequence of objects indexed by base-0 indices.
Definition: permutation.h:130
Definition: btas.h:44
bool left_right_congruent(const Left &left, const Right &right) const
Definition: gemm_helper.h:238
constexpr const bool is_tensor_of_tensor_v
Definition: type_traits.h:155
auto outer(const IndexList &p)
Definition: index_list.h:879
btas::Tensor< T, Range, Storage > shift(const btas::Tensor< T, Range, Storage > &arg, const Index &range_shift)
Shift the range of arg.
Definition: btas.h:197
container::svector< index1_type > index_type
Definition: range.h:52
bool left_result_congruent(const Left &left, const Result &result) const
Definition: gemm_helper.h:205
const index1_type * upbound_data() const
Range upper bound data accessor.
Definition: range.h:710
auto binary(const btas::Tensor< T, Range, Storage > &arg1, const btas::Tensor< T, Range, Storage > &arg2, Op &&op)
Definition: btas.h:438
btas::Tensor< T, Range, Storage > permute(const btas::Tensor< T, Range, Storage > &arg, const TiledArray::Permutation &perm)
Computes the result of applying permutation perm to arg.
Definition: btas.h:158
btas::Tensor< T, Range, Storage > & neg_to(btas::Tensor< T, Range, Storage > &result)
Definition: btas.h:494
T dot(const btas::Tensor< T, Range, Storage > &arg1, const btas::Tensor< T, Range, Storage > &arg2)
Definition: btas.h:740
auto rank(const DistArray< Tile, Policy > &a)
Definition: dist_array.h:1617
ordinal trait specifies properties of the ordinal
Definition: type_traits.h:330
btas::Tensor< T, Range, Storage > & subt_to(btas::Tensor< T, Range, Storage > &result, const btas::Tensor< T, Range, Storage > &arg)
result[i] -= arg[i]
Definition: btas.h:340
btas::Tensor< T, Range, Storage > clone(const btas::Tensor< T, Range, Storage > &arg)
Returns a copy of arg.
Definition: btas.h:151
std::enable_if_t< is_numeric_v< T >, U > enable_if_numeric_t
SFINAE type for enabling code when T is a numeric type.
Definition: type_traits.h:649
decltype(auto) scale(const btas::Tensor< T, Range, Storage > &result, const Scalar factor)
Definition: btas.h:476
const auto & data() const
Permutation data accessor.
Definition: permutation.h:388
T min(const btas::Tensor< T, Range, Storage > &arg)
Definition: btas.h:769
auto inner(const IndexList &p)
Definition: index_list.h:872
T squared_norm(const btas::Tensor< T, Range, Storage > &arg)
Definition: btas.h:734
Tile cast operation.
Definition: cast.h:168
btas::Tensor< T, Range, Storage > & add_to(btas::Tensor< T, Range, Storage > &result, const btas::Tensor< T, Range, Storage > &arg)
result[i] += arg[i]
Definition: btas.h:267
#define TA_ASSERT(EXPR,...)
Definition: error.h:39
unsigned int result_rank() const
Result rank accessor.
Definition: gemm_helper.h:133
void ignore_tile_position(bool b)
Definition: utility.h:81
constexpr const bool is_btas_tensor_v
Definition: btas.h:827
void gemm(Op op_a, Op op_b, const integer m, const integer n, const integer k, const S1 alpha, const T1 *a, const integer lda, const T2 *b, const integer ldb, const S2 beta, T3 *c, const integer ldc)
Definition: blas.h:71
const TiledArray::Range & make_ta_range(const TiledArray::Range &range)
Definition: btas.h:58
btas::Tensor< T, Range, Storage > mult(const btas::Tensor< T, Range, Storage > &arg1, const btas::Tensor< T, Range, Storage > &arg2)
result[i] = arg1[i] * arg2[i]
Definition: btas.h:363
T sum(const btas::Tensor< T, Range, Storage > &arg)
Definition: btas.h:723
bool is_congruent(const BlockRange &r1, const BlockRange &r2)
Test that two BlockRange objects are congruent.
Definition: block_range.h:400
bool is_congruent(const btas::RangeNd< Order, Args... > &r1, const btas::RangeNd< Order, Args... > &r2)
Test if the two ranges are congruent.
Definition: btas.h:98
btas::Tensor< T, Range, Storage > neg(const btas::Tensor< T, Range, Storage > &arg)
Definition: btas.h:502
btas::Tensor< T, Range, Storage > & conj_to(btas::Tensor< T, Range, Storage > &arg)
Definition: btas.h:554
Permute a tile.
Definition: permute.h:134
void compute_matrix_sizes(blas::integer &m, blas::integer &n, blas::integer &k, const Left &left, const Right &right) const
Compute the matrix dimension that can be used in a *GEMM call.
Definition: gemm_helper.h:254
blas::Op right_op() const
Definition: gemm_helper.h:276
TiledArray::Range::index_type index_type
Definition: btas.h:48
T abs_max(const btas::Tensor< T, Range, Storage > &arg)
Definition: btas.h:774
auto unary(const btas::Tensor< T, Range, Storage > &arg, Op &&op)
Definition: btas.h:574
TiledArray::Range::ordinal_type ordinal_type
Definition: btas.h:49
btas::Tensor< T, Range, Storage > add(const btas::Tensor< T, Range, Storage > &arg1, const btas::Tensor< T, Range, Storage > &arg2)
result[i] = arg1[i] + arg2[i]
Definition: btas.h:218
Tensor interface for external data.
unsigned int right_rank() const
Right-hand argument rank accessor.
Definition: gemm_helper.h:143
bool congruent(const Range &r1, const Range &r2)
Test if the two ranges are congruent.
Definition: btas.h:83
btas::Tensor< T, Range, Storage > subt(const btas::Tensor< T, Range, Storage > &arg1, const btas::Tensor< T, Range, Storage > &arg2)
result[i] = arg1[i] - arg2[i]
Definition: btas.h:291
btas::Tensor< T, Range, Storage > gemm(const btas::Tensor< T, Range, Storage > &left, const btas::Tensor< T, Range, Storage > &right, Scalar factor, const TiledArray::math::GemmHelper &gemm_helper)
Definition: btas.h:596
T inner_product(const btas::Tensor< T, Range, Storage > &arg1, const btas::Tensor< T, Range, Storage > &arg2)
Definition: btas.h:746
btas::Tensor< T, Range, Storage > & inplace_unary(const btas::Tensor< T, Range, Storage > &arg, Op &&op)
Definition: btas.h:589
btas::Tensor< T, Range, Storage > & inplace_binary(btas::Tensor< T, Range, Storage > &arg1, const btas::Tensor< T, Range, Storage > &arg2, Op &&op)
Definition: btas.h:457
btas::Tensor< T, Range, Storage > conj(const btas::Tensor< T, Range, Storage > &arg)
Definition: btas.h:518
btas::Tensor< T, Range, Storage > & mult_to(btas::Tensor< T, Range, Storage > &result, const btas::Tensor< T, Range, Storage > &arg)
result[i] *= arg[i]
Definition: btas.h:412
An N-dimensional tensor object.
Definition: tensor.h:50
Permutation of a bipartite set.
Definition: permutation.h:610
std::size_t ordinal_type
Ordinal type, to conform to TWG spec.
Definition: range.h:59
T norm(const btas::Tensor< T, Range, Storage > &arg)
Definition: btas.h:753
auto inner_size(const IndexList &p)
Definition: index_list.h:881
bool operator==(const TiledArray::Range &range1, const btas::BaseRangeNd< Args... > &range2)
Definition: btas.h:120
unsigned int rank() const
Rank accessor.
Definition: range.h:669
btas::Tensor< T, Range, Storage > & scale_to(btas::Tensor< T, Range, Storage > &result, const Scalar factor)
Definition: btas.h:467
A (hyperrectangular) interval on $\mathbb{Z}^n$, the space of integer $n$-indices.
Definition: range.h:46