TiledArray  0.7.0
gemm_helper.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2014 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * Justus Calvin
19  * Department of Chemistry, Virginia Tech
20  *
21  * gemm_helper.h
22  * Jan 20, 2014
23  *
24  */
25 
26 #ifndef TILEDARRAY_MATH_GEMM_HELPER_H__INCLUDED
27 #define TILEDARRAY_MATH_GEMM_HELPER_H__INCLUDED
28 
29 #include <TiledArray/error.h>
30 #include <TiledArray/madness.h>
31 
32 namespace TiledArray {
33  namespace math {
34 
36 
39  class GemmHelper {
40  private:
41 
42  madness::cblas::CBLAS_TRANSPOSE left_op_;
44  madness::cblas::CBLAS_TRANSPOSE right_op_;
46  unsigned int result_rank_;
47 
49 
53  struct ContractArg {
54  unsigned int inner[2];
55  unsigned int outer[2];
56  unsigned int rank;
57  }
58  left_,
59  right_;
60 
61 
62  public:
63 
64  GemmHelper(const madness::cblas::CBLAS_TRANSPOSE left_op,
65  const madness::cblas::CBLAS_TRANSPOSE right_op,
66  const unsigned int result_rank, const unsigned int left_rank,
67  const unsigned int right_rank) :
68  left_op_(left_op), right_op_(right_op),
69  result_rank_(result_rank), left_(), right_()
70  {
71  // Compute the number of contracted dimensions in left and right.
72  TA_ASSERT(((left_rank + right_rank - result_rank) % 2u) == 0u);
73 
74  left_.rank = left_rank;
75  right_.rank = right_rank;
76  const unsigned int contract_size = num_contract_ranks();
77 
78  // Store the inner and outer dimension ranges for the left-hand argument.
79  if(left_op == madness::cblas::NoTrans) {
80  left_.outer[0] = 0u;
81  left_.outer[1] = left_.inner[0] = left_rank - contract_size;
82  left_.inner[1] = left_rank;
83  } else {
84  left_.inner[0] = 0ul;
85  left_.inner[1] = left_.outer[0] = contract_size;
86  left_.outer[1] = left_rank;
87  }
88 
89  // Store the inner and outer dimension ranges for the right-hand argument.
90  if(right_op == madness::cblas::NoTrans) {
91  right_.inner[0] = 0u;
92  right_.inner[1] = right_.outer[0] = contract_size;
93  right_.outer[1] = right_rank;
94  } else {
95  right_.outer[0] = 0u;
96  right_.outer[1] = right_.inner[0] = right_rank - contract_size;
97  right_.inner[1] = right_rank;
98  }
99  }
100 
102 
105  GemmHelper(const GemmHelper& other) :
106  left_op_(other.left_op_), right_op_(other.right_op_),
107  result_rank_(other.result_rank_),
108  left_(other.left_), right_(other.right_)
109  { }
110 
112 
114  GemmHelper& operator=(const GemmHelper& other) {
115  left_op_ = other.left_op_;
116  right_op_ = other.right_op_;
117  result_rank_ = other.result_rank_;
118  left_ = other.left_;
119  right_ = other.right_;
120 
121  return *this;
122  }
123 
125 
127  unsigned int num_contract_ranks() const {
128  return (left_.rank + right_.rank - result_rank_) >> 1;
129  }
130 
132 
134  unsigned int result_rank() const { return result_rank_; }
135 
137 
139  unsigned int left_rank() const { return left_.rank; }
140 
142 
144  unsigned int right_rank() const { return right_.rank; }
145 
146  unsigned int left_inner_begin() const { return left_.inner[0]; }
147  unsigned int left_inner_end() const { return left_.inner[1]; }
148  unsigned int left_outer_begin() const { return left_.outer[0]; }
149  unsigned int left_outer_end() const { return left_.outer[1]; }
150 
151  unsigned int right_inner_begin() const { return right_.inner[0]; }
152  unsigned int right_inner_end() const { return right_.inner[1]; }
153  unsigned int right_outer_begin() const { return right_.outer[0]; }
154  unsigned int right_outer_end() const { return right_.outer[1]; }
155 
157 
165  template <typename R, typename Left, typename Right>
166  R make_result_range(const Left& left, const Right& right) const {
167  // Get pointers to lower and upper bounds of left and right.
168  const auto* MADNESS_RESTRICT const left_lower = left.lobound_data();
169  const auto* MADNESS_RESTRICT const left_upper = left.upbound_data();
170  const auto* MADNESS_RESTRICT const right_lower = right.lobound_data();
171  const auto* MADNESS_RESTRICT const right_upper = right.upbound_data();
172 
173  // Create the start and finish indices
174  std::vector<std::size_t> lower, upper;
175  lower.reserve(result_rank_);
176  upper.reserve(result_rank_);
177 
178  // Copy left-hand argument outer dimensions to start and finish
179  for(unsigned int i = left_.outer[0]; i < left_.outer[1]; ++i) {
180  lower.push_back(left_lower[i]);
181  upper.push_back(left_upper[i]);
182  }
183 
184  // Copy right-hand argument outer dimensions to start and finish
185  for(unsigned int i = right_.outer[0]; i < right_.outer[1]; ++i) {
186  lower.push_back(right_lower[i]);
187  upper.push_back(right_upper[i]);
188  }
189 
190  // Construct the result tile range
191  return R(lower, upper);
192  }
193 
195 
204  template <typename Left, typename Result>
205  bool left_result_congruent(const Left& left, const Result& result) const {
206  return std::equal(left + left_.outer[0], left + left_.outer[1], result);
207  }
208 
210 
219  template <typename Right, typename Result>
220  bool right_result_congruent(const Right& right, const Result& result) const {
221  return std::equal(right + right_.outer[0], right + right_.outer[1],
222  result + (left_.outer[1] - left_.outer[0]));
223  }
224 
226 
235  template <typename Left, typename Right>
236  bool left_right_congruent(const Left& left, const Right& right) const {
237  return std::equal(left + left_.inner[0], left + left_.inner[1],
238  right + right_.inner[0]);
239  }
240 
242 
251  template <typename Left, typename Right>
252  void compute_matrix_sizes(integer& m, integer& n, integer& k,
253  const Left& left, const Right& right) const
254  {
255  // Check that the arguments are not empty and have the correct ranks
256  TA_ASSERT(left.rank() == left_.rank);
257  TA_ASSERT(right.rank() == right_.rank);
258  const auto* MADNESS_RESTRICT const left_extent = left.extent_data();
259  const auto* MADNESS_RESTRICT const right_extent = right.extent_data();
260 
261  // Compute fused dimension sizes
262  m = 1;
263  for(unsigned int i = left_.outer[0]; i < left_.outer[1]; ++i)
264  m *= left_extent[i];
265  k = 1;
266  for(unsigned int i = left_.inner[0]; i < left_.inner[1]; ++i)
267  k *= left_extent[i];
268  n = 1;
269  for(unsigned int i = right_.outer[0]; i < right_.outer[1]; ++i)
270  n *= right_extent[i];
271  }
272 
273  madness::cblas::CBLAS_TRANSPOSE left_op() const { return left_op_; }
274  madness::cblas::CBLAS_TRANSPOSE right_op() const { return right_op_; }
275  }; // class GemmHelper
276 
277  } // namespace math
278 } // namespace TiledArray
279 
280 #endif // TILEDARRAY_MATH_GEMM_HELPER_H__INCLUDED
void compute_matrix_sizes(integer &m, integer &n, integer &k, const Left &left, const Right &right) const
Compute the matrix dimension that can be used in a *GEMM call.
Definition: gemm_helper.h:252
unsigned int right_inner_begin() const
Definition: gemm_helper.h:151
bool left_result_congruent(const Left &left, const Result &result) const
Test that the outer dimensions of left are congruent (have equal extent) with that of the result tensor.
Definition: gemm_helper.h:205
bool right_result_congruent(const Right &right, const Result &result) const
Test that the outer dimensions of right are congruent (have equal extent) with that of the result tensor.
Definition: gemm_helper.h:220
madness::cblas::CBLAS_TRANSPOSE right_op() const
Definition: gemm_helper.h:274
bool left_right_congruent(const Left &left, const Right &right) const
Test that the inner dimensions of left are congruent (have equal extent) with that of right.
Definition: gemm_helper.h:236
unsigned int result_rank() const
Result rank accessor.
Definition: gemm_helper.h:134
void outer(const std::size_t m, const std::size_t n, const X *const x, const Y *const y, A *a, const Op &op)
Compute the outer product of x and y to modify a.
Definition: outer.h:248
GemmHelper & operator=(const GemmHelper &other)
Functor assignment operator.
Definition: gemm_helper.h:114
R make_result_range(const Left &left, const Right &right) const
Construct a result range based on left and right ranges.
Definition: gemm_helper.h:166
unsigned int left_outer_end() const
Definition: gemm_helper.h:149
#define TA_ASSERT(a)
Definition: error.h:107
unsigned int num_contract_ranks() const
Compute the number of contracted ranks.
Definition: gemm_helper.h:127
unsigned int right_inner_end() const
Definition: gemm_helper.h:152
unsigned int right_outer_begin() const
Definition: gemm_helper.h:153
GemmHelper(const GemmHelper &other)
Functor copy constructor.
Definition: gemm_helper.h:105
unsigned int right_rank() const
Right-hand argument rank accessor.
Definition: gemm_helper.h:144
Contraction to *GEMM helper.
Definition: gemm_helper.h:39
madness::cblas::CBLAS_TRANSPOSE left_op() const
Definition: gemm_helper.h:273
unsigned int left_outer_begin() const
Definition: gemm_helper.h:148
unsigned int right_outer_end() const
Definition: gemm_helper.h:154
unsigned int left_rank() const
Left-hand argument rank accessor.
Definition: gemm_helper.h:139
unsigned int left_inner_end() const
Definition: gemm_helper.h:147
GemmHelper(const madness::cblas::CBLAS_TRANSPOSE left_op, const madness::cblas::CBLAS_TRANSPOSE right_op, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank)
Definition: gemm_helper.h:64
unsigned int left_inner_begin() const
Definition: gemm_helper.h:146