TiledArray  0.7.0
blas.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2013 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * justus
19  * Department of Chemistry, Virginia Tech
20  *
21  * blas.h
22  * Nov 17, 2013
23  *
24  */
25 
26 #ifndef TILEDARRAY_BLAS_H__INCLUDED
27 #define TILEDARRAY_BLAS_H__INCLUDED
28 
29 #include <madness/tensor/cblas.h>
30 #include <TiledArray/type_traits.h>
31 #include <TiledArray/math/eigen.h>
32 
33 namespace TiledArray {
34  namespace math {
35 
36  // BLAS _GEMM wrapper functions
37 
// Generic GEMM fallback written with Eigen expressions.
//
// Computes C = alpha * op(A) * op(B) + beta * C when beta is non-zero, and
// C = alpha * op(A) * op(B) otherwise, where op() is identity, transpose() or
// adjoint() as selected by op_a / op_b.
//
//   op_a, op_b : transpose selectors for A and B
//   m, n       : row and column extents of C
//   k          : inner (contracted) extent
//   alpha      : scalar applied to the product
//   a, lda     : data pointer and outer stride for A
//   b, ldb     : data pointer and outer stride for B
//   beta       : scalar applied to the existing contents of C
//   c, ldc     : data pointer and outer stride for C (written in place)
38  template <typename S1, typename T1, typename T2, typename S2, typename T3>
39  inline void gemm(madness::cblas::CBLAS_TRANSPOSE op_a,
40  madness::cblas::CBLAS_TRANSPOSE op_b, const integer m, const integer n,
41  const integer k, const S1 alpha, const T1* a, const integer lda,
42  const T2* b, const integer ldb, const S2 beta, T3* c, const integer ldc)
43  {
44  // Define operations
// Each constant is the switch key `op_a | (op_b << 2)` for one (op_a, op_b)
// pair, named <op_a>_<op_b>.
// NOTE(review): these values assume the madness::cblas enumerators are
// NoTrans = 0, Trans = 1, ConjTrans = 2 -- confirm against madness cblas.h.
45  static const unsigned int
46  notrans_notrans = 0x00000000,
47  notrans_trans = 0x00000004,
48  trans_notrans = 0x00000001,
49  trans_trans = 0x00000005,
50  notrans_conjtrans = 0x00000008,
51  trans_conjtrans = 0x00000009,
52  conjtrans_notrans = 0x00000002,
53  conjtrans_trans = 0x00000006,
54  conjtrans_conjtrans = 0x0000000a;
55 
56  // Construct matrix maps for a, b, and c.
// NOTE(review): the embedded source numbering jumps 56 -> 61, 63 -> 65 and
// 67 -> 69 in this listing, so the declarations that open the A, B and C map
// objects are not visible here; only their trailing constructor arguments are.
// Visible arguments for A: rows = (NoTrans ? m : k), cols = (NoTrans ? k : m),
// outer stride lda.
61  (op_a == madness::cblas::NoTrans ? m : k),
62  (op_a == madness::cblas::NoTrans ? k : m),
63  Eigen::OuterStride<>(lda));
// Visible arguments for B: rows = (NoTrans ? k : n), cols = (NoTrans ? n : k),
// outer stride ldb.
65  (op_b == madness::cblas::NoTrans ? k : n),
66  (op_b == madness::cblas::NoTrans ? n : k),
67  Eigen::OuterStride<>(ldb));
// C is always mapped as m x n with outer stride ldc.
69  C(c, m, n, Eigen::OuterStride<>(ldc));
70 
// When beta compares equal to zero, C is not read on the right-hand side at
// all; it is simply overwritten with the scaled product.
71  const bool beta_is_nonzero = (beta != static_cast<S2>(0));
72 
// Dispatch on the combined transpose key. There is no default case, so a key
// that matches none of the nine constants leaves C unmodified.
73  switch(op_a | (op_b << 2)) {
74  case notrans_notrans:
75  if (beta_is_nonzero)
76  C.noalias() = alpha * A * B + beta * C;
77  else
78  C.noalias() = alpha * A * B;
79  break;
80  case notrans_trans:
81  if (beta_is_nonzero)
82  C.noalias() = alpha * A * B.transpose() + beta * C;
83  else
84  C.noalias() = alpha * A * B.transpose();
85  break;
86  case trans_notrans:
87  if (beta_is_nonzero)
88  C.noalias() = alpha * A.transpose() * B + beta * C;
89  else
90  C.noalias() = alpha * A.transpose() * B;
91  break;
92  case trans_trans:
93  if (beta_is_nonzero)
94  C.noalias() = alpha * A.transpose() * B.transpose() + beta * C;
95  else
96  C.noalias() = alpha * A.transpose() * B.transpose();
97  break;
98 
// Conjugate-transpose combinations use adjoint() on the affected operand.
99  case notrans_conjtrans:
100  if (beta_is_nonzero)
101  C.noalias() = alpha * A * B.adjoint() + beta * C;
102  else
103  C.noalias() = alpha * A * B.adjoint();
104  break;
105  case trans_conjtrans:
106  if (beta_is_nonzero)
107  C.noalias() = alpha * A.transpose() * B.adjoint() + beta * C;
108  else
109  C.noalias() = alpha * A.transpose() * B.adjoint();
110  break;
111  case conjtrans_notrans:
112  if (beta_is_nonzero)
113  C.noalias() = alpha * A.adjoint() * B + beta * C;
114  else
115  C.noalias() = alpha * A.adjoint() * B;
116  break;
117  case conjtrans_trans:
118  if (beta_is_nonzero)
119  C.noalias() = alpha * A.adjoint() * B.transpose() + beta * C;
120  else
121  C.noalias() = alpha * A.adjoint() * B.transpose();
122  break;
123  case conjtrans_conjtrans:
124  if (beta_is_nonzero)
125  C.noalias() = alpha * A.adjoint() * B.adjoint() + beta * C;
126  else
127  C.noalias() = alpha * A.adjoint() * B.adjoint();
128  break;
129  }
130  }
131 
132  inline void gemm(madness::cblas::CBLAS_TRANSPOSE op_a,
133  madness::cblas::CBLAS_TRANSPOSE op_b, const integer m, const integer n,
134  const integer k, const float alpha, const float* a, const integer lda,
135  const float* b, const integer ldb, const float beta, float* c, const integer ldc)
136  {
137  madness::cblas::gemm(op_b, op_a, n, m, k, alpha, b, ldb, a, lda, beta, c, ldc);
138  }
139 
140  inline void gemm(madness::cblas::CBLAS_TRANSPOSE op_a,
141  madness::cblas::CBLAS_TRANSPOSE op_b, const integer m, const integer n,
142  const integer k, const double alpha, const double* a, const integer lda,
143  const double* b, const integer ldb, const double beta, double* c, const integer ldc)
144  {
145  madness::cblas::gemm(op_b, op_a, n, m, k, alpha, b, ldb, a, lda, beta, c, ldc);
146  }
147 
148  inline void gemm(madness::cblas::CBLAS_TRANSPOSE op_a,
149  madness::cblas::CBLAS_TRANSPOSE op_b, const integer m, const integer n,
150  const integer k, const std::complex<float> alpha, const std::complex<float>* a,
151  const integer lda, const std::complex<float>* b, const integer ldb,
152  const std::complex<float> beta, std::complex<float>* c, const integer ldc)
153  {
154  madness::cblas::gemm(op_b, op_a, n, m, k, alpha, b, ldb, a, lda, beta, c, ldc);
155  }
156 
157  inline void gemm(madness::cblas::CBLAS_TRANSPOSE op_a,
158  madness::cblas::CBLAS_TRANSPOSE op_b, const integer m, const integer n,
159  const integer k, const std::complex<double> alpha, const std::complex<double>* a,
160  const integer lda, const std::complex<double>* b, const integer ldb,
161  const std::complex<double> beta, std::complex<double>* c, const integer ldc)
162  {
163  madness::cblas::gemm(op_b, op_a, n, m, k, alpha, b, ldb, a, lda, beta, c, ldc);
164  }
165 
166 
167  // BLAS _SCAL wrapper functions
168 
169  template <typename T, typename U>
170  inline typename std::enable_if<detail::is_numeric<T>::value>::type
171  scale(const integer n, const T alpha, U* x) {
172  eigen_map(x, n) *= alpha;
173  }
174 
175  inline void scale(const integer n, const float alpha, float* x) {
176  madness::cblas::scal(n, alpha, x, 1);
177  }
178 
179  inline void scale(const integer n, const double alpha, double* x) {
180  madness::cblas::scal(n, alpha, x, 1);
181  }
182 
183  inline void scale(const integer n, const std::complex<float> alpha, std::complex<float>* x) {
184  madness::cblas::scal(n, alpha, x, 1);
185  }
186 
187  inline void scale(const integer n, const std::complex<double> alpha, std::complex<double>* x) {
188  madness::cblas::scal(n, alpha, x, 1);
189  }
190 
191  inline void scale(const integer n, const float alpha, std::complex<float>* x) {
192  madness::cblas::scal(n, alpha, x, 1);
193  }
194 
195  inline void scale(const integer n, const double alpha, std::complex<double>* x) {
196  madness::cblas::scal(n, alpha, x, 1);
197  }
198 
199 
200  // BLAS _DOT wrapper functions
201 
202  template <typename T, typename U>
203  T dot(const integer n, const T* x, const U* y) {
204  return eigen_map(x, n).dot(eigen_map(y, n));
205  }
206 
207  inline float dot(integer n, const float* x, const float* y) {
208  return madness::cblas::dot(n, x, 1, y, 1);
209  }
210 
211  inline double dot(integer n, const double* x, const double* y) {
212  return madness::cblas::dot(n, x, 1, y, 1);
213  }
214 
215  inline std::complex<float> dot(integer n, const std::complex<float>* x, const std::complex<float>* y) {
216  return madness::cblas::dot(n, x, 1, y, 1);
217  }
218 
219  inline std::complex<double> dot(integer n, const std::complex<double>* x, const std::complex<double>* y) {
220  return madness::cblas::dot(n, x, 1, y, 1);
221  }
222 
223  // Import the madness dot functions into the TiledArray namespace
224  using madness::cblas::dot;
225 
226 
227  } // namespace math
228 } // namespace TiledArray
229 
230 #endif // TILEDARRAY_BLAS_H__INCLUDED
std::complex< double > dot(integer n, const std::complex< double > *x, const std::complex< double > *y)
Definition: blas.h:219
Eigen::Map< const Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor >, Eigen::AutoAlign > eigen_map(const T *t, const std::size_t m, const std::size_t n)
Construct a const Eigen::Map object over a raw pointer, viewed as an m-by-n row-major matrix.
Definition: eigen.h:51
void gemm(madness::cblas::CBLAS_TRANSPOSE op_a, madness::cblas::CBLAS_TRANSPOSE op_b, const integer m, const integer n, const integer k, const std::complex< double > alpha, const std::complex< double > *a, const integer lda, const std::complex< double > *b, const integer ldb, const std::complex< double > beta, std::complex< double > *c, const integer ldc)
Definition: blas.h:157
T dot(const integer n, const T *x, const U *y)
Definition: blas.h:203
void gemm(madness::cblas::CBLAS_TRANSPOSE op_a, madness::cblas::CBLAS_TRANSPOSE op_b, const integer m, const integer n, const integer k, const S1 alpha, const T1 *a, const integer lda, const T2 *b, const integer ldb, const S2 beta, T3 *c, const integer ldc)
Definition: blas.h:39
std::enable_if< detail::is_numeric< T >::value >::type scale(const integer n, const T alpha, U *x)
Definition: blas.h:171