24 #ifndef TILEDARRAY_CUDA_MULT_KERNEL_IMPL_H__INCLUDED
25 #define TILEDARRAY_CUDA_MULT_KERNEL_IMPL_H__INCLUDED
28 #include <thrust/device_vector.h>
29 #include <thrust/execution_policy.h>
36 cudaStream_t stream,
int device_id) {
37 CudaSafeCall(cudaSetDevice(device_id));
39 thrust::multiplies<T> mul_op;
41 thrust::cuda::par.on(stream), thrust::device_pointer_cast(arg),
42 thrust::device_pointer_cast(arg) + n, thrust::device_pointer_cast(result),
43 thrust::device_pointer_cast(result), mul_op);
49 std::size_t n, cudaStream_t stream,
int device_id) {
50 CudaSafeCall(cudaSetDevice(device_id));
52 thrust::multiplies<T> mul_op;
54 thrust::cuda::par.on(stream), thrust::device_pointer_cast(arg1),
55 thrust::device_pointer_cast(arg1) + n, thrust::device_pointer_cast(arg2),
56 thrust::device_pointer_cast(result), mul_op);
61 #endif // TILEDARRAY_CUDA_MULT_KERNEL_IMPL_H__INCLUDED