26 #ifndef TILEDARRAY_TENSOR_KENERLS_H__INCLUDED
27 #define TILEDARRAY_TENSOR_KENERLS_H__INCLUDED
34 template <
typename,
typename>
62 template <
typename TR,
typename Op,
typename T1,
typename... Ts,
63 typename std::enable_if<
64 is_tensor<TR, T1, Ts...>::value ||
65 is_tensor_of_tensor<TR, T1, Ts...>::value>::type* =
nullptr>
66 inline TR
tensor_op(
Op&& op,
const T1& tensor1,
const Ts&... tensors) {
67 if constexpr (std::is_invocable_r_v<TR, Op, const T1&, const Ts&...>) {
68 return std::forward<Op>(op)(tensor1, tensors...);
93 template <
typename TR,
typename Op,
typename T1,
typename... Ts,
94 typename std::enable_if<
95 (is_tensor<T1, Ts...>::value ||
96 is_tensor_of_tensor<TR, T1, Ts...>::value) &&
97 is_contiguous_tensor<T1, Ts...>::value>::type* =
nullptr>
99 const Ts&... tensors) {
100 if constexpr (std::is_invocable_r_v<TR,
Op,
const Permutation&,
const T1&,
102 return std::forward<Op>(op)(perm, tensor1, tensors...);
105 tensor1, tensors...);
116 template <
typename T>
120 template <
typename Op,
typename Tensor,
typename... Tensors>
125 const auto& range = tensor.range();
127 this->operator()(result, std::forward<Op>(op), std::forward<Tensor>(tensor),
128 std::forward<Tensors>(tensors)...);
134 template <
typename Op,
typename Tensor,
typename... Tensors>
136 Tensors&&... tensors)
const {
140 const auto& range = result.range();
141 for (
auto&& i : range)
142 result[std::forward<decltype(i)>(i)] = std::forward<Op>(op)(
143 std::forward<Tensor>(tensor)[std::forward<decltype(i)>(i)],
144 std::forward<Tensors>(tensors)[std::forward<decltype(i)>(i)]...);
147 template <
typename Op,
typename Tensor,
typename... Tensors>
149 Tensors&&... tensors)
const {
155 const auto& range = tensor.range();
156 T result(perm ^ range);
157 this->operator()(result, std::forward<Op>(op), perm,
158 std::forward<Tensor>(tensor),
159 std::forward<Tensors>(tensors)...);
163 template <
typename Op,
typename Tensor,
typename... Tensors>
165 Tensors&&... tensors)
const {
172 const auto& range = tensor.range();
173 for (
auto&& i : range)
174 result[perm ^ std::forward<decltype(i)>(i)] = std::forward<Op>(op)(
175 std::forward<Tensor>(tensor)[std::forward<decltype(i)>(i)],
176 std::forward<Tensors>(tensors)[std::forward<decltype(i)>(i)]...);
193 template <
typename Op,
typename TR,
typename... Ts,
194 typename std::enable_if<
195 is_tensor<TR, Ts...>::value &&
196 is_contiguous_tensor<TR, Ts...>::value>::type* =
nullptr>
201 const auto volume = result.range().volume();
217 template <
typename Op,
typename TR,
typename... Ts,
218 typename std::enable_if<
219 is_tensor_of_tensor<TR, Ts...>::value &&
220 is_contiguous_tensor<TR, Ts...>::value>::type* =
nullptr>
225 const auto volume = result.range().volume();
227 for (decltype(result.range().volume()) i = 0ul; i <
volume; ++i) {
257 template <
typename InputOp,
typename OutputOp,
typename TR,
typename T1,
259 typename std::enable_if<
260 is_tensor<TR, T1, Ts...>::value &&
261 is_contiguous_tensor<TR, T1, Ts...>::value>::type* =
nullptr>
264 const T1& tensor1,
const Ts&... tensors) {
271 permute(std::forward<InputOp>(input_op), std::forward<OutputOp>(output_op),
272 result, perm, tensor1, tensors...);
309 template <
typename InputOp,
typename OutputOp,
typename TR,
typename T1,
311 typename std::enable_if<
312 is_tensor_of_tensor<TR, T1, Ts...>::value &&
313 is_contiguous_tensor<TR, T1, Ts...>::value>::type* =
nullptr>
316 const T1& tensor1,
const Ts&... tensors) {
323 auto wrapper_input_op =
324 [&input_op](
typename T1::const_reference MADNESS_RESTRICT value1,
325 typename Ts::const_reference MADNESS_RESTRICT... values) ->
326 typename T1::value_type {
327 return tensor_op<TR::value_type>(std::forward<InputOp>(input_op),
331 auto wrapper_output_op =
332 [&output_op](
typename T1::pointer MADNESS_RESTRICT
const result_value,
333 const typename TR::value_type value) {
338 permute(std::move(wrapper_input_op), std::move(wrapper_output_op), result,
339 perm, tensor1, tensors...);
352 template <
typename Op,
typename TR,
typename... Ts,
353 typename std::enable_if<
354 is_tensor<TR, Ts...>::value &&
355 !(is_contiguous_tensor<TR, Ts...>::value)>::type* =
nullptr>
360 const auto stride =
inner_size(result, tensors...);
361 const auto volume = result.range().volume();
363 for (decltype(result.range().volume()) i = 0ul; i <
volume; i += stride)
365 result.data() + result.range().ordinal(i),
366 (tensors.data() + tensors.range().ordinal(i))...);
379 template <
typename Op,
typename TR,
typename... Ts,
380 typename std::enable_if<
381 is_tensor_of_tensor<TR, Ts...>::value &&
382 !(is_contiguous_tensor<TR, Ts...>::value)>::type* =
nullptr>
387 const auto stride =
inner_size(result, tensors...);
388 const auto volume = result.range().volume();
390 auto inplace_tensor_range =
392 typename TR::pointer MADNESS_RESTRICT
const result_data,
393 typename Ts::const_pointer MADNESS_RESTRICT
const... tensors_data) {
394 for (decltype(result.range().volume()) i = 0ul; i < stride; ++i)
398 for (decltype(result.range().volume()) i = 0ul; i <
volume; i += stride)
399 inplace_tensor_range(result.data() + result.range().ordinal(i),
400 (tensors.data() + tensors.range().ordinal(i))...);
417 template <
typename Op,
typename TR,
typename... Ts,
418 typename std::enable_if<
419 is_tensor<TR, Ts...>::value &&
420 is_contiguous_tensor<TR, Ts...>::value>::type* =
nullptr>
425 const auto volume = result.range().volume();
427 auto wrapper_op = [&op](
typename TR::pointer MADNESS_RESTRICT result,
428 typename Ts::const_reference MADNESS_RESTRICT... ts) {
429 new (result)
typename TR::value_type(std::forward<Op>(op)(ts...));
449 typename Op,
typename TR,
typename... Ts,
450 typename std::enable_if<is_tensor_of_tensor<TR, Ts...>::value &&
451 is_contiguous_tensor<TR>::value>::type* =
nullptr>
452 inline void tensor_init(
Op&& op, TR& result,
const Ts&... tensors) {
456 const auto volume = result.range().volume();
458 for (decltype(result.range().volume()) i = 0ul; i <
volume; ++i) {
459 new (result.data() + i)
typename TR::value_type(
460 tensor_op<typename TR::value_type>(op, tensors[i]...));
486 typename Op,
typename TR,
typename T1,
typename... Ts,
487 typename std::enable_if<is_tensor<TR, T1, Ts...>::value>::type* =
nullptr>
489 const T1& tensor1,
const Ts&... tensors) {
495 auto output_op = [](
typename TR::pointer MADNESS_RESTRICT result,
496 typename TR::const_reference MADNESS_RESTRICT temp) {
497 new (result)
typename TR::value_type(temp);
500 permute(std::forward<Op>(op), std::move(output_op), result, perm, tensor1,
519 template <
typename Op,
typename TR,
typename T1,
typename... Ts,
520 typename std::enable_if<
521 is_tensor_of_tensor<TR, T1, Ts...>::value>::type* =
nullptr>
523 const T1& tensor1,
const Ts&... tensors) {
529 auto output_op = [](
typename TR::pointer MADNESS_RESTRICT result,
530 typename TR::const_reference MADNESS_RESTRICT temp) {
531 new (result)
typename TR::value_type(temp);
533 auto tensor_input_op =
534 [&op](
typename T1::const_reference MADNESS_RESTRICT value1,
535 typename Ts::const_reference MADNESS_RESTRICT... values) ->
536 typename TR::value_type {
537 return tensor_op<typename TR::value_type>(std::forward<Op>(op), value1,
541 permute(std::move(tensor_input_op), output_op, result, perm, tensor1,
563 typename Op,
typename TR,
typename T1,
typename... Ts,
564 typename std::enable_if<
565 is_tensor<TR, T1, Ts...>::value && is_contiguous_tensor<TR>::value &&
566 !is_contiguous_tensor<T1, Ts...>::value>::type* =
nullptr>
568 const Ts&... tensors) {
572 const auto stride =
inner_size(tensor1, tensors...);
573 const auto volume = tensor1.range().volume();
575 auto wrapper_op = [&op](
typename TR::pointer MADNESS_RESTRICT result_ptr,
576 const typename T1::value_type value1,
577 const typename Ts::value_type... values) {
578 new (result_ptr)
typename T1::value_type(op(value1, values...));
581 for (decltype(tensor1.range().volume()) i = 0ul; i <
volume; i += stride)
583 (tensor1.data() + tensor1.range().ordinal(i)),
584 (tensors.data() + tensors.range().ordinal(i))...);
604 template <
typename Op,
typename TR,
typename T1,
typename... Ts,
605 typename std::enable_if<
606 is_tensor_of_tensor<TR, T1, Ts...>::value &&
607 is_contiguous_tensor<TR>::value &&
608 !is_contiguous_tensor<T1, Ts...>::value>::type* =
nullptr>
609 inline void tensor_init(
Op&& op, TR& result,
const T1& tensor1,
610 const Ts&... tensors) {
614 const auto stride =
inner_size(tensor1, tensors...);
615 const auto volume = tensor1.range().volume();
617 auto inplace_tensor_range =
619 typename TR::pointer MADNESS_RESTRICT
const result_data,
620 typename T1::const_pointer MADNESS_RESTRICT
const tensor1_data,
621 typename Ts::const_pointer MADNESS_RESTRICT
const... tensors_data) {
622 for (decltype(result.range().volume()) i = 0ul; i < stride; ++i)
623 new (result_data + i)
624 typename TR::value_type(tensor_op<typename TR::value_type>(
625 op, tensor1_data[i], tensors_data[i]...));
629 inplace_tensor_range(result.data() + i,
630 (tensor1.data() + tensor1.range().ordinal(i)),
631 (tensors.data() + tensors.range().ordinal(i))...);
659 typename ReduceOp,
typename JoinOp,
typename Scalar,
typename T1,
661 typename std::enable_if_t<
662 is_tensor<T1, Ts...>::value && is_contiguous_tensor<T1, Ts...>::value &&
663 !is_reduce_op_v<std::decay_t<ReduceOp>, std::decay_t<Scalar>,
664 std::decay_t<T1>, std::decay_t<Ts>...>>* =
nullptr>
666 const T1& tensor1,
const Ts&... tensors) {
670 const auto volume = tensor1.range().volume();
673 tensor1.data(), tensors.data()...);
694 typename ReduceOp,
typename JoinOp,
typename Scalar,
typename T1,
696 typename std::enable_if_t<
697 is_tensor<T1, Ts...>::value && is_contiguous_tensor<T1, Ts...>::value &&
698 is_reduce_op_v<std::decay_t<ReduceOp>, std::decay_t<Scalar>,
699 std::decay_t<T1>, std::decay_t<Ts>...>>* =
nullptr>
701 const T1& tensor1,
const Ts&... tensors) {
725 template <
typename ReduceOp,
typename JoinOp,
typename Scalar,
typename T1,
727 typename std::enable_if<
728 is_tensor_of_tensor<T1, Ts...>::value &&
729 is_contiguous_tensor<T1, Ts...>::value>::type* =
nullptr>
731 const T1& tensor1,
const Ts&... tensors) {
735 const auto volume = tensor1.range().volume();
738 for (decltype(tensor1.range().volume()) i = 0ul; i <
volume; ++i) {
741 join_op(result, temp);
766 template <
typename ReduceOp,
typename JoinOp,
typename Scalar,
typename T1,
768 typename std::enable_if<
769 is_tensor<T1, Ts...>::value &&
770 !is_contiguous_tensor<T1, Ts...>::value>::type* =
nullptr>
772 const Scalar
identity,
const T1& tensor1,
773 const Ts&... tensors) {
777 const auto stride =
inner_size(tensor1, tensors...);
778 const auto volume = tensor1.range().volume();
781 for (decltype(tensor1.range().volume()) i = 0ul; i <
volume; i += stride) {
784 tensor1.data() + tensor1.range().ordinal(i),
785 (tensors.data() + tensors.range().ordinal(i))...);
786 join_op(result, temp);
810 template <
typename ReduceOp,
typename JoinOp,
typename Scalar,
typename T1,
812 typename std::enable_if<
813 is_tensor_of_tensor<T1, Ts...>::value &&
814 !is_contiguous_tensor<T1, Ts...>::value>::type* =
nullptr>
816 const Scalar
identity,
const T1& tensor1,
817 const Ts&... tensors) {
821 const auto stride =
inner_size(tensor1, tensors...);
822 const auto volume = tensor1.range().volume();
824 auto tensor_reduce_range =
826 Scalar& MADNESS_RESTRICT result,
827 typename T1::const_pointer MADNESS_RESTRICT
const tensor1_data,
828 typename Ts::const_pointer MADNESS_RESTRICT
const... tensors_data) {
829 for (decltype(result.range().volume()) i = 0ul; i < stride; ++i) {
831 tensor1_data[i], tensors_data[i]...);
832 join_op(result, temp);
837 for (decltype(tensor1.range().volume()) i = 0ul; i <
volume; i += stride) {
839 tensor_reduce_range(result, tensor1.data() + tensor1.range().ordinal(i),
840 (tensors.data() + tensors.range().ordinal(i))...);
841 join_op(result, temp);
850 #endif // TILEDARRAY_TENSOR_KENERLS_H__INCLUDED