cublas.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2018 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * Chong Peng
19  * Department of Chemistry, Virginia Tech
20  * July 23, 2018
21  *
22  */
23 
24 #ifndef TILEDARRAY_MATH_CUBLAS_H__INCLUDED
25 #define TILEDARRAY_MATH_CUBLAS_H__INCLUDED
26 
27 #include <TiledArray/config.h>
28 
29 #ifdef TILEDARRAY_HAS_CUDA
30 
31 #include <TiledArray/error.h>
33 #include <cublas_v2.h>
34 #include <thrust/system/cuda/error.h>
35 #include <thrust/system_error.h>
36 
37 #include <TiledArray/math/blas.h>
38 
39 #define CublasSafeCall(err) __cublasSafeCall(err, __FILE__, __LINE__)
40 
41 inline void __cublasSafeCall(cublasStatus_t err, const char *file,
42  const int line) {
43 #ifdef TILEDARRAY_CHECK_CUDA_ERROR
44  if (CUBLAS_STATUS_SUCCESS != err) {
45  std::stringstream ss;
46  ss << "cublasSafeCall() failed at: " << file << "(" << line << ")";
47  std::string what = ss.str();
48  throw std::runtime_error(what);
49  }
50 #endif
51 
52  return;
53 }
54 
55 namespace TiledArray {
56 
57 /*
58  * cuBLAS interface functions
59  */
60 
67 class cuBLASHandlePool {
68  public:
69  static const cublasHandle_t &handle() {
70  static thread_local cublasHandle_t *handle_{nullptr};
71  if (handle_ == nullptr) {
72  handle_ = new cublasHandle_t;
73  CublasSafeCall(cublasCreate(handle_));
74  CublasSafeCall(cublasSetPointerMode(*handle_, CUBLAS_POINTER_MODE_HOST));
75  }
76  return *handle_;
77  }
78 };
79 // thread_local cublasHandle_t *cuBLASHandlePool::handle_;
80 
81 inline cublasOperation_t to_cublas_op(math::blas::Op cblas_op) {
82  cublasOperation_t result{};
83  switch (cblas_op) {
84  case math::blas::Op::NoTrans:
85  result = CUBLAS_OP_N;
86  break;
87  case math::blas::Op::Trans:
88  result = CUBLAS_OP_T;
89  break;
90  case math::blas::Op::ConjTrans:
91  result = CUBLAS_OP_C;
92  break;
93  }
94  return result;
95 }
96 
98 
99 template <typename T>
100 cublasStatus_t cublasGemm(cublasHandle_t handle, cublasOperation_t transa,
101  cublasOperation_t transb, int m, int n, int k,
102  const T *alpha, const T *A, int lda, const T *B,
103  int ldb, const T *beta, T *C, int ldc);
104 template <>
105 inline cublasStatus_t cublasGemm<float>(
106  cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
107  int m, int n, int k, const float *alpha, const float *A, int lda,
108  const float *B, int ldb, const float *beta, float *C, int ldc) {
109  return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb,
110  beta, C, ldc);
111 }
112 template <>
113 inline cublasStatus_t cublasGemm<double>(
114  cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
115  int m, int n, int k, const double *alpha, const double *A, int lda,
116  const double *B, int ldb, const double *beta, double *C, int ldc) {
117  return cublasDgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb,
118  beta, C, ldc);
119 }
120 
122 
123 template <typename T, typename Scalar>
124 cublasStatus_t cublasAxpy(cublasHandle_t handle, int n, const Scalar *alpha,
125  const T *x, int incx, T *y, int incy);
126 template <>
127 inline cublasStatus_t cublasAxpy<float, float>(cublasHandle_t handle, int n,
128  const float *alpha,
129  const float *x, int incx,
130  float *y, int incy) {
131  return cublasSaxpy(handle, n, alpha, x, incx, y, incy);
132 }
133 
134 template <>
135 inline cublasStatus_t cublasAxpy<double, double>(cublasHandle_t handle, int n,
136  const double *alpha,
137  const double *x, int incx,
138  double *y, int incy) {
139  return cublasDaxpy(handle, n, alpha, x, incx, y, incy);
140 }
141 
142 template <>
143 inline cublasStatus_t cublasAxpy<float, int>(cublasHandle_t handle, int n,
144  const int *alpha, const float *x,
145  int incx, float *y, int incy) {
146  const float alpha_float = float(*alpha);
147  return cublasSaxpy(handle, n, &alpha_float, x, incx, y, incy);
148 }
149 
150 template <>
151 inline cublasStatus_t cublasAxpy<float, double>(cublasHandle_t handle, int n,
152  const double *alpha,
153  const float *x, int incx,
154  float *y, int incy) {
155  const float alpha_float = float(*alpha);
156  return cublasSaxpy(handle, n, &alpha_float, x, incx, y, incy);
157 }
158 
159 template <>
160 inline cublasStatus_t cublasAxpy<double, int>(cublasHandle_t handle, int n,
161  const int *alpha, const double *x,
162  int incx, double *y, int incy) {
163  const double alpha_double = double(*alpha);
164  return cublasDaxpy(handle, n, &alpha_double, x, incx, y, incy);
165 }
166 
167 template <>
168 inline cublasStatus_t cublasAxpy<double, float>(cublasHandle_t handle, int n,
169  const float *alpha,
170  const double *x, int incx,
171  double *y, int incy) {
172  const double alpha_double = double(*alpha);
173  return cublasDaxpy(handle, n, &alpha_double, x, incx, y, incy);
174 }
175 
176 template <>
177 inline cublasStatus_t cublasAxpy<float, detail::ComplexConjugate<void>>(
178  cublasHandle_t handle, int n, const detail::ComplexConjugate<void> *alpha,
179  const float *x, int incx, float *y, int incy) {
180  return CUBLAS_STATUS_SUCCESS;
181 }
182 
183 template <>
184 inline cublasStatus_t
185 cublasAxpy<float, detail::ComplexConjugate<detail::ComplexNegTag>>(
186  cublasHandle_t handle, int n,
187  const detail::ComplexConjugate<detail::ComplexNegTag> *alpha,
188  const float *x, int incx, float *y, int incy) {
189  const float alpha_float = float(-1.0);
190  return cublasSaxpy(handle, n, &alpha_float, x, incx, y, incy);
191 }
192 
193 template <>
194 inline cublasStatus_t cublasAxpy<float, detail::ComplexConjugate<int>>(
195  cublasHandle_t handle, int n, const detail::ComplexConjugate<int> *alpha,
196  const float *x, int incx, float *y, int incy) {
197  const float alpha_float = float(alpha->factor());
198  return cublasSaxpy(handle, n, &alpha_float, x, incx, y, incy);
199 }
200 
201 template <>
202 inline cublasStatus_t cublasAxpy<float, detail::ComplexConjugate<float>>(
203  cublasHandle_t handle, int n, const detail::ComplexConjugate<float> *alpha,
204  const float *x, int incx, float *y, int incy) {
205  const float alpha_float = float(alpha->factor());
206  return cublasSaxpy(handle, n, &alpha_float, x, incx, y, incy);
207 }
208 
209 template <>
210 inline cublasStatus_t cublasAxpy<float, detail::ComplexConjugate<double>>(
211  cublasHandle_t handle, int n, const detail::ComplexConjugate<double> *alpha,
212  const float *x, int incx, float *y, int incy) {
213  const float alpha_float = float(alpha->factor());
214  return cublasSaxpy(handle, n, &alpha_float, x, incx, y, incy);
215 }
216 
217 template <>
218 inline cublasStatus_t cublasAxpy<double, detail::ComplexConjugate<void>>(
219  cublasHandle_t handle, int n, const detail::ComplexConjugate<void> *alpha,
220  const double *x, int incx, double *y, int incy) {
221  return CUBLAS_STATUS_SUCCESS;
222 }
223 
224 template <>
225 inline cublasStatus_t
226 cublasAxpy<double, detail::ComplexConjugate<detail::ComplexNegTag>>(
227  cublasHandle_t handle, int n,
228  const detail::ComplexConjugate<detail::ComplexNegTag> *alpha,
229  const double *x, int incx, double *y, int incy) {
230  const double alpha_double = double(-1.0);
231  return cublasDaxpy(handle, n, &alpha_double, x, incx, y, incy);
232 }
233 
234 template <>
235 inline cublasStatus_t cublasAxpy<double, detail::ComplexConjugate<int>>(
236  cublasHandle_t handle, int n, const detail::ComplexConjugate<int> *alpha,
237  const double *x, int incx, double *y, int incy) {
238  const double alpha_double = double(alpha->factor());
239  return cublasDaxpy(handle, n, &alpha_double, x, incx, y, incy);
240 }
241 
242 template <>
243 inline cublasStatus_t cublasAxpy<double, detail::ComplexConjugate<float>>(
244  cublasHandle_t handle, int n, const detail::ComplexConjugate<float> *alpha,
245  const double *x, int incx, double *y, int incy) {
246  const double alpha_double = double(alpha->factor());
247  return cublasDaxpy(handle, n, &alpha_double, x, incx, y, incy);
248 }
249 
250 template <>
251 inline cublasStatus_t cublasAxpy<double, detail::ComplexConjugate<double>>(
252  cublasHandle_t handle, int n, const detail::ComplexConjugate<double> *alpha,
253  const double *x, int incx, double *y, int incy) {
254  const double alpha_double = double(alpha->factor());
255  return cublasDaxpy(handle, n, &alpha_double, x, incx, y, incy);
256 }
257 
259 
260 template <typename T>
261 cublasStatus_t cublasDot(cublasHandle_t handle, int n, const T *x, int incx,
262  const T *y, int incy, T *result);
263 template <>
264 inline cublasStatus_t cublasDot<float>(cublasHandle_t handle, int n,
265  const float *x, int incx, const float *y,
266  int incy, float *result) {
267  return cublasSdot(handle, n, x, incx, y, incy, result);
268 }
269 
270 template <>
271 inline cublasStatus_t cublasDot<double>(cublasHandle_t handle, int n,
272  const double *x, int incx,
273  const double *y, int incy,
274  double *result) {
275  return cublasDdot(handle, n, x, incx, y, incy, result);
276 }
277 
279 template <typename T, typename Scalar>
280 cublasStatus_t cublasScal(cublasHandle_t handle, int n, const Scalar *alpha,
281  T *x, int incx);
282 
283 template <>
284 inline cublasStatus_t cublasScal<float, float>(cublasHandle_t handle, int n,
285  const float *alpha, float *x,
286  int incx) {
287  return cublasSscal(handle, n, alpha, x, incx);
288 };
289 
290 template <>
291 inline cublasStatus_t cublasScal<double, double>(cublasHandle_t handle, int n,
292  const double *alpha, double *x,
293  int incx) {
294  return cublasDscal(handle, n, alpha, x, incx);
295 };
296 
297 template <>
298 inline cublasStatus_t cublasScal<float, int>(cublasHandle_t handle, int n,
299  const int *alpha, float *x,
300  int incx) {
301  const float alpha_float = float(*alpha);
302  return cublasSscal(handle, n, &alpha_float, x, incx);
303 };
304 
305 template <>
306 inline cublasStatus_t cublasScal<float, double>(cublasHandle_t handle, int n,
307  const double *alpha, float *x,
308  int incx) {
309  const float alpha_float = float(*alpha);
310  return cublasSscal(handle, n, &alpha_float, x, incx);
311 };
312 
313 //
314 template <>
315 inline cublasStatus_t cublasScal<double, int>(cublasHandle_t handle, int n,
316  const int *alpha, double *x,
317  int incx) {
318  const double alpha_double = double(*alpha);
319  return cublasDscal(handle, n, &alpha_double, x, incx);
320 };
321 
322 template <>
323 inline cublasStatus_t cublasScal<double, float>(cublasHandle_t handle, int n,
324  const float *alpha, double *x,
325  int incx) {
326  const double alpha_double = double(*alpha);
327  return cublasDscal(handle, n, &alpha_double, x, incx);
328 };
329 
330 template <>
331 inline cublasStatus_t cublasScal<float, detail::ComplexConjugate<void>>(
332  cublasHandle_t handle, int n, const detail::ComplexConjugate<void> *alpha,
333  float *x, int incx) {
334  return CUBLAS_STATUS_SUCCESS;
335 }
336 
337 template <>
338 inline cublasStatus_t
339 cublasScal<float, detail::ComplexConjugate<detail::ComplexNegTag>>(
340  cublasHandle_t handle, int n,
341  const detail::ComplexConjugate<detail::ComplexNegTag> *alpha, float *x,
342  int incx) {
343  const float alpha_float = float(-1.0);
344  return cublasSscal(handle, n, &alpha_float, x, incx);
345 }
346 
347 template <>
348 inline cublasStatus_t cublasScal<float, detail::ComplexConjugate<int>>(
349  cublasHandle_t handle, int n, const detail::ComplexConjugate<int> *alpha,
350  float *x, int incx) {
351  const float alpha_float = float(alpha->factor());
352  return cublasSscal(handle, n, &alpha_float, x, incx);
353 }
354 
355 template <>
356 inline cublasStatus_t cublasScal<float, detail::ComplexConjugate<float>>(
357  cublasHandle_t handle, int n, const detail::ComplexConjugate<float> *alpha,
358  float *x, int incx) {
359  const float alpha_float = float(alpha->factor());
360  return cublasSscal(handle, n, &alpha_float, x, incx);
361 }
362 
363 template <>
364 inline cublasStatus_t cublasScal<float, detail::ComplexConjugate<double>>(
365  cublasHandle_t handle, int n, const detail::ComplexConjugate<double> *alpha,
366  float *x, int incx) {
367  const float alpha_float = float(alpha->factor());
368  return cublasSscal(handle, n, &alpha_float, x, incx);
369 }
370 
371 template <>
372 inline cublasStatus_t cublasScal<double, detail::ComplexConjugate<void>>(
373  cublasHandle_t handle, int n, const detail::ComplexConjugate<void> *alpha,
374  double *x, int incx) {
375  return CUBLAS_STATUS_SUCCESS;
376 }
377 
378 template <>
379 inline cublasStatus_t
380 cublasScal<double, detail::ComplexConjugate<detail::ComplexNegTag>>(
381  cublasHandle_t handle, int n,
382  const detail::ComplexConjugate<detail::ComplexNegTag> *alpha, double *x,
383  int incx) {
384  const double alpha_double = double(-1.0);
385  return cublasDscal(handle, n, &alpha_double, x, incx);
386 }
387 
388 template <>
389 inline cublasStatus_t cublasScal<double, detail::ComplexConjugate<int>>(
390  cublasHandle_t handle, int n, const detail::ComplexConjugate<int> *alpha,
391  double *x, int incx) {
392  const double alpha_double = double(alpha->factor());
393  return cublasDscal(handle, n, &alpha_double, x, incx);
394 }
395 
396 template <>
397 inline cublasStatus_t cublasScal<double, detail::ComplexConjugate<float>>(
398  cublasHandle_t handle, int n, const detail::ComplexConjugate<float> *alpha,
399  double *x, int incx) {
400  const double alpha_double = double(alpha->factor());
401  return cublasDscal(handle, n, &alpha_double, x, incx);
402 }
403 
404 template <>
405 inline cublasStatus_t cublasScal<double, detail::ComplexConjugate<double>>(
406  cublasHandle_t handle, int n, const detail::ComplexConjugate<double> *alpha,
407  double *x, int incx) {
408  const double alpha_double = double(alpha->factor());
409  return cublasDscal(handle, n, &alpha_double, x, incx);
410 }
411 
413 template <typename T>
414 cublasStatus_t cublasCopy(cublasHandle_t handle, int n, const T *x, int incx,
415  T *y, int incy);
416 
417 template <>
418 inline cublasStatus_t cublasCopy(cublasHandle_t handle, int n, const float *x,
419  int incx, float *y, int incy) {
420  return cublasScopy(handle, n, x, incx, y, incy);
421 }
422 
423 template <>
424 inline cublasStatus_t cublasCopy(cublasHandle_t handle, int n, const double *x,
425  int incx, double *y, int incy) {
426  return cublasDcopy(handle, n, x, incx, y, incy);
427 }
428 
429 } // end of namespace TiledArray
430 
431 #endif // TILEDARRAY_HAS_CUDA
432 
433 #endif // TILEDARRAY_MATH_CUBLAS_H__INCLUDED
::blas::Op Op
Definition: blas.h:46