perm_index.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2014 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * Justus Calvin
19  * Department of Chemistry, Virginia Tech
20  *
21  * perm_index.h
22  * Oct 10, 2014
23  *
24  */
25 
26 #ifndef TILEDARRAY_PERM_INDEX_H__INCLUDED
27 #define TILEDARRAY_PERM_INDEX_H__INCLUDED
28 
#include <TiledArray/range.h>

#include <algorithm>
#include <cstddef>
#include <memory>
#include <utility>
30 
31 namespace TiledArray {
32 namespace detail {
33 
35 
38 class PermIndex {
39  std::size_t* weights_;
40  unsigned int
42  ndim_;
43 
44  public:
46  PermIndex() : weights_(NULL), ndim_(0) {}
47 
49 
51  PermIndex(const Range& range, const Permutation& perm)
52  : weights_(NULL), ndim_(perm.size()) {
53  if (ndim_ > 0) {
54  // Check the input data
55  TA_ASSERT(range.rank() == perm.size());
56 
57  // Construct the inverse permutation
58  const Permutation inv_perm_ = -perm;
59 
60  // Allocate memory for this object
61  weights_ = static_cast<std::size_t*>(
62  malloc((ndim_ + ndim_) * sizeof(std::size_t)));
63  if (!weights_) throw std::bad_alloc();
64 
65  // Construct MADNESS_RESTRICTed pointers to the input data
66  const auto* MADNESS_RESTRICT const inv_perm = &inv_perm_.data().front();
67  const auto* MADNESS_RESTRICT const range_size = range.extent_data();
68  const auto* MADNESS_RESTRICT const range_weight = range.stride_data();
69 
70  // Construct MADNESS_RESTRICTed pointers to the object data
71  std::size_t* MADNESS_RESTRICT const input_weight = weights_;
72  std::size_t* MADNESS_RESTRICT const output_weight = weights_ + ndim_;
73 
74  // Initialize input and output weights
75  std::size_t volume = 1ul;
76  for (int i = int(ndim_) - 1; i >= 0; --i) {
77  // Load input data for iteration i.
78  const auto inv_perm_i = inv_perm[i];
79  const auto weight = range_weight[i];
80  const auto size = range_size[inv_perm_i];
81 
82  // Store the input and output weights
83  output_weight[inv_perm_i] = volume;
84  volume *= size;
85  input_weight[i] = weight;
86  }
87  }
88  }
89 
90  PermIndex(const PermIndex& other) : weights_(NULL), ndim_(other.ndim_) {
91  if (ndim_) {
92  // Allocate memory for this object
93  weights_ = static_cast<std::size_t*>(
94  malloc((ndim_ + ndim_) * sizeof(std::size_t)));
95  if (!weights_) throw std::bad_alloc();
96 
97  // Copy data
98  memcpy(weights_, other.weights_, (ndim_ + ndim_) * sizeof(std::size_t));
99  }
100  }
101 
103  free(weights_);
104  weights_ = NULL;
105  }
106 
107  PermIndex& operator=(const PermIndex& other) {
108  // Deallocate memory
109  if (ndim_ && (ndim_ != other.ndim_)) {
110  free(weights_);
111  weights_ = NULL;
112  }
113 
114  const std::size_t bytes = (other.ndim_ + other.ndim_) * sizeof(std::size_t);
115 
116  if (!weights_ && bytes) {
117  // Allocate new memory
118  weights_ = static_cast<std::size_t*>(malloc(bytes));
119  if (!weights_) throw std::bad_alloc();
120  }
121 
122  // copy the data (safe if ndim_ == 0)
123  ndim_ = other.ndim_;
124  memcpy(weights_, other.weights_, bytes);
125 
126  return *this;
127  }
128 
130 
132  int dim() const { return ndim_; }
133 
135 
137  const std::size_t* data() const { return weights_; }
138 
140  std::size_t operator()(std::size_t index) const {
141  TA_ASSERT(ndim_);
142  TA_ASSERT(weights_);
143 
144  // Construct MADNESS_RESTRICTed pointers to data
145  const std::size_t* MADNESS_RESTRICT const input_weight = weights_;
146  const std::size_t* MADNESS_RESTRICT const output_weight = weights_ + ndim_;
147 
148  // create result index
149  std::size_t perm_index = 0ul;
150 
151  for (unsigned int i = 0u; i < ndim_; ++i) {
152  const std::size_t input_weight_i = input_weight[i];
153  const std::size_t output_weight_i = output_weight[i];
154  perm_index += index / input_weight_i * output_weight_i;
155  index %= input_weight_i;
156  }
157 
158  return perm_index;
159  }
160 
161  // Check for valid permutation
162  operator bool() const { return ndim_; }
163 }; // class PermIndex
164 
165 } // namespace detail
166 } // namespace TiledArray
167 
168 #endif // MADNESS_PERM_INDEX_H__INCLUDED
169 TILEDARRAY_PERM_INDEX_H__INCLUDED
std::size_t operator()(std::size_t index) const
Compute the permuted index for the current block.
Definition: perm_index.h:140
index_type size() const
Domain size accessor.
Definition: permutation.h:214
Permutation of a sequence of objects indexed by base-0 indices.
Definition: permutation.h:130
int dim() const
Dimension accessor.
Definition: perm_index.h:132
PermIndex(const PermIndex &other)
Definition: perm_index.h:90
PermIndex & operator=(const PermIndex &other)
Definition: perm_index.h:107
const auto & data() const
Permutation data accessor.
Definition: permutation.h:388
#define TA_ASSERT(EXPR,...)
Definition: error.h:39
A functor that permutes ordinal indices.
Definition: perm_index.h:38
PermIndex()
Default constructor.
Definition: perm_index.h:46
const index1_type * stride_data() const
Range stride data accessor.
Definition: range.h:759
PermIndex(const Range &range, const Permutation &perm)
Construct permuting functor.
Definition: perm_index.h:51
size_t volume(const DistArray< Tile, Policy > &a)
Definition: dist_array.h:1622
const index1_type * extent_data() const
Range extent data accessor.
Definition: range.h:735
unsigned int rank() const
Rank accessor.
Definition: range.h:669
const std::size_t * data() const
Data accessor.
Definition: perm_index.h:137
A (hyperrectangular) interval on $\mathbb{Z}^n$, the space of integer $n$-indices.
Definition: range.h:46