kernels.h
/*
 * This file is a part of TiledArray.
 * Copyright (C) 2015  Virginia Tech
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 *  Justus Calvin
 *  Department of Chemistry, Virginia Tech
 *
 *  kernels.h
 *  Jun 1, 2015
 *
 */

#ifndef TILEDARRAY_TENSOR_KENERLS_H__INCLUDED
#define TILEDARRAY_TENSOR_KENERLS_H__INCLUDED

#include <TiledArray/tensor/permute.h>
#include <TiledArray/tensor/utility.h>

namespace TiledArray {

template <typename, typename>
class Tensor;

namespace detail {

template <typename T>
struct transform;

// -------------------------------------------------------------------------
// Tensor kernel operations that generate a new tensor

/// Tensor operations with contiguous data.
template <typename TR, typename Op, typename T1, typename... Ts,
          typename std::enable_if<
              is_tensor<TR, T1, Ts...>::value ||
              is_tensor_of_tensor<TR, T1, Ts...>::value>::type* = nullptr>
inline TR tensor_op(Op&& op, const T1& tensor1, const Ts&... tensors) {
  if constexpr (std::is_invocable_r_v<TR, Op, const T1&, const Ts&...>) {
    return std::forward<Op>(op)(tensor1, tensors...);
  } else {
    return TiledArray::detail::transform<TR>()(std::forward<Op>(op), tensor1,
                                               tensors...);
  }
  abort();  // unreachable
}
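
// A minimal usage sketch (not from the original header; it assumes
// TiledArray::Tensor<double> and Range construction from a vector of
// extents): when Op is an element-wise functor rather than a whole-tensor
// callable, tensor_op falls through to detail::transform and produces a
// freshly allocated result.
//
//   TiledArray::Range r(std::vector<std::size_t>{4, 4});
//   TiledArray::Tensor<double> a(r, 1.0), b(r, 2.0);
//   auto plus = [](const double x, const double y) { return x + y; };
//   auto c = TiledArray::detail::tensor_op<TiledArray::Tensor<double>>(plus, a, b);
//   // every element of c is now 3.0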

template <typename TR, typename Op, typename T1, typename... Ts,
          typename std::enable_if<
              (is_tensor<T1, Ts...>::value ||
               is_tensor_of_tensor<TR, T1, Ts...>::value) &&
              is_contiguous_tensor<T1, Ts...>::value>::type* = nullptr>
inline TR tensor_op(Op&& op, const Permutation& perm, const T1& tensor1,
                    const Ts&... tensors) {
  if constexpr (std::is_invocable_r_v<TR, Op, const Permutation&, const T1&,
                                      const Ts&...>) {
    return std::forward<Op>(op)(perm, tensor1, tensors...);
  } else {
    return TiledArray::detail::transform<TR>()(std::forward<Op>(op), perm,
                                               tensor1, tensors...);
  }
}

template <typename T>
struct transform {
  template <typename Op, typename Tensor, typename... Tensors>
  T operator()(Op&& op, Tensor&& tensor, Tensors&&... tensors) const {
    TA_ASSERT(!empty(tensor, tensors...));
    TA_ASSERT(is_range_set_congruent(tensor, tensors...));

    const auto& range = tensor.range();
    T result(range);
    this->operator()(result, std::forward<Op>(op),
                     std::forward<Tensor>(tensor),
                     std::forward<Tensors>(tensors)...);
    return result;
  }

  template <typename Op, typename Tensor, typename... Tensors>
  void operator()(T& result, Op&& op, Tensor&& tensor,
                  Tensors&&... tensors) const {
    TA_ASSERT(!empty(result, tensor, tensors...));
    TA_ASSERT(is_range_set_congruent(result, tensor, tensors...));

    const auto& range = result.range();
    for (auto&& i : range)
      result[std::forward<decltype(i)>(i)] = std::forward<Op>(op)(
          std::forward<Tensor>(tensor)[std::forward<decltype(i)>(i)],
          std::forward<Tensors>(tensors)[std::forward<decltype(i)>(i)]...);
  }

  template <typename Op, typename Tensor, typename... Tensors>
  T operator()(Op&& op, const Permutation& perm, Tensor&& tensor,
               Tensors&&... tensors) const {
    TA_ASSERT(!empty(tensor, tensors...));
    TA_ASSERT(is_range_set_congruent(tensor, tensors...));
    TA_ASSERT(perm);
    TA_ASSERT(perm.size() == tensor.range().rank());

    const auto& range = tensor.range();
    T result(perm ^ range);
    this->operator()(result, std::forward<Op>(op), perm,
                     std::forward<Tensor>(tensor),
                     std::forward<Tensors>(tensors)...);
    return result;
  }

  template <typename Op, typename Tensor, typename... Tensors>
  void operator()(T& result, Op&& op, const Permutation& perm, Tensor&& tensor,
                  Tensors&&... tensors) const {
    TA_ASSERT(!empty(result, tensor, tensors...));
    TA_ASSERT(is_range_congruent(result, tensor, perm));
    TA_ASSERT(is_range_set_congruent(tensor, tensors...));
    TA_ASSERT(perm);
    TA_ASSERT(perm.size() == tensor.range().rank());

    const auto& range = tensor.range();
    for (auto&& i : range)
      result[perm ^ std::forward<decltype(i)>(i)] = std::forward<Op>(op)(
          std::forward<Tensor>(tensor)[std::forward<decltype(i)>(i)],
          std::forward<Tensors>(tensors)[std::forward<decltype(i)>(i)]...);
  }
};
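
// A sketch of the permuted transform path (assumes Tensor<double> and that
// Permutation can be list-initialized): the result is allocated with the
// permuted range perm ^ range, and each source index i is written to
// perm ^ i, so a mode swap behaves like a transpose.
//
//   TiledArray::Range r(std::vector<std::size_t>{2, 3});
//   TiledArray::Tensor<double> a(r, 1.0);
//   TiledArray::Permutation p{1, 0};  // swap the two modes
//   auto negate = [](const double x) { return -x; };
//   auto at = TiledArray::detail::transform<TiledArray::Tensor<double>>()(negate, p, a);
//   // at.range() has extents {3, 2} and at(j, i) == -a(i, j)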

// -------------------------------------------------------------------------
// Tensor kernel operations with in-place memory operations

/// In-place tensor operations with contiguous data.
template <typename Op, typename TR, typename... Ts,
          typename std::enable_if<
              is_tensor<TR, Ts...>::value &&
              is_contiguous_tensor<TR, Ts...>::value>::type* = nullptr>
inline void inplace_tensor_op(Op&& op, TR& result, const Ts&... tensors) {
  TA_ASSERT(!empty(result, tensors...));
  TA_ASSERT(is_range_set_congruent(result, tensors...));

  const auto volume = result.range().volume();

  math::inplace_vector_op(std::forward<Op>(op), volume, result.data(),
                          tensors.data()...);
}
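
// A minimal in-place sketch (Tensor<double> and the functor below are
// assumptions): the op receives a mutable reference to the result element
// followed by the corresponding elements of the argument tensors.
//
//   TiledArray::Range r(std::vector<std::size_t>{8});
//   TiledArray::Tensor<double> x(r, 1.0), y(r, 2.0);
//   auto add_to = [](double& xi, const double yi) { xi += yi; };
//   TiledArray::detail::inplace_tensor_op(add_to, x, y);
//   // every element of x is now 3.0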

template <typename Op, typename TR, typename... Ts,
          typename std::enable_if<
              is_tensor_of_tensor<TR, Ts...>::value &&
              is_contiguous_tensor<TR, Ts...>::value>::type* = nullptr>
inline void inplace_tensor_op(Op&& op, TR& result, const Ts&... tensors) {
  TA_ASSERT(!empty(result, tensors...));
  TA_ASSERT(is_range_set_congruent(result, tensors...));

  const auto volume = result.range().volume();

  for (decltype(result.range().volume()) i = 0ul; i < volume; ++i) {
    inplace_tensor_op(op, result[i], tensors[i]...);
  }
}

template <typename InputOp, typename OutputOp, typename TR, typename T1,
          typename... Ts,
          typename std::enable_if<
              is_tensor<TR, T1, Ts...>::value &&
              is_contiguous_tensor<TR, T1, Ts...>::value>::type* = nullptr>
inline void inplace_tensor_op(InputOp&& input_op, OutputOp&& output_op,
                              const Permutation& perm, TR& result,
                              const T1& tensor1, const Ts&... tensors) {
  TA_ASSERT(!empty(result, tensor1, tensors...));
  TA_ASSERT(is_range_congruent(result, tensor1, perm));
  TA_ASSERT(is_range_set_congruent(tensor1, tensors...));
  TA_ASSERT(perm);
  TA_ASSERT(perm.size() == tensor1.range().rank());

  permute(std::forward<InputOp>(input_op), std::forward<OutputOp>(output_op),
          result, perm, tensor1, tensors...);
}
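
// A sketch of the two-op permuted update (the lambdas are assumptions):
// input_op combines the argument elements into a temporary value, and
// permute() then passes the address of the permuted result element together
// with that value to output_op, which performs the in-place update.
//
//   auto input_op = [](const double l, const double r) { return l * r; };
//   auto output_op = [](double* MADNESS_RESTRICT out, const double v) { *out += v; };
//   // inplace_tensor_op(input_op, output_op, perm, result, a, b) accumulates
//   // the permuted element-wise product of a and b into result.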

template <typename InputOp, typename OutputOp, typename TR, typename T1,
          typename... Ts,
          typename std::enable_if<
              is_tensor_of_tensor<TR, T1, Ts...>::value &&
              is_contiguous_tensor<TR, T1, Ts...>::value>::type* = nullptr>
inline void inplace_tensor_op(InputOp&& input_op, OutputOp&& output_op,
                              const Permutation& perm, TR& result,
                              const T1& tensor1, const Ts&... tensors) {
  TA_ASSERT(!empty(result, tensor1, tensors...));
  TA_ASSERT(is_range_congruent(result, tensor1, perm));
  TA_ASSERT(is_range_set_congruent(tensor1, tensors...));
  TA_ASSERT(perm);
  TA_ASSERT(perm.size() == tensor1.range().rank());

  auto wrapper_input_op =
      [&input_op](typename T1::const_reference MADNESS_RESTRICT value1,
                  typename Ts::const_reference MADNESS_RESTRICT... values) ->
      typename T1::value_type {
        return tensor_op<typename TR::value_type>(
            std::forward<InputOp>(input_op), value1, values...);
      };

  auto wrapper_output_op =
      [&output_op](typename T1::pointer MADNESS_RESTRICT const result_value,
                   const typename TR::value_type value) {
        inplace_tensor_op(std::forward<OutputOp>(output_op), *result_value,
                          value);
      };

  permute(std::move(wrapper_input_op), std::move(wrapper_output_op), result,
          perm, tensor1, tensors...);
}

template <typename Op, typename TR, typename... Ts,
          typename std::enable_if<
              is_tensor<TR, Ts...>::value &&
              !(is_contiguous_tensor<TR, Ts...>::value)>::type* = nullptr>
inline void inplace_tensor_op(Op&& op, TR& result, const Ts&... tensors) {
  TA_ASSERT(!empty(result, tensors...));
  TA_ASSERT(is_range_set_congruent(result, tensors...));

  const auto stride = inner_size(result, tensors...);
  const auto volume = result.range().volume();

  for (decltype(result.range().volume()) i = 0ul; i < volume; i += stride)
    math::inplace_vector_op(std::forward<Op>(op), stride,
                            result.data() + result.range().ordinal(i),
                            (tensors.data() + tensors.range().ordinal(i))...);
}
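
// Illustration of the stride logic (example values only): inner_size reports
// the length of the longest contiguous run shared by result and the argument
// tensors, and range().ordinal(i) maps the i-th view element back to its
// offset in the underlying buffer. For a 10x6 block view of a 10x10 tensor
// the stride is 6, so the loop issues ten contiguous vector ops of length 6,
// one per row of the block.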

template <typename Op, typename TR, typename... Ts,
          typename std::enable_if<
              is_tensor_of_tensor<TR, Ts...>::value &&
              !(is_contiguous_tensor<TR, Ts...>::value)>::type* = nullptr>
inline void inplace_tensor_op(Op&& op, TR& result, const Ts&... tensors) {
  TA_ASSERT(!empty(result, tensors...));
  TA_ASSERT(is_range_set_congruent(result, tensors...));

  const auto stride = inner_size(result, tensors...);
  const auto volume = result.range().volume();

  auto inplace_tensor_range =
      [&op, stride](
          typename TR::pointer MADNESS_RESTRICT const result_data,
          typename Ts::const_pointer MADNESS_RESTRICT const... tensors_data) {
        for (decltype(result.range().volume()) i = 0ul; i < stride; ++i)
          inplace_tensor_op(op, result_data[i], tensors_data[i]...);
      };

  for (decltype(result.range().volume()) i = 0ul; i < volume; i += stride)
    inplace_tensor_range(result.data() + result.range().ordinal(i),
                         (tensors.data() + tensors.range().ordinal(i))...);
}

// -------------------------------------------------------------------------
// Tensor initialization functions for argument tensors with contiguous
// memory layout

/// Initialize tensor with contiguous tensor arguments.
template <typename Op, typename TR, typename... Ts,
          typename std::enable_if<
              is_tensor<TR, Ts...>::value &&
              is_contiguous_tensor<TR, Ts...>::value>::type* = nullptr>
inline void tensor_init(Op&& op, TR& result, const Ts&... tensors) {
  TA_ASSERT(!empty(result, tensors...));
  TA_ASSERT(is_range_set_congruent(result, tensors...));

  const auto volume = result.range().volume();

  auto wrapper_op = [&op](typename TR::pointer MADNESS_RESTRICT result,
                          typename Ts::const_reference MADNESS_RESTRICT... ts) {
    new (result) typename TR::value_type(std::forward<Op>(op)(ts...));
  };

  math::vector_ptr_op(std::move(wrapper_op), volume, result.data(),
                      tensors.data()...);
}
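
// A sketch of how tensor_init differs from inplace_tensor_op (the names
// below are assumptions): result elements are constructed with placement
// new, so result only needs allocated storage of the right volume, e.g. a
// tensor freshly created from a range.
//
//   TiledArray::Range r(std::vector<std::size_t>{8});
//   TiledArray::Tensor<double> src(r, 2.0);
//   TiledArray::Tensor<double> dst(r);  // storage for r.volume() doubles
//   auto square = [](const double v) { return v * v; };
//   TiledArray::detail::tensor_init(square, dst, src);
//   // every element of dst is now 4.0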

template <
    typename Op, typename TR, typename... Ts,
    typename std::enable_if<is_tensor_of_tensor<TR, Ts...>::value &&
                            is_contiguous_tensor<TR>::value>::type* = nullptr>
inline void tensor_init(Op&& op, TR& result, const Ts&... tensors) {
  TA_ASSERT(!empty(result, tensors...));
  TA_ASSERT(is_range_set_congruent(result, tensors...));

  const auto volume = result.range().volume();

  for (decltype(result.range().volume()) i = 0ul; i < volume; ++i) {
    new (result.data() + i) typename TR::value_type(
        tensor_op<typename TR::value_type>(op, tensors[i]...));
  }
}

template <
    typename Op, typename TR, typename T1, typename... Ts,
    typename std::enable_if<is_tensor<TR, T1, Ts...>::value>::type* = nullptr>
inline void tensor_init(Op&& op, const Permutation& perm, TR& result,
                        const T1& tensor1, const Ts&... tensors) {
  TA_ASSERT(!empty(result, tensor1, tensors...));
  TA_ASSERT(is_range_set_congruent(perm, result, tensor1, tensors...));
  TA_ASSERT(perm);
  TA_ASSERT(perm.size() == result.range().rank());

  auto output_op = [](typename TR::pointer MADNESS_RESTRICT result,
                      typename TR::const_reference MADNESS_RESTRICT temp) {
    new (result) typename TR::value_type(temp);
  };

  permute(std::forward<Op>(op), std::move(output_op), result, perm, tensor1,
          tensors...);
}

template <typename Op, typename TR, typename T1, typename... Ts,
          typename std::enable_if<
              is_tensor_of_tensor<TR, T1, Ts...>::value>::type* = nullptr>
inline void tensor_init(Op&& op, const Permutation& perm, TR& result,
                        const T1& tensor1, const Ts&... tensors) {
  TA_ASSERT(!empty(result, tensor1, tensors...));
  TA_ASSERT(is_range_set_congruent(perm, result, tensor1, tensors...));
  TA_ASSERT(perm);
  TA_ASSERT(perm.size() == result.range().rank());

  auto output_op = [](typename TR::pointer MADNESS_RESTRICT result,
                      typename TR::const_reference MADNESS_RESTRICT temp) {
    new (result) typename TR::value_type(temp);
  };
  auto tensor_input_op =
      [&op](typename T1::const_reference MADNESS_RESTRICT value1,
            typename Ts::const_reference MADNESS_RESTRICT... values) ->
      typename TR::value_type {
        return tensor_op<typename TR::value_type>(std::forward<Op>(op), value1,
                                                  values...);
      };

  permute(std::move(tensor_input_op), output_op, result, perm, tensor1,
          tensors...);
}

template <
    typename Op, typename TR, typename T1, typename... Ts,
    typename std::enable_if<
        is_tensor<TR, T1, Ts...>::value && is_contiguous_tensor<TR>::value &&
        !is_contiguous_tensor<T1, Ts...>::value>::type* = nullptr>
inline void tensor_init(Op&& op, TR& result, const T1& tensor1,
                        const Ts&... tensors) {
  TA_ASSERT(!empty(result, tensor1, tensors...));
  TA_ASSERT(is_range_set_congruent(result, tensor1, tensors...));

  const auto stride = inner_size(tensor1, tensors...);
  const auto volume = tensor1.range().volume();

  auto wrapper_op = [&op](typename TR::pointer MADNESS_RESTRICT result_ptr,
                          const typename T1::value_type value1,
                          const typename Ts::value_type... values) {
    new (result_ptr) typename T1::value_type(op(value1, values...));
  };

  for (decltype(tensor1.range().volume()) i = 0ul; i < volume; i += stride)
    math::vector_ptr_op(wrapper_op, stride, result.data() + i,
                        (tensor1.data() + tensor1.range().ordinal(i)),
                        (tensors.data() + tensors.range().ordinal(i))...);
}

template <typename Op, typename TR, typename T1, typename... Ts,
          typename std::enable_if<
              is_tensor_of_tensor<TR, T1, Ts...>::value &&
              is_contiguous_tensor<TR>::value &&
              !is_contiguous_tensor<T1, Ts...>::value>::type* = nullptr>
inline void tensor_init(Op&& op, TR& result, const T1& tensor1,
                        const Ts&... tensors) {
  TA_ASSERT(!empty(result, tensor1, tensors...));
  TA_ASSERT(is_range_set_congruent(result, tensor1, tensors...));

  const auto stride = inner_size(tensor1, tensors...);
  const auto volume = tensor1.range().volume();

  auto inplace_tensor_range =
      [&op, stride](
          typename TR::pointer MADNESS_RESTRICT const result_data,
          typename T1::const_pointer MADNESS_RESTRICT const tensor1_data,
          typename Ts::const_pointer MADNESS_RESTRICT const... tensors_data) {
        for (decltype(result.range().volume()) i = 0ul; i < stride; ++i)
          new (result_data + i)
              typename TR::value_type(tensor_op<typename TR::value_type>(
                  op, tensor1_data[i], tensors_data[i]...));
      };

  for (decltype(volume) i = 0ul; i < volume; i += stride)
    inplace_tensor_range(result.data() + i,
                         (tensor1.data() + tensor1.range().ordinal(i)),
                         (tensors.data() + tensors.range().ordinal(i))...);
}

// -------------------------------------------------------------------------
// Reduction kernels for argument tensors

/// Reduction operation for contiguous tensors.
template <
    typename ReduceOp, typename JoinOp, typename Scalar, typename T1,
    typename... Ts,
    typename std::enable_if_t<
        is_tensor<T1, Ts...>::value && is_contiguous_tensor<T1, Ts...>::value &&
        !is_reduce_op_v<std::decay_t<ReduceOp>, std::decay_t<Scalar>,
                        std::decay_t<T1>, std::decay_t<Ts>...>>* = nullptr>
Scalar tensor_reduce(ReduceOp&& reduce_op, JoinOp&& join_op, Scalar identity,
                     const T1& tensor1, const Ts&... tensors) {
  TA_ASSERT(!empty(tensor1, tensors...));
  TA_ASSERT(is_range_set_congruent(tensor1, tensors...));

  const auto volume = tensor1.range().volume();

  math::reduce_op(reduce_op, join_op, identity, volume, identity,
                  tensor1.data(), tensors.data()...);

  return identity;
}
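
// A minimal reduction sketch (Tensor<double> and the lambdas are
// assumptions): reduce_op folds one set of elements into the running value,
// join_op combines partial results, and identity seeds both.
//
//   TiledArray::Range r(std::vector<std::size_t>{8});
//   TiledArray::Tensor<double> t(r, 2.0);
//   auto square_add = [](double& acc, const double v) { acc += v * v; };
//   auto add = [](double& acc, const double v) { acc += v; };
//   const double norm2 = TiledArray::detail::tensor_reduce(square_add, add, 0.0, t);
//   // norm2 == 8 * 4.0 == 32.0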

template <
    typename ReduceOp, typename JoinOp, typename Scalar, typename T1,
    typename... Ts,
    typename std::enable_if_t<
        is_tensor<T1, Ts...>::value && is_contiguous_tensor<T1, Ts...>::value &&
        is_reduce_op_v<std::decay_t<ReduceOp>, std::decay_t<Scalar>,
                       std::decay_t<T1>, std::decay_t<Ts>...>>* = nullptr>
Scalar tensor_reduce(ReduceOp&& reduce_op, JoinOp&& join_op, Scalar identity,
                     const T1& tensor1, const Ts&... tensors) {
  reduce_op(identity, &tensor1, &tensors...);
  return identity;
}

template <typename ReduceOp, typename JoinOp, typename Scalar, typename T1,
          typename... Ts,
          typename std::enable_if<
              is_tensor_of_tensor<T1, Ts...>::value &&
              is_contiguous_tensor<T1, Ts...>::value>::type* = nullptr>
Scalar tensor_reduce(ReduceOp&& reduce_op, JoinOp&& join_op, Scalar identity,
                     const T1& tensor1, const Ts&... tensors) {
  TA_ASSERT(!empty(tensor1, tensors...));
  TA_ASSERT(is_range_set_congruent(tensor1, tensors...));

  const auto volume = tensor1.range().volume();

  auto result = identity;
  for (decltype(tensor1.range().volume()) i = 0ul; i < volume; ++i) {
    auto temp =
        tensor_reduce(reduce_op, join_op, identity, tensor1[i], tensors[i]...);
    join_op(result, temp);
  }

  return result;
}

template <typename ReduceOp, typename JoinOp, typename Scalar, typename T1,
          typename... Ts,
          typename std::enable_if<
              is_tensor<T1, Ts...>::value &&
              !is_contiguous_tensor<T1, Ts...>::value>::type* = nullptr>
Scalar tensor_reduce(ReduceOp&& reduce_op, JoinOp&& join_op,
                     const Scalar identity, const T1& tensor1,
                     const Ts&... tensors) {
  TA_ASSERT(!empty(tensor1, tensors...));
  TA_ASSERT(is_range_set_congruent(tensor1, tensors...));

  const auto stride = inner_size(tensor1, tensors...);
  const auto volume = tensor1.range().volume();

  Scalar result = identity;
  for (decltype(tensor1.range().volume()) i = 0ul; i < volume; i += stride) {
    Scalar temp = identity;
    math::reduce_op(reduce_op, join_op, identity, stride, temp,
                    tensor1.data() + tensor1.range().ordinal(i),
                    (tensors.data() + tensors.range().ordinal(i))...);
    join_op(result, temp);
  }

  return result;
}

template <typename ReduceOp, typename JoinOp, typename Scalar, typename T1,
          typename... Ts,
          typename std::enable_if<
              is_tensor_of_tensor<T1, Ts...>::value &&
              !is_contiguous_tensor<T1, Ts...>::value>::type* = nullptr>
Scalar tensor_reduce(ReduceOp&& reduce_op, JoinOp&& join_op,
                     const Scalar identity, const T1& tensor1,
                     const Ts&... tensors) {
  TA_ASSERT(!empty(tensor1, tensors...));
  TA_ASSERT(is_range_set_congruent(tensor1, tensors...));

  const auto stride = inner_size(tensor1, tensors...);
  const auto volume = tensor1.range().volume();

  auto tensor_reduce_range =
      [&reduce_op, &join_op, &identity, stride](
          Scalar& MADNESS_RESTRICT result,
          typename T1::const_pointer MADNESS_RESTRICT const tensor1_data,
          typename Ts::const_pointer MADNESS_RESTRICT const... tensors_data) {
        for (decltype(stride) i = 0ul; i < stride; ++i) {
          Scalar temp = tensor_reduce(reduce_op, join_op, identity,
                                      tensor1_data[i], tensors_data[i]...);
          join_op(result, temp);
        }
      };

  Scalar result = identity;
  for (decltype(tensor1.range().volume()) i = 0ul; i < volume; i += stride) {
    Scalar temp = identity;
    tensor_reduce_range(temp, tensor1.data() + tensor1.range().ordinal(i),
                        (tensors.data() + tensors.range().ordinal(i))...);
    join_op(result, temp);
  }

  return result;
}

}  // namespace detail
}  // namespace TiledArray

#endif  // TILEDARRAY_TENSOR_KENERLS_H__INCLUDED