parallel_gemm.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2015 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * justus
19  * Department of Chemistry, Virginia Tech
20  *
21  * parallel_gemm.h
22  * Apr 29, 2015
23  *
24  */
25 
26 #ifndef TILEDARRAY_PARALLEL_GEMM_H__INCLUDED
27 #define TILEDARRAY_PARALLEL_GEMM_H__INCLUDED
28 
#include <TiledArray/blas.h>
#include <TiledArray/vector_op.h>

#include <cstdlib>  // posix_memalign, free
#include <limits>   // std::numeric_limits
#include <memory>   // std::shared_ptr
#include <new>      // std::bad_alloc
32 
33 #define TILEDARRAY_DYNAMIC_BLOCK_SIZE std::numeric_limits<std::size_t>::max();
34 
35 namespace TiledArray {
36 namespace math {
37 
38 //#ifdef HAVE_INTEL_TBB
39 
40 template <typename T, integer Size>
41 class MatrixBlockTask : public tbb::task {
42  const integer rows_;
43  const integer cols_;
44  T* data_;
45  const integer ld_;
46  std::shared_ptr<T> result_;
47 
49 
53  void copy_block(T* result, const T* data, const integer ld) {
54  const T* const block_end = result + (TILEDARRAY_LOOP_UNWIND * Size);
55  for (; result < block_end; result += Size, data += ld)
57  }
58 
60 
66  void copy_block(const integer m, const integer n, T* result, const T* data,
67  const integer ld) {
68  const T* const block_end = result + (m * Size);
69  for (; result < block_end; result += Size, data += ld)
71  }
72 
73  public:
74  MatrixBlockTask(const integer rows, const integer cols, const T* const data,
75  const integer ld)
76  : rows_(rows), cols_(cols), data_(data), ld_(ld) {}
77 
79  virtual tbb::task* execut() {
80  // Compute block iteration limit
82  const integer mx =
83  rows_ & index_mask; // = rows - rows % TILEDARRAY_LOOP_UNWIND
84  const integer nx =
85  cols_ & index_mask; // = cols - cols % TILEDARRAY_LOOP_UNWIND
86  const integer m_tail = rows_ - mx;
87  const integer n_tail = cols_ - nx;
88 
89  // Copy data into block_
90  integer i = 0ul;
91  T* result_i = result_.get();
92  const T* data_i = data_;
93  for (; i < mx;
94  i += TILEDARRAY_LOOP_UNWIND, result_i += Size, data_i += ld_) {
95  integer j = 0ul;
96  for (; j < nx; j += TILEDARRAY_LOOP_UNWIND)
97  copy_block(result_i + j, data_i + j);
98 
99  if (n_tail)
100  copy_block(TILEDARRAY_LOOP_UNWIND, n_tail, result_i + j, data_i + j);
101  }
102 
103  if (m_tail) {
104  integer j = 0ul;
105  for (; j < nx; j += TILEDARRAY_LOOP_UNWIND)
106  copy_block(m_tail, TILEDARRAY_LOOP_UNWIND, result_i + j, data_i + j);
107 
108  if (n_tail) copy_block(m_tail, n_tail, result_i + j, data_i + j);
109  }
110 
111  return nullptr;
112  }
113 
114  std::shared_ptr<T> result() {
115  constexpr integer size = Size * Size;
116  constexpr integer bytes = size * sizeof(T);
117 
118  T* result_ptr = nullptr;
119  if (!posix_memalign(result_ptr, TILEARRAY_ALIGNMENT, bytes))
120  throw std::bad_alloc();
121 
122  result_.reset(result_ptr);
123 
124  return result_;
125  }
126 
127 }; // class MatrixBlockTask
128 
129 template <integer Size, typename C, typename A = C, typename B = C,
130  typename Alpha = C, typename Beta = C>
131 class GemmTask : public tbb::task {
132  const blas::Op op_a_, op_b_;
133  const integer m_, n_, k_;
134  const Alpha alpha_;
135  std::shared_ptr<A> a_;
136  constexpr integer lda_ = Size;
137  std::shared_ptr<B> b_;
138  const Beta beta_;
139  std::shared_ptr<C> c_;
140  const integer ldc_;
141 
142  public:
143  GemmTask(blas::Op op_a, blas::Op op_b, const integer m, const integer n,
144  const integer k, const Alpha alpha, const std::shared_ptr<A>& a,
145  const std::shared_ptr<B>& b, const Beta beta,
146  const std::shared_ptr<C>& c, const integer ldc)
147  : op_a_(op_a),
148  op_b_(op_b),
149  m_(m),
150  n_(n),
151  k_(k),
152  alpha_(alpha),
153  a_(a),
154  b_(b),
155  beta_(beta),
156  c_(c),
157  ldc_(c) {}
158 
159  virtual tbb::task execute() {
160  gemm(op_a_, op_b_, m_, n_, k_, alpha_, a_.get(), Size, b_.get(), Size, c_,
161  ldc_);
162  }
163 
164 }; // class GemmTask
165 
166 //#endif // HAVE_INTEL_TBB
167 
168 } // namespace math
169 } // namespace TiledArray
170 
171 #endif // TILEDARRAY_PARALLEL_GEMM_H__INCLUDED
::blas::Op Op
Definition: blas.h:46
int64_t integer
Definition: blas.h:44
virtual tbb::task * execut()
Task body.
Definition: parallel_gemm.h:79
virtual tbb::task execute()
TILEDARRAY_FORCE_INLINE void copy_block(Result *const result, const Arg *const arg)
Definition: vector_op.h:219
std::shared_ptr< T > result()
MatrixBlockTask(const integer rows, const integer cols, const T *const data, const integer ld)
Definition: parallel_gemm.h:74
#define TILEDARRAY_LOOP_UNWIND
Definition: vector_op.h:40
std::integral_constant< std::size_t, ~std::size_t(TILEDARRAY_LOOP_UNWIND - 1ul)> index_mask
Definition: vector_op.h:54
GemmTask(blas::Op op_a, blas::Op op_b, const integer m, const integer n, const integer k, const Alpha alpha, const std::shared_ptr< A > &a, const std::shared_ptr< B > &b, const Beta beta, const std::shared_ptr< C > &c, const integer ldc)
decltype(auto) gemm(const Tile< Left > &left, const Tile< Right > &right, const Scalar factor, const math::GemmHelper &gemm_config)
Contract 2 tensors over head/tail modes and scale the product.
Definition: tile.h:1396