1 #ifndef MPQC_ARRAY_THREAD_HPP
2 #define MPQC_ARRAY_THREAD_HPP
4 #include "mpqc/mpi.hpp"
5 #include "mpqc/utility/timer.hpp"
6 #include "mpqc/array/forward.hpp"
7 #include "mpqc/array/socket.hpp"
11 #include <boost/thread/thread.hpp>
12 #include "boost/thread/tss.hpp"
16 namespace ArrayServer {
20 static const size_t BUFFER = (32<<20);
30 std::vector<range> extents;
34 const std::vector<range> &r) {
// Constructor body: records the target object and the rank, then fills the
// flat `data` vector with (begin, end) integer pairs, one pair per dimension
// for every piece produced by splitting the last range.
// NOTE(review): the start of the signature and several statements of this
// body are not visible in this excerpt.
35 this->
object = object;
// The rank is the number of ranges supplied by the caller.
36 this->rank = r.size();
// N = BUFFER divided by the size of a double.
40 size_t N = BUFFER/
sizeof(double);
// NOTE(review): the body of this loop is not visible here; presumably it
// accumulates `block` over the first rank-1 ranges -- confirm in the full file.
42 for (
int i = 0; i < rank-1; ++i)
// `block` is required to be strictly smaller than N.
44 MPQC_ASSERT(block < N);
// Iterate over the pieces returned by split(r.back(), N/block).
// NOTE(review): presumably split() limits each piece to N/block elements so
// a single segment fits in the buffer -- confirm against split()'s definition.
46 BOOST_FOREACH (
range rj,
split(r.back(), N/block)) {
// The first rank-1 ranges are recorded as dereferenced begin/end pairs ...
47 for (
int i = 0; i < rank-1; ++i) {
48 data.push_back(*r[i].begin());
49 data.push_back(*r[i].end());
// ... followed by the begin/end pair of the current piece of the last range.
51 data.push_back(*rj.begin());
52 data.push_back(*rj.end());
// Builds a proxy from a Descriptor: copies the descriptor's object and count
// and resizes `data` to 2*rank*count integers (two integers per dimension
// per segment).
// NOTE(review): the statement that sets `rank` before the resize is not
// visible in this excerpt.
57 array_proxy(Descriptor ds) {
58 this->
object = ds.object;
60 this->count = ds.count;
61 this->data.resize(2*rank*count);
// Produces a Descriptor carrying this proxy's object and count.
// NOTE(review): the declaration of `ds`, any further fields, and the return
// statement are not visible in this excerpt.
64 Descriptor descriptor()
const {
66 ds.object = this->object;
68 ds.count = this->count;
// Decodes the flat `data` vector into a list of Segment records.
// For each of the `count` segments, `rank` ranges are built from consecutive
// integer pairs, and the segment size is the product of the range sizes.
// NOTE(review): the declarations of `r` and `size`, and the statement that
// advances `data` past each pair, are not visible in this excerpt.
72 std::vector<Segment> segments()
const {
73 std::vector<Segment> segments;
// Local iterator into the member vector; shadows the member name `data`.
74 auto data = this->data.begin();
75 for (
int j = 0; j < count; ++j) {
78 for (
int i = 0; i < rank; ++i) {
// The two integers at the current position form one range.
79 r.push_back(range(data[0], data[1]));
80 size *= r.back().size();
// Append one Segment holding the element count and the per-dimension extents.
83 segments.push_back(Segment());
84 segments.back().size = size;
85 segments.back().extents = r;
94 std::vector<int> data;
101 enum Request { INVALID = 0x0,
// Constructs a message with the given tag and request type.
// Both parameters have defaults (tag 0, request INVALID), so this also acts
// as the default constructor; `explicit` blocks implicit conversion from int.
112 explicit Message(
int tag = 0, Request r = INVALID)
113 : tag(tag), request(r) {}
117 this->dataspace = ds;
// Constructor initialisation: keeps the communicator and sets tag_ to 1<<20.
// NOTE(review): the constructor signature and opening brace are not visible
// in this excerpt.
128 : comm_(comm), tag_(1<<20)
// Transfer buffer of array_proxy::BUFFER bytes.
// NOTE(review): the malloc result is not checked for NULL on the visible lines.
130 buffer_ = malloc(array_proxy::BUFFER);
// Start the local socket, then gather every rank's socket address so that
// servers_ can be indexed by rank.
131 this->socket_.start();
132 this->servers_ = comm.allgather(this->socket_.address());
// Launch the worker thread running Thread::run(this). The thread object is
// held through a raw pointer allocated with new.
133 this->thread_ =
new boost::thread(&Thread::run,
this);
147 delete this->thread_;
// Sends a JOIN message to this process's own rank, then blocks until the
// worker thread has finished.
// NOTE(review): the enclosing function header and the origin of `tag` are
// not visible in this excerpt.
153 send(
Message(tag, Message::JOIN), comm_.rank());
154 this->thread_->join();
// Sends a SYNC message to this process's own rank and then waits for a
// Message from that same rank on tag | SEND_MASK; the received value is
// discarded.
// NOTE(review): the enclosing function header and the origin of `tag` are
// not visible in this excerpt.
160 send(
Message(tag, Message::SYNC), comm_.rank());
161 comm_.recv<
Message>(comm_.rank(), tag | SEND_MASK);
// Delivers a control message to the server socket of rank `proc`.
// The message is taken by value so that its `src` field can be stamped with
// the caller's rank before sending. servers_.at(proc) is bounds-checked and
// throws std::out_of_range for an invalid rank.
164 void send(
Message msg,
int proc)
const {
165 msg.src = comm_.rank();
166 ArraySocket::send(&msg, this->servers_.at(proc));
172 size_t count, MPI_Datatype type,
173 int proc,
int tag)
const {
// Data send: the caller's tag must have neither mask bit set; the transfer
// itself is issued with RECV_MASK OR-ed into the tag.
// NOTE(review): the start of the signature (including the `data` parameter)
// is not visible in this excerpt.
174 MPQC_ASSERT(!(tag & SEND_MASK));
175 MPQC_ASSERT(!(tag & RECV_MASK));
176 comm_.send(data, count, type, proc, tag | RECV_MASK);
181 size_t count, MPI_Datatype type,
182 int proc,
int tag)
const {
// Data receive: the caller's tag must have neither mask bit set; the
// transfer itself is issued with SEND_MASK OR-ed into the tag.
// NOTE(review): the start of the signature (including the `data` parameter)
// is not visible in this excerpt.
183 MPQC_ASSERT(!(tag & SEND_MASK));
184 MPQC_ASSERT(!(tag & RECV_MASK));
185 comm_.recv(data, count, type, proc, tag | SEND_MASK);
// Accessor for a function-local static shared_ptr<Thread>, returned by
// reference.
// NOTE(review): only the static declaration and the MPI initialisation call
// (requesting MPI_THREAD_MULTIPLE) are visible; the condition guarding the
// initialisation and the creation of the Thread are not shown in this excerpt.
188 static std::shared_ptr<Thread>& instance() {
189 static std::shared_ptr<Thread> thread;
191 MPI::initialize(MPI_THREAD_MULTIPLE);
197 static void run(Thread *thread) {
// Hands out an integer in the range [N, 2N-1] with N = 1<<21: the counter
// next_ is post-incremented and reduced modulo N, so values repeat after N
// calls. The increment is performed while holding mutex_.
// NOTE(review): the enclosing function header is not visible in this excerpt.
206 const unsigned int N = 1 << 21;
207 boost::mutex::scoped_lock lock(mutex_);
208 return int(N + (next_++ % N));
// Maps rank1, expressed in communicator comm1, to the corresponding rank in
// this->comm_ by translating between the two communicators' groups.
// Asserts that the process exists in this->comm_ (result != MPI_UNDEFINED).
// NOTE(review): the declaration of rank2 and the return statement are not
// visible here; the groups obtained from MPI_Comm_group are not released
// with MPI_Group_free on any visible line.
211 int translate(MPI::Comm comm1,
int rank1)
const {
213 MPI_Group group1, group2;
214 MPI_Comm_group(comm1, &group1);
215 MPI_Comm_group(this->comm_, &group2);
216 MPI_Group_translate_ranks(group1, 1, &rank1, group2, &rank2);
217 MPQC_ASSERT(rank2 != MPI_UNDEFINED);
// Server-side dispatch: wait on the socket for the next message, then branch
// on its request type.
// NOTE(review): the enclosing function/loop header, the closing braces and
// the bodies of the SYNC and JOIN branches are not visible in this excerpt.
233 this->socket_.wait(&msg);
// READ: handled by read() using the dataspace carried in the message.
248 if (msg.request == Message::READ) {
250 read(msg, msg.dataspace);
// WRITE: handled by write() using the dataspace carried in the message.
253 if (msg.request == Message::WRITE) {
255 write(msg, msg.dataspace);
258 if (msg.request == Message::SYNC) {
262 if (msg.request == Message::JOIN) {
// Report the request value and raise std::runtime_error.
// NOTE(review): presumably only reached for unrecognised requests -- confirm
// against the control flow in the full file.
267 printf(
"invalid message request %i\n", msg.request);
268 throw std::runtime_error(
"invalid message");
// Worker thread object: allocated with new, joined and deleted elsewhere in
// this file.
276 boost::thread *thread_;
// Socket addresses gathered from every rank; indexed by rank when sending.
279 std::vector<ArraySocket::Address> servers_;
// Counter behind the generated integer tags, and the mutex guarding it.
// Both are mutable.
282 mutable unsigned int next_;
283 mutable boost::mutex mutex_;
// Answers a SYNC request by sending a Message back to the requesting rank
// (msg.src) on msg.tag | SEND_MASK.
// NOTE(review): with the visible constructor Message(int tag, Request r),
// Message(Message::SYNC) binds the enum value to `tag` and leaves `request`
// as INVALID -- confirm whether that is intended or another overload applies.
285 void sync(Message msg) {
287 comm_.send(Message(Message::SYNC), msg.src, msg.tag | SEND_MASK);
// Serves a READ request: builds an array_proxy from the descriptor and runs
// the READ variant of io() against the requesting rank and tag.
290 void read(Message msg, array_proxy::Descriptor ds) {
291 io<Message::READ>(array_proxy(ds), msg.src, msg.tag);
// Serves a WRITE request: builds an array_proxy from the descriptor and runs
// the WRITE variant of io() against the requesting rank and tag.
294 void write(Message msg, array_proxy::Descriptor ds) {
295 io<Message::WRITE>(array_proxy(ds), msg.src, msg.tag);
// Server-side transfer loop, specialised at compile time on the request type.
//   ds   - proxy whose `data` vector is filled from the peer before use
//   proc - rank of the requesting process
//   tag  - message tag; RECV_MASK / SEND_MASK are OR-ed in per direction
// NOTE(review): closing braces and a few statements are not visible in this
// excerpt.
300 template<Message::Request OP>
301 void io(array_proxy ds,
int proc,
int tag) {
// Receive the integer segment bounds from the peer into ds.data.
// NOTE(review): &ds.data[0] assumes the vector is non-empty.
304 comm_.recv(&ds.data[0], ds.data.size(), MPI_INT, proc, tag | RECV_MASK);
305 const auto &segments = ds.segments();
// The thread's raw buffer_ is used as an array of doubles for every segment.
306 double* buffer = static_cast<double*>(this->buffer_);
// NOTE(review): `i` is a signed int compared against an unsigned size().
307 for (
int i = 0; i < segments.size(); ++i) {
308 const auto &extents = segments[i].extents;
// WRITE: receive the segment's values from the peer, then store them into
// the object.
309 if (OP == Message::WRITE) {
312 comm_.recv(buffer, segments[i].size, MPI_DOUBLE,
313 proc, tag | RECV_MASK);
315 ds.object->put(extents, buffer);
// READ: load the segment's values from the object, then send them to the peer.
317 if (OP == Message::READ) {
321 ds.object->get(extents, buffer);
322 comm_.send(buffer, segments[i].size, MPI_DOUBLE,
323 proc, tag | SEND_MASK);
347 thread_ = Thread::instance();
353 return *this->thread_;
// Client-side write: forwards to the WRITE variant of io().
//   data   - source values
//   object - target array object
//   r      - ranges selecting the region
//   rank   - passed straight through to io()
// NOTE(review): the C-style cast removes const from `data` so that it matches
// the pointer type io() takes; the template header for T and the closing
// brace are not visible in this excerpt.
364 void write(
const T *data,
ArrayBase *
object,
365 const std::vector<range> &r,
int rank)
const {
367 io<Message::WRITE>((T*)data,
object, r, rank);
372 const std::vector<range> &r,
int rank)
const {
// Client-side read: forwards to the READ variant of io() with the same
// arguments.
// NOTE(review): the start of the signature (the `data` and `object`
// parameters) is not visible in this excerpt.
374 io<Message::READ>(data,
object, r, rank);
384 std::shared_ptr<Thread> thread_;
// Tag stored in thread-specific storage: the first time a given thread
// reaches this code, a new int is obtained from thread_->next() and stored;
// later calls on that thread find it already set.
// NOTE(review): the enclosing function header and its return statement are
// not visible in this excerpt.
387 static boost::thread_specific_ptr<int> tag;
388 if (!tag.get()) tag.reset(
new int(thread_->next()));
// Client-side transfer, specialised on the request type OP and element type T.
// Visible steps: send the request message, send the segment bounds, then move
// each segment's values in the direction given by OP.
// NOTE(review): most of the signature, the rest of the static_assert, the
// definitions of `ds`, `proc` and `buffer`, and all closing braces are not
// visible in this excerpt.
392 template<Message::Request OP,
typename T>
394 const std::vector<range> &r,
398 static_assert(OP == Message::WRITE ||
403 int tag = this->tag();
// Announce the request, with the proxy's descriptor, to rank `proc`.
409 thread_->send(
Message(tag, OP, ds.descriptor()), proc);
// Send the integer segment bounds.
// NOTE(review): &ds.data[0] assumes the vector is non-empty.
415 thread_->send(&ds.data[0], ds.data.size(), MPI_INT, proc, tag);
418 auto segments = ds.segments();
// NOTE(review): `requests` is not used on any visible line.
419 std::vector<MPI_Request> requests(ds.count);
// NOTE(review): `i` is a signed int compared against an unsigned size().
423 for (
int i = 0; i < segments.size(); ++i) {
424 size_t size = segments[i].size;
// Values travel as size*sizeof(T) raw bytes.
// NOTE(review): the server-side io() in this file transfers segments[i].size
// elements of MPI_DOUBLE, so this presumably requires sizeof(T) to equal
// sizeof(double) -- confirm.
425 if (OP == Message::READ) {
430 thread_->
recv(buffer, size*
sizeof(T), MPI_BYTE, proc, tag);
432 if (OP == Message::WRITE) {
437 thread_->send(buffer, size*
sizeof(T), MPI_BYTE, proc, tag);
454 #endif // MPQC_ARRAY_THREAD_HPP