1 #ifndef MPQC_ARRAY_CORE_HPP
2 #define MPQC_ARRAY_CORE_HPP
4 #include "mpqc/array/forward.hpp"
6 #include "mpqc_config.h"
7 #include "mpqc/mpi.hpp"
9 #include "mpqc/array/parallel.hpp"
17 #include "mpqc/range.hpp"
18 #include <boost/foreach.hpp>
19 #include "mpqc/utility/mutex.hpp"
20 #include "mpqc/utility/exception.hpp"
22 #include <boost/noncopyable.hpp>
/// Template over the extent type: takes a std::vector<Extent> of extents and
/// sizes local storage with data_.resize(this->size()).
/// NOTE(review): the name of this definition (presumably a constructor) and
/// how size() is derived from `extents` are not visible here -- confirm.
36 template<
typename Extent>
38 const std::vector<Extent> &extents)
41 data_.resize(this->size());
/// Writes the packed `buffer` into the sub-block of data_ selected by the
/// per-dimension ranges `r`.
/// The work is delegated to apply() with putv as the leaf operation, which
/// copies from the buffer into data_ with std::copy.
/// NOTE(review): `buffer` is reinterpreted as const T*; callers presumably
/// pass T elements packed in the order apply() walks the region -- confirm.
48 void _put(
const std::vector<range> &r,
const void *buffer) {
49 apply(putv, r, this->dims_, &this->data_[0], (
const T*)buffer);
/// Reads the sub-block of data_ selected by the per-dimension ranges `r`
/// into the packed `buffer`.
/// The work is delegated to apply() with getv as the leaf operation, which
/// copies from data_ into the buffer with std::copy.
/// NOTE(review): `buffer` is reinterpreted as T*; it presumably must hold at
/// least as many T elements as the selected region contains -- confirm.
52 void _get(
const std::vector<range> &r,
void *buffer)
const {
53 apply(getv, r, this->dims_, &this->data_[0], (T*)buffer);
/// Leaf operation used by _put via apply(): copies `size` elements from
/// `buffer` into `data` with std::copy.
/// NOTE(review): the return statement is not visible here; apply() returns
/// whatever this function yields to its own callers.
60 static size_t putv(
size_t size, T *data,
const T *buffer) {
61 std::copy(buffer, buffer+size, data);
/// Leaf operation used by _get via apply(): copies `size` elements from
/// `data` into `buffer` with std::copy.
/// NOTE(review): the return statement is not visible here; apply() returns
/// whatever this function yields to its own callers.
65 static size_t getv(
size_t size,
const T *data, T *buffer) {
66 std::copy(data, data+size, buffer);
/// Translates the per-dimension selection `r` over an array with extents
/// `dims` into a (counts, strides) description and a starting offset, then
/// delegates to the counts/strides overload of apply().
///
/// @param f       leaf operation (putv / getv in this file), called as
///                f(count, data, buffer)
/// @param r       selected range per dimension, indexed in parallel with dims
/// @param dims    extent of each dimension
/// @param data    pointer to the first element of the whole array
/// @param buffer  packed user buffer
/// @return the value returned by the counts/strides overload
70 template<
class F,
typename data_ptr,
typename buffer_ptr>
71 static size_t apply(F f,
72 const std::vector<range> &r,
73 const std::vector<size_t> &dims,
74 data_ptr data, buffer_ptr buffer) {
76 std::vector<size_t> counts;
77 std::vector<size_t> strides;
// The offset of the first selected element accumulates first-index * stride
// per dimension. A stride is recorded only for dimensions that are partially
// selected; the recursive overload advances `data` by these strides.
// NOTE(review): counts.back() requires a non-empty `counts`, and `begin` is
// used here, but their initialization (and the code that advances `stride`
// or starts a new count) is not visible in this view.
80 for (
size_t i = 0, stride = 1; i < dims.size(); ++i) {
81 counts.back() *= r[i].size();
82 begin += *r[i].begin()*stride;
86 if (dims[i] != r[i].size()) {
88 strides.push_back(stride);
91 return apply(f, counts, strides, data+begin, buffer, strides.size());
/// Recursive worker over a (counts, strides) description of a strided region.
/// In its base case it returns f(counts[0], data, buffer), handing one
/// contiguous run to the leaf operation. Otherwise it iterates
/// counts[level] times, recursing and advancing `data` by strides[level]
/// after each step.
/// NOTE(review): the `level` parameter, the base-case test, the level
/// decrement, and the definition/update of `size` used in `buffer+size` are
/// not visible in this view.
94 template<
class F,
typename data_ptr,
typename buffer_ptr>
95 static size_t apply(F f,
96 const std::vector<size_t> &counts,
97 const std::vector<size_t> &strides,
98 data_ptr data, buffer_ptr buffer,
104 return f(counts[0], data, buffer);
108 for (
size_t i = 0; i < counts[level]; ++i) {
110 size_t n = apply(f, counts, strides, data, buffer+size, level);
111 data += strides[level];
// Fragment of a constructor taking `dims` (and presumably `comm`). It passes
// ArrayBase::dims_, not the `dims` parameter, to initialize().
// NOTE(review): presumably ArrayBase::dims_ is set from `dims` in a line not
// visible here -- confirm.
128 const std::vector<size_t> &dims,
132 initialize(ArrayBase::dims_, comm);
/// Distributes an array with extents `dims` over the processes of `comm`
/// using ARMCI.
/// Every dimension except the last is kept whole. The last dimension is cut
/// into contiguous blocks of ceil(dims.back() / comm.size()) indices, and one
/// tile is appended to tiles_ per rank. This rank's ARMCI allocation becomes
/// data_.
/// @throws std::runtime_error (via check) if ARMCI_Init or ARMCI_Malloc
///         reports an error on any process of `comm`.
/// NOTE(review): dims.back() is used without a guard, so `dims` must be
/// non-empty.
137 void initialize(
const std::vector<size_t> &dims,
const MPI::Comm &comm) {
// Only MPI::Comm::World() is accepted here.
140 MPQC_CHECK(comm == MPI::Comm::World());
143 std::vector<range> extents;
144 for (
int i = 0; i < int(dims.size())-1; ++i) {
146 extents.push_back(
range(0, dims[i]));
// Ceiling division: number of last-dimension indices per rank.
148 size_t block = (dims.back() + comm.size() - 1)/comm.size();
// check() tests comm.any(err != 0), so a failure on one rank is presumably
// reported on all ranks.
// NOTE(review): ARMCI_Init runs on every initialize() call; confirm that the
// ARMCI implementation in use tolerates repeated initialization.
151 check(ARMCI_Init(),
"ARMCI_Init", comm);
// ARMCI_Malloc fills one base pointer per rank. `size` is defined in a line
// not visible here (presumably this rank's element count) -- confirm.
153 std::vector<void*> data(comm.size(), NULL);
154 check(ARMCI_Malloc(&data[0], size*
sizeof(T)),
"ARMCI_Malloc", comm);
155 data_ = data[comm.rank()];
// One tile per rank. The leading extents are copied whole; the last dimension
// gets [min(D, i*block), min(D, begin+block)) with D = dims.back(), so trailing
// ranks may receive empty ranges.
// NOTE(review): `i` is size_t, while the types of comm.size()/comm.rank() are
// not visible; if they are int, expect sign-compare warnings.
157 for (
size_t i = 0; i < comm.size(); ++i) {
161 tile.local = (i == comm.rank());
162 size_t begin = std::min<size_t>(dims.back(), i*block);
163 size_t end = std::min<size_t>(dims.back(), begin+block);
164 tile.extents = extents;
165 tile.extents.push_back(range(begin, end));
166 tiles_.push_back(tile);
/// Releases the ARMCI allocation that initialize() stored in data_.
/// NOTE(review): ARMCI_Free is a collective operation in ARMCI, so every
/// process must destroy its instance. Also confirm the class is non-copyable,
/// since a copy would free data_ twice.
172 ~array_parallel_impl() {
173 ARMCI_Free(this->data_);
/// Const accessor returning a reference to an MPI::Comm.
/// NOTE(review): the member it returns is not visible in this view.
180 const MPI::Comm& comm()
const {
/// Writes the packed `buffer` into the distributed region `r` by running
/// apply<PUT> over every tile. apply() skips tiles that do not intersect `r`.
/// The const qualifier is cast away only because apply() takes a non-const
/// void*. Presumably the PUT path only reads from `buffer` -- confirm.
186 void _put(
const std::vector<range> &r,
const void *buffer) {
187 apply<PUT>(this->tiles_, r, (
void*)buffer);
/// Reads the distributed region `r` into the packed `buffer` by running
/// apply<GET> over every tile. apply() skips tiles that do not intersect `r`.
190 void _get(
const std::vector<range> &r,
void *buffer)
const {
191 apply<GET>(this->tiles_, r, buffer);
/// Direction of a transfer performed by apply(). _put uses PUT and _get uses
/// GET.
196 enum OP { PUT, GET };
/// Tile descriptor whose data is addressed through a void* (array_tile<void*>).
198 typedef array_tile<void*> Tile;
/// Tiles built by initialize(), one per rank of the communicator.
200 std::vector<Tile> tiles_;
/// Throws std::runtime_error(func + " failed") if `err` is nonzero on any
/// process of `comm`. With the default MPI::Comm::Self(), only the local
/// error code is considered.
/// NOTE(review): comm.any() is presumably collective, so with a non-Self
/// communicator every rank must reach this call -- confirm.
/// @param err   return code of the ARMCI call (0 means success)
/// @param func  name used in the exception message; " failed" is appended
/// @param comm  processes whose error codes are combined
202 static void check(
int err,
const std::string &func,
203 MPI::Comm comm = MPI::Comm::Self()) {
204 if (comm.any((err != 0))) {
205 throw std::runtime_error(func +
" failed");
/// Transfers the region `r` between a packed local buffer and the distributed
/// tiles using the ARMCI strided operations ARMCI_GetS / ARMCI_PutS.
/// For each tile, the intersection x = t.subset(r) is described by
/// per-dimension counts (count[0] in bytes) and byte strides. The local side
/// uses the packed subset shape; the remote side uses the tile's extents.
/// The ARMCI calls run between mutex::global::lock() and unlock().
///
/// NOTE(review) suspected defects (not fixed here; parts of this function
/// are not visible):
///  - local_strides[i] / remote_strides[i] are built from the previous entry
///    after that entry was already multiplied by sizeof(T). For i >= 1 the
///    byte stride therefore appears to pick up an extra factor of sizeof(T).
///    This matters whenever more than one stride level is passed to ARMCI.
///    Compute element strides first, then scale to bytes.
///  - check() can throw while mutex::global is held, which skips the unlock()
///    below. Consider an RAII guard.
///  - check() already appends " failed", so these messages read
///    "ARMCI_GetS failed failed" / "ARMCI_PutS failed failed".
///  - BOOST_FOREACH (Tile t, tiles) copies every tile, including its extents
///    vector; `const Tile &t` would avoid the copy.
210 static void apply(
const std::vector<Tile> &tiles,
211 const std::vector<range> &r,
214 T *local = (T*)buffer;
216 BOOST_FOREACH (Tile t, tiles) {
218 std::vector<range> x = t.subset(r);
// Skip tiles that do not intersect the requested region.
220 if (x.empty())
continue;
224 int remote_strides[N];
225 int local_strides[N];
// count[i] is the subset size in dimension i. Strides are running products
// of the subset (local) and tile (remote) extents, scaled to bytes. `total`
// and `count` are declared in lines not visible here.
228 for (
size_t i = 0; i < N; ++i) {
229 total *= x[i].size();
230 count[i] = x[i].size();
232 x[i].size()*((i > 0) ? local_strides[i-1] : 1);
234 t.extents[i].size()*((i > 0) ? remote_strides[i - 1] : 1);
235 local_strides[i] *=
sizeof(T);
236 remote_strides[i] *=
sizeof(T);
238 count[0] *=
sizeof(T);
// Offset of the subset's first element inside the tile; dimension 0 varies
// fastest (stride starts at 1 and grows by each tile extent).
240 T *remote = (T*)t.data;
241 for (
size_t i = 0, stride = 1; i < N; ++i) {
242 remote += (x[i].begin() - t.extents[i].begin())*stride;
243 stride *= t.extents[i].size();
252 mutex::global::lock();
255 err = ARMCI_GetS(remote, remote_strides,
256 local, local_strides,
258 check(err,
"ARMCI_GetS failed");
262 err = ARMCI_PutS(local, local_strides,
263 remote, remote_strides,
265 check(err,
"ARMCI_PutS failed");
267 mutex::global::unlock();
/// Driver tag object, presumably used to request the in-core array
/// implementation.
/// NOTE(review): if this declaration is at namespace scope in a header,
/// `static` gives every translation unit its own copy -- confirm intent.
282 static const detail::array_core_driver ARRAY_CORE;