permute.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2015 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * Justus Calvin
19  * Department of Chemistry, Virginia Tech
20  *
21  * permute.h
22  * May 31, 2015
23  *
24  */
25 
26 #ifndef TILEDARRAY_TENSOR_PERMUTE_H__INCLUDED
27 #define TILEDARRAY_TENSOR_PERMUTE_H__INCLUDED
28 
30 #include <TiledArray/perm_index.h>
32 
33 namespace TiledArray {
34 namespace detail {
35 
37 
50 template <typename SizeType, typename ExtentType>
51 inline void fuse_dimensions(SizeType* MADNESS_RESTRICT const fused_size,
52  SizeType* MADNESS_RESTRICT const fused_weight,
53  const ExtentType* MADNESS_RESTRICT const size,
54  const Permutation& perm) {
55  const unsigned int ndim1 = perm.size() - 1u;
56 
57  int i = ndim1;
58  fused_size[3] = size[i--];
59  while ((i >= 0) && (perm[i + 1u] == (perm[i] + 1u)))
60  fused_size[3] *= size[i--];
61  fused_weight[3] = 1u;
62 
63  if ((i >= 0) && (perm[i] != ndim1)) {
64  fused_size[2] = size[i--];
65  while ((i >= 0) && (perm[i] != ndim1)) fused_size[2] *= size[i--];
66 
67  fused_weight[2] = fused_size[3];
68 
69  fused_size[1] = size[i--];
70  while ((i >= 0) && (perm[i + 1] == (perm[i] + 1u)))
71  fused_size[1] *= size[i--];
72 
73  fused_weight[1] = fused_size[2] * fused_weight[2];
74  } else {
75  fused_size[2] = 1ul;
76  fused_weight[2] = 0ul;
77 
78  fused_size[1] = size[i--];
79  while ((i >= 0) && (perm[i + 1] == (perm[i] + 1u)))
80  fused_size[1] *= size[i--];
81 
82  fused_weight[1] = fused_size[3];
83  }
84 
85  if (i >= 0) {
86  fused_size[0] = size[i--];
87  while (i >= 0) fused_size[0] *= size[i--];
88 
89  fused_weight[0] = fused_size[1] * fused_weight[1];
90  } else {
91  fused_size[0] = 1ul;
92  fused_weight[0] = 0ul;
93  }
94 }
95 
97 
114 template <typename InputOp, typename OutputOp, typename Result, typename Perm,
115  typename Arg0, typename... Args,
116  typename = std::enable_if_t<detail::is_permutation_v<Perm>>>
117 inline void permute(InputOp&& input_op, OutputOp&& output_op, Result& result,
118  const Perm& perm, const Arg0& arg0, const Args&... args) {
119  detail::PermIndex perm_index_op(arg0.range(), outer(perm));
120 
121  // Cache constants
122  const unsigned int ndim = arg0.range().rank();
123  const unsigned int ndim1 = ndim - 1;
124  const auto volume = arg0.range().volume();
125 
126  // Get pointer to arg extent
127  const auto* MADNESS_RESTRICT const arg0_extent = arg0.range().extent_data();
128 
129  if (perm[ndim1] == ndim1) {
130  // This is the simple case where the last dimension is not permuted.
131  // Therefore, it can be shuffled in chunks.
132 
133  // Determine which dimensions can be permuted with the least significant
134  // dimension.
135  typename Result::ordinal_type block_size = arg0_extent[ndim1];
136  for (int i = int(ndim1) - 1; i >= 0; --i) {
137  if (int(perm[i]) != i) break;
138  block_size *= arg0_extent[i];
139  }
140 
141  // Combine the input and output operations
142  auto op = [=](typename Result::pointer result,
143  typename Arg0::const_reference a0,
144  typename Args::const_reference... as) {
145  output_op(result, input_op(a0, as...));
146  };
147 
148  // Permute the data
149  for (typename Result::ordinal_type index = 0ul; index < volume;
150  index += block_size) {
151  const typename Result::ordinal_type perm_index = perm_index_op(index);
152 
153  // Copy the block
154  math::vector_ptr_op(op, block_size, result.data() + perm_index,
155  arg0.data() + index, (args.data() + index)...);
156  }
157 
158  } else {
159  // This is the more complicated case. Here we permute in terms of matrix
160  // transposes. The data layout of the input and output matrices are
161  // chosen such that they both contain stride one dimensions.
162 
163  // Here we partition the n dimensional index space, I, of the permute
164  // tensor with up to four parts
165  // {I_1, ..., I_i, I_i+1, ..., I_j, I_j+1, ..., I_k, I_k+1, ..., I_n}
166  // where the subrange {I_k+1, ..., I_n} is the (fused) inner dimension
167  // in the input tensor, and the subrange {I_i+1, ..., I_j} is the
168  // (fused) inner dimension in the output tensor that has been mapped to
169  // the input tensor. These ranges are used to form a set of matrices in
170  // the input tensor that are transposed and copied to the output tensor.
171  // The remaining (fused) index ranges {I_1, ..., I_i} and
172  // {I_j+1, ..., I_k} are used to form the outer loop around the matrix
173  // transpose operations. These outer ranges may or may not be zero size.
174  typename Result::ordinal_type other_fused_size[4];
175  typename Result::ordinal_type other_fused_weight[4];
176  fuse_dimensions(other_fused_size, other_fused_weight,
177  arg0.range().extent_data(), perm);
178 
179  // Compute the fused stride for the result matrix transpose.
180  const auto* MADNESS_RESTRICT const result_extent =
181  result.range().extent_data();
182  typename Result::ordinal_type result_outer_stride = 1ul;
183  for (unsigned int i = perm[ndim1] + 1u; i < ndim; ++i)
184  result_outer_stride *= result_extent[i];
185 
186  // Copy data from the input to the output matrix via a series of matrix
187  // transposes.
188  for (typename Result::ordinal_type i = 0ul; i < other_fused_size[0]; ++i) {
189  typename Result::ordinal_type index = i * other_fused_weight[0];
190  for (typename Result::ordinal_type j = 0ul; j < other_fused_size[2];
191  ++j, index += other_fused_weight[2]) {
192  // Compute the ordinal index of the input and output matrices.
193  typename Result::ordinal_type perm_index = perm_index_op(index);
194 
195  math::transpose(input_op, output_op, other_fused_size[1],
196  other_fused_size[3], result_outer_stride,
197  result.data() + perm_index, other_fused_weight[1],
198  arg0.data() + index, (args.data() + index)...);
199  }
200  }
201  }
202 }
203 
204 } // namespace detail
205 } // namespace TiledArray
206 
207 #endif // TILEDARRAY_TENSOR_PERMUTE_H__INCLUDED
index_type size() const
Domain size accessor.
Definition: permutation.h:214
void permute(InputOp &&input_op, OutputOp &&output_op, Result &result, const Perm &perm, const Arg0 &arg0, const Args &... args)
Construct a permuted tensor copy.
Definition: permute.h:117
Permutation of a sequence of objects indexed by base-0 indices.
Definition: permutation.h:130
void fuse_dimensions(SizeType *MADNESS_RESTRICT const fused_size, SizeType *MADNESS_RESTRICT const fused_weight, const ExtentType *MADNESS_RESTRICT const size, const Permutation &perm)
Compute the fused dimensions for permutation.
Definition: permute.h:51
auto outer(const Permutation &p)
Definition: permutation.h:820
void transpose(InputOp &&input_op, OutputOp &&output_op, const std::size_t m, const std::size_t n, const std::size_t result_stride, Result *result, const std::size_t arg_stride, const Args *const ... args)
Matrix transpose and initialization.
Definition: transpose.h:178
A functor that permutes ordinal indices.
Definition: perm_index.h:38
void vector_ptr_op(Op &&op, const std::size_t n, Result *const result, const Args *const ... args)
Definition: vector_op.h:538
size_t volume(const DistArray< Tile, Policy > &a)
Definition: dist_array.h:1622