replicator.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2013 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  */
19 
20 #ifndef TILEDARRAY_REPLICATOR_H__INCLUDED
21 #define TILEDARRAY_REPLICATOR_H__INCLUDED
22 
24 
25 namespace TiledArray {
26 namespace detail {
27 
29 
34 template <typename A>
35 class Replicator : public madness::WorldObject<Replicator<A> >,
36  private madness::Spinlock {
37  private:
38  typedef Replicator<A> Replicator_;
39  typedef madness::WorldObject<Replicator_>
40  wobj_type;
41  typedef std::stack<madness::CallbackInterface*,
42  std::vector<madness::CallbackInterface*> >
43  callback_type;
44 
45  A destination_;
46  std::vector<typename A::ordinal_type>
47  indices_;
48  std::vector<Future<typename A::value_type> > data_;
49  madness::AtomicInt sent_;
50  World& world_;
51  volatile callback_type callbacks_;
52  volatile mutable bool probe_;
53 
55  void do_callbacks() {
56  callback_type& callbacks = const_cast<callback_type&>(callbacks_);
57  while (!callbacks.empty()) {
58  callbacks.top()->notify();
59  callbacks.pop();
60  }
61  }
62 
64  class DelaySend : public madness::TaskInterface {
65  private:
66  Replicator_& parent_;
67 
68  public:
70  DelaySend(Replicator_& parent)
71  : madness::TaskInterface(madness::TaskAttributes::hipri()),
72  parent_(parent) {
73  typename std::vector<Future<typename A::value_type> >::iterator it =
74  parent_.data_.begin();
75  typename std::vector<Future<typename A::value_type> >::iterator end =
76  parent_.data_.end();
77  for (; it != end; ++it) {
78  if (!it->probe()) {
79  madness::DependencyInterface::inc();
80  it->register_callback(this);
81  }
82  }
83  }
84 
86  virtual ~DelaySend() {}
87 
89  virtual void run(const madness::TaskThreadEnv&) { parent_.send(); }
90 
91  }; // class DelaySend
92 
94 
96  bool probe() const {
97  madness::ScopedMutex<madness::Spinlock> locker(this);
98 
99  if (!probe_) {
100  typename std::vector<Future<typename A::value_type> >::const_iterator it =
101  data_.begin();
102  typename std::vector<Future<typename A::value_type> >::const_iterator
103  end = data_.end();
104  for (; it != end; ++it)
105  if (!it->probe()) break;
106 
107  probe_ = (it == end);
108  }
109 
110  return probe_;
111  }
112 
114  void delay_send() {
115  if (probe()) {
116  // The data is ready so send it now.
117  send(); // Replication is done
118  } else {
119  // The local data is not ready to be sent, so create a task that will
120  // send it when it is ready.
121  DelaySend* delay_send_task = new DelaySend(*this);
122  world_.taskq.add(delay_send_task);
123  }
124  }
125 
127  void send() {
128  const long sent = ++sent_;
129  const ProcessID dest = (world_.rank() + sent) % world_.size();
130 
131  if (dest != world_.rank()) {
132  wobj_type::task(dest, &Replicator_::send_handler, indices_, data_,
133  madness::TaskAttributes::hipri());
134  } else
135  do_callbacks(); // Replication is done
136  }
137 
138  void send_handler(const std::vector<typename A::ordinal_type>& indices,
140  typename std::vector<typename A::ordinal_type>::const_iterator index_it =
141  indices.begin();
142  typename std::vector<Future<typename A::value_type> >::const_iterator
143  data_it = data.begin();
144  typename std::vector<Future<typename A::value_type> >::const_iterator
145  data_end = data.end();
146 
147  for (; data_it != data_end; ++data_it, ++index_it)
148  destination_.set(*index_it, data_it->get());
149 
150  delay_send();
151  }
152 
153  public:
154  Replicator(const A& source, const A destination)
155  : wobj_type(source.world()),
156  madness::Spinlock(),
157  destination_(destination),
158  indices_(),
159  data_(),
160  sent_(),
161  world_(source.world()),
162  callbacks_(),
163  probe_(false) {
164  sent_ = 0;
165 
166  // Generate a list of local tiles from other.
167  typename A::pmap_interface::const_iterator end = source.pmap()->end();
168  typename A::pmap_interface::const_iterator it = source.pmap()->begin();
169  indices_.reserve(source.pmap()->local_size());
170  data_.reserve(source.pmap()->local_size());
171  if (source.is_dense()) {
172  // When dense, all tiles are present
173  for (; it != end; ++it) {
174  indices_.push_back(*it);
175  data_.push_back(source.find(*it));
176  destination_.set(*it, data_.back());
177  }
178  } else {
179  // When sparse, we need to generate a list
180  for (; it != end; ++it)
181  if (!source.is_zero(*it)) {
182  indices_.push_back(*it);
183  data_.push_back(source.find(*it));
184  destination_.set(*it, data_.back());
185  }
186  }
187 
189  delay_send();
190 
191  // Process any pending messages
192  wobj_type::process_pending();
193  }
194 
196 
198  bool done() {
199  madness::ScopedMutex<madness::Spinlock> locker(this);
200  return sent_ == world_.size();
201  }
202 
204 
209  void register_callback(madness::CallbackInterface* callback) {
210  madness::ScopedMutex<madness::Spinlock> locker(this);
211  if (sent_ == world_.size())
212  callback->notify();
213  else
214  const_cast<callback_type&>(callbacks_).push(callback);
215  }
216 
217 }; // class Replicator
218 
219 } // namespace detail
220 } // namespace TiledArray
221 
222 #endif // TILEDARRAY_REPLICATOR_H__INCLUDED
Replicator(const A &source, const A destination)
Definition: replicator.h:154
bool done()
Check that the replication is complete.
Definition: replicator.h:198
Replicate an Array object.
Definition: replicator.h:36
void register_callback(madness::CallbackInterface *callback)
Add a callback.
Definition: replicator.h:209
std::vector< T > vector
Definition: vector.h:41