blas.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2013 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * justus
19  * Department of Chemistry, Virginia Tech
20  *
21  * blas.h
22  * Nov 17, 2013
23  *
24  */
25 
26 #ifndef TILEDARRAY_MATH_BLAS_H__INCLUDED
27 #define TILEDARRAY_MATH_BLAS_H__INCLUDED
28 
30 #include <TiledArray/type_traits.h>
31 
32 #include <blas/dot.hh>
33 #include <blas/gemm.hh>
34 #include <blas/scal.hh>
35 #include <blas/util.hh>
36 #include <blas/wrappers.hh>
37 
38 #include <cstdint>
39 
41 
44 using integer = int64_t;
45 
46 using Op = ::blas::Op;
47 static constexpr auto NoTranspose = Op::NoTrans;
48 static constexpr auto Transpose = Op::Trans;
49 static constexpr auto ConjTranspose = Op::ConjTrans;
50 
53 inline int64_t to_int(Op op) {
54  if (op == NoTranspose)
55  return 0;
56  else if (op == Transpose)
57  return 1;
58  else // op == ConjTranspose
59  return 2;
60 }
61 
62 template <typename T, int Options = ::Eigen::ColMajor>
63 using Matrix = ::Eigen::Matrix<T, ::Eigen::Dynamic, ::Eigen::Dynamic, Options>;
64 
65 template <typename T>
66 using Vector = ::Eigen::Matrix<T, ::Eigen::Dynamic, 1, ::Eigen::ColMajor>;
67 
68 // BLAS _GEMM wrapper functions
69 
70 template <typename S1, typename T1, typename T2, typename S2, typename T3>
71 inline void gemm(Op op_a, Op op_b, const integer m, const integer n,
72  const integer k, const S1 alpha, const T1* a,
73  const integer lda, const T2* b, const integer ldb,
74  const S2 beta, T3* c, const integer ldc) {
75  // Define operations
76  static const unsigned int notrans_notrans = 0x00000000,
77  notrans_trans = 0x00000004,
78  trans_notrans = 0x00000001,
79  trans_trans = 0x00000005,
80  notrans_conjtrans = 0x00000008,
81  trans_conjtrans = 0x00000009,
82  conjtrans_notrans = 0x00000002,
83  conjtrans_trans = 0x00000006,
84  conjtrans_conjtrans = 0x0000000a;
85 
86  // Construct matrix maps for a, b, and c.
87  typedef Eigen::Matrix<T1, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
88  matrixA_type;
89  typedef Eigen::Matrix<T2, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
90  matrixB_type;
91  typedef Eigen::Matrix<T3, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
92  matrixC_type;
93  Eigen::Map<const matrixA_type, Eigen::AutoAlign, Eigen::OuterStride<>> A(
94  a, (op_a == NoTranspose ? m : k), (op_a == NoTranspose ? k : m),
95  Eigen::OuterStride<>(lda));
96  Eigen::Map<const matrixB_type, Eigen::AutoAlign, Eigen::OuterStride<>> B(
97  b, (op_b == NoTranspose ? k : n), (op_b == NoTranspose ? n : k),
98  Eigen::OuterStride<>(ldb));
99  Eigen::Map<matrixC_type, Eigen::AutoAlign, Eigen::OuterStride<>> C(
100  c, m, n, Eigen::OuterStride<>(ldc));
101 
102  const bool beta_is_nonzero = (beta != static_cast<S2>(0));
103 
104  switch (to_int(op_a) | (to_int(op_b) << 2)) {
105  case notrans_notrans:
106  if (beta_is_nonzero)
107  C.noalias() = alpha * A * B + beta * C;
108  else
109  C.noalias() = alpha * A * B;
110  break;
111  case notrans_trans:
112  if (beta_is_nonzero)
113  C.noalias() = alpha * A * B.transpose() + beta * C;
114  else
115  C.noalias() = alpha * A * B.transpose();
116  break;
117  case trans_notrans:
118  if (beta_is_nonzero)
119  C.noalias() = alpha * A.transpose() * B + beta * C;
120  else
121  C.noalias() = alpha * A.transpose() * B;
122  break;
123  case trans_trans:
124  if (beta_is_nonzero)
125  C.noalias() = alpha * A.transpose() * B.transpose() + beta * C;
126  else
127  C.noalias() = alpha * A.transpose() * B.transpose();
128  break;
129 
130  case notrans_conjtrans:
131  if (beta_is_nonzero)
132  C.noalias() = alpha * A * B.adjoint() + beta * C;
133  else
134  C.noalias() = alpha * A * B.adjoint();
135  break;
136  case trans_conjtrans:
137  if (beta_is_nonzero)
138  C.noalias() = alpha * A.transpose() * B.adjoint() + beta * C;
139  else
140  C.noalias() = alpha * A.transpose() * B.adjoint();
141  break;
142  case conjtrans_notrans:
143  if (beta_is_nonzero)
144  C.noalias() = alpha * A.adjoint() * B + beta * C;
145  else
146  C.noalias() = alpha * A.adjoint() * B;
147  break;
148  case conjtrans_trans:
149  if (beta_is_nonzero)
150  C.noalias() = alpha * A.adjoint() * B.transpose() + beta * C;
151  else
152  C.noalias() = alpha * A.adjoint() * B.transpose();
153  break;
154  case conjtrans_conjtrans:
155  if (beta_is_nonzero)
156  C.noalias() = alpha * A.adjoint() * B.adjoint() + beta * C;
157  else
158  C.noalias() = alpha * A.adjoint() * B.adjoint();
159  break;
160  }
161 }
162 
163 inline void gemm(Op op_a, Op op_b, const integer m, const integer n,
164  const integer k, const float alpha, const float* a,
165  const integer lda, const float* b, const integer ldb,
166  const float beta, float* c, const integer ldc) {
167  ::blas::gemm(::blas::Layout::ColMajor, op_b, op_a, n, m, k, alpha, b, ldb, a,
168  lda, beta, c, ldc);
169 }
170 
171 inline void gemm(Op op_a, Op op_b, const integer m, const integer n,
172  const integer k, const double alpha, const double* a,
173  const integer lda, const double* b, const integer ldb,
174  const double beta, double* c, const integer ldc) {
175  ::blas::gemm(::blas::Layout::ColMajor, op_b, op_a, n, m, k, alpha, b, ldb, a,
176  lda, beta, c, ldc);
177 }
178 
179 inline void gemm(Op op_a, Op op_b, const integer m, const integer n,
180  const integer k, const std::complex<float> alpha,
181  const std::complex<float>* a, const integer lda,
182  const std::complex<float>* b, const integer ldb,
183  const std::complex<float> beta, std::complex<float>* c,
184  const integer ldc) {
185  ::blas::gemm(::blas::Layout::ColMajor, op_b, op_a, n, m, k, alpha, b, ldb, a,
186  lda, beta, c, ldc);
187 }
188 
189 inline void gemm(Op op_a, Op op_b, const integer m, const integer n,
190  const integer k, const std::complex<double> alpha,
191  const std::complex<double>* a, const integer lda,
192  const std::complex<double>* b, const integer ldb,
193  const std::complex<double> beta, std::complex<double>* c,
194  const integer ldc) {
195  ::blas::gemm(::blas::Layout::ColMajor, op_b, op_a, n, m, k, alpha, b, ldb, a,
196  lda, beta, c, ldc);
197 }
198 
199 // BLAS _SCAL wrapper functions
200 
201 template <typename T, typename U>
202 inline typename std::enable_if<detail::is_numeric_v<T>>::type scale(
203  const integer n, const T alpha, U* x) {
204  Vector<T>::Map(x, n) *= alpha;
205 }
206 
207 inline void scale(const integer n, const float alpha, float* x) {
208  ::blas::scal(n, alpha, x, 1);
209 }
210 
211 inline void scale(const integer n, const double alpha, double* x) {
212  ::blas::scal(n, alpha, x, 1);
213 }
214 
215 inline void scale(const integer n, const std::complex<float> alpha,
216  std::complex<float>* x) {
217  ::blas::scal(n, alpha, x, 1);
218 }
219 
220 inline void scale(const integer n, const std::complex<double> alpha,
221  std::complex<double>* x) {
222  ::blas::scal(n, alpha, x, 1);
223 }
224 
225 inline void scale(const integer n, const float alpha, std::complex<float>* x) {
226  ::blas::scal(n, std::complex<float>{alpha, 0}, x, 1);
227 }
228 
229 inline void scale(const integer n, const double alpha,
230  std::complex<double>* x) {
231  ::blas::scal(n, std::complex<double>{alpha, 0}, x, 1);
232 }
233 
234 // BLAS _DOT wrapper functions
235 
236 template <typename T, typename U>
237 T dot(const integer n, const T* x, const U* y) {
238  return Vector<T>::Map(x, n).dot(Vector<T>::Map(y, n));
239 }
240 
241 inline float dot(const integer n, const float* x, const float* y) {
242  return ::blas::dot(n, x, 1, y, 1);
243 }
244 
245 inline double dot(integer n, const double* x, const double* y) {
246  return ::blas::dot(n, x, 1, y, 1);
247 }
248 
249 inline std::complex<float> dot(integer n, const std::complex<float>* x,
250  const std::complex<float>* y) {
251  return ::blas::dot(n, x, 1, y, 1);
252 }
253 
254 inline std::complex<double> dot(integer n, const std::complex<double>* x,
255  const std::complex<double>* y) {
256  return ::blas::dot(n, x, 1, y, 1);
257 }
258 
259 // Import the madness dot functions into the TiledArray namespace
261 
262 } // namespace TiledArray::math::blas
263 
264 namespace TiledArray {
265 // namespace blas = TiledArray::math::blas;
266 }
267 
268 #endif // TILEDARRAY_MATH_BLAS_H__INCLUDED
::blas::Op Op
Definition: blas.h:46
std::complex< double > dot(integer n, const std::complex< double > *x, const std::complex< double > *y)
Definition: blas.h:254
int64_t integer
Definition: blas.h:44
void gemm(Op op_a, Op op_b, const integer m, const integer n, const integer k, const S1 alpha, const T1 *a, const integer lda, const T2 *b, const integer ldb, const S2 beta, T3 *c, const integer ldc)
Definition: blas.h:71
void gemm(Op op_a, Op op_b, const integer m, const integer n, const integer k, const std::complex< double > alpha, const std::complex< double > *a, const integer lda, const std::complex< double > *b, const integer ldb, const std::complex< double > beta, std::complex< double > *c, const integer ldc)
Definition: blas.h:189
int64_t to_int(Op op)
Definition: blas.h:53
std::enable_if< detail::is_numeric_v< T > >::type scale(const integer n, const T alpha, U *x)
Definition: blas.h:202
::Eigen::Matrix< T, ::Eigen::Dynamic, 1, ::Eigen::ColMajor > Vector
Definition: blas.h:66
T dot(const integer n, const T *x, const U *y)
Definition: blas.h:237
::Eigen::Matrix< T, ::Eigen::Dynamic, ::Eigen::Dynamic, Options > Matrix
Definition: blas.h:63