24 #ifndef TILEDARRAY_MATH_CUBLAS_H__INCLUDED
25 #define TILEDARRAY_MATH_CUBLAS_H__INCLUDED
27 #include <TiledArray/config.h>
29 #ifdef TILEDARRAY_HAS_CUDA
33 #include <cublas_v2.h>
34 #include <thrust/system/cuda/error.h>
35 #include <thrust/system_error.h>
// CublasSafeCall(err): checks the cublasStatus_t produced by the expression
// `err`, forwarding the call site (__FILE__, __LINE__) to __cublasSafeCall so
// a failure can be reported with its location.
// Example (as used by cuBLASHandlePool): CublasSafeCall(cublasCreate(handle_));
39 #define CublasSafeCall(err) __cublasSafeCall(err, __FILE__, __LINE__)
// __cublasSafeCall: the check is guarded by TILEDARRAY_CHECK_CUDA_ERROR; when
// enabled and `err` is not CUBLAS_STATUS_SUCCESS, throws std::runtime_error
// whose message names the source file and line of the failing call.
// NOTE(review): in this copy the line-number parameter (`line`), the
// declaration of the stream `ss`, the closing `#endif` and the braces are not
// visible -- confirm they are present.
// NOTE(review): identifiers beginning with `__` are reserved for the
// implementation in C++; consider renaming. The message also omits the
// numeric cublasStatus_t value, which would make failures easier to diagnose.
41 inline void __cublasSafeCall(cublasStatus_t err,
const char *file,
43 #ifdef TILEDARRAY_CHECK_CUDA_ERROR
44 if (CUBLAS_STATUS_SUCCESS != err) {
// Message format: "cublasSafeCall() failed at: <file>(<line>)"
46 ss <<
"cublasSafeCall() failed at: " << file <<
"(" << line <<
")";
47 std::string what = ss.str();
48 throw std::runtime_error(what);
// cuBLASHandlePool: hands out a lazily created cuBLAS handle, one per calling
// thread (the handle pointer is a `static thread_local` inside handle()).
67 class cuBLASHandlePool {
// handle(): returns the calling thread's cublasHandle_t. On the first call in
// a thread it allocates the handle, creates it with cublasCreate and sets
// CUBLAS_POINTER_MODE_HOST, so scalar arguments/results (alpha, beta, dot
// results) used with this handle are expected to be host pointers. Both cuBLAS
// calls are checked through CublasSafeCall.
// NOTE(review): the handle is allocated with `new` and no matching
// cublasDestroy/delete is visible in this copy, so it appears to live until
// process exit -- confirm this is intentional. The access specifier, the
// return statement and the closing braces are also not visible here.
69 static const cublasHandle_t &handle() {
70 static thread_local cublasHandle_t *handle_{
nullptr};
// First call on this thread: allocate and initialize the handle.
71 if (handle_ ==
nullptr) {
72 handle_ =
new cublasHandle_t;
73 CublasSafeCall(cublasCreate(handle_));
74 CublasSafeCall(cublasSetPointerMode(*handle_, CUBLAS_POINTER_MODE_HOST));
// Fragment of a conversion from math::blas::Op to cublasOperation_t. Only the
// value-initialized `result` and the case labels (NoTrans, Trans, ConjTrans)
// are visible in this copy; the enclosing function signature, switch header,
// per-case assignments and return are missing.
// NOTE(review): presumably NoTrans -> CUBLAS_OP_N, Trans -> CUBLAS_OP_T and
// ConjTrans -> CUBLAS_OP_C -- confirm against the complete source.
82 cublasOperation_t result{};
84 case math::blas::Op::NoTrans:
87 case math::blas::Op::Trans:
90 case math::blas::Op::ConjTrans:
// cublasGemm<T>: type-dispatched wrapper over cuBLAS GEMM,
// C = alpha * op(A) * op(B) + beta * C, where op() is chosen by transa/transb
// and lda/ldb/ldc are leading dimensions in cuBLAS's column-major convention.
// The primary template is only declared (no body); the specializations below
// cover float (cublasSgemm) and double (cublasDgemm).
// NOTE(review): in this copy the `template <typename T>` / `template <>` lines,
// the tail of each forwarding call (beta, C, ldc) and the closing braces are
// not visible -- confirm they are present.
100 cublasStatus_t cublasGemm(cublasHandle_t handle, cublasOperation_t transa,
101 cublasOperation_t transb,
int m,
int n,
int k,
102 const T *alpha,
const T *A,
int lda,
const T *B,
103 int ldb,
const T *beta, T *C,
int ldc);
// float: forwards every argument unchanged to cublasSgemm.
105 inline cublasStatus_t cublasGemm<float>(
106 cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
107 int m,
int n,
int k,
const float *alpha,
const float *A,
int lda,
108 const float *B,
int ldb,
const float *beta,
float *C,
int ldc) {
109 return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb,
// double: forwards every argument unchanged to cublasDgemm.
113 inline cublasStatus_t cublasGemm<double>(
114 cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
115 int m,
int n,
int k,
const double *alpha,
const double *A,
int lda,
116 const double *B,
int ldb,
const double *beta,
double *C,
int ldc) {
117 return cublasDgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb,
// cublasAxpy<T, Scalar>: y = alpha * x + y over n elements of T with strides
// incx/incy, where alpha may have a scalar type different from T. The primary
// template is only declared; the specializations below dispatch to
// cublasSaxpy (T = float) or cublasDaxpy (T = double).
// When Scalar differs from T, *alpha is dereferenced on the host, converted to
// T, and the address of that local copy is passed to cuBLAS. This relies on
// the handle being in host pointer mode (cuBLASHandlePool sets
// CUBLAS_POINTER_MODE_HOST).
// NOTE(review): several specializations in this copy are missing lines (the
// `template <>` headers, the `alpha` parameter line and closing braces). For
// example, the <float, float> body uses `alpha`, which is not in its visible
// parameter list.
123 template <
typename T,
typename Scalar>
124 cublasStatus_t cublasAxpy(cublasHandle_t handle,
int n,
const Scalar *alpha,
125 const T *x,
int incx, T *y,
int incy);
// Scalar == T (float): alpha is passed through unchanged.
127 inline cublasStatus_t cublasAxpy<float, float>(cublasHandle_t handle,
int n,
129 const float *x,
int incx,
130 float *y,
int incy) {
131 return cublasSaxpy(handle, n, alpha, x, incx, y, incy);
// Scalar == T (double): alpha is passed through unchanged.
135 inline cublasStatus_t cublasAxpy<double, double>(cublasHandle_t handle,
int n,
137 const double *x,
int incx,
138 double *y,
int incy) {
139 return cublasDaxpy(handle, n, alpha, x, incx, y, incy);
// int alpha on float data: converted to float before the call.
143 inline cublasStatus_t cublasAxpy<float, int>(cublasHandle_t handle,
int n,
144 const int *alpha,
const float *x,
145 int incx,
float *y,
int incy) {
146 const float alpha_float = float(*alpha);
147 return cublasSaxpy(handle, n, &alpha_float, x, incx, y, incy);
// double alpha on float data: narrowed to float (may lose precision).
151 inline cublasStatus_t cublasAxpy<float, double>(cublasHandle_t handle,
int n,
153 const float *x,
int incx,
154 float *y,
int incy) {
155 const float alpha_float = float(*alpha);
156 return cublasSaxpy(handle, n, &alpha_float, x, incx, y, incy);
// int alpha on double data: converted to double before the call.
160 inline cublasStatus_t cublasAxpy<double, int>(cublasHandle_t handle,
int n,
161 const int *alpha,
const double *x,
162 int incx,
double *y,
int incy) {
163 const double alpha_double = double(*alpha);
164 return cublasDaxpy(handle, n, &alpha_double, x, incx, y, incy);
// float alpha on double data: widened to double before the call.
168 inline cublasStatus_t cublasAxpy<double, float>(cublasHandle_t handle,
int n,
170 const double *x,
int incx,
171 double *y,
int incy) {
172 const double alpha_double = double(*alpha);
173 return cublasDaxpy(handle, n, &alpha_double, x, incx, y, incy);
// cublasAxpy specializations for float data with a detail::ComplexConjugate
// scalar. Presumably ComplexConjugate<S> means "conjugate, scaled by
// factor()". Conjugation is the identity on real data, which matches the code
// below using only the numeric factor (or -1 for ComplexNegTag) -- confirm
// against the ComplexConjugate definition.
// NOTE(review): the `template <>` headers and closing braces are not visible
// in this copy.
// ComplexConjugate<void>: returns success without calling cuBLAS, so y is
// left unchanged.
// NOTE(review): the ComplexNegTag case below computes y += (-1) * x. By
// analogy this case would be expected to compute y += x (alpha = 1) rather
// than do nothing. As written, a request for y += conj(x) on float data leaves
// y untouched -- confirm the intended semantics.
177 inline cublasStatus_t cublasAxpy<float, detail::ComplexConjugate<void>>(
178 cublasHandle_t handle,
int n,
const detail::ComplexConjugate<void> *alpha,
179 const float *x,
int incx,
float *y,
int incy) {
180 return CUBLAS_STATUS_SUCCESS;
// ComplexNegTag: alpha is fixed at -1, i.e. y = y - x.
184 inline cublasStatus_t
185 cublasAxpy<float, detail::ComplexConjugate<detail::ComplexNegTag>>(
186 cublasHandle_t handle,
int n,
187 const detail::ComplexConjugate<detail::ComplexNegTag> *alpha,
188 const float *x,
int incx,
float *y,
int incy) {
189 const float alpha_float = float(-1.0);
190 return cublasSaxpy(handle, n, &alpha_float, x, incx, y, incy);
// ComplexConjugate<int>: alpha->factor() converted to float.
194 inline cublasStatus_t cublasAxpy<float, detail::ComplexConjugate<int>>(
195 cublasHandle_t handle,
int n,
const detail::ComplexConjugate<int> *alpha,
196 const float *x,
int incx,
float *y,
int incy) {
197 const float alpha_float = float(alpha->factor());
198 return cublasSaxpy(handle, n, &alpha_float, x, incx, y, incy);
// ComplexConjugate<float>: alpha->factor() used as the float scale.
202 inline cublasStatus_t cublasAxpy<float, detail::ComplexConjugate<float>>(
203 cublasHandle_t handle,
int n,
const detail::ComplexConjugate<float> *alpha,
204 const float *x,
int incx,
float *y,
int incy) {
205 const float alpha_float = float(alpha->factor());
206 return cublasSaxpy(handle, n, &alpha_float, x, incx, y, incy);
// ComplexConjugate<double>: alpha->factor() narrowed to float.
210 inline cublasStatus_t cublasAxpy<float, detail::ComplexConjugate<double>>(
211 cublasHandle_t handle,
int n,
const detail::ComplexConjugate<double> *alpha,
212 const float *x,
int incx,
float *y,
int incy) {
213 const float alpha_float = float(alpha->factor());
214 return cublasSaxpy(handle, n, &alpha_float, x, incx, y, incy);
// cublasAxpy specializations for double data with a detail::ComplexConjugate
// scalar. Presumably ComplexConjugate<S> means "conjugate, scaled by
// factor()". Conjugation is the identity on real data, which matches the code
// below using only the numeric factor (or -1 for ComplexNegTag) -- confirm
// against the ComplexConjugate definition.
// NOTE(review): the `template <>` headers and closing braces are not visible
// in this copy.
// ComplexConjugate<void>: returns success without calling cuBLAS, so y is
// left unchanged.
// NOTE(review): the ComplexNegTag case below computes y += (-1) * x. By
// analogy this case would be expected to compute y += x (alpha = 1) rather
// than do nothing -- confirm the intended semantics.
218 inline cublasStatus_t cublasAxpy<double, detail::ComplexConjugate<void>>(
219 cublasHandle_t handle,
int n,
const detail::ComplexConjugate<void> *alpha,
220 const double *x,
int incx,
double *y,
int incy) {
221 return CUBLAS_STATUS_SUCCESS;
// ComplexNegTag: alpha is fixed at -1, i.e. y = y - x.
225 inline cublasStatus_t
226 cublasAxpy<double, detail::ComplexConjugate<detail::ComplexNegTag>>(
227 cublasHandle_t handle,
int n,
228 const detail::ComplexConjugate<detail::ComplexNegTag> *alpha,
229 const double *x,
int incx,
double *y,
int incy) {
230 const double alpha_double = double(-1.0);
231 return cublasDaxpy(handle, n, &alpha_double, x, incx, y, incy);
// ComplexConjugate<int>: alpha->factor() converted to double.
235 inline cublasStatus_t cublasAxpy<double, detail::ComplexConjugate<int>>(
236 cublasHandle_t handle,
int n,
const detail::ComplexConjugate<int> *alpha,
237 const double *x,
int incx,
double *y,
int incy) {
238 const double alpha_double = double(alpha->factor());
239 return cublasDaxpy(handle, n, &alpha_double, x, incx, y, incy);
// ComplexConjugate<float>: alpha->factor() widened to double.
243 inline cublasStatus_t cublasAxpy<double, detail::ComplexConjugate<float>>(
244 cublasHandle_t handle,
int n,
const detail::ComplexConjugate<float> *alpha,
245 const double *x,
int incx,
double *y,
int incy) {
246 const double alpha_double = double(alpha->factor());
247 return cublasDaxpy(handle, n, &alpha_double, x, incx, y, incy);
// ComplexConjugate<double>: alpha->factor() used as the double scale.
251 inline cublasStatus_t cublasAxpy<double, detail::ComplexConjugate<double>>(
252 cublasHandle_t handle,
int n,
const detail::ComplexConjugate<double> *alpha,
253 const double *x,
int incx,
double *y,
int incy) {
254 const double alpha_double = double(alpha->factor());
255 return cublasDaxpy(handle, n, &alpha_double, x, incx, y, incy);
// cublasDot<T>: dot product of the n-element vectors x and y (strides
// incx/incy), written through `result`. The primary template is only
// declared; float dispatches to cublasSdot and double to cublasDdot.
// Whether `result` must point to host or device memory depends on the
// handle's pointer mode (cuBLASHandlePool sets CUBLAS_POINTER_MODE_HOST).
// NOTE(review): in this copy the double specialization's `double *result`
// parameter line, the `template <>` headers and the closing braces are not
// visible -- confirm they are present.
260 template <
typename T>
261 cublasStatus_t cublasDot(cublasHandle_t handle,
int n,
const T *x,
int incx,
262 const T *y,
int incy, T *result);
// float: forwards to cublasSdot.
264 inline cublasStatus_t cublasDot<float>(cublasHandle_t handle,
int n,
265 const float *x,
int incx,
const float *y,
266 int incy,
float *result) {
267 return cublasSdot(handle, n, x, incx, y, incy, result);
// double: forwards to cublasDdot.
271 inline cublasStatus_t cublasDot<double>(cublasHandle_t handle,
int n,
272 const double *x,
int incx,
273 const double *y,
int incy,
275 return cublasDdot(handle, n, x, incx, y, incy, result);
// cublasScal<T, Scalar>: x = alpha * x over n elements with stride incx,
// where alpha may have a scalar type different from T. The primary template
// is only declared; the specializations below dispatch to cublasSscal
// (T = float) or cublasDscal (T = double). Mixed-type alphas are dereferenced
// on the host, converted to T and passed by the address of a local, which
// relies on host pointer mode.
// NOTE(review): parameter lines such as `T *x, int incx);` / `int incx) {`,
// the `template <>` headers and the closing braces are missing in this copy.
// For example, the bodies use `incx`, which is not in their visible parameter
// lists.
279 template <
typename T,
typename Scalar>
280 cublasStatus_t cublasScal(cublasHandle_t handle,
int n,
const Scalar *alpha,
// Scalar == T (float): alpha passed through unchanged.
284 inline cublasStatus_t cublasScal<float, float>(cublasHandle_t handle,
int n,
285 const float *alpha,
float *x,
287 return cublasSscal(handle, n, alpha, x, incx);
// Scalar == T (double): alpha passed through unchanged.
291 inline cublasStatus_t cublasScal<double, double>(cublasHandle_t handle,
int n,
292 const double *alpha,
double *x,
294 return cublasDscal(handle, n, alpha, x, incx);
// int alpha on float data: converted to float.
298 inline cublasStatus_t cublasScal<float, int>(cublasHandle_t handle,
int n,
299 const int *alpha,
float *x,
301 const float alpha_float = float(*alpha);
302 return cublasSscal(handle, n, &alpha_float, x, incx);
// double alpha on float data: narrowed to float (may lose precision).
306 inline cublasStatus_t cublasScal<float, double>(cublasHandle_t handle,
int n,
307 const double *alpha,
float *x,
309 const float alpha_float = float(*alpha);
310 return cublasSscal(handle, n, &alpha_float, x, incx);
// int alpha on double data: converted to double.
315 inline cublasStatus_t cublasScal<double, int>(cublasHandle_t handle,
int n,
316 const int *alpha,
double *x,
318 const double alpha_double = double(*alpha);
319 return cublasDscal(handle, n, &alpha_double, x, incx);
// float alpha on double data: widened to double.
323 inline cublasStatus_t cublasScal<double, float>(cublasHandle_t handle,
int n,
324 const float *alpha,
double *x,
326 const double alpha_double = double(*alpha);
327 return cublasDscal(handle, n, &alpha_double, x, incx);
// cublasScal specializations for float data with a detail::ComplexConjugate
// scalar. Only the numeric factor is applied (presumably because conjugation
// is the identity on real data -- confirm against ComplexConjugate).
// NOTE(review): the `template <>` headers, the NegTag case's `int incx) {`
// line and the closing braces are not visible in this copy.
// ComplexConjugate<void>: returns success without touching x, i.e. a scale by
// one.
331 inline cublasStatus_t cublasScal<float, detail::ComplexConjugate<void>>(
332 cublasHandle_t handle,
int n,
const detail::ComplexConjugate<void> *alpha,
333 float *x,
int incx) {
334 return CUBLAS_STATUS_SUCCESS;
// ComplexNegTag: alpha is fixed at -1, i.e. x = -x.
338 inline cublasStatus_t
339 cublasScal<float, detail::ComplexConjugate<detail::ComplexNegTag>>(
340 cublasHandle_t handle,
int n,
341 const detail::ComplexConjugate<detail::ComplexNegTag> *alpha,
float *x,
343 const float alpha_float = float(-1.0);
344 return cublasSscal(handle, n, &alpha_float, x, incx);
// ComplexConjugate<int>: alpha->factor() converted to float.
348 inline cublasStatus_t cublasScal<float, detail::ComplexConjugate<int>>(
349 cublasHandle_t handle,
int n,
const detail::ComplexConjugate<int> *alpha,
350 float *x,
int incx) {
351 const float alpha_float = float(alpha->factor());
352 return cublasSscal(handle, n, &alpha_float, x, incx);
// ComplexConjugate<float>: alpha->factor() used as the float scale.
356 inline cublasStatus_t cublasScal<float, detail::ComplexConjugate<float>>(
357 cublasHandle_t handle,
int n,
const detail::ComplexConjugate<float> *alpha,
358 float *x,
int incx) {
359 const float alpha_float = float(alpha->factor());
360 return cublasSscal(handle, n, &alpha_float, x, incx);
// ComplexConjugate<double>: alpha->factor() narrowed to float.
364 inline cublasStatus_t cublasScal<float, detail::ComplexConjugate<double>>(
365 cublasHandle_t handle,
int n,
const detail::ComplexConjugate<double> *alpha,
366 float *x,
int incx) {
367 const float alpha_float = float(alpha->factor());
368 return cublasSscal(handle, n, &alpha_float, x, incx);
// cublasScal specializations for double data with a detail::ComplexConjugate
// scalar. Only the numeric factor is applied (presumably because conjugation
// is the identity on real data -- confirm against ComplexConjugate).
// NOTE(review): the `template <>` headers, the NegTag case's `int incx) {`
// line and the closing braces are not visible in this copy.
// ComplexConjugate<void>: returns success without touching x, i.e. a scale by
// one.
372 inline cublasStatus_t cublasScal<double, detail::ComplexConjugate<void>>(
373 cublasHandle_t handle,
int n,
const detail::ComplexConjugate<void> *alpha,
374 double *x,
int incx) {
375 return CUBLAS_STATUS_SUCCESS;
// ComplexNegTag: alpha is fixed at -1, i.e. x = -x.
379 inline cublasStatus_t
380 cublasScal<double, detail::ComplexConjugate<detail::ComplexNegTag>>(
381 cublasHandle_t handle,
int n,
382 const detail::ComplexConjugate<detail::ComplexNegTag> *alpha,
double *x,
384 const double alpha_double = double(-1.0);
385 return cublasDscal(handle, n, &alpha_double, x, incx);
// ComplexConjugate<int>: alpha->factor() converted to double.
389 inline cublasStatus_t cublasScal<double, detail::ComplexConjugate<int>>(
390 cublasHandle_t handle,
int n,
const detail::ComplexConjugate<int> *alpha,
391 double *x,
int incx) {
392 const double alpha_double = double(alpha->factor());
393 return cublasDscal(handle, n, &alpha_double, x, incx);
// ComplexConjugate<float>: alpha->factor() widened to double.
397 inline cublasStatus_t cublasScal<double, detail::ComplexConjugate<float>>(
398 cublasHandle_t handle,
int n,
const detail::ComplexConjugate<float> *alpha,
399 double *x,
int incx) {
400 const double alpha_double = double(alpha->factor());
401 return cublasDscal(handle, n, &alpha_double, x, incx);
// ComplexConjugate<double>: alpha->factor() used as the double scale.
405 inline cublasStatus_t cublasScal<double, detail::ComplexConjugate<double>>(
406 cublasHandle_t handle,
int n,
const detail::ComplexConjugate<double> *alpha,
407 double *x,
int incx) {
408 const double alpha_double = double(alpha->factor());
409 return cublasDscal(handle, n, &alpha_double, x, incx);
// cublasCopy: copies n elements from x (stride incx) into y (stride incy).
// float dispatches to cublasScopy and double to cublasDcopy.
// NOTE(review): in this copy the primary template's trailing parameters
// (`T *y, int incy);`), any `template <>` lines before the float/double
// versions and the closing braces are not visible. As they appear here, the
// float/double versions are plain inline overloads rather than explicit
// specializations. Confirm which was intended; callers that spell an explicit
// template argument (e.g. cublasCopy<float>(...)) would not reach plain
// overloads.
413 template <
typename T>
414 cublasStatus_t cublasCopy(cublasHandle_t handle,
int n,
const T *x,
int incx,
// float: forwards to cublasScopy.
418 inline cublasStatus_t cublasCopy(cublasHandle_t handle,
int n,
const float *x,
419 int incx,
float *y,
int incy) {
420 return cublasScopy(handle, n, x, incx, y, incy);
// double: forwards to cublasDcopy.
424 inline cublasStatus_t cublasCopy(cublasHandle_t handle,
int n,
const double *x,
425 int incx,
double *y,
int incy) {
426 return cublasDcopy(handle, n, x, incx, y, incy);
431 #endif // TILEDARRAY_HAS_CUDA
433 #endif // TILEDARRAY_MATH_CUBLAS_H__INCLUDED