20 #ifndef TILEDARRAY_REDUCE_TASK_H__INCLUDED
21 #define TILEDARRAY_REDUCE_TASK_H__INCLUDED
23 #include <TiledArray/config.h>
27 #ifdef TILEDARRAY_HAS_CUDA
// NOTE(review): source is elided here — these are fragments of the trait
// helpers that strip cv/ref qualifiers from a reduction op's argument types.
47 template <
typename T,
typename U>
55 template <
typename opT>
// Presumably first_argument_type / second_argument_type aliases — the
// trailing part of each typedef is on lines not shown; verify against repo.
60 typedef typename std::remove_cv<
typename std::remove_reference<
63 typedef typename std::remove_cv<
typename std::remove_reference<
// Applies the wrapped binary op to both halves of a std::pair argument.
117 op_(result, arg.first, arg.second);
199 template <
typename opT>
// Result/argument type aliases derived from the user-supplied op type.
202 typedef typename opT::result_type result_type;
203 typedef typename std::remove_const<
204 typename std::remove_reference<typename opT::argument_type>::type>::type
// Task object that accumulates arguments into a single reduced result.
// NOTE(review): the class body below is heavily elided; member/function
// fragments only.
212 class ReduceTaskImpl :
public madness::TaskInterface {
// Owning task this object reports readiness to.
220 ReduceTaskImpl* parent_;
// Optional user callback invoked when this object is consumed.
223 madness::CallbackInterface* callback_;
// Countdown of unready futures (see register_callbacks below).
224 madness::AtomicInt count_;
// Register this object as a callback on a single future argument; if the
// future is already ready, notify the parent immediately.
230 template <
typename T>
233 parent_->ready(
this);
236 f.register_callback(
this);
// Pair overload: only ready when BOTH futures have values; otherwise
// register on each and wait for the countdown.
245 template <
typename T,
typename U>
247 if (p.first.probe() && p.second.probe()) {
248 parent_->ready(
this);
251 p.first.register_callback(
this);
252 p.second.register_callback(
this);
// Constructor fragment: stores parent/argument/callback, then hooks the
// readiness callbacks onto the argument.
264 template <
typename Arg>
266 madness::CallbackInterface* callback)
267 : parent_(parent), arg_(
arg), callback_(callback) {
269 register_callbacks(arg_);
// notify(): last pending future decrements the count to zero and signals
// the parent that this argument is ready.
276 if ((--count_) == 0) parent_->ready(
this);
282 const argument_type&
arg()
const {
return arg_; }
// Compile-time switch for task-trace debugging output.
289 static constexpr
const bool trace_tasks =
290 #ifdef TILEDARRAY_ENABLE_TASK_DEBUG_TRACE
// destroy() fragment: fire the user's callback (with an optional debug
// trace message) before the object is reclaimed.
296 if (object->callback_) {
298 object->callback_->notify_debug(
"destroy(*ReduceObject)");
300 object->callback_->notify();
307 #ifdef TILEDARRAY_HAS_CUDA
// Host callback enqueued via cudaLaunchHostFunc: deletes a batch of
// ReduceObjects once the CUDA stream reaches this point. userData layout:
// [0] = madness::World*, [1..n-1] = ReduceObject*. Deletion is deferred to
// a hipri taskq task because CUDA host callbacks must not block.
309 static void CUDART_CB cuda_reduceobject_delete_callback(
void* userData) {
312 std::vector<void*>* objects =
static_cast<std::vector<void*>*
>(userData);
315 madness::World* world =
static_cast<madness::World*
>((*objects)[0]);
318 std::size_t n_objects = objects->size();
// Index 0 is the World pointer, so objects start at 1.
320 for (std::size_t i = 1; i < n_objects; i++) {
322 ReduceObject* reduce_object =
323 static_cast<ReduceObject*
>((*objects)[i]);
336 world->taskq.add(
destroy_vector, objects, TaskAttributes::hipri());
// Accumulate callback timing stats (slot 0).
339 TiledArray::detail::cuda_callback_duration_ns<0>() +=
// Host callback: decrement the dependency count of each ReduceTaskImpl in
// userData after the stream has consumed their inputs.
343 static void CUDART_CB cuda_dependency_dec_callback(
void* userData) {
346 std::vector<void*>* objects =
static_cast<std::vector<void*>*
>(userData);
348 for (
auto& item : *objects) {
350 ReduceTaskImpl* dep =
static_cast<ReduceTaskImpl*
>(item);
360 TiledArray::detail::cuda_callback_duration_ns<1>() +=
// Combined host callback: decrement one task dependency AND destroy one
// ReduceObject. userData layout is fixed: [World*, ReduceTaskImpl*,
// ReduceObject*] (asserted below).
364 static void CUDART_CB
365 cuda_dependency_dec_reduceobject_delete_callback(
void* userData) {
368 std::vector<void*>* objects =
static_cast<std::vector<void*>*
>(userData);
370 assert(objects->size() == 3);
373 madness::World* world =
static_cast<madness::World*
>(objects->at(0));
376 ReduceTaskImpl* dep =
static_cast<ReduceTaskImpl*
>(objects->at(1));
381 ReduceObject* reduce_object =
static_cast<ReduceObject*
>(objects->at(2));
// Actual destruction happens on a MADNESS task, not in the CUDA callback.
383 auto destroy = [](ReduceObject* object) {
391 world->taskq.add(destroy, reduce_object, TaskAttributes::hipri());
396 TiledArray::detail::cuda_callback_duration_ns<2>() +=
// Host callback: release a heap-allocated shared_ptr<result_type> (the
// "ready result") once the stream no longer needs it. userData layout:
// [World*, shared_ptr<result_type>*].
400 static void CUDART_CB cuda_readyresult_reset_callback(
void* userData) {
403 std::vector<void*>* objects =
static_cast<std::vector<void*>*
>(userData);
406 madness::World* world =
static_cast<madness::World*
>((*objects)[0]);
408 auto reset = [](std::vector<void*>* objects) {
411 std::shared_ptr<result_type>* result =
412 static_cast<std::shared_ptr<result_type>*
>((*objects)[1]);
421 world->taskq.add(reset, objects, TaskAttributes::hipri());
424 TiledArray::detail::cuda_callback_duration_ns<3>() +=
// Task-identification hook used by the MADNESS thread pool for tracing.
429 virtual void get_id(std::pair<void*, unsigned short>&
id)
const {
430 return PoolTaskInterface::make_id(
id, *
this);
// Fold a newly produced partial result into this task's accumulated state.
// Three cases (elided branches): a pending ready_object_, a pending
// ready_result_, or neither (store `result` for a later partner).
// NOTE(review): body is elided; branch structure inferred — verify in repo.
441 void reduce(std::shared_ptr<result_type>& result) {
// Claim the pending ready object (cast away volatile under the lock).
446 ReduceObject* ready_object =
const_cast<ReduceObject*
>(ready_object_);
447 ready_object_ =
nullptr;
451 op_(*result, ready_object->arg());
454 #ifdef TILEDARRAY_HAS_CUDA
// If op_ ran on a CUDA stream, defer dependency-dec + object deletion to a
// stream host callback instead of doing them synchronously.
455 auto stream_ptr = tls_cudastream_accessor();
458 if (stream_ptr ==
nullptr) {
// callback_object ownership passes to the CUDA host callback.
462 auto callback_object =
new std::vector<void*>(3);
463 (*callback_object)[0] = &world_;
464 (*callback_object)[1] =
this;
465 (*callback_object)[2] = ready_object;
467 cudaSetDevice(cudaEnv::instance()->current_cuda_device_id()));
468 CudaSafeCall(cudaLaunchHostFunc(
469 *stream_ptr, cuda_dependency_dec_reduceobject_delete_callback,
// Clear the thread-local stream so later work doesn't reuse it.
471 synchronize_stream(
nullptr);
479 }
else if (ready_result_) {
// Merge with the previously stored partial result.
481 std::shared_ptr<result_type> ready_result = ready_result_;
482 ready_result_.reset();
486 op_(*result, *ready_result);
489 #ifdef TILEDARRAY_HAS_CUDA
490 auto stream_ptr = tls_cudastream_accessor();
491 if (stream_ptr ==
nullptr) {
492 ready_result.reset();
// Keep the shared_ptr alive on the heap until the stream callback resets
// it (the stream may still read the old result's memory).
494 auto ready_result_heap =
495 new std::shared_ptr<result_type>(ready_result);
496 auto callback_object =
new std::vector<void*>(2);
497 (*callback_object)[0] = &world_;
498 (*callback_object)[1] = ready_result_heap;
500 cudaSetDevice(cudaEnv::instance()->current_cuda_device_id()));
501 CudaSafeCall(cudaLaunchHostFunc(
502 *stream_ptr, cuda_readyresult_reset_callback, callback_object));
503 synchronize_stream(
nullptr);
508 ready_result.reset();
// No partner available: park this result until the next reduce() call.
512 ready_result_ = result;
// Task body: reduce one ready argument object into an accumulated result,
// then recycle both. NOTE(review): interior lines elided.
523 void reduce_result_object(std::shared_ptr<result_type> result,
524 const ReduceObject*
object) {
526 op_(*result, object->arg());
529 #ifdef TILEDARRAY_HAS_CUDA
// CUDA path: the ReduceObject may still be read by the stream, so its
// deletion is deferred to a stream host callback.
530 auto stream_ptr = tls_cudastream_accessor();
531 if (stream_ptr ==
nullptr) {
534 auto callback_object =
new std::vector<void*>(2);
535 (*callback_object)[0] = &world_;
536 (*callback_object)[1] =
const_cast<ReduceObject*
>(object);
538 cudaSetDevice(cudaEnv::instance()->current_cuda_device_id()));
539 CudaSafeCall(cudaLaunchHostFunc(
540 *stream_ptr, cuda_reduceobject_delete_callback, callback_object));
541 synchronize_stream(
nullptr);
552 #ifdef TILEDARRAY_HAS_CUDA
// Likewise defer this task's dependency decrement until the stream drains.
553 if (stream_ptr ==
nullptr) {
556 auto callback_object2 =
new std::vector<void*>(1);
557 (*callback_object2)[0] =
this;
559 cudaSetDevice(cudaEnv::instance()->current_cuda_device_id()));
560 CudaSafeCall(cudaLaunchHostFunc(
561 *stream_ptr, cuda_dependency_dec_callback, callback_object2));
// Task body: reduce a PAIR of ready argument objects into a freshly
// default-constructed result, then recycle both objects and decrement this
// task's dependency count twice. NOTE(review): interior lines elided.
570 void reduce_object_object(
const ReduceObject* object1,
571 const ReduceObject* object2) {
// Seed a new accumulator from the op's identity value, op_().
573 auto result = std::make_shared<result_type>(op_());
576 op_(*result, object1->arg());
577 op_(*result, object2->arg());
580 #ifdef TILEDARRAY_HAS_CUDA
// CUDA path: defer deletion of both objects to one stream host callback.
581 auto stream_ptr = tls_cudastream_accessor();
582 if (stream_ptr ==
nullptr) {
586 auto callback_object1 =
new std::vector<void*>(3);
587 (*callback_object1)[0] = &world_;
588 (*callback_object1)[1] =
const_cast<ReduceObject*
>(object1);
589 (*callback_object1)[2] =
const_cast<ReduceObject*
>(object2);
591 cudaSetDevice(cudaEnv::instance()->current_cuda_device_id()));
592 CudaSafeCall(cudaLaunchHostFunc(
593 *stream_ptr, cuda_reduceobject_delete_callback, callback_object1));
594 synchronize_stream(
nullptr);
607 #ifdef TILEDARRAY_HAS_CUDA
// Two dependency decrements (one per consumed object): `this` is listed
// twice so cuda_dependency_dec_callback decrements twice.
608 if (stream_ptr ==
nullptr) {
612 auto callback_object2 =
new std::vector<void*>(2);
613 (*callback_object2)[0] =
this;
614 (*callback_object2)[1] =
this;
616 cudaSetDevice(cudaEnv::instance()->current_cuda_device_id()));
617 CudaSafeCall(cudaLaunchHostFunc(
618 *stream_ptr, cuda_dependency_dec_callback, callback_object2));
628 #ifdef TILEDARRAY_HAS_CUDA
// Finalization for CUDA-resident result tiles: the final op_ application is
// submitted as a CUDA-aware task and result_ is set from its future.
629 template <
typename Result = result_type>
630 std::enable_if_t<detail::is_cuda_tile_v<Result>,
void> internal_run(
631 const madness::TaskThreadEnv&) {
634 auto post_result = madness::add_cuda_task(world_, op_, *ready_result_);
635 result_.set(post_result);
// User callback fires only once the asynchronous result is available.
638 result_.register_callback(callback_);
// Host-only finalization: apply op_ to the accumulated value directly and
// notify the callback synchronously.
642 template <
typename Result = result_type>
643 std::enable_if_t<!detail::is_cuda_tile_v<Result>,
void>
647 internal_run(
const madness::TaskThreadEnv&) {
649 result_.set(op_(*ready_result_));
651 if (callback_) callback_->notify();
// Accumulated partial result awaiting a partner (name elided; presumably
// ready_result_ — see uses above).
656 std::shared_ptr<result_type>
// Pending single argument awaiting a partner (presumably ready_object_);
// volatile because it is checked outside the lock — verify in repo.
658 volatile ReduceObject*
// Future delivered to the caller with the final reduced value.
660 Future<result_type> result_;
// Guards ready_result_/ready_object_ hand-offs.
661 madness::Spinlock lock_;
// Optional completion callback supplied by the caller.
662 madness::CallbackInterface* callback_;
// Constructor: one initial dependency keeps the task from running until
// submit(); the accumulator is seeded with the op's identity, op().
671 ReduceTaskImpl(World& world, opT op, madness::CallbackInterface* callback)
672 :
madness::TaskInterface(1, TaskAttributes::hipri()),
675 ready_result_(std::make_shared<result_type>(op())),
676 ready_object_(nullptr),
679 callback_(callback) {}
681 virtual ~ReduceTaskImpl() {}
// TaskInterface entry point: dispatches to the CUDA or host internal_run.
684 virtual void run(
const madness::TaskThreadEnv& threadEnv) {
685 internal_run(threadEnv);
// Called when an argument object becomes ready. Pairs it with whatever is
// pending: a stored partial result, a stored object, or neither.
// NOTE(review): locking and the leading if-condition are elided.
694 void ready(ReduceObject*
object) {
// Pair with the stored partial result: reduce them on a new hipri task.
698 std::shared_ptr<result_type> ready_result = ready_result_;
699 ready_result_.reset();
702 world_.taskq.add(
this, &ReduceTaskImpl::reduce_result_object,
703 ready_result,
object, TaskAttributes::hipri());
704 }
else if (ready_object_) {
// Pair with the previously stored object: reduce object-object.
705 ReduceObject* ready_object =
const_cast<ReduceObject*
>(ready_object_);
706 ready_object_ =
nullptr;
709 world_.taskq.add(
this, &ReduceTaskImpl::reduce_object_object,
object,
710 ready_object, TaskAttributes::hipri());
// Nothing pending: park this object until the next ready() call.
712 ready_object_ = object;
720 const Future<result_type>& result()
const {
return result_; }
725 World& world()
const {
return world_; }
// Owning pointer to the shared implementation; null after move/submit.
729 ReduceTaskImpl* pimpl_;
// Constructor fragment: allocates the impl and starts the argument count
// at zero. (Signature head is on an elided line.)
743 madness::CallbackInterface* callback =
nullptr)
744 : pimpl_(new ReduceTaskImpl(world, op, callback)), count_(0ul) {}
// Move construction: steal the impl and count, leaving `other` empty.
750 : pimpl_(other.pimpl_), count_(other.count_) {
751 other.pimpl_ =
nullptr;
// Move assignment fragment: take over other's state.
// NOTE(review): disposal of any pre-existing pimpl_ is on elided lines —
// verify there is no leak in the repo source.
765 pimpl_ = other.pimpl_;
766 count_ = other.count_;
767 other.pimpl_ =
nullptr;
// add(): registers one argument; a new ReduceObject wires itself to the
// impl and to the optional per-argument callback.
785 template <
typename Arg>
786 int add(
const Arg& arg, madness::CallbackInterface* callback =
nullptr) {
789 new typename ReduceTaskImpl::ReduceObject(pimpl_, arg, callback);
796 int count()
const {
return count_; }
// submit() fragment: hand the impl to the task queue; the queue takes over
// the TaskInterface lifetime from here.
811 World& world = pimpl_->world();
812 world.taskq.add(pimpl_);
821 operator bool()
const {
return pimpl_ !=
nullptr; }
// ReducePairTask fragments: convenience wrapper reducing (left, right)
// argument pairs. NOTE(review): most of the class body is elided.
909 template <
typename opT>
916 second_argument_type;
// Constructor fragment: optional completion callback, defaulting to none.
932 madness::CallbackInterface* callback =
nullptr)
// add(): registers one left/right pair (packed into a std::pair upstream)
// with an optional per-pair callback.
965 template <
typename L,
typename R>
966 void add(
const L& left,
const R& right,
967 madness::CallbackInterface* callback =
nullptr) {
978 #endif // TILEDARRAY_REDUCE_TASK_H__INCLUDED