23 #ifndef TILEDARRAY_EXTERNAL_BTAS_H__INCLUDED
24 #define TILEDARRAY_EXTERNAL_BTAS_H__INCLUDED
30 #include "TiledArray/config.h"
37 #include <btas/features.h>
38 #include <btas/generic/axpy_impl.h>
39 #include <btas/generic/permute.h>
40 #include <btas/tensor.h>
42 #include <madness/world/archive.h>
47 const static blas::Layout order = blas::Layout::RowMajor;
50 constexpr
static const bool is_general_layout =
false;
66 template <::blas::Layout Order,
typename... Args>
68 const btas::RangeNd<Order, Args...>& range) {
69 TA_ASSERT(Order == ::blas::Layout::RowMajor &&
70 "TiledArray::detail::make_ta_range(btas::RangeNd<Order,...>): "
71 "not supported for col-major Order");
97 template <blas::Layout Order,
typename... Args>
99 const btas::RangeNd<Order, Args...>& r2) {
100 return (r1.rank() == r2.rank()) &&
101 std::equal(r1.extent_data(), r1.extent_data() + r1.rank(),
105 template <
typename T,
typename Range,
typename Storage>
106 decltype(
auto)
make_ti(const
btas::Tensor<T, Range, Storage>& arg) {
108 btas::Tensor<T, Range, Storage>>(
109 arg.range(), arg.data());
112 template <
typename T,
typename Range,
typename Storage>
115 btas::Tensor<T, Range, Storage>>(
116 arg.range(), arg.data());
119 template <
typename... Args>
121 const btas::BaseRangeNd<Args...>& range2) {
123 if (
rank == range2.rank()) {
126 const auto lobound_match =
127 std::equal(range1_lobound_data, range1_lobound_data +
rank,
128 cbegin(range2.lobound()));
131 return std::equal(range1_upbound_data, range1_upbound_data +
rank,
132 cbegin(range2.upbound()));
138 template <
typename T1,
typename S1,
typename T2,
typename S2>
139 bool operator==(
const btas::Tensor<T1, TiledArray::Range, S1>& t1,
140 const btas::Tensor<T2, TiledArray::Range, S2>& t2) {
144 return t1_view.size() == t2_view.size() &&
145 std::equal(data(t1_view), data(t1_view) + t1_view.size(),
150 template <
typename T,
typename Range,
typename Storage>
151 inline btas::Tensor<T, Range, Storage>
clone(
152 const btas::Tensor<T, Range, Storage>& arg) {
157 template <
typename T,
typename Range,
typename Storage>
158 inline btas::Tensor<T, Range, Storage>
permute(
159 const btas::Tensor<T, Range, Storage>& arg,
161 btas::Tensor<T, Range, Storage> result;
167 template <
typename T,
typename Range,
typename Storage>
168 inline btas::Tensor<T, Range, Storage>
permute(
169 const btas::Tensor<T, Range, Storage>& arg,
171 btas::Tensor<T, Range, Storage> result;
172 constexpr
bool is_tot =
174 if constexpr (!is_tot) {
181 auto inner_perm =
inner(perm);
183 for (
auto& x : result) x = p(x, inner_perm);
196 template <
typename T,
typename Range,
typename Storage,
typename Index>
197 inline btas::Tensor<T, Range, Storage>
shift(
198 const btas::Tensor<T, Range, Storage>& arg,
const Index& range_shift) {
199 auto shifted_arg =
clone(arg);
209 template <
typename T,
typename Range,
typename Storage,
typename Index>
211 btas::Tensor<T, Range, Storage>& arg,
const Index& range_shift) {
212 const_cast<Range&
>(arg.range()).inplace_shift(range_shift);
217 template <
typename T,
typename Range,
typename Storage>
218 inline btas::Tensor<T, Range, Storage>
add(
219 const btas::Tensor<T, Range, Storage>& arg1,
220 const btas::Tensor<T, Range, Storage>& arg2) {
221 auto arg1_view =
make_ti(arg1);
222 auto arg2_view =
make_ti(arg2);
223 return arg1_view.add(arg2_view);
227 template <
typename T,
typename Range,
typename Storage,
typename Scalar,
228 typename std::enable_if<
229 TiledArray::detail::is_numeric_v<Scalar>>::type* =
nullptr>
230 inline btas::Tensor<T, Range, Storage>
add(
231 const btas::Tensor<T, Range, Storage>& arg1,
232 const btas::Tensor<T, Range, Storage>& arg2,
const Scalar factor) {
233 auto arg1_view =
make_ti(arg1);
234 auto arg2_view =
make_ti(arg2);
235 return arg1_view.add(arg2_view, factor);
240 typename T,
typename Range,
typename Storage,
typename Perm,
241 typename = std::enable_if_t<TiledArray::detail::is_permutation_v<Perm>>>
242 inline btas::Tensor<T, Range, Storage>
add(
243 const btas::Tensor<T, Range, Storage>& arg1,
244 const btas::Tensor<T, Range, Storage>& arg2,
const Perm& perm) {
245 auto arg1_view =
make_ti(arg1);
246 auto arg2_view =
make_ti(arg2);
247 return arg1_view.add(arg2_view, perm);
251 template <
typename T,
typename Range,
typename Storage,
typename Scalar,
253 typename std::enable_if<
254 TiledArray::detail::is_numeric_v<Scalar> &&
255 TiledArray::detail::is_permutation_v<Perm>>::type* =
nullptr>
256 inline btas::Tensor<T, Range, Storage>
add(
257 const btas::Tensor<T, Range, Storage>& arg1,
258 const btas::Tensor<T, Range, Storage>& arg2,
const Scalar factor,
260 auto arg1_view =
make_ti(arg1);
261 auto arg2_view =
make_ti(arg2);
262 return arg1_view.add(arg2_view, factor, perm);
266 template <
typename T,
typename Range,
typename Storage>
267 inline btas::Tensor<T, Range, Storage>&
add_to(
268 btas::Tensor<T, Range, Storage>& result,
269 const btas::Tensor<T, Range, Storage>& arg) {
270 auto result_view =
make_ti(result);
272 result_view.add_to(arg_view);
277 template <
typename T,
typename Range,
typename Storage,
typename Scalar,
278 typename std::enable_if<
279 TiledArray::detail::is_numeric_v<Scalar>>::type* =
nullptr>
280 inline btas::Tensor<T, Range, Storage>&
add_to(
281 btas::Tensor<T, Range, Storage>& result,
282 const btas::Tensor<T, Range, Storage>& arg,
const Scalar factor) {
283 auto result_view =
make_ti(result);
285 result_view.add_to(arg_view, factor);
290 template <
typename T,
typename Range,
typename Storage>
291 inline btas::Tensor<T, Range, Storage>
subt(
292 const btas::Tensor<T, Range, Storage>& arg1,
293 const btas::Tensor<T, Range, Storage>& arg2) {
294 auto arg1_view =
make_ti(arg1);
295 auto arg2_view =
make_ti(arg2);
296 return arg1_view.subt(arg2_view);
300 template <
typename T,
typename Range,
typename Storage,
typename Scalar,
301 typename std::enable_if<
302 TiledArray::detail::is_numeric_v<Scalar>>::type* =
nullptr>
303 inline btas::Tensor<T, Range, Storage>
subt(
304 const btas::Tensor<T, Range, Storage>& arg1,
305 const btas::Tensor<T, Range, Storage>& arg2,
const Scalar factor) {
306 auto arg1_view =
make_ti(arg1);
307 auto arg2_view =
make_ti(arg2);
308 return arg1_view.subt(arg2_view, factor);
313 typename T,
typename Range,
typename Storage,
typename Perm,
314 typename = std::enable_if_t<TiledArray::detail::is_permutation_v<Perm>>>
315 inline btas::Tensor<T, Range, Storage>
subt(
316 const btas::Tensor<T, Range, Storage>& arg1,
317 const btas::Tensor<T, Range, Storage>& arg2,
const Perm& perm) {
318 auto arg1_view =
make_ti(arg1);
319 auto arg2_view =
make_ti(arg2);
320 return arg1_view.subt(arg2_view, perm);
324 template <
typename T,
typename Range,
typename Storage,
typename Scalar,
326 typename std::enable_if<
327 TiledArray::detail::is_numeric_v<Scalar> &&
328 TiledArray::detail::is_permutation_v<Perm>>::type* =
nullptr>
329 inline btas::Tensor<T, Range, Storage>
subt(
330 const btas::Tensor<T, Range, Storage>& arg1,
331 const btas::Tensor<T, Range, Storage>& arg2,
const Scalar factor,
333 auto arg1_view =
make_ti(arg1);
334 auto arg2_view =
make_ti(arg2);
335 return arg1_view.subt(arg2_view, factor, perm);
339 template <
typename T,
typename Range,
typename Storage>
340 inline btas::Tensor<T, Range, Storage>&
subt_to(
341 btas::Tensor<T, Range, Storage>& result,
342 const btas::Tensor<T, Range, Storage>& arg) {
343 auto result_view =
make_ti(result);
345 result_view.subt_to(arg_view);
349 template <
typename T,
typename Range,
typename Storage,
typename Scalar,
350 typename std::enable_if<
351 TiledArray::detail::is_numeric_v<Scalar>>::type* =
nullptr>
352 inline btas::Tensor<T, Range, Storage>&
subt_to(
353 btas::Tensor<T, Range, Storage>& result,
354 const btas::Tensor<T, Range, Storage>& arg,
const Scalar factor) {
355 auto result_view =
make_ti(result);
357 result_view.subt_to(arg_view, factor);
362 template <
typename T,
typename Range,
typename Storage>
363 inline btas::Tensor<T, Range, Storage>
mult(
364 const btas::Tensor<T, Range, Storage>& arg1,
365 const btas::Tensor<T, Range, Storage>& arg2) {
366 auto arg1_view =
make_ti(arg1);
367 auto arg2_view =
make_ti(arg2);
368 return arg1_view.mult(arg2_view);
372 template <
typename T,
typename Range,
typename Storage,
typename Scalar,
373 typename std::enable_if<
374 TiledArray::detail::is_numeric_v<Scalar>>::type* =
nullptr>
375 inline btas::Tensor<T, Range, Storage>
mult(
376 const btas::Tensor<T, Range, Storage>& arg1,
377 const btas::Tensor<T, Range, Storage>& arg2,
const Scalar factor) {
378 auto arg1_view =
make_ti(arg1);
379 auto arg2_view =
make_ti(arg2);
380 return arg1_view.mult(arg2_view, factor);
385 typename T,
typename Range,
typename Storage,
typename Perm,
386 typename = std::enable_if_t<TiledArray::detail::is_permutation_v<Perm>>>
387 inline btas::Tensor<T, Range, Storage>
mult(
388 const btas::Tensor<T, Range, Storage>& arg1,
389 const btas::Tensor<T, Range, Storage>& arg2,
const Perm& perm) {
390 auto arg1_view =
make_ti(arg1);
391 auto arg2_view =
make_ti(arg2);
392 return arg1_view.mult(arg2_view, perm);
396 template <
typename T,
typename Range,
typename Storage,
typename Scalar,
398 typename std::enable_if<
399 TiledArray::detail::is_numeric_v<Scalar> &&
400 TiledArray::detail::is_permutation_v<Perm>>::type* =
nullptr>
401 inline btas::Tensor<T, Range, Storage>
mult(
402 const btas::Tensor<T, Range, Storage>& arg1,
403 const btas::Tensor<T, Range, Storage>& arg2,
const Scalar factor,
405 auto arg1_view =
make_ti(arg1);
406 auto arg2_view =
make_ti(arg2);
407 return arg1_view.mult(arg2_view, factor, perm);
411 template <
typename T,
typename Range,
typename Storage>
412 inline btas::Tensor<T, Range, Storage>&
mult_to(
413 btas::Tensor<T, Range, Storage>& result,
414 const btas::Tensor<T, Range, Storage>& arg) {
415 auto result_view =
make_ti(result);
417 result_view.mult_to(arg_view);
422 template <
typename T,
typename Range,
typename Storage,
typename Scalar,
423 typename std::enable_if<
424 TiledArray::detail::is_numeric_v<Scalar>>::type* =
nullptr>
425 inline btas::Tensor<T, Range, Storage>&
mult_to(
426 btas::Tensor<T, Range, Storage>& result,
427 const btas::Tensor<T, Range, Storage>& arg,
const Scalar factor) {
428 auto result_view =
make_ti(result);
430 result_view.mult_to(arg_view, factor);
437 template <
typename T,
typename Range,
typename Storage,
typename Op>
438 inline auto binary(
const btas::Tensor<T, Range, Storage>& arg1,
439 const btas::Tensor<T, Range, Storage>& arg2,
Op&& op) {
440 auto arg1_view =
make_ti(arg1);
441 auto arg2_view =
make_ti(arg2);
442 return arg1_view.binary(arg2_view, std::forward<Op>(op));
446 typename T,
typename Range,
typename Storage,
typename Op,
typename Perm,
447 typename = std::enable_if_t<TiledArray::detail::is_permutation_v<Perm>>>
448 inline auto binary(
const btas::Tensor<T, Range, Storage>& arg1,
449 const btas::Tensor<T, Range, Storage>& arg2,
Op&& op,
451 auto arg1_view =
make_ti(arg1);
452 auto arg2_view =
make_ti(arg2);
453 return arg1_view.binary(arg2_view, std::forward<Op>(op), perm);
456 template <
typename T,
typename Range,
typename Storage,
typename Op>
458 btas::Tensor<T, Range, Storage>& arg1,
459 const btas::Tensor<T, Range, Storage>& arg2,
Op&& op) {
460 auto arg1_view =
make_ti(arg1);
461 auto arg2_view =
make_ti(arg2);
462 return arg1_view.inplace_binary(arg2_view, std::forward<Op>(op));
465 template <
typename T,
typename Range,
typename Storage,
typename Scalar,
466 std::enable_if_t<TiledArray::detail::is_numeric_v<Scalar>>* =
nullptr>
468 btas::Tensor<T, Range, Storage>& result,
const Scalar factor) {
469 auto result_view =
make_ti(result);
470 result_view.scale_to(factor);
474 template <
typename T,
typename Range,
typename Storage,
typename Scalar,
475 std::enable_if_t<TiledArray::detail::is_numeric_v<Scalar>>* =
nullptr>
476 inline decltype(
auto)
scale(const
btas::Tensor<T, Range, Storage>& result,
477 const Scalar factor) {
478 auto result_view =
make_ti(result);
479 return result_view.scale(factor);
483 typename T,
typename Range,
typename Storage,
typename Scalar,
485 std::enable_if_t<TiledArray::detail::is_numeric_v<Scalar> &&
486 TiledArray::detail::is_permutation_v<Perm>>* =
nullptr>
487 inline decltype(
auto)
scale(const
btas::Tensor<T, Range, Storage>& result,
488 const Scalar factor, const Perm& perm) {
489 auto result_view =
make_ti(result);
490 return result_view.scale(factor, perm);
493 template <
typename T,
typename Range,
typename Storage>
494 inline btas::Tensor<T, Range, Storage>&
neg_to(
495 btas::Tensor<T, Range, Storage>& result) {
496 auto result_view =
make_ti(result);
497 result_view.neg_to();
501 template <
typename T,
typename Range,
typename Storage>
502 inline btas::Tensor<T, Range, Storage>
neg(
503 const btas::Tensor<T, Range, Storage>& arg) {
505 return arg_view.neg();
509 typename T,
typename Range,
typename Storage,
typename Perm,
510 typename = std::enable_if_t<TiledArray::detail::is_permutation_v<Perm>>>
511 inline btas::Tensor<T, Range, Storage>
neg(
512 const btas::Tensor<T, Range, Storage>& arg,
const Perm& perm) {
514 return arg_view.neg(perm);
517 template <
typename T,
typename Range,
typename Storage>
518 inline btas::Tensor<T, Range, Storage>
conj(
519 const btas::Tensor<T, Range, Storage>& arg) {
521 return arg_view.conj();
525 typename T,
typename Range,
typename Storage,
typename Perm,
526 typename = std::enable_if_t<TiledArray::detail::is_permutation_v<Perm>>>
527 inline btas::Tensor<T, Range, Storage>
conj(
528 const btas::Tensor<T, Range, Storage>& arg,
const Perm& perm) {
530 return arg_view.conj(perm);
533 template <
typename T,
typename Range,
typename Storage,
typename Scalar,
534 std::enable_if_t<TiledArray::detail::is_numeric_v<Scalar>>* =
nullptr>
535 inline btas::Tensor<T, Range, Storage>
conj(
536 const btas::Tensor<T, Range, Storage>& arg,
const Scalar factor) {
538 return arg_view.conj(factor);
542 typename T,
typename Range,
typename Storage,
typename Scalar,
544 std::enable_if_t<TiledArray::detail::is_numeric_v<Scalar> &&
545 TiledArray::detail::is_permutation_v<Perm>>* =
nullptr>
546 inline btas::Tensor<T, Range, Storage>
conj(
547 const btas::Tensor<T, Range, Storage>& arg,
const Scalar factor,
550 return arg_view.conj(factor, perm);
553 template <
typename T,
typename Range,
typename Storage>
554 inline btas::Tensor<T, Range, Storage>&
conj_to(
555 btas::Tensor<T, Range, Storage>& arg) {
561 template <
typename T,
typename Range,
typename Storage,
typename Scalar,
562 std::enable_if_t<TiledArray::detail::is_numeric_v<Scalar>>* =
nullptr>
563 inline btas::Tensor<T, Range, Storage>&
conj_to(
564 btas::Tensor<T, Range, Storage>& arg,
const Scalar factor) {
566 arg_view.conj_to(factor);
573 template <
typename T,
typename Range,
typename Storage,
typename Op>
574 inline auto unary(
const btas::Tensor<T, Range, Storage>& arg,
Op&& op) {
576 return arg_view.unary(std::forward<Op>(op));
580 typename T,
typename Range,
typename Storage,
typename Op,
typename Perm,
581 typename = std::enable_if_t<TiledArray::detail::is_permutation_v<Perm>>>
582 inline auto unary(
const btas::Tensor<T, Range, Storage>& arg,
Op&& op,
585 return arg_view.unary(std::forward<Op>(op), perm);
588 template <
typename T,
typename Range,
typename Storage,
typename Op>
590 const btas::Tensor<T, Range, Storage>& arg,
Op&& op) {
592 return arg_view.inplace_unary(std::forward<Op>(op));
595 template <
typename T,
typename Range,
typename Storage,
typename Scalar>
596 inline btas::Tensor<T, Range, Storage>
gemm(
597 const btas::Tensor<T, Range, Storage>& left,
598 const btas::Tensor<T, Range, Storage>& right, Scalar factor,
607 typedef btas::Tensor<T, Range, Storage> Tensor;
615 std::cbegin(right.range().lobound())));
619 std::cbegin(right.range().upbound())));
621 std::cbegin(left.range().extent()), std::cbegin(right.range().extent())));
630 (gemm_helper.
left_op() == TiledArray::math::blas::Op::NoTrans ? k : m);
632 (gemm_helper.
right_op() == TiledArray::math::blas::Op::NoTrans ? n : k);
637 n, k, factor_t, left.data(), lda, right.data(),
638 ldb, T(0), result.data(), n);
643 template <
typename T,
typename Range,
typename Storage,
typename Scalar>
644 inline void gemm(btas::Tensor<T, Range, Storage>& result,
645 const btas::Tensor<T, Range, Storage>& left,
646 const btas::Tensor<T, Range, Storage>& right, Scalar factor,
663 std::cbegin(result.range().lobound())));
667 std::cbegin(result.range().upbound())));
670 std::cbegin(result.range().extent())));
676 std::cbegin(right.range().lobound()),
677 std::cbegin(result.range().lobound())));
680 std::cbegin(right.range().upbound()),
681 std::cbegin(result.range().upbound())));
684 std::cbegin(result.range().extent())));
690 std::cbegin(right.range().lobound())));
694 std::cbegin(right.range().upbound())));
696 std::cbegin(left.range().extent()), std::cbegin(right.range().extent())));
705 (gemm_helper.
left_op() == TiledArray::math::blas::Op::NoTrans ? k : m);
707 (gemm_helper.
right_op() == TiledArray::math::blas::Op::NoTrans ? n : k);
712 n, k, factor_t, left.data(), lda, right.data(),
713 ldb, T(1), result.data(), n);
717 template <
typename T,
typename Range,
typename Storage>
718 inline T
trace(
const btas::Tensor<T, Range, Storage>& arg) {
722 template <
typename T,
typename Range,
typename Storage>
723 inline T
sum(
const btas::Tensor<T, Range, Storage>& arg) {
727 template <
typename T,
typename Range,
typename Storage>
728 inline T
product(
const btas::Tensor<T, Range, Storage>& arg) {
733 template <
typename T,
typename Range,
typename Storage>
735 return make_ti(arg).squared_norm();
739 template <
typename T,
typename Range,
typename Storage>
740 inline T
dot(
const btas::Tensor<T, Range, Storage>& arg1,
741 const btas::Tensor<T, Range, Storage>& arg2) {
745 template <
typename T,
typename Range,
typename Storage>
747 const btas::Tensor<T, Range, Storage>& arg2) {
752 template <
typename T,
typename Range,
typename Storage>
753 inline T
norm(
const btas::Tensor<T, Range, Storage>& arg) {
757 template <
typename T,
typename Range,
typename Storage,
typename ResultType>
758 inline void norm(
const btas::Tensor<T, Range, Storage>& arg,
759 ResultType& result) {
760 result =
make_ti(arg).template norm<ResultType>();
763 template <
typename T,
typename Range,
typename Storage>
764 inline T
max(
const btas::Tensor<T, Range, Storage>& arg) {
768 template <
typename T,
typename Range,
typename Storage>
769 inline T
min(
const btas::Tensor<T, Range, Storage>& arg) {
773 template <
typename T,
typename Range,
typename Storage>
774 inline T
abs_max(
const btas::Tensor<T, Range, Storage>& arg) {
778 template <
typename T,
typename Range,
typename Storage>
779 inline T
abs_min(
const btas::Tensor<T, Range, Storage>& arg) {
790 template <
typename T,
typename Range,
typename Storage>
800 template <
typename Perm>
801 typename std::enable_if<!TiledArray::detail::is_permutation_v<Perm>,
813 template <
typename T,
typename... Args>
816 template <
typename T,
typename... Args>
818 :
public std::true_type {};
820 template <
typename T,
typename Enabler =
void>
823 template <
typename T,
typename... Args>
826 template <
typename T>
830 template <::blas::Layout _Order,
typename _Index,
typename _Ordinal>
832 static constexpr
const auto type = _Order == ::blas::Layout::RowMajor
833 ? OrdinalType::RowMajor
834 : OrdinalType::ColMajor;
842 template <
typename T,
typename Allocator,
typename Range_,
typename Storage>
844 btas::Tensor<T, Range_, Storage>> {
845 auto operator()(
const btas::Tensor<T, Range_, Storage>& arg)
const {
848 std::copy(btas::cbegin(arg), btas::cend(arg), begin(result));