mult.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2013 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * Justus Calvin
19  * Department of Chemistry, Virginia Tech
20  *
21  * mult.h
22  * May 8, 2013
23  *
24  */
25 
26 #ifndef TILEDARRAY_TILE_OP_MULT_H__INCLUDED
27 #define TILEDARRAY_TILE_OP_MULT_H__INCLUDED
28 
#include <TiledArray/error.h>
#include <TiledArray/zero_tensor.h>

#include <functional>
#include <type_traits>
#include <utility>
33 
34 namespace TiledArray {
35 namespace detail {
36 
38 
55 template <typename Result, typename Left, typename Right, bool LeftConsumable,
56  bool RightConsumable>
57 class Mult {
58  public:
60  typedef Left left_type;
61  typedef Right right_type;
62  typedef Result result_type;
63 
64  using left_value_type = typename left_type::value_type;
65  using right_value_type = typename right_type::value_type;
66  using result_value_type = typename result_type::value_type;
68  const right_value_type&);
69 
71  static constexpr bool left_is_consumable =
72  LeftConsumable && std::is_same<result_type, left_type>::value;
74  static constexpr bool right_is_consumable =
75  RightConsumable && std::is_same<result_type, right_type>::value;
76 
77  private:
81 
82  // Permuting tile evaluation function
83  // These operations cannot consume the argument tile since this operation
84  // requires temporary storage space.
85  template <typename Perm, typename = std::enable_if_t<
86  TiledArray::detail::is_permutation_v<Perm>>>
87  result_type eval(const left_type& first, const right_type& second,
88  const Perm& perm) const {
89  if (!element_op_) {
90  using TiledArray::mult;
91  return mult(first, second, perm);
92  } else {
93  using TiledArray::binary;
94  return binary(first, second, element_op_, perm);
95  }
96  }
97 
98  template <typename Perm, typename = std::enable_if_t<
99  TiledArray::detail::is_permutation_v<Perm>>>
100  result_type eval(ZeroTensor, const right_type& second,
101  const Perm& perm) const {
102  TA_ASSERT(false); // Invalid arguments for this operation
103  return result_type();
104  }
105 
106  template <typename Perm, typename = std::enable_if_t<
107  TiledArray::detail::is_permutation_v<Perm>>>
108  result_type eval(const left_type& first, ZeroTensor, const Perm& perm) const {
109  TA_ASSERT(false); // Invalid arguments for this operation
110  return result_type();
111  }
112 
113  // Non-permuting tile evaluation functions
114  // The compiler will select the correct functions based on the
115  // consumability of the arguments.
116 
117  template <bool LC, bool RC,
118  typename std::enable_if<!(LC || RC)>::type* = nullptr>
119  result_type eval(const left_type& first, const right_type& second) const {
120  if (!element_op_) {
121  using TiledArray::mult;
122  return mult(first, second);
123  } else {
124  using TiledArray::binary;
125  return binary(first, second, element_op_);
126  }
127  }
128 
129  template <bool LC, bool RC, typename std::enable_if<LC>::type* = nullptr>
130  result_type eval(left_type& first, const right_type& second) const {
131  TA_ASSERT(!element_op_);
132  using TiledArray::mult_to;
133  return mult_to(first, second);
134  }
135 
136  template <bool LC, bool RC,
137  typename std::enable_if<!LC && RC>::type* = nullptr>
138  result_type eval(const left_type& first, right_type& second) const {
139  TA_ASSERT(!element_op_);
140  using TiledArray::mult_to;
141  return mult_to(second, first);
142  }
143 
144  template <bool LC, bool RC, typename std::enable_if<!RC>::type* = nullptr>
145  result_type eval(ZeroTensor, const right_type& second) const {
146  TA_ASSERT(false); // Invalid arguments for this operation
147  return result_type();
148  }
149 
150  template <bool LC, bool RC, typename std::enable_if<RC>::type* = nullptr>
151  result_type eval(ZeroTensor, right_type& second) const {
152  TA_ASSERT(false); // Invalid arguments for this operation
153  return result_type();
154  }
155 
156  template <bool LC, bool RC, typename std::enable_if<!LC>::type* = nullptr>
157  result_type eval(const left_type& first, ZeroTensor) const {
158  TA_ASSERT(false); // Invalid arguments for this operation
159  return result_type();
160  }
161 
162  template <bool LC, bool RC, typename std::enable_if<LC>::type* = nullptr>
163  result_type eval(left_type& first, ZeroTensor) const {
164  TA_ASSERT(false); // Invalid arguments for this operation
165  return result_type();
166  }
167 
168  public:
172  Mult() = default;
176  template <typename ElementOp,
177  typename = std::enable_if_t<
178  !std::is_same_v<std::remove_reference_t<ElementOp>, Mult_> &&
179  std::is_invocable_r_v<
180  result_value_type, std::remove_reference_t<ElementOp>,
181  const left_value_type&, const right_value_type&>>>
182  explicit Mult(ElementOp&& op) : element_op_(std::forward<ElementOp>(op)) {}
183 
185 
193  template <
194  typename L, typename R, typename Perm,
195  typename = std::enable_if_t<TiledArray::detail::is_permutation_v<Perm>>>
196  result_type operator()(L&& left, R&& right, const Perm& perm) const {
197  return eval(std::forward<L>(left), std::forward<R>(right), perm);
198  }
199 
201 
208  template <typename L, typename R>
209  result_type operator()(L&& left, R&& right) const {
210  return Mult_::template eval<left_is_consumable, right_is_consumable>(
211  std::forward<L>(left), std::forward<R>(right));
212  }
213 
215 
221  template <typename R>
222  result_type consume_left(left_type& left, R&& right) const {
223  constexpr bool can_consume_left =
225  std::is_same<result_type, left_type>::value;
226  constexpr bool can_consume_right =
227  right_is_consumable && !(std::is_const<R>::value || can_consume_left);
228  return Mult_::template eval<can_consume_left, can_consume_right>(
229  left, std::forward<R>(right));
230  }
231 
233 
239  template <typename L>
240  result_type consume_right(L&& left, right_type& right) const {
241  constexpr bool can_consume_right =
243  std::is_same<result_type, right_type>::value;
244  constexpr bool can_consume_left =
245  left_is_consumable && !(std::is_const<L>::value || can_consume_right);
246  return Mult_::template eval<can_consume_left, can_consume_right>(
247  std::forward<L>(left), right);
248  }
249 
250 }; // class Mult
251 
253 
267 template <typename Result, typename Left, typename Right, typename Scalar,
268  bool LeftConsumable, bool RightConsumable>
269 class ScalMult {
270  public:
271  typedef ScalMult<Result, Left, Right, Scalar, LeftConsumable,
272  RightConsumable>
274  typedef Left left_type;
275  typedef Right right_type;
276  typedef Scalar scalar_type;
277  typedef Result result_type;
278 
280  static constexpr bool left_is_consumable =
281  LeftConsumable && std::is_same<result_type, left_type>::value;
283  static constexpr bool right_is_consumable =
284  RightConsumable && std::is_same<result_type, right_type>::value;
285 
286  private:
287  scalar_type factor_;
288 
289  // Permuting tile evaluation function
290  // These operations cannot consume the argument tile since this operation
291  // requires temporary storage space.
292 
293  template <typename Perm, typename = std::enable_if_t<
294  TiledArray::detail::is_permutation_v<Perm>>>
295  result_type eval(const left_type& first, const right_type& second,
296  const Perm& perm) const {
297  using TiledArray::mult;
298  return mult(first, second, factor_, perm);
299  }
300 
301  template <typename Perm, typename = std::enable_if_t<
302  TiledArray::detail::is_permutation_v<Perm>>>
303  result_type eval(ZeroTensor, const right_type& second,
304  const Perm& perm) const {
305  TA_ASSERT(false); // Invalid arguments for this operation
306  return result_type();
307  }
308 
309  template <typename Perm, typename = std::enable_if_t<
310  TiledArray::detail::is_permutation_v<Perm>>>
311  result_type eval(const left_type& first, ZeroTensor, const Perm& perm) const {
312  TA_ASSERT(false); // Invalid arguments for this operation
313  return result_type();
314  }
315 
316  // Non-permuting tile evaluation functions
317  // The compiler will select the correct functions based on the
318  // consumability of the arguments.
319 
320  template <bool LC, bool RC,
321  typename std::enable_if<!(LC || RC)>::type* = nullptr>
322  result_type eval(const left_type& first, const right_type& second) const {
323  using TiledArray::mult;
324  return mult(first, second, factor_);
325  }
326 
327  template <bool LC, bool RC, typename std::enable_if<LC>::type* = nullptr>
328  result_type eval(left_type& first, const right_type& second) const {
329  using TiledArray::mult_to;
330  return mult_to(first, second, factor_);
331  }
332 
333  template <bool LC, bool RC,
334  typename std::enable_if<!LC && RC>::type* = nullptr>
335  result_type eval(const left_type& first, right_type& second) const {
336  using TiledArray::mult_to;
337  return mult_to(second, first, factor_);
338  }
339 
340  template <bool LC, bool RC, typename std::enable_if<!RC>::type* = nullptr>
341  result_type eval(ZeroTensor, const right_type& second) const {
342  TA_ASSERT(false); // Invalid arguments for this operation
343  return result_type();
344  }
345 
346  template <bool LC, bool RC, typename std::enable_if<RC>::type* = nullptr>
347  result_type eval(ZeroTensor, right_type& second) const {
348  TA_ASSERT(false); // Invalid arguments for this operation
349  return result_type();
350  }
351 
352  template <bool LC, bool RC, typename std::enable_if<!LC>::type* = nullptr>
353  result_type eval(const left_type& first, ZeroTensor) const {
354  TA_ASSERT(false); // Invalid arguments for this operation
355  return result_type();
356  }
357 
358  template <bool LC, bool RC, typename std::enable_if<LC>::type* = nullptr>
359  result_type eval(left_type& first, ZeroTensor) const {
360  TA_ASSERT(false); // Invalid arguments for this operation
361  return result_type();
362  }
363 
364  public:
365  // Compiler generated functions
366  ScalMult(const ScalMult_&) = default;
367  ScalMult(ScalMult_&&) = default;
368  ~ScalMult() = default;
369  ScalMult_& operator=(const ScalMult_&) = default;
371 
373 
375  explicit ScalMult(const Scalar factor) : factor_(factor) {}
376 
378 
386  template <
387  typename L, typename R, typename Perm,
388  typename = std::enable_if_t<TiledArray::detail::is_permutation_v<Perm>>>
389  result_type operator()(L&& left, R&& right, const Perm& perm) const {
390  return eval(std::forward<L>(left), std::forward<R>(right), perm);
391  }
392 
394 
401  template <typename L, typename R>
402  result_type operator()(L&& left, R&& right) const {
403  return ScalMult_::template eval<left_is_consumable, right_is_consumable>(
404  std::forward<L>(left), std::forward<R>(right));
405  }
406 
408 
414  template <typename R>
415  result_type consume_left(left_type& left, R&& right) const {
416  constexpr bool can_consume_left =
418  std::is_same<result_type, left_type>::value;
419  constexpr bool can_consume_right =
420  right_is_consumable && !(std::is_const<R>::value || can_consume_left);
421  return ScalMult_::template eval<can_consume_left, can_consume_right>(
422  left, std::forward<R>(right));
423  }
424 
426 
433  template <typename L>
434  result_type consume_right(L&& left, right_type& right) const {
435  constexpr bool can_consume_right =
437  std::is_same<result_type, right_type>::value;
438  constexpr bool can_consume_left =
439  left_is_consumable && !(std::is_const<L>::value || can_consume_right);
440  return ScalMult_::template eval<can_consume_left, can_consume_right>(
441  std::forward<L>(left), right);
442  }
443 
444 }; // class ScalMult
445 
446 } // namespace detail
447 } // namespace TiledArray
448 
449 #endif // TILEDARRAY_TILE_OP_MULT_H__INCLUDED
result_type consume_right(L &&left, right_type &right) const
Multiply left to right.
Definition: mult.h:240
static constexpr bool left_is_consumable
Indicates whether it is possible to consume the left tile.
Definition: mult.h:280
static constexpr bool left_is_consumable
Indicates whether it is possible to consume the left tile.
Definition: mult.h:71
Consumable tile type trait.
Definition: type_traits.h:611
Tile scale-multiplication operation.
Definition: mult.h:269
typename left_type::value_type left_value_type
Definition: mult.h:64
Mult< Result, Left, Right, LeftConsumable, RightConsumable > Mult_
Definition: mult.h:59
Tile multiplication operation.
Definition: mult.h:57
typename right_type::value_type right_value_type
Definition: mult.h:65
Tile< Result > & mult_to(Tile< Result > &result, const Tile< Arg > &arg)
Multiply to the result tile.
Definition: tile.h:1081
static constexpr bool right_is_consumable
Indicates whether it is possible to consume the right tile.
Definition: mult.h:283
decltype(auto) binary(const Tile< Left > &left, const Tile< Right > &right, Op &&op)
Binary element-wise transform producing a new tile.
Definition: tile.h:1118
result_type operator()(L &&left, R &&right) const
Multiply operator.
Definition: mult.h:209
ScalMult_ & operator=(const ScalMult_ &)=default
Right right_type
Right-hand argument base type.
Definition: mult.h:61
ScalMult(ScalMult_ &&)=default
#define TA_ASSERT(EXPR,...)
Definition: error.h:39
ScalMult(const ScalMult_ &)=default
Result result_type
The result tile type.
Definition: mult.h:62
typename result_type::value_type result_value_type
Definition: mult.h:66
decltype(auto) mult(const Tile< Left > &left, const Tile< Right > &right)
Multiplication tile arguments.
Definition: tile.h:1018
result_value_type(const left_value_type &, const right_value_type &) element_op_type
Definition: mult.h:68
result_type consume_right(L &&left, right_type &right) const
Multiply left to right and scale the result.
Definition: mult.h:434
ScalMult_ & operator=(ScalMult_ &&)=default
Place-holder object for a zero tensor.
Definition: zero_tensor.h:32
static constexpr bool right_is_consumable
Indicates whether it is possible to consume the right tile.
Definition: mult.h:74
Right right_type
Right-hand argument base type.
Definition: mult.h:275
ScalMult(const Scalar factor)
Constructor.
Definition: mult.h:375
Mult(ElementOp &&op)
Definition: mult.h:182
ScalMult< Result, Left, Right, Scalar, LeftConsumable, RightConsumable > ScalMult_
This class type.
Definition: mult.h:273
Scalar scalar_type
Scaling factor type.
Definition: mult.h:276
result_type consume_left(left_type &left, R &&right) const
Multiply right to left.
Definition: mult.h:222
Result result_type
Result tile type.
Definition: mult.h:277
Left left_type
Left-hand argument base type.
Definition: mult.h:60
result_type operator()(L &&left, R &&right, const Perm &perm) const
Multiply-and-permute operator.
Definition: mult.h:196
Left left_type
Left-hand argument base type.
Definition: mult.h:274
result_type operator()(L &&left, R &&right) const
Scale-and-multiply operator.
Definition: mult.h:402
result_type consume_left(left_type &left, R &&right) const
Multiply right to left and scale the result.
Definition: mult.h:415
result_type operator()(L &&left, R &&right, const Perm &perm) const
Scale-multiply-and-permute operator.
Definition: mult.h:389