cuda_task_fn.h
Go to the documentation of this file.
1 //
2 // Created by Chong Peng on 2019-03-20.
3 //
4 
5 #ifndef TILEDARRAY_CUDA_CUDA_TASK_FN_H__INCLUDED
6 #define TILEDARRAY_CUDA_CUDA_TASK_FN_H__INCLUDED
7 
8 #include <TiledArray/config.h>
9 
10 #ifdef TILEDARRAY_HAS_CUDA
11 
13 #include <TiledArray/util/time.h>
14 #include <cuda_runtime.h>
15 #include <madness/world/taskfn.h>
16 
17 namespace TiledArray {
18 namespace detail {
19 
/// Gives access to the counter associated with \c CallbackId.
/// Each distinct \c CallbackId instantiates its own function-local static
/// atomic, zero-initialized on first use. Judging by the name, the counter is
/// meant to accumulate callback durations in nanoseconds.
/// \return reference to the atomic counter for this \c CallbackId
template <int64_t CallbackId>
std::atomic<int64_t>& cuda_callback_duration_ns() {
  static std::atomic<int64_t> counter{0};
  return counter;
}
25 
/// Gives access to the single counter that the cudaTaskFn stream callback
/// adds its elapsed time (in nanoseconds) to.
/// \return reference to a function-local static atomic, zero-initialized on
///         first use
inline std::atomic<int64_t>& cuda_taskfn_callback_duration_ns() {
  static std::atomic<int64_t> counter{0};
  return counter;
}
30 
31 } // namespace detail
32 } // namespace TiledArray
33 
34 namespace madness {
35 
42 
43 template <typename fnT, typename arg1T = void, typename arg2T = void,
44  typename arg3T = void, typename arg4T = void, typename arg5T = void,
45  typename arg6T = void, typename arg7T = void, typename arg8T = void,
46  typename arg9T = void>
47 struct cudaTaskFn : public TaskInterface {
48  static_assert(not(std::is_const<arg1T>::value ||
49  std::is_reference<arg1T>::value),
50  "improper instantiation of cudaTaskFn, arg1T cannot be a const "
51  "or reference type");
52  static_assert(not(std::is_const<arg2T>::value ||
53  std::is_reference<arg2T>::value),
54  "improper instantiation of cudaTaskFn, arg2T cannot be a const "
55  "or reference type");
56  static_assert(not(std::is_const<arg3T>::value ||
57  std::is_reference<arg3T>::value),
58  "improper instantiation of cudaTaskFn, arg3T cannot be a const "
59  "or reference type");
60  static_assert(not(std::is_const<arg4T>::value ||
61  std::is_reference<arg4T>::value),
62  "improper instantiation of cudaTaskFn, arg4T cannot be a const "
63  "or reference type");
64  static_assert(not(std::is_const<arg5T>::value ||
65  std::is_reference<arg5T>::value),
66  "improper instantiation of cudaTaskFn, arg5T cannot be a const "
67  "or reference type");
68  static_assert(not(std::is_const<arg6T>::value ||
69  std::is_reference<arg6T>::value),
70  "improper instantiation of cudaTaskFn, arg6T cannot be a const "
71  "or reference type");
72  static_assert(not(std::is_const<arg7T>::value ||
73  std::is_reference<arg7T>::value),
74  "improper instantiation of cudaTaskFn, arg7T cannot be a const "
75  "or reference type");
76  static_assert(not(std::is_const<arg8T>::value ||
77  std::is_reference<arg8T>::value),
78  "improper instantiation of cudaTaskFn, arg8T cannot be a const "
79  "or reference type");
80  static_assert(not(std::is_const<arg9T>::value ||
81  std::is_reference<arg9T>::value),
82  "improper instantiation of cudaTaskFn, arg9T cannot be a const "
83  "or reference type");
84 
85  private:
87  typedef cudaTaskFn<fnT, arg1T, arg2T, arg3T, arg4T, arg5T, arg6T, arg7T,
88  arg8T, arg9T>
89  cudaTaskFn_;
90 
91  friend class AsyncTaskInterface;
92 
94  struct AsyncTaskInterface : public madness::TaskInterface {
95  AsyncTaskInterface(cudaTaskFn_* task, int ndepend = 0,
96  const TaskAttributes attr = TaskAttributes())
97  : TaskInterface(ndepend, attr), task_(task) {}
98 
99  virtual ~AsyncTaskInterface() = default;
100 
101  protected:
102  void run(const TaskThreadEnv& env) override {
103  // run the async function, the function must call synchronize_stream() to
104  // set the stream it used!!
105  task_->run_async();
106 
107  // get the stream used by async function
108  auto stream = TiledArray::tls_cudastream_accessor();
109 
110  // TA_ASSERT(stream != nullptr);
111 
112  // WARNING, need to handle NoOp
113  if (stream == nullptr) {
114  task_->notify();
115  } else {
116  // TODO should we use cuda callback or cuda events??
117  // insert cuda callback
118  cudaLaunchHostFunc(*stream, cuda_callback, task_);
119  // reset stream to nullptr
120  TiledArray::synchronize_stream(nullptr);
121  }
122  }
123 
124  private:
125  static void CUDART_CB cuda_callback(void* userData) {
126  const auto t0 = TiledArray::now();
127  // convert void * to AsyncTaskInterface*
128  auto* callback = static_cast<cudaTaskFn_*>(userData);
129  // std::stringstream address;
130  // address << (void*) callback;
131  // std::string message = "callback on cudaTaskFn: " + address.str() +
132  // '\n'; std::cout << message;
133  callback->notify();
134  const auto t1 = TiledArray::now();
135 
136  TiledArray::detail::cuda_taskfn_callback_duration_ns() +=
138  }
139 
140  cudaTaskFn_* task_;
141  };
142 
143  public:
144  typedef fnT functionT;
145  typedef typename detail::task_result_type<fnT>::resultT resultT;
147  typedef typename detail::task_result_type<fnT>::futureT futureT;
148 
149  // argument value typedefs
150 
151  static const unsigned int arity =
152  detail::ArgCount<arg1T, arg2T, arg3T, arg4T, arg5T, arg6T, arg7T, arg8T,
153  arg9T>::value;
157 
158  private:
159  futureT result_;
160  const functionT func_;
161  TaskInterface* async_task_;
162  futureT async_result_;
164 
165  // Holders for the task arguments, one per argument slot (arg1_ .. arg9_).
166  // Note: The type argNT for argN, where N is > arity should be void
167 
168  typename detail::task_arg<arg1T>::holderT
169  arg1_;
170  typename detail::task_arg<arg2T>::holderT
171  arg2_;
172  typename detail::task_arg<arg3T>::holderT
173  arg3_;
174  typename detail::task_arg<arg4T>::holderT
175  arg4_;
176  typename detail::task_arg<arg5T>::holderT
177  arg5_;
178  typename detail::task_arg<arg6T>::holderT
179  arg6_;
180  typename detail::task_arg<arg7T>::holderT
181  arg7_;
182  typename detail::task_arg<arg8T>::holderT
183  arg8_;
184  typename detail::task_arg<arg9T>::holderT
185  arg9_;
186 
/// Identity overload: a plain callable is passed through unchanged.
/// \param callable the function object to identify
/// \return a reference to the very same object
template <typename fT>
static fT& get_func(fT& callable) {
  return callable;
}
191 
192  template <typename ptrT, typename memfnT, typename resT>
193  static memfnT get_func(
194  const detail::MemFuncWrapper<ptrT, memfnT, resT>& wrapper) {
195  return detail::get_mem_func_ptr(wrapper);
196  }
197 
198  void get_id(std::pair<void*, unsigned short>& id) const override {
199  return make_id(id, get_func(func_));
200  }
201 
203  void run_async() {
204  detail::run_function(async_result_, func_, arg1_, arg2_, arg3_, arg4_,
205  arg5_, arg6_, arg7_, arg8_, arg9_);
206  }
207 
210 
213  template <typename T>
214  inline void check_dependency(Future<T>& fut) {
215  if (!fut.probe()) {
216  async_task_->inc();
217  fut.register_callback(async_task_);
218  }
219  }
220 
223 
226  template <typename T>
227  inline void check_dependency(Future<T>* fut) {
228  if (!fut->probe()) {
229  async_task_->inc();
230  fut->register_callback(async_task_);
231  }
232  }
233 
235  template <typename T>
236  inline void check_dependency(detail::ArgHolder<std::vector<Future<T>>>& arg) {
237  check_dependency(static_cast<std::vector<Future<T>>&>(arg));
238  }
239 
241  template <typename T>
242  inline void check_dependency(std::vector<Future<T>>& vec) {
243  for (typename std::vector<Future<T>>::iterator it = vec.begin();
244  it != vec.end(); ++it)
245  check_dependency(*it);
246  }
247 
249  inline void check_dependency(const std::vector<Future<void>>&) {}
250 
252  template <typename T>
253  inline void check_dependency(const detail::ArgHolder<T>&) {}
254 
256  inline void check_dependency(const Future<void>&) {}
257 
259  void check_dependencies() {
260  this->inc(); // the current cudaTaskFn depends on the internal
261  // AsyncTaskInterface, dependency = 1
262  check_dependency(arg1_);
263  check_dependency(arg2_);
264  check_dependency(arg3_);
265  check_dependency(arg4_);
266  check_dependency(arg5_);
267  check_dependency(arg6_);
268  check_dependency(arg7_);
269  check_dependency(arg8_);
270  check_dependency(arg9_);
271  }
272 
273  // Copies are not allowed.
274  cudaTaskFn(const cudaTaskFn_&);
275  cudaTaskFn_ operator=(cudaTaskFn_&);
276 
277  public:
278 #if MADNESS_TASKQ_VARIADICS
279 
280  cudaTaskFn(const futureT& result, functionT func, const TaskAttributes& attr)
281  : TaskInterface(attr),
282  result_(result),
283  func_(func),
284  async_task_(new AsyncTaskInterface(this)),
285  async_result_(),
286  arg1_(),
287  arg2_(),
288  arg3_(),
289  arg4_(),
290  arg5_(),
291  arg6_(),
292  arg7_(),
293  arg8_(),
294  arg9_() {
295  MADNESS_ASSERT(arity == 0u);
296  check_dependencies();
297  }
298 
299  template <typename a1T>
300  cudaTaskFn(const futureT& result, functionT func, a1T&& a1,
301  const TaskAttributes& attr)
302  : TaskInterface(attr),
303  result_(result),
304  func_(func),
305  async_task_(new AsyncTaskInterface(this)),
306  async_result_(),
307  arg1_(std::forward<a1T>(a1)),
308  arg2_(),
309  arg3_(),
310  arg4_(),
311  arg5_(),
312  arg6_(),
313  arg7_(),
314  arg8_(),
315  arg9_() {
316  MADNESS_ASSERT(arity == 1u);
317  check_dependencies();
318  }
319 
320  template <typename a1T, typename a2T>
321  cudaTaskFn(const futureT& result, functionT func, a1T&& a1, a2T&& a2,
322  const TaskAttributes& attr)
323  : TaskInterface(attr),
324  result_(result),
325  func_(func),
326  async_task_(new AsyncTaskInterface(this)),
327  async_result_(),
328  arg1_(std::forward<a1T>(a1)),
329  arg2_(std::forward<a2T>(a2)),
330  arg3_(),
331  arg4_(),
332  arg5_(),
333  arg6_(),
334  arg7_(),
335  arg8_(),
336  arg9_() {
337  MADNESS_ASSERT(arity == 2u);
338  check_dependencies();
339  }
340 
341  template <typename a1T, typename a2T, typename a3T>
342  cudaTaskFn(const futureT& result, functionT func, a1T&& a1, a2T&& a2,
343  a3T&& a3, const TaskAttributes& attr)
344  : TaskInterface(attr),
345  result_(result),
346  func_(func),
347  async_task_(new AsyncTaskInterface(this)),
348  async_result_(),
349  arg1_(std::forward<a1T>(a1)),
350  arg2_(std::forward<a2T>(a2)),
351  arg3_(std::forward<a3T>(a3)),
352  arg4_(),
353  arg5_(),
354  arg6_(),
355  arg7_(),
356  arg8_(),
357  arg9_() {
358  MADNESS_ASSERT(arity == 3u);
359  check_dependencies();
360  }
361 
362  template <typename a1T, typename a2T, typename a3T, typename a4T>
363  cudaTaskFn(const futureT& result, functionT func, a1T&& a1, a2T&& a2,
364  a3T&& a3, a4T&& a4, const TaskAttributes& attr)
365  : TaskInterface(attr),
366  result_(result),
367  func_(func),
368  async_task_(new AsyncTaskInterface(this)),
369  async_result_(),
370  arg1_(std::forward<a1T>(a1)),
371  arg2_(std::forward<a2T>(a2)),
372  arg3_(std::forward<a3T>(a3)),
373  arg4_(std::forward<a4T>(a4)),
374  arg5_(),
375  arg6_(),
376  arg7_(),
377  arg8_(),
378  arg9_() {
379  MADNESS_ASSERT(arity == 4u);
380  check_dependencies();
381  }
382 
383  template <typename a1T, typename a2T, typename a3T, typename a4T,
384  typename a5T>
385  cudaTaskFn(const futureT& result, functionT func, a1T&& a1, a2T&& a2,
386  a3T&& a3, a4T&& a4, a5T&& a5, const TaskAttributes& attr)
387  : TaskInterface(attr),
388  result_(result),
389  func_(func),
390  async_task_(new AsyncTaskInterface(this)),
391  async_result_(),
392  arg1_(std::forward<a1T>(a1)),
393  arg2_(std::forward<a2T>(a2)),
394  arg3_(std::forward<a3T>(a3)),
395  arg4_(std::forward<a4T>(a4)),
396  arg5_(std::forward<a5T>(a5)),
397  arg6_(),
398  arg7_(),
399  arg8_(),
400  arg9_() {
401  MADNESS_ASSERT(arity == 5u);
402  check_dependencies();
403  }
404 
405  template <typename a1T, typename a2T, typename a3T, typename a4T,
406  typename a5T, typename a6T>
407  cudaTaskFn(const futureT& result, functionT func, a1T&& a1, a2T&& a2,
408  a3T&& a3, a4T&& a4, a5T&& a5, a6T&& a6, const TaskAttributes& attr)
409  : TaskInterface(attr),
410  result_(result),
411  func_(func),
412  async_task_(new AsyncTaskInterface(this)),
413  async_result_(),
414  arg1_(std::forward<a1T>(a1)),
415  arg2_(std::forward<a2T>(a2)),
416  arg3_(std::forward<a3T>(a3)),
417  arg4_(std::forward<a4T>(a4)),
418  arg5_(std::forward<a5T>(a5)),
419  arg6_(std::forward<a6T>(a6)),
420  arg7_(),
421  arg8_(),
422  arg9_() {
423  MADNESS_ASSERT(arity == 6u);
424  check_dependencies();
425  }
426 
427  template <typename a1T, typename a2T, typename a3T, typename a4T,
428  typename a5T, typename a6T, typename a7T>
429  cudaTaskFn(const futureT& result, functionT func, a1T&& a1, a2T&& a2,
430  a3T&& a3, a4T&& a4, a5T&& a5, a6T&& a6, a7T&& a7,
431  const TaskAttributes& attr)
432  : TaskInterface(attr),
433  result_(result),
434  func_(func),
435  async_task_(new AsyncTaskInterface(this)),
436  async_result_(),
437  arg1_(std::forward<a1T>(a1)),
438  arg2_(std::forward<a2T>(a2)),
439  arg3_(std::forward<a3T>(a3)),
440  arg4_(std::forward<a4T>(a4)),
441  arg5_(std::forward<a5T>(a5)),
442  arg6_(std::forward<a6T>(a6)),
443  arg7_(std::forward<a7T>(a7)),
444  arg8_(),
445  arg9_() {
446  MADNESS_ASSERT(arity == 7u);
447  check_dependencies();
448  }
449 
450  template <typename a1T, typename a2T, typename a3T, typename a4T,
451  typename a5T, typename a6T, typename a7T, typename a8T>
452  cudaTaskFn(const futureT& result, functionT func, a1T&& a1, a2T&& a2,
453  a3T&& a3, a4T&& a4, a5T&& a5, a6T&& a6, a7T&& a7, a8T&& a8,
454  const TaskAttributes& attr)
455  : TaskInterface(attr),
456  result_(result),
457  func_(func),
458  async_task_(new AsyncTaskInterface(this)),
459  async_result_(),
460  arg1_(std::forward<a1T>(a1)),
461  arg2_(std::forward<a2T>(a2)),
462  arg3_(std::forward<a3T>(a3)),
463  arg4_(std::forward<a4T>(a4)),
464  arg5_(std::forward<a5T>(a5)),
465  arg6_(std::forward<a6T>(a6)),
466  arg7_(std::forward<a7T>(a7)),
467  arg8_(std::forward<a8T>(a8)),
468  arg9_() {
469  MADNESS_ASSERT(arity == 8u);
470  check_dependencies();
471  }
472 
473  template <typename a1T, typename a2T, typename a3T, typename a4T,
474  typename a5T, typename a6T, typename a7T, typename a8T,
475  typename a9T>
476  cudaTaskFn(const futureT& result, functionT func, a1T&& a1, a2T&& a2,
477  a3T&& a3, a4T&& a4, a5T&& a5, a6T&& a6, a7T&& a7, a8T&& a8,
478  a9T&& a9, const TaskAttributes& attr)
479  : TaskInterface(attr),
480  result_(result),
481  func_(func),
482  async_task_(new AsyncTaskInterface(this)),
483  async_result_(),
484  arg1_(std::forward<a1T>(a1)),
485  arg2_(std::forward<a2T>(a2)),
486  arg3_(std::forward<a3T>(a3)),
487  arg4_(std::forward<a4T>(a4)),
488  arg5_(std::forward<a5T>(a5)),
489  arg6_(std::forward<a6T>(a6)),
490  arg7_(std::forward<a7T>(a7)),
491  arg8_(std::forward<a8T>(a8)),
492  arg9_(std::forward<a9T>(a9)) {
493  MADNESS_ASSERT(arity == 9u);
494  check_dependencies();
495  }
496 
497  cudaTaskFn(const futureT& result, functionT func, const TaskAttributes& attr,
498  archive::BufferInputArchive& input_arch)
499  : TaskInterface(attr),
500  result_(result),
501  func_(func),
502  async_task_(new AsyncTaskInterface(this)),
503  async_result_(),
504  arg1_(input_arch),
505  arg2_(input_arch),
506  arg3_(input_arch),
507  arg4_(input_arch),
508  arg5_(input_arch),
509  arg6_(input_arch),
510  arg7_(input_arch),
511  arg8_(input_arch),
512  arg9_(input_arch) {
513  check_dependencies();
514  }
515 #else // MADNESS_TASKQ_VARIADICS
516  cudaTaskFn(const futureT& result, functionT func, const TaskAttributes& attr)
517  : TaskInterface(attr),
518  result_(result),
519  func_(func),
520  async_task_(new AsyncTaskInterface(this)),
521  async_result_(),
522  arg1_(),
523  arg2_(),
524  arg3_(),
525  arg4_(),
526  arg5_(),
527  arg6_(),
528  arg7_(),
529  arg8_(),
530  arg9_() {
531  MADNESS_ASSERT(arity == 0u);
532  check_dependencies();
533  }
534 
535  template <typename a1T>
536  cudaTaskFn(const futureT& result, functionT func, const a1T& a1,
537  const TaskAttributes& attr)
538  : TaskInterface(attr),
539  result_(result),
540  func_(func),
541  async_task_(new AsyncTaskInterface(this)),
542  async_result_(),
543  arg1_(a1),
544  arg2_(),
545  arg3_(),
546  arg4_(),
547  arg5_(),
548  arg6_(),
549  arg7_(),
550  arg8_(),
551  arg9_() {
552  MADNESS_ASSERT(arity == 1u);
553  check_dependencies();
554  }
555 
556  template <typename a1T, typename a2T>
557  cudaTaskFn(const futureT& result, functionT func, const a1T& a1,
558  const a2T& a2, const TaskAttributes& attr = TaskAttributes())
559  : TaskInterface(attr),
560  result_(result),
561  func_(func),
562  async_task_(new AsyncTaskInterface(this)),
563  async_result_(),
564  arg1_(a1),
565  arg2_(a2),
566  arg3_(),
567  arg4_(),
568  arg5_(),
569  arg6_(),
570  arg7_(),
571  arg8_(),
572  arg9_() {
573  MADNESS_ASSERT(arity == 2u);
574  check_dependencies();
575  }
576 
577  template <typename a1T, typename a2T, typename a3T>
578  cudaTaskFn(const futureT& result, functionT func, const a1T& a1,
579  const a2T& a2, const a3T& a3, const TaskAttributes& attr)
580  : TaskInterface(attr),
581  result_(result),
582  func_(func),
583  async_task_(new AsyncTaskInterface(this)),
584  async_result_(),
585  arg1_(a1),
586  arg2_(a2),
587  arg3_(a3),
588  arg4_(),
589  arg5_(),
590  arg6_(),
591  arg7_(),
592  arg8_(),
593  arg9_() {
594  MADNESS_ASSERT(arity == 3u);
595  check_dependencies();
596  }
597 
598  template <typename a1T, typename a2T, typename a3T, typename a4T>
599  cudaTaskFn(const futureT& result, functionT func, const a1T& a1,
600  const a2T& a2, const a3T& a3, const a4T& a4,
601  const TaskAttributes& attr)
602  : TaskInterface(attr),
603  result_(result),
604  func_(func),
605  async_task_(new AsyncTaskInterface(this)),
606  async_result_(),
607  arg1_(a1),
608  arg2_(a2),
609  arg3_(a3),
610  arg4_(a4),
611  arg5_(),
612  arg6_(),
613  arg7_(),
614  arg8_(),
615  arg9_() {
616  MADNESS_ASSERT(arity == 4u);
617  check_dependencies();
618  }
619 
620  template <typename a1T, typename a2T, typename a3T, typename a4T,
621  typename a5T>
622  cudaTaskFn(const futureT& result, functionT func, const a1T& a1,
623  const a2T& a2, const a3T& a3, const a4T& a4, const a5T& a5,
624  const TaskAttributes& attr)
625  : TaskInterface(attr),
626  result_(result),
627  func_(func),
628  async_task_(new AsyncTaskInterface(this)),
629  async_result_(),
630  arg1_(a1),
631  arg2_(a2),
632  arg3_(a3),
633  arg4_(a4),
634  arg5_(a5),
635  arg6_(),
636  arg7_(),
637  arg8_(),
638  arg9_() {
639  MADNESS_ASSERT(arity == 5u);
640  check_dependencies();
641  }
642 
643  template <typename a1T, typename a2T, typename a3T, typename a4T,
644  typename a5T, typename a6T>
645  cudaTaskFn(const futureT& result, functionT func, const a1T& a1,
646  const a2T& a2, const a3T& a3, const a4T& a4, const a5T& a5,
647  const a6T& a6, const TaskAttributes& attr)
648  : TaskInterface(attr),
649  result_(result),
650  func_(func),
651  async_task_(new AsyncTaskInterface(this)),
652  async_result_(),
653  arg1_(a1),
654  arg2_(a2),
655  arg3_(a3),
656  arg4_(a4),
657  arg5_(a5),
658  arg6_(a6),
659  arg7_(),
660  arg8_(),
661  arg9_() {
662  MADNESS_ASSERT(arity == 6u);
663  check_dependencies();
664  }
665 
666  template <typename a1T, typename a2T, typename a3T, typename a4T,
667  typename a5T, typename a6T, typename a7T>
668  cudaTaskFn(const futureT& result, functionT func, const a1T& a1,
669  const a2T& a2, const a3T& a3, const a4T& a4, const a5T& a5,
670  const a6T& a6, const a7T& a7, const TaskAttributes& attr)
671  : TaskInterface(attr),
672  result_(result),
673  func_(func),
674  async_task_(new AsyncTaskInterface(this)),
675  async_result_(),
676  arg1_(a1),
677  arg2_(a2),
678  arg3_(a3),
679  arg4_(a4),
680  arg5_(a5),
681  arg6_(a6),
682  arg7_(a7),
683  arg8_(),
684  arg9_() {
685  MADNESS_ASSERT(arity == 7u);
686  check_dependencies();
687  }
688 
689  template <typename a1T, typename a2T, typename a3T, typename a4T,
690  typename a5T, typename a6T, typename a7T, typename a8T>
691  cudaTaskFn(const futureT& result, functionT func, const a1T& a1,
692  const a2T& a2, const a3T& a3, const a4T& a4, const a5T& a5,
693  const a6T& a6, const a7T& a7, const a8T& a8,
694  const TaskAttributes& attr)
695  : TaskInterface(attr),
696  result_(result),
697  func_(func),
698  async_task_(new AsyncTaskInterface(this)),
699  async_result_(),
700  arg1_(a1),
701  arg2_(a2),
702  arg3_(a3),
703  arg4_(a4),
704  arg5_(a5),
705  arg6_(a6),
706  arg7_(a7),
707  arg8_(a8),
708  arg9_() {
709  MADNESS_ASSERT(arity == 8u);
710  check_dependencies();
711  }
712 
713  template <typename a1T, typename a2T, typename a3T, typename a4T,
714  typename a5T, typename a6T, typename a7T, typename a8T,
715  typename a9T>
716  cudaTaskFn(const futureT& result, functionT func, const a1T& a1,
717  const a2T& a2, const a3T& a3, const a4T& a4, const a5T& a5,
718  const a6T& a6, const a7T& a7, const a8T& a8, const a9T& a9,
719  const TaskAttributes& attr)
720  : TaskInterface(attr),
721  result_(result),
722  func_(func),
723  async_task_(new AsyncTaskInterface(this)),
724  async_result_(),
725  arg1_(a1),
726  arg2_(a2),
727  arg3_(a3),
728  arg4_(a4),
729  arg5_(a5),
730  arg6_(a6),
731  arg7_(a7),
732  arg8_(a8),
733  arg9_(a9) {
734  MADNESS_ASSERT(arity == 9u);
735  check_dependencies();
736  }
737 
738  cudaTaskFn(const futureT& result, functionT func, const TaskAttributes& attr,
739  archive::BufferInputArchive& input_arch)
740  : TaskInterface(attr),
741  result_(result),
742  func_(func),
743  async_task_(new AsyncTaskInterface(this)),
744  async_result_(),
745  arg1_(input_arch),
746  arg2_(input_arch),
747  arg3_(input_arch),
748  arg4_(input_arch),
749  arg5_(input_arch),
750  arg6_(input_arch),
751  arg7_(input_arch),
752  arg8_(input_arch),
753  arg9_(input_arch) {
754  check_dependencies();
755  }
756 #endif // MADNESS_TASKQ_VARIADICS
757 
758  // no need to delete async_task_, as it will be deleted by the TaskQueue
759  virtual ~cudaTaskFn() = default;
760 
761  const futureT& result() const { return result_; }
762 
763  TaskInterface* async_task() { return async_task_; }
764 
765 #ifdef HAVE_INTEL_TBB
766  virtual tbb::task* execute() {
767  result_.set(std::move(async_result_));
768  return nullptr;
769  }
770 #else
771  protected:
774  void run(const TaskThreadEnv& env) override {
775  result_.set(std::move(async_result_));
776  }
777 #endif // HAVE_INTEL_TBB
778 
779 }; // class cudaTaskFn
780 
794 template <typename fnT, typename a1T, typename a2T, typename a3T, typename a4T,
795  typename a5T, typename a6T, typename a7T, typename a8T, typename a9T>
796 typename cudaTaskFn<fnT, a1T, a2T, a3T, a4T, a5T, a6T, a7T, a8T, a9T>::futureT
797 add_cuda_taskfn(
798  madness::World& world,
799  cudaTaskFn<fnT, a1T, a2T, a3T, a4T, a5T, a6T, a7T, a8T, a9T>* t) {
800  typename cudaTaskFn<fnT, a1T, a2T, a3T, a4T, a5T, a6T, a7T, a8T, a9T>::futureT
801  res(t->result());
802  // add the cuda task
803  world.taskq.add(static_cast<TaskInterface*>(t));
804  // add the internal async task in cuda task as well
805  world.taskq.add(t->async_task());
806  return res;
807 }
808 
813 template <
814  typename fnT, typename... argsT,
815  typename = std::enable_if_t<!meta::taskattr_is_last_arg<argsT...>::value>>
816 typename detail::function_enabler<fnT(future_to_ref_t<argsT>...)>::type
817 add_cuda_task(madness::World& world, fnT&& fn, argsT&&... args) {
819  using taskT =
820  cudaTaskFn<std::decay_t<fnT>,
821  std::remove_const_t<std::remove_reference_t<argsT>>...>;
822 
823  return add_cuda_taskfn(
824  world, new taskT(typename taskT::futureT(), std::forward<fnT>(fn),
825  std::forward<argsT>(args)..., TaskAttributes()));
826 }
827 
832 template <
833  typename fnT, typename... argsT,
834  typename = std::enable_if_t<meta::taskattr_is_last_arg<argsT...>::value>>
835 typename meta::drop_last_arg_and_apply_callable<
836  detail::function_enabler, fnT, future_to_ref_t<argsT>...>::type::type
837 add_cuda_task(madness::World& world, fnT&& fn, argsT&&... args) {
839  using taskT = typename meta::drop_last_arg_and_apply<
840  cudaTaskFn, std::decay_t<fnT>,
841  std::remove_const_t<std::remove_reference_t<argsT>>...>::type;
842 
843  return add_cuda_taskfn(
844  world, new taskT(typename taskT::futureT(), std::forward<fnT>(fn),
845  std::forward<argsT>(args)...));
846 }
847 
853 template <typename objT, typename memfnT, typename... argsT>
854 typename detail::memfunc_enabler<objT, memfnT>::type add_cuda_task(
855  madness::World& world, objT&& obj, memfnT memfn, argsT&&... args) {
856  return add_cuda_task(world,
857  detail::wrap_mem_fn(std::forward<objT>(obj), memfn),
858  std::forward<argsT>(args)...);
859 }
860 
861 } // namespace madness
862 
863 #endif // TILEDARRAY_HAS_CUDA
864 #endif // TILEDARRAY_CUDA_CUDA_TASK_FN_H__INCLUDED
int64_t duration_in_ns(time_point const &t0, time_point const &t1)
Definition: time.h:45
time_point now()
Definition: time.h:35
std::vector< T > vector
Definition: vector.h:41