gemm_helper.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2014 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * Justus Calvin
19  * Department of Chemistry, Virginia Tech
20  *
21  * gemm_helper.h
22  * Jan 20, 2014
23  *
24  */
25 
26 #ifndef TILEDARRAY_MATH_GEMM_HELPER_H__INCLUDED
27 #define TILEDARRAY_MATH_GEMM_HELPER_H__INCLUDED
28 
#include <TiledArray/error.h>
#include <TiledArray/math/blas.h>

#include <algorithm>
#include <vector>
33 
34 namespace TiledArray::math {
35 
37 
40 class GemmHelper {
41  private:
42  blas::Op left_op_;
44  blas::Op right_op_;
46  unsigned int result_rank_;
47 
49 
53  struct ContractArg {
54  unsigned int inner[2];
55  unsigned int outer[2];
56  unsigned int rank;
57  } left_,
58  right_;
59 
60  public:
62  const unsigned int result_rank, const unsigned int left_rank,
63  const unsigned int right_rank)
64  : left_op_(left_op),
65  right_op_(right_op),
66  result_rank_(result_rank),
67  left_(),
68  right_() {
69  // Compute the number of contracted dimensions in left and right.
70  TA_ASSERT(((left_rank + right_rank - result_rank) % 2u) == 0u);
71 
72  left_.rank = left_rank;
73  right_.rank = right_rank;
74  const unsigned int contract_size = num_contract_ranks();
75 
76  // Store the inner and outer dimension ranges for the left-hand argument.
77  if (left_op == blas::NoTranspose) {
78  left_.outer[0] = 0u;
79  left_.outer[1] = left_.inner[0] = left_rank - contract_size;
80  left_.inner[1] = left_rank;
81  } else {
82  left_.inner[0] = 0ul;
83  left_.inner[1] = left_.outer[0] = contract_size;
84  left_.outer[1] = left_rank;
85  }
86 
87  // Store the inner and outer dimension ranges for the right-hand argument.
88  if (right_op == blas::NoTranspose) {
89  right_.inner[0] = 0u;
90  right_.inner[1] = right_.outer[0] = contract_size;
91  right_.outer[1] = right_rank;
92  } else {
93  right_.outer[0] = 0u;
94  right_.outer[1] = right_.inner[0] = right_rank - contract_size;
95  right_.inner[1] = right_rank;
96  }
97  }
98 
100 
103  GemmHelper(const GemmHelper& other)
104  : left_op_(other.left_op_),
105  right_op_(other.right_op_),
106  result_rank_(other.result_rank_),
107  left_(other.left_),
108  right_(other.right_) {}
109 
111 
113  GemmHelper& operator=(const GemmHelper& other) {
114  left_op_ = other.left_op_;
115  right_op_ = other.right_op_;
116  result_rank_ = other.result_rank_;
117  left_ = other.left_;
118  right_ = other.right_;
119 
120  return *this;
121  }
122 
124 
126  unsigned int num_contract_ranks() const {
127  return (left_.rank + right_.rank - result_rank_) >> 1;
128  }
129 
131 
133  unsigned int result_rank() const { return result_rank_; }
134 
136 
138  unsigned int left_rank() const { return left_.rank; }
139 
141 
143  unsigned int right_rank() const { return right_.rank; }
144 
145  unsigned int left_inner_begin() const { return left_.inner[0]; }
146  unsigned int left_inner_end() const { return left_.inner[1]; }
147  unsigned int left_outer_begin() const { return left_.outer[0]; }
148  unsigned int left_outer_end() const { return left_.outer[1]; }
149 
150  unsigned int right_inner_begin() const { return right_.inner[0]; }
151  unsigned int right_inner_end() const { return right_.inner[1]; }
152  unsigned int right_outer_begin() const { return right_.outer[0]; }
153  unsigned int right_outer_end() const { return right_.outer[1]; }
154 
156 
164  template <typename R, typename Left, typename Right>
165  R make_result_range(const Left& left, const Right& right) const {
166  // Get pointers to lower and upper bounds of left and right.
167  const auto* MADNESS_RESTRICT const left_lower = left.lobound_data();
168  const auto* MADNESS_RESTRICT const left_upper = left.upbound_data();
169  const auto* MADNESS_RESTRICT const right_lower = right.lobound_data();
170  const auto* MADNESS_RESTRICT const right_upper = right.upbound_data();
171 
172  // Create the start and finish indices
173  std::vector<std::size_t> lower, upper;
174  lower.reserve(result_rank_);
175  upper.reserve(result_rank_);
176 
177  // Copy left-hand argument outer dimensions to start and finish
178  for (unsigned int i = left_.outer[0]; i < left_.outer[1]; ++i) {
179  lower.push_back(left_lower[i]);
180  upper.push_back(left_upper[i]);
181  }
182 
183  // Copy right-hand argument outer dimensions to start and finish
184  for (unsigned int i = right_.outer[0]; i < right_.outer[1]; ++i) {
185  lower.push_back(right_lower[i]);
186  upper.push_back(right_upper[i]);
187  }
188 
189  // Construct the result tile range
190  return R(lower, upper);
191  }
192 
195 
204  template <typename Left, typename Result>
205  bool left_result_congruent(const Left& left, const Result& result) const {
206  return std::equal(left + left_.outer[0], left + left_.outer[1], result);
207  }
208 
211 
220  template <typename Right, typename Result>
221  bool right_result_congruent(const Right& right, const Result& result) const {
222  return std::equal(right + right_.outer[0], right + right_.outer[1],
223  result + (left_.outer[1] - left_.outer[0]));
224  }
225 
228 
237  template <typename Left, typename Right>
238  bool left_right_congruent(const Left& left, const Right& right) const {
239  return std::equal(left + left_.inner[0], left + left_.inner[1],
240  right + right_.inner[0]);
241  }
242 
244 
253  template <typename Left, typename Right>
255  blas::integer& k, const Left& left,
256  const Right& right) const {
257  // Check that the arguments are not empty and have the correct ranks
258  TA_ASSERT(left.rank() == left_.rank);
259  TA_ASSERT(right.rank() == right_.rank);
260  const auto* MADNESS_RESTRICT const left_extent = left.extent_data();
261  const auto* MADNESS_RESTRICT const right_extent = right.extent_data();
262 
263  // Compute fused dimension sizes
264  m = 1;
265  for (unsigned int i = left_.outer[0]; i < left_.outer[1]; ++i)
266  m *= left_extent[i];
267  k = 1;
268  for (unsigned int i = left_.inner[0]; i < left_.inner[1]; ++i)
269  k *= left_extent[i];
270  n = 1;
271  for (unsigned int i = right_.outer[0]; i < right_.outer[1]; ++i)
272  n *= right_extent[i];
273  }
274 
275  blas::Op left_op() const { return left_op_; }
276  blas::Op right_op() const { return right_op_; }
277 }; // class GemmHelper
278 
279 } // namespace TiledArray::math
280 
281 #endif // TILEDARRAY_MATH_GEMM_HELPER_H__INCLUDED
R make_result_range(const Left &left, const Right &right) const
Construct a result range based on left and right ranges.
Definition: gemm_helper.h:165
unsigned int right_outer_begin() const
Definition: gemm_helper.h:152
Contraction to *GEMM helper.
Definition: gemm_helper.h:40
::blas::Op Op
Definition: blas.h:46
unsigned int left_rank() const
Left-hand argument rank accessor.
Definition: gemm_helper.h:138
GemmHelper(const blas::Op left_op, const blas::Op right_op, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank)
Definition: gemm_helper.h:61
unsigned int left_inner_end() const
Definition: gemm_helper.h:146
unsigned int right_inner_end() const
Definition: gemm_helper.h:151
unsigned int left_outer_end() const
Definition: gemm_helper.h:148
void outer(const std::size_t m, const std::size_t n, const X *const x, const Y *const y, A *a, const Op &op)
Compute the outer product of x and y and use it to modify a.
Definition: outer.h:239
unsigned int right_outer_end() const
Definition: gemm_helper.h:153
blas::Op left_op() const
Definition: gemm_helper.h:275
int64_t integer
Definition: blas.h:44
bool right_result_congruent(const Right &right, const Result &result) const
Definition: gemm_helper.h:221
bool left_right_congruent(const Left &left, const Right &right) const
Definition: gemm_helper.h:238
bool left_result_congruent(const Left &left, const Result &result) const
Definition: gemm_helper.h:205
unsigned int left_inner_begin() const
Definition: gemm_helper.h:145
auto rank(const DistArray< Tile, Policy > &a)
Definition: dist_array.h:1617
#define TA_ASSERT(EXPR,...)
Definition: error.h:39
unsigned int result_rank() const
Result rank accessor.
Definition: gemm_helper.h:133
void compute_matrix_sizes(blas::integer &m, blas::integer &n, blas::integer &k, const Left &left, const Right &right) const
Compute the matrix dimensions (m, n, k) that can be used in a *GEMM call.
Definition: gemm_helper.h:254
blas::Op right_op() const
Definition: gemm_helper.h:276
unsigned int right_rank() const
Right-hand argument rank accessor.
Definition: gemm_helper.h:143
unsigned int num_contract_ranks() const
Compute the number of contracted ranks.
Definition: gemm_helper.h:126
unsigned int right_inner_begin() const
Definition: gemm_helper.h:150
GemmHelper(const GemmHelper &other)
Functor copy constructor.
Definition: gemm_helper.h:103
unsigned int left_outer_begin() const
Definition: gemm_helper.h:147
GemmHelper & operator=(const GemmHelper &other)
Functor assignment operator.
Definition: gemm_helper.h:113
auto inner(const Permutation &p)
Definition: permutation.h:813