26 #ifndef TILEDARRAY_TENSOR_PERMUTE_H__INCLUDED
27 #define TILEDARRAY_TENSOR_PERMUTE_H__INCLUDED
/// Groups the extents in `size` into fused_size[0..3] by walking the
/// dimensions from the last index (perm.size() - 1) toward 0, and records
/// fused_weight[0..2] as products of the fused sizes.
/// NOTE(review): the weights look like element strides of each fused group;
/// confirm against the code that consumes them.
50 template <
typename SizeType,
typename ExtentType>
52 SizeType* MADNESS_RESTRICT
const fused_weight,
53 const ExtentType* MADNESS_RESTRICT
const size,
// Index of the last dimension, derived from the permutation's length.
55 const unsigned int ndim1 = perm.
size() - 1u;
// Group 3: take size[i], then keep multiplying in lower dimensions for as
// long as perm[i + 1] is exactly perm[i] + 1.
58 fused_size[3] = size[i--];
59 while ((i >= 0) && (perm[i + 1u] == (perm[i] + 1u)))
60 fused_size[3] *= size[i--];
// Taken when dimensions remain and perm does not send the current one to
// the last position (ndim1).
63 if ((i >= 0) && (perm[i] != ndim1)) {
// Group 2: all dimensions down to, but not including, the one that perm
// maps to ndim1.
64 fused_size[2] = size[i--];
65 while ((i >= 0) && (perm[i] != ndim1)) fused_size[2] *= size[i--];
// The weight of group 2 is the size of group 3.
67 fused_weight[2] = fused_size[3];
// Group 1: fused with the same consecutive-perm-value test as group 3.
69 fused_size[1] = size[i--];
70 while ((i >= 0) && (perm[i + 1] == (perm[i] + 1u)))
71 fused_size[1] *= size[i--];
// The weight of group 1 is the size of group 2 times the weight of group 2.
73 fused_weight[1] = fused_size[2] * fused_weight[2];
// NOTE(review): presumably the counterpart case in which there is no
// separate group 2, so its weight is zeroed -- confirm.
76 fused_weight[2] = 0ul;
// Group 1 is fused with the same consecutive-perm-value test as above.
78 fused_size[1] = size[i--];
79 while ((i >= 0) && (perm[i + 1] == (perm[i] + 1u)))
80 fused_size[1] *= size[i--];
// Here the weight of group 1 is the size of group 3.
82 fused_weight[1] = fused_size[3];
// Group 0: the product of every dimension that is still left.
86 fused_size[0] = size[i--];
87 while (i >= 0) fused_size[0] *= size[i--];
// The weight of group 0 is the size of group 1 times the weight of group 1.
89 fused_weight[0] = fused_size[1] * fused_weight[1];
// NOTE(review): presumably the case in which no leading dimensions remain,
// so the weight of group 0 is zeroed -- confirm.
92 fused_weight[0] = 0ul;
/// Permuting kernel: elements of arg0 (together with the matching elements
/// of args...) are passed through input_op, and the value is handed to
/// output_op along with a pointer into result.
///
/// Only enabled when detail::is_permutation_v<Perm> is true.
///
/// \param input_op  callable invoked as input_op(a0, as...)
/// \param output_op callable invoked as output_op(result_pointer, value)
/// \param result    destination; its range().extent_data() and data() are used
/// \param perm      permutation, indexed as perm[i]
/// \param arg0      first argument; supplies the range (rank, volume, extents)
/// \param args      further arguments, read at the same offsets as arg0
114 template <
typename InputOp,
typename OutputOp,
typename Result,
typename Perm,
115 typename Arg0,
typename... Args,
116 typename = std::enable_if_t<detail::is_permutation_v<Perm>>>
117 inline void permute(InputOp&& input_op, OutputOp&& output_op, Result& result,
118 const Perm& perm,
const Arg0& arg0,
const Args&... args) {
// Rank of arg0's range, the index of its last dimension, and its volume.
122 const unsigned int ndim = arg0.range().rank();
123 const unsigned int ndim1 = ndim - 1;
124 const auto volume = arg0.range().volume();
// Extents of arg0, one entry per dimension.
127 const auto* MADNESS_RESTRICT
const arg0_extent = arg0.range().extent_data();
// Case: perm leaves the last dimension in the last position.
129 if (perm[ndim1] == ndim1) {
// block_size starts as the extent of the last dimension and is multiplied by
// the extents of the preceding dimensions for as long as perm also leaves
// them in place; the scan stops at the first dimension that perm moves.
135 typename Result::ordinal_type block_size = arg0_extent[ndim1];
136 for (
int i =
int(ndim1) - 1; i >= 0; --i) {
137 if (
int(perm[i]) != i)
break;
138 block_size *= arg0_extent[i];
// Per-element operation (captures by copy): apply input_op to the argument
// elements and pass the value to output_op with the destination pointer.
142 auto op = [=](
typename Result::pointer result,
143 typename Arg0::const_reference a0,
144 typename Args::const_reference... as) {
145 output_op(result, input_op(a0, as...));
// Step through arg0 in strides of block_size; for each starting offset,
// perm_index_op gives the offset used on the result side.
// NOTE(review): perm_index_op is defined elsewhere -- presumably it maps an
// ordinal in arg0's range to the permuted ordinal in result's range; confirm.
149 for (
typename Result::ordinal_type index = 0ul; index <
volume;
150 index += block_size) {
151 const typename Result::ordinal_type perm_index = perm_index_op(index);
// arg0 and every member of args are read starting at offset `index`.
155 arg0.data() + index, (args.data() + index)...);
// Scratch arrays with four entries each for fused sizes and fused weights.
// NOTE(review): presumably filled from arg0's extents and perm by the call
// whose trailing arguments follow -- confirm at the callee.
174 typename Result::ordinal_type other_fused_size[4];
175 typename Result::ordinal_type other_fused_weight[4];
177 arg0.range().extent_data(), perm);
// result_outer_stride is the product of result's extents over all dimensions
// after position perm[ndim1].
180 const auto* MADNESS_RESTRICT
const result_extent =
181 result.range().extent_data();
182 typename Result::ordinal_type result_outer_stride = 1ul;
183 for (
unsigned int i = perm[ndim1] + 1u; i < ndim; ++i)
184 result_outer_stride *= result_extent[i];
// Outer loop over fused group 0; `index` starts at i * other_fused_weight[0].
188 for (
typename Result::ordinal_type i = 0ul; i < other_fused_size[0]; ++i) {
189 typename Result::ordinal_type index = i * other_fused_weight[0];
// Inner loop over fused group 2; `index` advances by other_fused_weight[2]
// on every iteration.
190 for (
typename Result::ordinal_type j = 0ul; j < other_fused_size[2];
191 ++j, index += other_fused_weight[2]) {
// Offset into result that corresponds to `index` in arg0.
193 typename Result::ordinal_type perm_index = perm_index_op(index);
// The call receives the size of group 3 with result_outer_stride, the result
// pointer at perm_index with the weight of group 1, and the argument
// pointers at `index`.
// NOTE(review): looks like a strided (transpose-style) copy of groups 1 and
// 3 -- confirm at the callee.
196 other_fused_size[3], result_outer_stride,
197 result.data() + perm_index, other_fused_weight[1],
198 arg0.data() + index, (args.data() + index)...);
207 #endif // TILEDARRAY_TENSOR_PERMUTE_H__INCLUDED