5 #ifndef TILEDARRAY_CUDA_CUDA_TASK_FN_H__INCLUDED
6 #define TILEDARRAY_CUDA_CUDA_TASK_FN_H__INCLUDED
8 #include <TiledArray/config.h>
10 #ifdef TILEDARRAY_HAS_CUDA
14 #include <cuda_runtime.h>
15 #include <madness/world/taskfn.h>
// Per-callback-id accumulator stored in a function-local static
// std::atomic<int64_t> initialized to 0. Because it is a function template,
// each distinct CallabackId value gets its own independent counter.
// NOTE(review): the "_ns" suffix suggests nanoseconds — confirm against the
// code that increments it.
// NOTE(review): the template parameter name is misspelled ("CallabackId");
// consider renaming it to CallbackId.
// NOTE(review): no return statement or closing brace follows the counter
// declaration in this copy of the file — confirm the body is complete.
20 template <
int64_t CallabackId>
21 std::atomic<int64_t>& cuda_callback_duration_ns() {
22 static std::atomic<int64_t> value{0};
// Single accumulator for the whole process, stored in a function-local static
// std::atomic<int64_t> initialized to 0.
// cudaTaskFn's CUDA host callback (AsyncTaskInterface::cuda_callback) adds to it.
// NOTE(review): the units are presumably nanoseconds, per the "_ns" suffix — confirm.
// NOTE(review): no return statement or closing brace is present in this copy.
26 inline std::atomic<int64_t>& cuda_taskfn_callback_duration_ns() {
27 static std::atomic<int64_t> value{0};
// cudaTaskFn: a task (derives from TaskInterface) that calls a callable of
// type fnT with up to nine stored arguments.
// Input dependencies are registered on a nested AsyncTaskInterface
// (async_task_). When that task runs, it enqueues a host callback on the
// thread's CUDA stream. This outer task then publishes the stored async result
// into its result future.
//
// Template parameters:
//   fnT           - callable type (a plain callable, or a detail::MemFuncWrapper
//                   for member functions; see the get_func overloads)
//   arg1T..arg9T  - stored argument types. Unused trailing slots default to void.
//                   Each must be neither const-qualified nor a reference type;
//                   the static_asserts below enforce this.
43 template <
typename fnT,
typename arg1T = void,
typename arg2T = void,
44 typename arg3T = void,
typename arg4T = void,
typename arg5T = void,
45 typename arg6T = void,
typename arg7T = void,
typename arg8T = void,
46 typename arg9T =
void>
47 struct cudaTaskFn :
public TaskInterface {
// Reject const-qualified or reference argument types at compile time.
48 static_assert(not(std::is_const<arg1T>::value ||
49 std::is_reference<arg1T>::value),
50 "improper instantiation of cudaTaskFn, arg1T cannot be a const "
52 static_assert(not(std::is_const<arg2T>::value ||
53 std::is_reference<arg2T>::value),
54 "improper instantiation of cudaTaskFn, arg2T cannot be a const "
56 static_assert(not(std::is_const<arg3T>::value ||
57 std::is_reference<arg3T>::value),
58 "improper instantiation of cudaTaskFn, arg3T cannot be a const "
60 static_assert(not(std::is_const<arg4T>::value ||
61 std::is_reference<arg4T>::value),
62 "improper instantiation of cudaTaskFn, arg4T cannot be a const "
64 static_assert(not(std::is_const<arg5T>::value ||
65 std::is_reference<arg5T>::value),
66 "improper instantiation of cudaTaskFn, arg5T cannot be a const "
68 static_assert(not(std::is_const<arg6T>::value ||
69 std::is_reference<arg6T>::value),
70 "improper instantiation of cudaTaskFn, arg6T cannot be a const "
72 static_assert(not(std::is_const<arg7T>::value ||
73 std::is_reference<arg7T>::value),
74 "improper instantiation of cudaTaskFn, arg7T cannot be a const "
76 static_assert(not(std::is_const<arg8T>::value ||
77 std::is_reference<arg8T>::value),
78 "improper instantiation of cudaTaskFn, arg8T cannot be a const "
80 static_assert(not(std::is_const<arg9T>::value ||
81 std::is_reference<arg9T>::value),
82 "improper instantiation of cudaTaskFn, arg9T cannot be a const "
// Self type alias (cudaTaskFn_), used by the nested AsyncTaskInterface to
// point back at its owning task. The alias name is on a continuation line that
// is not visible in this copy.
87 typedef cudaTaskFn<fnT, arg1T, arg2T, arg3T, arg4T, arg5T, arg6T, arg7T,
// NOTE(review): since C++11, nested classes already have access to the
// enclosing class's members, so this friend declaration is likely redundant.
// It also uses the class-key "class" while the definition below uses "struct".
91 friend class AsyncTaskInterface;
// Nested task that carries the input dependencies of the owning cudaTaskFn.
// The check_dependency overloads register it as the callback on input futures.
// When run, it looks up the thread's CUDA stream and enqueues cuda_callback on
// that stream via cudaLaunchHostFunc, passing the owning task (task_) as user
// data.
94 struct AsyncTaskInterface :
public madness::TaskInterface {
// task:    back-pointer to the owning cudaTaskFn (stored in task_).
// ndepend: initial dependency count passed to TaskInterface.
// attr:    task attributes passed to TaskInterface.
95 AsyncTaskInterface(cudaTaskFn_* task,
int ndepend = 0,
96 const TaskAttributes attr = TaskAttributes())
97 : TaskInterface(ndepend, attr), task_(task) {}
99 virtual ~AsyncTaskInterface() =
default;
102 void run(
const TaskThreadEnv& env)
override {
// Pointer-like handle to the thread-local CUDA stream. It is compared with
// nullptr and dereferenced below.
108 auto stream = TiledArray::tls_cudastream_accessor();
// NOTE(review): the body of this branch (no stream selected) and the point
// where it closes are not visible in this copy. Confirm that
// cudaLaunchHostFunc below is only reached when stream is non-null.
113 if (stream ==
nullptr) {
// Enqueue the host callback on the stream. Per the CUDA runtime docs, it runs
// after all work previously enqueued on the stream completes, and it must not
// call CUDA API functions itself.
// NOTE(review): the returned cudaError_t is ignored; consider checking it.
118 cudaLaunchHostFunc(*stream, cuda_callback, task_);
// NOTE(review): the semantics of synchronize_stream(nullptr) are defined
// elsewhere (presumably it clears the thread's stream selection) — confirm.
120 TiledArray::synchronize_stream(
nullptr);
// CUDA host callback. userData is the cudaTaskFn_* that was passed to
// cudaLaunchHostFunc above. The callback adds to
// cuda_taskfn_callback_duration_ns().
125 static void CUDART_CB cuda_callback(
void* userData) {
128 auto* callback =
static_cast<cudaTaskFn_*
>(userData);
136 TiledArray::detail::cuda_taskfn_callback_duration_ns() +=
// Public type aliases: the callable type, plus the result and future types
// derived from it via detail::task_result_type.
144 typedef fnT functionT;
145 typedef typename detail::task_result_type<fnT>::resultT resultT;
147 typedef typename detail::task_result_type<fnT>::futureT futureT;
// Expected number of arguments, computed by detail::ArgCount. Each constructor
// asserts that arity equals the number of arguments it received.
151 static const unsigned int arity =
152 detail::ArgCount<arg1T, arg2T, arg3T, arg4T, arg5T, arg6T, arg7T, arg8T,
// The callable to invoke.
160 const functionT func_;
// Raw pointer to the nested AsyncTaskInterface created in every constructor.
// NOTE(review): it is not deleted by this class's defaulted destructor.
// Presumably the task queue it is submitted to (see add_cuda_taskfn) takes
// ownership — confirm.
161 TaskInterface* async_task_;
// Future written by detail::run_function and later moved into the task's result.
162 futureT async_result_;
// Storage for the nine arguments, one detail::task_arg<argNT>::holderT each.
// They are referenced elsewhere as arg1_ .. arg9_; the member names sit on
// lines that are not visible in this copy.
168 typename detail::task_arg<arg1T>::holderT
170 typename detail::task_arg<arg2T>::holderT
172 typename detail::task_arg<arg3T>::holderT
174 typename detail::task_arg<arg4T>::holderT
176 typename detail::task_arg<arg5T>::holderT
178 typename detail::task_arg<arg6T>::holderT
180 typename detail::task_arg<arg7T>::holderT
182 typename detail::task_arg<arg8T>::holderT
184 typename detail::task_arg<arg9T>::holderT
// get_func: yields the underlying callable used to identify this task
// (see get_id).
// Generic overload: takes the callable by reference.
// NOTE(review): its body is not visible in this copy; presumably it returns f.
187 template <
typename fT>
188 static fT& get_func(fT& f) {
// Overload for member-function wrappers: extracts the raw member-function
// pointer through detail::get_mem_func_ptr.
192 template <
typename ptrT,
typename memfnT,
typename resT>
193 static memfnT get_func(
194 const detail::MemFuncWrapper<ptrT, memfnT, resT>& wrapper) {
195 return detail::get_mem_func_ptr(wrapper);
// get_id: fills id using make_id on the callable returned by get_func(func_).
198 void get_id(std::pair<void*, unsigned short>&
id)
const override {
199 return make_id(
id, get_func(func_));
// Invokes func_ with the stored arguments and stores the outcome in
// async_result_ (via detail::run_function).
// NOTE(review): the signature of the enclosing member function is not visible
// in this copy.
204 detail::run_function(async_result_, func_, arg1_, arg2_, arg3_, arg4_,
205 arg5_, arg6_, arg7_, arg8_, arg9_);
// check_dependency overloads: make the nested async task (async_task_) wait on
// any future-valued argument. Overload resolution picks the behavior for each
// argument type.
// Future<T>: registers async_task_ as a callback on fut.
// NOTE(review): the lines around the call are not visible in this copy; they
// may hold e.g. a readiness check or a dependency-count increment — confirm.
213 template <
typename T>
214 inline void check_dependency(Future<T>& fut) {
217 fut.register_callback(async_task_);
// Future<T>*: same as above, through a pointer (surrounding lines not visible).
226 template <
typename T>
227 inline void check_dependency(Future<T>* fut) {
230 fut->register_callback(async_task_);
// ArgHolder wrapping a vector of futures: forwards to the vector overload.
235 template <
typename T>
236 inline void check_dependency(detail::ArgHolder<
std::vector<Future<T>>>& arg) {
237 check_dependency(
static_cast<std::vector<Future<T>
>&>(arg));
// Vector of futures: checks each element.
241 template <
typename T>
242 inline void check_dependency(
std::vector<Future<T>>& vec) {
243 for (
typename std::vector<Future<T>>::iterator it = vec.begin();
244 it != vec.end(); ++it)
245 check_dependency(*it);
// No-op overloads: vectors of Future<void>, arbitrary (non-future) ArgHolder
// values, and Future<void> register no dependency.
249 inline void check_dependency(
const std::vector<Future<void>>&) {}
252 template <
typename T>
253 inline void check_dependency(
const detail::ArgHolder<T>&) {}
256 inline void check_dependency(
const Future<void>&) {}
// Applies check_dependency to each of the nine argument holders. The
// check_dependency overloads decide whether each argument adds a dependency to
// async_task_. Called at the end of every constructor.
259 void check_dependencies() {
262 check_dependency(arg1_);
263 check_dependency(arg2_);
264 check_dependency(arg3_);
265 check_dependency(arg4_);
266 check_dependency(arg5_);
267 check_dependency(arg6_);
268 check_dependency(arg7_);
269 check_dependency(arg8_);
270 check_dependency(arg9_);
274 cudaTaskFn(
const cudaTaskFn_&);
275 cudaTaskFn_ operator=(cudaTaskFn_&);
// Constructors. Each takes the result future, the callable, 0-9 arguments and
// TaskAttributes. Each one:
//   - constructs the nested AsyncTaskInterface (async_task_) with a
//     back-pointer to this task;
//   - initializes the argument holders;
//   - asserts (MADNESS_ASSERT) that the argument count equals arity;
//   - calls check_dependencies().
// With MADNESS_TASKQ_VARIADICS the arguments are perfect-forwarded (aNT&&);
// otherwise they are taken by const reference.
// NOTE(review): async_task_ is a raw owning-looking pointer. If a later member
// initializer or check_dependencies() throws, the AsyncTaskInterface allocated
// here is never freed.
278 #if MADNESS_TASKQ_VARIADICS
280 cudaTaskFn(
const futureT& result, functionT func,
const TaskAttributes& attr)
281 : TaskInterface(attr),
284 async_task_(new AsyncTaskInterface(this)),
295 MADNESS_ASSERT(arity == 0u);
296 check_dependencies();
299 template <
typename a1T>
300 cudaTaskFn(
const futureT& result, functionT func, a1T&& a1,
301 const TaskAttributes& attr)
302 : TaskInterface(attr),
305 async_task_(new AsyncTaskInterface(this)),
307 arg1_(std::forward<a1T>(a1)),
316 MADNESS_ASSERT(arity == 1u);
317 check_dependencies();
320 template <
typename a1T,
typename a2T>
321 cudaTaskFn(
const futureT& result, functionT func, a1T&& a1, a2T&& a2,
322 const TaskAttributes& attr)
323 : TaskInterface(attr),
326 async_task_(new AsyncTaskInterface(this)),
328 arg1_(std::forward<a1T>(a1)),
329 arg2_(std::forward<a2T>(a2)),
337 MADNESS_ASSERT(arity == 2u);
338 check_dependencies();
341 template <
typename a1T,
typename a2T,
typename a3T>
342 cudaTaskFn(
const futureT& result, functionT func, a1T&& a1, a2T&& a2,
343 a3T&& a3,
const TaskAttributes& attr)
344 : TaskInterface(attr),
347 async_task_(new AsyncTaskInterface(this)),
349 arg1_(std::forward<a1T>(a1)),
350 arg2_(std::forward<a2T>(a2)),
351 arg3_(std::forward<a3T>(a3)),
358 MADNESS_ASSERT(arity == 3u);
359 check_dependencies();
362 template <
typename a1T,
typename a2T,
typename a3T,
typename a4T>
363 cudaTaskFn(
const futureT& result, functionT func, a1T&& a1, a2T&& a2,
364 a3T&& a3, a4T&& a4,
const TaskAttributes& attr)
365 : TaskInterface(attr),
368 async_task_(new AsyncTaskInterface(this)),
370 arg1_(std::forward<a1T>(a1)),
371 arg2_(std::forward<a2T>(a2)),
372 arg3_(std::forward<a3T>(a3)),
373 arg4_(std::forward<a4T>(a4)),
379 MADNESS_ASSERT(arity == 4u);
380 check_dependencies();
// NOTE(review): a5T is used by this constructor, but its declaration in the
// template parameter list is not visible in this copy.
383 template <
typename a1T,
typename a2T,
typename a3T,
typename a4T,
385 cudaTaskFn(
const futureT& result, functionT func, a1T&& a1, a2T&& a2,
386 a3T&& a3, a4T&& a4, a5T&& a5,
const TaskAttributes& attr)
387 : TaskInterface(attr),
390 async_task_(new AsyncTaskInterface(this)),
392 arg1_(std::forward<a1T>(a1)),
393 arg2_(std::forward<a2T>(a2)),
394 arg3_(std::forward<a3T>(a3)),
395 arg4_(std::forward<a4T>(a4)),
396 arg5_(std::forward<a5T>(a5)),
401 MADNESS_ASSERT(arity == 5u);
402 check_dependencies();
405 template <
typename a1T,
typename a2T,
typename a3T,
typename a4T,
406 typename a5T,
typename a6T>
407 cudaTaskFn(
const futureT& result, functionT func, a1T&& a1, a2T&& a2,
408 a3T&& a3, a4T&& a4, a5T&& a5, a6T&& a6,
const TaskAttributes& attr)
409 : TaskInterface(attr),
412 async_task_(new AsyncTaskInterface(this)),
414 arg1_(std::forward<a1T>(a1)),
415 arg2_(std::forward<a2T>(a2)),
416 arg3_(std::forward<a3T>(a3)),
417 arg4_(std::forward<a4T>(a4)),
418 arg5_(std::forward<a5T>(a5)),
419 arg6_(std::forward<a6T>(a6)),
423 MADNESS_ASSERT(arity == 6u);
424 check_dependencies();
427 template <
typename a1T,
typename a2T,
typename a3T,
typename a4T,
428 typename a5T,
typename a6T,
typename a7T>
429 cudaTaskFn(
const futureT& result, functionT func, a1T&& a1, a2T&& a2,
430 a3T&& a3, a4T&& a4, a5T&& a5, a6T&& a6, a7T&& a7,
431 const TaskAttributes& attr)
432 : TaskInterface(attr),
435 async_task_(new AsyncTaskInterface(this)),
437 arg1_(std::forward<a1T>(a1)),
438 arg2_(std::forward<a2T>(a2)),
439 arg3_(std::forward<a3T>(a3)),
440 arg4_(std::forward<a4T>(a4)),
441 arg5_(std::forward<a5T>(a5)),
442 arg6_(std::forward<a6T>(a6)),
443 arg7_(std::forward<a7T>(a7)),
446 MADNESS_ASSERT(arity == 7u);
447 check_dependencies();
450 template <
typename a1T,
typename a2T,
typename a3T,
typename a4T,
451 typename a5T,
typename a6T,
typename a7T,
typename a8T>
452 cudaTaskFn(
const futureT& result, functionT func, a1T&& a1, a2T&& a2,
453 a3T&& a3, a4T&& a4, a5T&& a5, a6T&& a6, a7T&& a7, a8T&& a8,
454 const TaskAttributes& attr)
455 : TaskInterface(attr),
458 async_task_(new AsyncTaskInterface(this)),
460 arg1_(std::forward<a1T>(a1)),
461 arg2_(std::forward<a2T>(a2)),
462 arg3_(std::forward<a3T>(a3)),
463 arg4_(std::forward<a4T>(a4)),
464 arg5_(std::forward<a5T>(a5)),
465 arg6_(std::forward<a6T>(a6)),
466 arg7_(std::forward<a7T>(a7)),
467 arg8_(std::forward<a8T>(a8)),
469 MADNESS_ASSERT(arity == 8u);
470 check_dependencies();
// NOTE(review): a9T is used by this constructor, but its declaration in the
// template parameter list is not visible in this copy.
473 template <
typename a1T,
typename a2T,
typename a3T,
typename a4T,
474 typename a5T,
typename a6T,
typename a7T,
typename a8T,
476 cudaTaskFn(
const futureT& result, functionT func, a1T&& a1, a2T&& a2,
477 a3T&& a3, a4T&& a4, a5T&& a5, a6T&& a6, a7T&& a7, a8T&& a8,
478 a9T&& a9,
const TaskAttributes& attr)
479 : TaskInterface(attr),
482 async_task_(new AsyncTaskInterface(this)),
484 arg1_(std::forward<a1T>(a1)),
485 arg2_(std::forward<a2T>(a2)),
486 arg3_(std::forward<a3T>(a3)),
487 arg4_(std::forward<a4T>(a4)),
488 arg5_(std::forward<a5T>(a5)),
489 arg6_(std::forward<a6T>(a6)),
490 arg7_(std::forward<a7T>(a7)),
491 arg8_(std::forward<a8T>(a8)),
492 arg9_(std::forward<a9T>(a9)) {
493 MADNESS_ASSERT(arity == 9u);
494 check_dependencies();
// Constructor that takes a BufferInputArchive. It presumably deserializes the
// arguments from input_arch, but most of its initializer list and body is not
// visible in this copy — confirm.
497 cudaTaskFn(
const futureT& result, functionT func,
const TaskAttributes& attr,
498 archive::BufferInputArchive& input_arch)
499 : TaskInterface(attr),
502 async_task_(new AsyncTaskInterface(this)),
513 check_dependencies();
515 #else // MADNESS_TASKQ_VARIADICS
516 cudaTaskFn(
const futureT& result, functionT func,
const TaskAttributes& attr)
517 : TaskInterface(attr),
520 async_task_(new AsyncTaskInterface(this)),
531 MADNESS_ASSERT(arity == 0u);
532 check_dependencies();
535 template <
typename a1T>
536 cudaTaskFn(
const futureT& result, functionT func,
const a1T& a1,
537 const TaskAttributes& attr)
538 : TaskInterface(attr),
541 async_task_(new AsyncTaskInterface(this)),
552 MADNESS_ASSERT(arity == 1u);
553 check_dependencies();
// NOTE(review): unlike the neighbouring overloads, this one gives attr a
// default value of TaskAttributes(). Confirm whether this inconsistency is
// intended.
556 template <
typename a1T,
typename a2T>
557 cudaTaskFn(
const futureT& result, functionT func,
const a1T& a1,
558 const a2T& a2,
const TaskAttributes& attr = TaskAttributes())
559 : TaskInterface(attr),
562 async_task_(new AsyncTaskInterface(this)),
573 MADNESS_ASSERT(arity == 2u);
574 check_dependencies();
577 template <
typename a1T,
typename a2T,
typename a3T>
578 cudaTaskFn(
const futureT& result, functionT func,
const a1T& a1,
579 const a2T& a2,
const a3T& a3,
const TaskAttributes& attr)
580 : TaskInterface(attr),
583 async_task_(new AsyncTaskInterface(this)),
594 MADNESS_ASSERT(arity == 3u);
595 check_dependencies();
598 template <
typename a1T,
typename a2T,
typename a3T,
typename a4T>
599 cudaTaskFn(
const futureT& result, functionT func,
const a1T& a1,
600 const a2T& a2,
const a3T& a3,
const a4T& a4,
601 const TaskAttributes& attr)
602 : TaskInterface(attr),
605 async_task_(new AsyncTaskInterface(this)),
616 MADNESS_ASSERT(arity == 4u);
617 check_dependencies();
// NOTE(review): a5T is used by this constructor, but its declaration in the
// template parameter list is not visible in this copy.
620 template <
typename a1T,
typename a2T,
typename a3T,
typename a4T,
622 cudaTaskFn(
const futureT& result, functionT func,
const a1T& a1,
623 const a2T& a2,
const a3T& a3,
const a4T& a4,
const a5T& a5,
624 const TaskAttributes& attr)
625 : TaskInterface(attr),
628 async_task_(new AsyncTaskInterface(this)),
639 MADNESS_ASSERT(arity == 5u);
640 check_dependencies();
643 template <
typename a1T,
typename a2T,
typename a3T,
typename a4T,
644 typename a5T,
typename a6T>
645 cudaTaskFn(
const futureT& result, functionT func,
const a1T& a1,
646 const a2T& a2,
const a3T& a3,
const a4T& a4,
const a5T& a5,
647 const a6T& a6,
const TaskAttributes& attr)
648 : TaskInterface(attr),
651 async_task_(new AsyncTaskInterface(this)),
662 MADNESS_ASSERT(arity == 6u);
663 check_dependencies();
666 template <
typename a1T,
typename a2T,
typename a3T,
typename a4T,
667 typename a5T,
typename a6T,
typename a7T>
668 cudaTaskFn(
const futureT& result, functionT func,
const a1T& a1,
669 const a2T& a2,
const a3T& a3,
const a4T& a4,
const a5T& a5,
670 const a6T& a6,
const a7T& a7,
const TaskAttributes& attr)
671 : TaskInterface(attr),
674 async_task_(new AsyncTaskInterface(this)),
685 MADNESS_ASSERT(arity == 7u);
686 check_dependencies();
689 template <
typename a1T,
typename a2T,
typename a3T,
typename a4T,
690 typename a5T,
typename a6T,
typename a7T,
typename a8T>
691 cudaTaskFn(
const futureT& result, functionT func,
const a1T& a1,
692 const a2T& a2,
const a3T& a3,
const a4T& a4,
const a5T& a5,
693 const a6T& a6,
const a7T& a7,
const a8T& a8,
694 const TaskAttributes& attr)
695 : TaskInterface(attr),
698 async_task_(new AsyncTaskInterface(this)),
709 MADNESS_ASSERT(arity == 8u);
710 check_dependencies();
// NOTE(review): a9T is used by this constructor, but its declaration in the
// template parameter list is not visible in this copy.
713 template <
typename a1T,
typename a2T,
typename a3T,
typename a4T,
714 typename a5T,
typename a6T,
typename a7T,
typename a8T,
716 cudaTaskFn(
const futureT& result, functionT func,
const a1T& a1,
717 const a2T& a2,
const a3T& a3,
const a4T& a4,
const a5T& a5,
718 const a6T& a6,
const a7T& a7,
const a8T& a8,
const a9T& a9,
719 const TaskAttributes& attr)
720 : TaskInterface(attr),
723 async_task_(new AsyncTaskInterface(this)),
734 MADNESS_ASSERT(arity == 9u);
735 check_dependencies();
// Constructor that takes a BufferInputArchive. It presumably deserializes the
// arguments from input_arch, but most of its initializer list and body is not
// visible in this copy — confirm.
738 cudaTaskFn(
const futureT& result, functionT func,
const TaskAttributes& attr,
739 archive::BufferInputArchive& input_arch)
740 : TaskInterface(attr),
743 async_task_(new AsyncTaskInterface(this)),
754 check_dependencies();
756 #endif // MADNESS_TASKQ_VARIADICS
// Defaulted virtual destructor.
// NOTE(review): async_task_ (a raw pointer) is not deleted here. Presumably
// the task queue that runs it owns it — confirm.
759 virtual ~cudaTaskFn() =
default;
// The future that this task fulfils.
761 const futureT& result()
const {
return result_; }
// The nested AsyncTaskInterface; add_cuda_taskfn submits it to the task queue.
763 TaskInterface* async_task() {
return async_task_; }
// Both the TBB execute() path and the run() path move the value produced by
// detail::run_function (async_result_) into the result future.
// NOTE(review): the rest of execute() (e.g. its return statement) and the
// preprocessor branch between the two functions are not visible in this copy.
765 #ifdef HAVE_INTEL_TBB
766 virtual tbb::task* execute() {
767 result_.set(std::move(async_result_));
774 void run(
const TaskThreadEnv& env)
override {
775 result_.set(std::move(async_result_));
777 #endif // HAVE_INTEL_TBB
// add_cuda_taskfn: submits a heap-allocated cudaTaskFn to world.taskq. Two
// tasks are added: the task itself (as a TaskInterface*) and its nested async
// task (t->async_task()). The return type is cudaTaskFn<...>::futureT.
// NOTE(review): the line carrying the function name, and the statements that
// initialize and return the result, are not visible in this copy.
794 template <
typename fnT,
typename a1T,
typename a2T,
typename a3T,
typename a4T,
795 typename a5T,
typename a6T,
typename a7T,
typename a8T,
typename a9T>
796 typename cudaTaskFn<fnT, a1T, a2T, a3T, a4T, a5T, a6T, a7T, a8T, a9T>::futureT
798 madness::World& world,
799 cudaTaskFn<fnT, a1T, a2T, a3T, a4T, a5T, a6T, a7T, a8T, a9T>* t) {
800 typename cudaTaskFn<fnT, a1T, a2T, a3T, a4T, a5T, a6T, a7T, a8T, a9T>::futureT
803 world.taskq.add(
static_cast<TaskInterface*
>(t));
805 world.taskq.add(t->async_task());
// add_cuda_task(world, fn, args...): overload chosen when the last argument is
// NOT a TaskAttributes (see the enable_if below).
// It builds a cudaTaskFn over std::decay_t<fnT> and the argument types with
// reference and const stripped, constructs it with a default TaskAttributes(),
// and submits it through add_cuda_taskfn.
// NOTE(review): the "template <" introducer and the start of the taskT alias
// declaration are not visible in this copy.
814 typename fnT,
typename... argsT,
815 typename = std::enable_if_t<!meta::taskattr_is_last_arg<argsT...>::value>>
816 typename detail::function_enabler<fnT(future_to_ref_t<argsT>...)>::type
817 add_cuda_task(madness::World& world, fnT&& fn, argsT&&... args) {
820 cudaTaskFn<std::decay_t<fnT>,
821 std::remove_const_t<std::remove_reference_t<argsT>>...>;
823 return add_cuda_taskfn(
824 world,
new taskT(
typename taskT::futureT(), std::forward<fnT>(fn),
825 std::forward<argsT>(args)..., TaskAttributes()));
// add_cuda_task(world, fn, args..., attr): overload chosen when the last
// argument IS a TaskAttributes.
// The cudaTaskFn type is built from every argument type except the last one
// (meta::drop_last_arg_and_apply). All arguments, including the trailing
// attributes, are forwarded to the constructor, and the task is submitted
// through add_cuda_taskfn.
// NOTE(review): the "template <" introducer is not visible in this copy.
833 typename fnT,
typename... argsT,
834 typename = std::enable_if_t<meta::taskattr_is_last_arg<argsT...>::value>>
835 typename meta::drop_last_arg_and_apply_callable<
836 detail::function_enabler, fnT, future_to_ref_t<argsT>...>::type::type
837 add_cuda_task(madness::World& world, fnT&& fn, argsT&&... args) {
839 using taskT =
typename meta::drop_last_arg_and_apply<
840 cudaTaskFn, std::decay_t<fnT>,
841 std::remove_const_t<std::remove_reference_t<argsT>>...>::type;
843 return add_cuda_taskfn(
844 world,
new taskT(
typename taskT::futureT(), std::forward<fnT>(fn),
845 std::forward<argsT>(args)...));
// add_cuda_task(world, obj, memfn, args...): member-function overload.
// It wraps obj and memfn with detail::wrap_mem_fn and forwards the wrapper and
// the arguments to the callable overloads of add_cuda_task.
// NOTE(review): the closing brace is not visible in this copy.
853 template <
typename objT,
typename memfnT,
typename... argsT>
854 typename detail::memfunc_enabler<objT, memfnT>::type add_cuda_task(
855 madness::World& world, objT&& obj, memfnT memfn, argsT&&... args) {
856 return add_cuda_task(world,
857 detail::wrap_mem_fn(std::forward<objT>(obj), memfn),
858 std::forward<argsT>(args)...);
863 #endif // TILEDARRAY_HAS_CUDA
864 #endif // TILEDARRAY_CUDA_CUDA_TASK_FN_H__INCLUDED