binary_wrapper.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2013 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * Justus Calvin
19  * Department of Chemistry, Virginia Tech
20  *
21  * binary_wrapper.h
22  * Oct 6, 2013
23  *
24  */
25 
26 #ifndef TILEDARRAY_TILE_OP_BINARY_WRAPPER_H__INCLUDED
27 #define TILEDARRAY_TILE_OP_BINARY_WRAPPER_H__INCLUDED
28 
29 #include <TiledArray/permutation.h>
31 #include <TiledArray/zero_tensor.h>
32 
33 namespace TiledArray {
34 namespace detail {
35 
37 
93 template <typename Op>
95  public:
97  typedef typename Op::left_type left_type;
98  typedef typename Op::right_type right_type;
99  typedef typename Op::result_type result_type;
100 
102  static constexpr bool left_is_consumable = Op::left_is_consumable;
105  static constexpr bool right_is_consumable = Op::right_is_consumable;
106 
107  template <typename T>
108  static constexpr bool is_lazy_tile_v = is_lazy_tile<std::decay_t<T>>::value;
109 
110  template <typename T>
111  static constexpr bool is_array_tile_v = is_array_tile<std::decay_t<T>>::value;
112 
113  template <typename T>
114  static constexpr bool is_nonarray_lazy_tile_v =
115  is_lazy_tile_v<T> && !is_array_tile_v<T>;
116 
117  template <typename T>
118  using eval_t = typename eval_trait<std::decay_t<T>>::type;
119 
120  private:
121  Op op_;
122  BipartitePermutation perm_;
123 
124  public:
125  // Compiler generated functions
126  BinaryWrapper(const BinaryWrapper<Op>&) = default;
128  ~BinaryWrapper() = default;
131 
132  template <typename Perm, typename = std::enable_if_t<
133  TiledArray::detail::is_permutation_v<Perm>>>
134  BinaryWrapper(const Op& op, const Perm& perm) : op_(op), perm_(perm) {}
135 
136  BinaryWrapper(const Op& op) : op_(op), perm_() {}
137 
139 
146  template <typename L, typename R,
147  std::enable_if_t<
148  !(is_lazy_tile_v<L> || is_lazy_tile_v<R>)&&!std::is_same<
149  std::decay_t<L>, ZeroTensor>::value &&
150  !std::is_same<std::decay_t<R>, ZeroTensor>::value>* = nullptr>
151  auto operator()(L&& left, R&& right) const {
152  static_assert(
153  std::is_same<std::decay_t<L>, left_type>::value,
154  "BinaryWrapper::operator()(L&&,R&&): invalid argument type L");
155  static_assert(
156  std::is_same<std::decay_t<R>, right_type>::value,
157  "BinaryWrapper::operator()(L&&,R&&): invalid argument type R");
158  if (perm_) return op_(std::forward<L>(left), std::forward<R>(right), perm_);
159 
160  return op_(std::forward<L>(left), std::forward<R>(right));
161  }
162 
164 
171  template <typename R, std::enable_if_t<!is_lazy_tile_v<R>>* = nullptr>
172  auto operator()(const ZeroTensor& left, R&& right) const {
173  static_assert(
174  std::is_same<std::decay_t<R>, right_type>::value,
175  "BinaryWrapper::operator()(zero,R&&): invalid argument type R");
176  if (perm_) return op_(left, std::forward<R>(right), perm_);
177 
178  return op_(left, std::forward<R>(right));
179  }
180 
182 
189  template <typename L, std::enable_if_t<!is_lazy_tile_v<L>>* = nullptr>
190  auto operator()(L&& left, const ZeroTensor& right) const {
191  static_assert(
192  std::is_same<std::decay_t<L>, left_type>::value,
193  "BinaryWrapper::operator()(L&&,zero): invalid argument type L");
194  if (perm_) return op_(std::forward<L>(left), right, perm_);
195 
196  return op_(std::forward<L>(left), right);
197  }
198 
199  // The following operators evaluate lazy tiles and then apply the wrapped
200  // operation, either directly or via the non-lazy overloads above.
201 
203 
213  template <
214  typename L, typename R,
215  std::enable_if_t<is_lazy_tile_v<L> && is_lazy_tile_v<R> &&
216  (left_is_consumable || right_is_consumable)>* = nullptr>
217  auto operator()(L&& left, R&& right) const {
218  auto eval_left = invoke_cast(std::forward<L>(left));
219  auto eval_right = invoke_cast(std::forward<R>(right));
220  auto continuation = [this](
221  madness::future_to_ref_t<decltype(eval_left)> l,
222  madness::future_to_ref_t<decltype(eval_right)> r) {
223  return BinaryWrapper_::operator()(l, r);
224  };
225  return meta::invoke(continuation, eval_left, eval_right);
226  }
227 
229 
239  template <
240  typename L, typename R,
241  std::enable_if_t<is_lazy_tile_v<L> &&
242  (!is_lazy_tile_v<R>)&&(left_is_consumable ||
243  right_is_consumable)>* = nullptr>
244  auto operator()(L&& left, R&& right) const {
245  auto eval_left = invoke_cast(std::forward<L>(left));
246  auto continuation = [this](madness::future_to_ref_t<decltype(eval_left)> l,
247  R&& r) {
248  return BinaryWrapper_::operator()(l, std::forward<R>(r));
249  };
250  return meta::invoke(continuation, eval_left, std::forward<R>(right));
251  }
252 
254 
264  template <
265  typename L, typename R,
266  std::enable_if_t<(!is_lazy_tile_v<L>)&&is_lazy_tile_v<R> &&
267  (left_is_consumable || right_is_consumable)>* = nullptr>
268  auto operator()(L&& left, R&& right) const {
269  auto eval_right = invoke_cast(std::forward<R>(right));
270  auto continuation =
271  [this](L&& l, madness::future_to_ref_t<decltype(eval_right)> r) {
272  return BinaryWrapper_::operator()(std::forward<L>(l), r);
273  };
274  return meta::invoke(continuation, std::forward<L>(left), eval_right);
275  }
276 
278 
287  template <
288  typename L, typename R,
289  std::enable_if_t<is_array_tile_v<L> && is_array_tile_v<R> &&
290  !(left_is_consumable || right_is_consumable)>* = nullptr>
291  auto operator()(L&& left, R&& right) const {
292  auto eval_left = invoke_cast(std::forward<L>(left));
293  auto eval_right = invoke_cast(std::forward<R>(right));
294 
295  if (perm_) return meta::invoke(op_, eval_left, eval_right, perm_);
296 
297  auto op_left = [=](eval_t<L>& _left, eval_t<R>& _right) {
298  return op_.consume_left(_left, _right);
299  };
300  auto op_right = [=](eval_t<L>& _left, eval_t<R>& _right) {
301  return op_.consume_right(_left, _right);
302  };
303  // Override consumable
304  if (is_consumable_tile<eval_t<L>>::value && left.is_consumable())
305  return meta::invoke(op_left, eval_left, eval_right);
306  if (is_consumable_tile<eval_t<R>>::value && right.is_consumable())
307  return meta::invoke(op_right, eval_left, eval_right);
308 
309  return meta::invoke(op_, eval_left, eval_right);
310  }
311 
312  template <
313  typename L, typename R,
314  std::enable_if_t<is_array_tile_v<L> &&
315  (!is_lazy_tile_v<R>)&&!(left_is_consumable ||
316  right_is_consumable)>* = nullptr>
317  auto operator()(L&& left, R&& right) const {
318  auto eval_left = invoke_cast(std::forward<L>(left));
319 
320  if (perm_) return op_(eval_left, std::forward<R>(right), perm_);
321 
322  // Override consumable
323  if (is_consumable_tile<eval_t<L>>::value && left.is_consumable())
324  return op_.consume_left(eval_left, std::forward<R>(right));
325 
326  return op_(eval_left, std::forward<R>(right));
327  }
328 
329  template <
330  typename L, typename R,
331  std::enable_if_t<is_array_tile_v<L> && is_nonarray_lazy_tile_v<R> &&
332  !(left_is_consumable || right_is_consumable)>* = nullptr>
333  auto operator()(L&& left, R&& right) const {
334  auto eval_left = invoke_cast(std::forward<L>(left));
335  auto eval_right = invoke_cast(std::forward<R>(right));
336 
337  if (perm_) return op_(eval_left, eval_right, perm_);
338 
339  // Override consumable
340  if (is_consumable_tile<eval_t<L>>::value && left.is_consumable())
341  return op_.consume_left(eval_left, eval_right);
342 
343  return op_(eval_left, eval_right);
344  }
345 
346  template <
347  typename L, typename R,
348  std::enable_if_t<(!is_lazy_tile_v<L>)&&is_array_tile_v<R> &&
349  !(left_is_consumable || right_is_consumable)>* = nullptr>
350  auto operator()(L&& left, R&& right) const {
351  auto eval_right = invoke_cast(std::forward<R>(right));
352 
353  if (perm_) return op_(std::forward<L>(left), eval_right, perm_);
354 
355  // Override consumable
356  if (is_consumable_tile<eval_t<R>>::value && right.is_consumable())
357  return op_.consume_right(std::forward<L>(left), eval_right);
358 
359  return op_(std::forward<L>(left), eval_right);
360  }
361 
362  template <
363  typename L, typename R,
364  std::enable_if_t<is_nonarray_lazy_tile_v<L> && is_array_tile_v<R> &&
365  !(left_is_consumable || right_is_consumable)>* = nullptr>
366  auto operator()(L&& left, R&& right) const {
367  auto eval_left = invoke_cast(std::forward<L>(left));
368  auto eval_right = invoke_cast(std::forward<R>(right));
369 
370  if (perm_) return op_(eval_left, eval_right, perm_);
371 
372  // Override consumable
373  if (is_consumable_tile<eval_t<R>>::value && right.is_consumable())
374  return op_.consume_right(eval_left, eval_right);
375 
376  return op_(eval_left, eval_right);
377  }
378 
379 }; // class BinaryWrapper
380 
381 } // namespace detail
382 } // namespace TiledArray
383 
384 #endif // TILEDARRAY_TILE_OP_BINARY_WRAPPER_H__INCLUDED
::blas::Op Op
Definition: blas.h:46
Consumable tile type trait.
Definition: type_traits.h:611
auto invoke_cast(Arg &&arg)
Definition: cast.h:176
static constexpr bool is_lazy_tile_v
static constexpr bool is_nonarray_lazy_tile_v
BinaryWrapper(const BinaryWrapper< Op > &)=default
static constexpr bool is_array_tile_v
Op::right_type right_type
Right-hand argument type.
BinaryWrapper(const Op &op, const Perm &perm)
Binary tile operation wrapper.
BinaryWrapper< Op > & operator=(const BinaryWrapper< Op > &)=default
Op::left_type left_type
Left-hand argument type.
auto operator()(const ZeroTensor &left, R &&right) const
Evaluate a zero tile with a non-zero tile and possibly permute.
auto operator()(L &&left, R &&right) const
Evaluate two non-zero tiles and possibly permute.
BinaryWrapper< Op > & operator=(BinaryWrapper< Op > &&)=default
auto operator()(L &&left, const ZeroTensor &right) const
Evaluate a non-zero tile with a zero tile and possibly permute.
BinaryWrapper(BinaryWrapper< Op > &&)=default
auto invoke(Function &&fn, Args &&... args) -> typename std::enable_if< !or_reduce< false, madness::is_future< std::decay_t< Args >>::value... >::value, decltype(fn(args...))>::type
Definition: meta.h:52
Place-holder object for a zero tensor.
Definition: zero_tensor.h:32
Op::result_type result_type
The result tile type.
typename eval_trait< std::decay_t< T > >::type eval_t
Detect tiles used by ArrayEvalImpl.
Definition: type_traits.h:670
Determine the object type used in the evaluation of tensor expressions.
Definition: type_traits.h:580
static constexpr bool right_is_consumable
Permutation of a bipartite set.
Definition: permutation.h:610
BinaryWrapper< Op > BinaryWrapper_
Detect lazy evaluation tiles.
Definition: type_traits.h:591
static constexpr bool left_is_consumable
Boolean value that indicates the left-hand argument can always be consumed.