thrust.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2018 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * Eduard Valeyev
19  * Department of Chemistry, Virginia Tech
20  * Mar 16, 2018
21  *
22  */
23 
24 #ifndef TILEDARRAY_CUDA_THRUST_H__INCLUDED
25 #define TILEDARRAY_CUDA_THRUST_H__INCLUDED
26 
27 #include <TiledArray/config.h>
28 
29 #ifdef TILEDARRAY_HAS_CUDA
30 
#include <cstddef>

#include <cuda_runtime_api.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
34 
35 // thrust::device_vector::data() returns a proxy, provide an overload for
36 // std::data() to provide raw ptr
37 namespace thrust {
38 
39 // thrust::device_malloc_allocator name changed to device_allocator after
40 // version 10
41 #if CUDART_VERSION < 10000
42 template <typename T>
43 using device_allocator = thrust::device_malloc_allocator<T>;
44 #endif
45 
46 template <typename T, typename Alloc>
47 const T* data(const thrust::device_vector<T, Alloc>& dev_vec) {
48  return thrust::raw_pointer_cast(dev_vec.data());
49 }
50 template <typename T, typename Alloc>
51 T* data(thrust::device_vector<T, Alloc>& dev_vec) {
52  return thrust::raw_pointer_cast(dev_vec.data());
53 }
54 
55 // this must be instantiated in a .cu file
56 template <typename T, typename Alloc>
57 void resize(thrust::device_vector<T, Alloc>& dev_vec, size_t size);
58 } // namespace thrust
59 
60 #endif // TILEDARRAY_HAS_CUDA
61 
62 #endif // TILEDARRAY_CUDA_THRUST_H__INCLUDED