26 #ifndef TILEDARRAY_TENSOR_PERMUTE_H__INCLUDED 27 #define TILEDARRAY_TENSOR_PERMUTE_H__INCLUDED 50 template <
// fuse_dimensions(fused_size, fused_weight, size, perm):
// Compute the fused dimensions for a permutation.
// The dimensions are walked from the last one (ndim1) toward the first. They
// are collapsed into at most four groups:
//   fused_size[3] - the trailing run of dimensions whose perm values stay
//                   consecutive (perm[i + 1] == perm[i] + 1);
//   fused_size[2] - the dimensions before that run, stopping at the dimension
//                   with perm[i] == ndim1;
//   fused_size[1] - a run starting at that dimension, again absorbing
//                   predecessors while perm values stay consecutive;
//   fused_size[0] - the leading dimensions that remain.
// fused_weight[k] holds the stride of group k, expressed as a product of the
// later groups' extents. A group that is absent gets weight 0ul.
// perm[i] is presumably the position that input dimension i takes in the
// result. TODO confirm against the Permutation documentation.
// NOTE(review): this listing is incomplete. The declaration of the counter
// `i`, part of the parameter list, and the lines separating the branches
// below are not shown. The comments describe the visible lines only.
typename SizeType>
52 SizeType * MADNESS_RESTRICT
const fused_weight,
55 const unsigned int ndim1 = perm.
dim() - 1u;
// Group 3: start at the last dimension and absorb preceding dimensions while
// their perm values remain consecutive.
58 fused_size[3] =
size[i--];
59 while((i >= 0) && (perm[i + 1u] == (perm[i] + 1u)))
60 fused_size[3] *=
size[i--];
// A middle group (group 2) is formed only when a dimension remains and it is
// not the one with perm value ndim1.
64 if((i >= 0) && (perm[i] != ndim1)) {
// Group 2: absorb dimensions until reaching the one whose perm value is ndim1.
65 fused_size[2] =
size[i--];
66 while((i >= 0) && (perm[i] != ndim1))
67 fused_size[2] *=
size[i--];
// Group 2's stride is the extent of group 3.
69 fused_weight[2] = fused_size[3];
// Group 1: the next dimension plus predecessors with consecutive perm values.
71 fused_size[1] =
size[i--];
72 while((i >= 0) && (perm[i + 1] == (perm[i] + 1u)))
73 fused_size[1] *=
size[i--];
// Group 1's stride spans groups 2 and 3.
75 fused_weight[1] = fused_size[2] * fused_weight[2];
// NOTE(review): the line opening the alternative branch (no middle group) is
// not shown. Presumably the branch assigns the absent group 2 weight 0ul and
// builds group 1 directly after group 3.
78 fused_weight[2] = 0ul;
80 fused_size[1] =
size[i--];
81 while((i >= 0) && (perm[i + 1] == (perm[i] + 1u)))
82 fused_size[1] *=
size[i--];
// Without a middle group, group 1's stride is just the extent of group 3.
84 fused_weight[1] = fused_size[3];
// Group 0: leading dimensions.
// NOTE(review): the guard and the loop header for this group are not shown.
// Presumably the group takes all remaining dimensions. Confirm against the
// full source.
88 fused_size[0] =
size[i--];
90 fused_size[0] *=
size[i--];
92 fused_weight[0] = fused_size[1] * fused_weight[1];
// Presumably the branch taken when no leading dimensions remain (its guard is
// not shown): group 0 gets weight 0ul.
95 fused_weight[0] = 0ul;
// permute(input_op, output_op, result, perm, arg0, args...):
// Construct a permuted tensor copy of arg0 (and args...) in result.
// Each element is produced as output_op(result_ptr, input_op(a0, as...)).
// Two paths are visible:
//  - when perm keeps the last dimension in place (perm[ndim1] == ndim1), the
//    data is processed in contiguous blocks of block_size elements;
//  - otherwise arg0's dimensions are fused into four groups, and each
//    (group 0, group 2) slab is handed to a call whose opening line is not
//    shown. Its trailing arguments match the transpose helper's signature;
//    presumably that call is transpose. TODO confirm.
// NOTE(review): this listing omits several lines, including the opening
// brace, the else line, and where perm_index_op is defined. The comments
// describe the visible lines only.
120 template <
typename InputOp,
typename OutputOp,
typename Result,
121 typename Arg0,
typename... Args>
122 inline void permute(InputOp&& input_op, OutputOp&& output_op, Result& result,
123 const Permutation& perm,
const Arg0& arg0,
const Args&... args)
// NOTE(review): ndim is unsigned, so a rank-0 argument makes ndim1 wrap to
// UINT_MAX. perm[ndim1] would then be out of range. Confirm that callers never
// pass rank-0 tensors.
128 const unsigned int ndim = arg0.range().rank();
129 const unsigned int ndim1 = ndim - 1;
130 const typename Result::size_type volume = arg0.range().volume();
133 const auto* MADNESS_RESTRICT
const arg0_extent = arg0.range().extent_data();
// Block path: the last dimension keeps its position under perm.
135 if(perm[ndim1] == ndim1) {
// block_size: product of the trailing extents that perm leaves in place.
// It is used as the step between consecutive blocks below.
// NOTE(review): original line 144, between the `if` and the multiply, is not
// shown. Presumably it is a `break` that stops the scan at the first dimension
// perm moves. As listed, the `if` would guard the multiply instead. Confirm.
141 typename Result::size_type block_size = arg0_extent[ndim1];
142 for(
int i =
int(ndim1) - 1 ; i >= 0; --i) {
143 if(
int(perm[i]) != i)
145 block_size *= arg0_extent[i];
// Element kernel, capturing by value: apply input_op to the argument elements
// and pass the value to output_op at the destination pointer.
// NOTE(review): the lambda parameter `result` shadows the function parameter
// of the same name.
149 auto op = [=] (
typename Result::pointer result,
150 typename Arg0::const_reference a0,
typename Args::const_reference... as)
151 { output_op(result, input_op(a0, as...)); };
// Walk the input ordinals in steps of block_size. perm_index_op (not defined
// in the visible lines) maps each input ordinal to its result ordinal.
// NOTE(review): the start of the per-block call (original lines 156-158) is
// not shown. Only its trailing pointer arguments appear below.
154 for(
typename Result::size_type index = 0ul; index < volume; index += block_size) {
155 const typename Result::size_type perm_index = perm_index_op(index);
159 arg0.data() + index, (args.data() + index)...);
// General path (the else line is not shown): fused group extents and strides
// for arg0 under perm.
178 typename Result::size_type other_fused_size[4];
179 typename Result::size_type other_fused_weight[4];
// Tail of the call that fills the arrays above. Its argument order matches
// fuse_dimensions(fused_size, fused_weight, size, perm); presumably that is
// the callee. TODO confirm.
181 arg0.range().extent_data(), perm);
// result_outer_stride: product of the result extents after position
// perm[ndim1]. It is passed as the result stride of each slab call.
184 const auto* MADNESS_RESTRICT
const result_extent = result.range().extent_data();
185 typename Result::size_type result_outer_stride = 1ul;
186 for(
unsigned int i = perm[ndim1] + 1u; i < ndim; ++i)
187 result_outer_stride *= result_extent[i];
// Iterate over fused groups 0 (outer) and 2 (inner). index advances by each
// group's weight, and perm_index_op gives the matching result ordinal. Each
// slab call receives:
//  - the extents other_fused_size[1] and other_fused_size[3];
//  - result_outer_stride with the result pointer;
//  - other_fused_weight[1] as the argument stride;
//  - the argument pointers offset by index.
191 for(
typename Result::size_type i = 0ul; i < other_fused_size[0]; ++i) {
192 typename Result::size_type index = i * other_fused_weight[0];
193 for(
typename Result::size_type j = 0ul; j < other_fused_size[2]; ++j, index += other_fused_weight[2]) {
195 typename Result::size_type perm_index = perm_index_op(index);
198 other_fused_size[1], other_fused_size[3],
199 result_outer_stride, result.data() + perm_index,
200 other_fused_weight[1], arg0.data() + index, (args.data() + index)...);
210 #endif // TILEDARRAY_TENSOR_PERMUTE_H__INCLUDED A functor that permutes ordinal indices.
void permute(InputOp &&input_op, OutputOp &&output_op, Result &result, const Permutation &perm, const Arg0 &arg0, const Args &... args)
Construct a permuted tensor copy.
index_type dim() const
Domain size accessor.
constexpr std::size_t size(T(&)[N])
Array size accessor.
void fuse_dimensions(SizeType *MADNESS_RESTRICT const fused_size, SizeType *MADNESS_RESTRICT const fused_weight, const SizeType *MADNESS_RESTRICT const size, const Permutation &perm)
Compute the fused dimensions for permutation.
Permutation of a sequence of objects indexed by base-0 indices.
void transpose(InputOp &&input_op, OutputOp &&output_op, const std::size_t m, const std::size_t n, const std::size_t result_stride, Result *result, const std::size_t arg_stride, const Args *const ... args)
Matrix transpose and initialization.
void vector_ptr_op(Op &&op, const std::size_t n, Result *const result, const Args *const ... args)