contract_reduce.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2013 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * Justus Calvin
19  * Department of Chemistry, Virginia Tech
20  *
21  * contract_reduce.h
22  * Oct 9, 2013
23  *
24  */
25 
26 #ifndef TILEDARRAY_TILE_OP_CONTRACT_REDUCE_H__INCLUDED
27 #define TILEDARRAY_TILE_OP_CONTRACT_REDUCE_H__INCLUDED
28 
30 #include <TiledArray/permutation.h>
34 #include "../tile_interface/add.h"
35 #include "../tile_interface/permute.h"
36 
37 namespace TiledArray {
38 namespace detail {
39 
41 
51 template <typename Result, typename Left, typename Right, typename Scalar>
53  public:
56  typedef const Left& first_argument_type;
57  typedef const Right& second_argument_type;
58  typedef Result result_type;
59  typedef Scalar scalar_type;
60 
61  using left_value_type = typename Left::value_type;
62  using right_value_type = typename Right::value_type;
63  using result_value_type = typename Result::value_type;
65  const right_value_type&);
66 
67  static_assert(
68  TiledArray::detail::is_tensor_v<left_value_type> ==
69  TiledArray::detail::is_tensor_v<right_value_type> &&
70  TiledArray::detail::is_tensor_v<left_value_type> ==
71  TiledArray::detail::is_tensor_v<result_value_type>,
72  "ContractReduce can only handle plain tensors or nested tensors "
73  "(tensors-of-tensors); mixed contractions are not supported");
74  static constexpr bool plain_tensors =
75  !(TiledArray::detail::is_tensor_v<left_value_type> &&
76  TiledArray::detail::is_tensor_v<right_value_type> &&
77  TiledArray::detail::is_tensor_v<result_value_type>);
78 
79  private:
80  struct Impl {
81  template <
82  typename Perm = BipartitePermutation,
83  typename ElemMultAddOp = TiledArray::function_ref<elem_muladd_op_type>,
84  typename = std::enable_if_t<
85  TiledArray::detail::is_permutation_v<Perm> &&
86  std::is_invocable_r_v<void, std::remove_reference_t<ElemMultAddOp>,
88  const right_value_type&>>>
89  Impl(const math::blas::Op left_op, const math::blas::Op right_op,
90  const scalar_type alpha, const unsigned int result_rank,
91  const unsigned int left_rank, const unsigned int right_rank,
92  const Perm& perm = {}, ElemMultAddOp&& elem_muladd_op = {})
93  : gemm_helper_(left_op, right_op, result_rank, left_rank, right_rank),
94  alpha_(alpha),
95  perm_(perm),
96  elem_muladd_op_(std::forward<ElemMultAddOp>(elem_muladd_op)) {
97  // non-unit alpha must be absorbed into elem_muladd_op
98  if (elem_muladd_op_) TA_ASSERT(alpha == scalar_type(1));
99  }
100 
101  math::GemmHelper gemm_helper_;
102  scalar_type alpha_;
103  BipartitePermutation perm_;
105 
110  };
111 
112  std::shared_ptr<Impl> pimpl_;
113 
114  public:
115  // Compiler generated defaults are fine
116 
117  ContractReduceBase() = default;
120  ~ContractReduceBase() = default;
123 
125 
137  template <
138  typename Perm = BipartitePermutation,
139  typename ElemMultAddOp = TiledArray::function_ref<elem_muladd_op_type>,
140  typename = std::enable_if_t<
141  TiledArray::detail::is_permutation_v<Perm> &&
142  std::is_invocable_r_v<void, std::remove_reference_t<ElemMultAddOp>,
144  const right_value_type&>>>
146  const math::blas::Op right_op, const scalar_type alpha,
147  const unsigned int result_rank,
148  const unsigned int left_rank,
149  const unsigned int right_rank, const Perm& perm = {},
150  ElemMultAddOp&& elem_muladd_op = {})
151  : pimpl_(std::make_shared<Impl>(
152  left_op, right_op, alpha, result_rank, left_rank, right_rank, perm,
153  std::forward<ElemMultAddOp>(elem_muladd_op))) {}
154 
156 
158  const math::GemmHelper& gemm_helper() const {
159  TA_ASSERT(pimpl_);
160  return pimpl_->gemm_helper_;
161  }
162 
164 
166  const BipartitePermutation& perm() const {
167  TA_ASSERT(pimpl_);
168  return pimpl_->perm_;
169  }
170 
172 
174  scalar_type factor() const {
175  TA_ASSERT(pimpl_);
176  return pimpl_->alpha_;
177  }
178 
180 
182  const auto& elem_muladd_op() const {
183  TA_ASSERT(pimpl_);
184  return pimpl_->elem_muladd_op_;
185  }
186 
187  //-------------- these are only used for unit tests -----------------
188 
190 
192  unsigned int num_contract_ranks() const {
193  TA_ASSERT(pimpl_);
194  return pimpl_->gemm_helper_.num_contract_ranks();
195  }
196 
198 
200  unsigned int result_rank() const {
201  TA_ASSERT(pimpl_);
202  return pimpl_->gemm_helper_.result_rank();
203  }
204 
206 
208  unsigned int left_rank() const {
209  TA_ASSERT(pimpl_);
210  return pimpl_->gemm_helper_.left_rank();
211  }
212 
214 
216  unsigned int right_rank() const {
217  TA_ASSERT(pimpl_);
218  return pimpl_->gemm_helper_.right_rank();
219  }
220 
221 }; // class ContractReduceBase
222 
224 
231 template <typename Result, typename Left, typename Right, typename Scalar>
232 class ContractReduce : public ContractReduceBase<Result, Left, Right, Scalar> {
233  public:
242  typedef Result result_type;
243  typedef Scalar scalar_type;
244 
249 
250  // Compiler generated defaults are fine. N.B. this is shallow-copy.
251 
252  ContractReduce() = default;
253  ContractReduce(const ContractReduce_&) = default;
255  ~ContractReduce() = default;
258 
260 
272  template <
273  typename Perm = BipartitePermutation,
274  typename ElemMultAddOp = TiledArray::function_ref<elem_muladd_op_type>,
275  typename = std::enable_if_t<
276  TiledArray::detail::is_permutation_v<Perm> &&
277  std::is_invocable_r_v<void, std::remove_reference_t<ElemMultAddOp>,
279  const right_value_type&>>>
280  ContractReduce(const math::blas::Op left_op, const math::blas::Op right_op,
281  const scalar_type alpha, const unsigned int result_rank,
282  const unsigned int left_rank, const unsigned int right_rank,
283  const Perm& perm = {}, ElemMultAddOp&& elem_muladd_op = {})
284  : ContractReduceBase_(left_op, right_op, alpha, result_rank, left_rank,
285  right_rank, perm,
286  std::forward<ElemMultAddOp>(elem_muladd_op)) {}
287 
289 
291  result_type operator()() const { return result_type(); }
292 
294  result_type operator()(const result_type& temp) const {
295  using TiledArray::empty;
296  TA_ASSERT(!empty(temp));
297 
298  if (!ContractReduceBase_::perm()) return temp;
299 
301  return permute(temp, ContractReduceBase_::perm());
302  }
303 
305 
310  void operator()(result_type& result, const result_type& arg) const {
311  using TiledArray::add_to;
312  add_to(result, arg);
313  }
314 
316 
322  void operator()(result_type& result, const first_argument_type& left,
323  const second_argument_type& right) const {
324  if constexpr (!ContractReduceBase_::plain_tensors) {
325  TA_ASSERT(this->elem_muladd_op());
326  // not yet implemented
327  using TiledArray::empty;
328  using TiledArray::gemm;
329  gemm(result, left, right, ContractReduceBase_::gemm_helper(),
330  this->elem_muladd_op());
331  } else { // plain tensors
332  TA_ASSERT(!this->elem_muladd_op());
333  using TiledArray::empty;
334  using TiledArray::gemm;
335  if (empty(result))
336  result = gemm(left, right, ContractReduceBase_::factor(),
338  else
339  gemm(result, left, right, ContractReduceBase_::factor(),
341  }
342  }
343 
344 }; // class ContractReduce
345 
347 
353 template <typename Result, typename Left, typename Right>
354 class ContractReduce<Result, Left, Right,
356  : public ContractReduceBase<Result, Left, Right,
357  TiledArray::detail::ComplexConjugate<void>> {
358  public:
359  typedef ContractReduce<Result, Left, Right,
362  typedef ContractReduceBase<Result, Left, Right,
369  typedef decltype(gemm(std::declval<Left>(), std::declval<Right>(), 1,
370  std::declval<math::GemmHelper>()))
372  typedef TiledArray::detail::ComplexConjugate<void> scalar_type;
373 
375  using typename ContractReduceBase_::left_value_type;
376  using typename ContractReduceBase_::result_value_type;
377  using typename ContractReduceBase_::right_value_type;
378 
379  // Compiler generated defaults are fine. N.B. This has shallow copy semantics.
380 
381  ContractReduce() = default;
382  ContractReduce(const ContractReduce_&) = default;
384  ~ContractReduce() = default;
385  ContractReduce_& operator=(const ContractReduce_&) = default;
386  ContractReduce_& operator=(ContractReduce_&&) = default;
387 
389 
400  template <
401  typename Perm = BipartitePermutation,
402  typename ElemMultAddOp = TiledArray::function_ref<elem_muladd_op_type>,
403  typename = std::enable_if_t<
404  TiledArray::detail::is_permutation_v<Perm> &&
405  std::is_invocable_r_v<void, std::remove_reference_t<ElemMultAddOp>,
407  const right_value_type&>>>
408  ContractReduce(const math::blas::Op left_op, const math::blas::Op right_op,
409  const scalar_type alpha, const unsigned int result_rank,
410  const unsigned int left_rank, const unsigned int right_rank,
411  const Perm& perm = {}, ElemMultAddOp&& elem_muladd_op = {})
412  : ContractReduceBase_(left_op, right_op, alpha, result_rank, left_rank,
413  right_rank, perm,
414  std::forward<ElemMultAddOp>(elem_muladd_op)) {}
415 
417 
419  result_type operator()() const { return result_type(); }
420 
423  using TiledArray::empty;
424  TA_ASSERT(!empty(temp));
425 
426  if (!ContractReduceBase_::perm()) {
427  using TiledArray::conj_to;
428  return conj_to(temp);
429  }
430 
431  using TiledArray::conj;
432  return conj(temp, ContractReduceBase_::perm());
433  }
434 
436 
441  void operator()(result_type& result, const result_type& arg) const {
442  using TiledArray::add_to;
443  add_to(result, arg);
444  }
445 
447 
453  void operator()(result_type& result, const first_argument_type& left,
454  const second_argument_type& right) const {
455  if constexpr (!ContractReduceBase_::plain_tensors) {
456  TA_ASSERT(this->elem_muladd_op());
457  // not yet implemented
458  abort();
459  } else {
460  TA_ASSERT(!this->elem_muladd_op());
461  using TiledArray::empty;
462  using TiledArray::gemm;
463  if (empty(result))
464  result = gemm(left, right, 1, ContractReduceBase_::gemm_helper());
465  else
466  gemm(result, left, right, 1, ContractReduceBase_::gemm_helper());
467  }
468  }
469 
470 }; // class ContractReduce
471 
473 
480 template <typename Result, typename Left, typename Right, typename Scalar>
481 class ContractReduce<Result, Left, Right,
483  : public ContractReduceBase<Result, Left, Right,
484  TiledArray::detail::ComplexConjugate<Scalar>> {
485  public:
486  typedef ContractReduce<Result, Left, Right,
489  typedef ContractReduceBase<Result, Left, Right,
496  typedef decltype(gemm(std::declval<Left>(), std::declval<Right>(), 1,
497  std::declval<math::GemmHelper>()))
499  typedef TiledArray::detail::ComplexConjugate<Scalar> scalar_type;
500 
502  using typename ContractReduceBase_::left_value_type;
503  using typename ContractReduceBase_::result_value_type;
504  using typename ContractReduceBase_::right_value_type;
505 
507  ContractReduce() = default;
508  ContractReduce(const ContractReduce_&) = default;
510  ~ContractReduce() = default;
511  ContractReduce_& operator=(const ContractReduce_&) = default;
512  ContractReduce_& operator=(ContractReduce_&&) = default;
513 
515 
526  template <
527  typename Perm = BipartitePermutation,
528  typename ElemMultAddOp = TiledArray::function_ref<elem_muladd_op_type>,
529  typename = std::enable_if_t<
530  TiledArray::detail::is_permutation_v<Perm> &&
531  std::is_invocable_r_v<void, std::remove_reference_t<ElemMultAddOp>,
533  const right_value_type&>>>
534  ContractReduce(const math::blas::Op left_op, const math::blas::Op right_op,
535  const scalar_type alpha, const unsigned int result_rank,
536  const unsigned int left_rank, const unsigned int right_rank,
537  const Perm& perm = {}, ElemMultAddOp&& elem_muladd_op = {})
538  : ContractReduceBase_(left_op, right_op, alpha, result_rank, left_rank,
539  right_rank, perm,
540  std::forward<ElemMultAddOp>(elem_muladd_op)) {}
541 
543 
545  result_type operator()() const { return result_type(); }
546 
549  using TiledArray::empty;
550  TA_ASSERT(!empty(temp));
551 
552  if (!ContractReduceBase_::perm()) {
553  using TiledArray::conj_to;
554  return conj_to(temp, ContractReduceBase_::factor().factor());
555  }
556 
557  using TiledArray::conj;
558  return conj(temp, ContractReduceBase_::factor().factor(),
560  }
561 
563 
568  void operator()(result_type& result, const result_type& arg) const {
569  using TiledArray::add_to;
570  add_to(result, arg);
571  }
572 
574 
580  void operator()(result_type& result, const first_argument_type& left,
581  const second_argument_type& right) const {
582  if constexpr (!ContractReduceBase_::plain_tensors) {
583  TA_ASSERT(this->elem_muladd_op());
584  // not yet implemented
585  abort();
586  } else {
587  TA_ASSERT(!this->elem_muladd_op());
588  using TiledArray::empty;
589  using TiledArray::gemm;
590  if (empty(result))
591  result = gemm(left, right, 1, ContractReduceBase_::gemm_helper());
592  else
593  gemm(result, left, right, 1, ContractReduceBase_::gemm_helper());
594  }
595  }
596 
597 }; // class ContractReduce
598 
599 } // namespace detail
600 } // namespace TiledArray
601 
602 #endif // TILEDARRAY_TILE_OP_CONTRACT_REDUCE_H__INCLUDED
Contraction to *GEMM helper.
Definition: gemm_helper.h:40
ContractReduceBase(const ContractReduceBase_ &)=default
::blas::Op Op
Definition: blas.h:46
const BipartitePermutation & perm() const
Permutation accessor.
ContractReduce_ & operator=(ContractReduce_ &&)=default
ContractReduceBase_ & operator=(const ContractReduceBase_ &)=default
Contract and (sum) reduce operation.
void operator()(result_type &result, const first_argument_type &left, const second_argument_type &right) const
Contract a pair of tiles and add to a target tile.
ContractReduce_ & operator=(const ContractReduce_ &)=default
void permute(InputOp &&input_op, OutputOp &&output_op, Result &result, const Perm &perm, const Arg0 &arg0, const Args &... args)
Construct a permuted tensor copy.
Definition: permute.h:117
ContractReduce(const ContractReduce_ &)=default
ContractReduceBase< Result, Left, Right, TiledArray::detail::ComplexConjugate< Scalar > > ContractReduceBase_
The base class type.
const math::GemmHelper & gemm_helper() const
Gemm meta data accessor.
ContractReduceBase< Result, Left, Right, Scalar > ContractReduceBase_
This class type.
ContractReduceBase(ContractReduceBase_ &&)=default
result_type operator()(const result_type &temp) const
Post processing step.
decltype(auto) conj(const Tile< Arg > &arg)
Create a complex conjugated copy of a tile.
Definition: tile.h:1256
ContractReduce< Result, Left, Right, TiledArray::detail::ComplexConjugate< void > > ContractReduce_
This class type.
ContractReduceBase(const math::blas::Op left_op, const math::blas::Op right_op, const scalar_type alpha, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank, const Perm &perm={}, ElemMultAddOp &&elem_muladd_op={})
Construct contract/reduce functor.
ContractReduce(ContractReduce_ &&)=default
void operator()(result_type &result, const result_type &arg) const
Reduce two result objects.
Contract and (sum) reduce base.
Result result_type
The result tile type.
ContractReduce< Result, Left, Right, Scalar > ContractReduce_
This class type.
const Left & first_argument_type
The left tile type.
unsigned int result_rank() const
Result rank accessor.
#define TA_ASSERT(EXPR,...)
Definition: error.h:39
ContractReduceBase_::first_argument_type first_argument_type
The left tile type.
void operator()(result_type &result, const result_type &arg) const
Reduce two result objects.
Tile< Result > & add_to(Tile< Result > &result, const Tile< Arg > &arg)
Add to the result tile.
Definition: tile.h:831
void operator()(result_type &result, const first_argument_type &left, const second_argument_type &right) const
Contract a pair of tiles and add to a target tile.
result_type operator()() const
Create a result type object.
unsigned int left_rank() const
Left-hand argument rank accessor.
const auto & elem_muladd_op() const
Element multiply-add op accessor.
Permute a tile.
Definition: permute.h:134
ContractReduceBase< Result, Left, Right, Scalar > ContractReduceBase_
This class type.
scalar_type factor() const
Scaling factor accessor.
ContractReduceBase_ & operator=(ContractReduceBase_ &&)=default
constexpr bool empty()
Test for empty tensors in an empty list.
Definition: utility.h:320
unsigned int right_rank() const
Right-hand argument rank accessor.
ContractReduce< Result, Left, Right, TiledArray::detail::ComplexConjugate< Scalar > > ContractReduce_
This class type.
Scalar scalar_type
The scaling factor type.
Specialization of ComplexConjugate for the case of a unit/identity factor.
Definition: complex.h:143
bool empty(const Tile< Arg > &arg)
Check that arg is empty (no data)
Definition: tile.h:646
void operator()(result_type &result, const result_type &arg) const
Reduce two result objects.
ContractReduceBase_::second_argument_type second_argument_type
The right tile type.
ContractReduce(const math::blas::Op left_op, const math::blas::Op right_op, const scalar_type alpha, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank, const Perm &perm={}, ElemMultAddOp &&elem_muladd_op={})
Construct contract/reduce functor.
const Right & second_argument_type
The right tile type.
unsigned int num_contract_ranks() const
Compute the number of contracted ranks.
Permutation of a bipartite set.
Definition: permutation.h:610
TILEDARRAY_FORCE_INLINE R conj(const R r)
Wrapper function for std::conj
Definition: complex.h:45
ContractReduceBase< Result, Left, Right, TiledArray::detail::ComplexConjugate< void > > ContractReduceBase_
The base class type.
void operator()(result_type &result, const first_argument_type &left, const second_argument_type &right) const
Contract a pair of tiles and add to a target tile.
Tile< Result > & conj_to(Tile< Result > &result)
In-place complex conjugate a tile.
Definition: tile.h:1311
decltype(auto) gemm(const Tile< Left > &left, const Tile< Right > &right, const Scalar factor, const math::GemmHelper &gemm_config)
Contract 2 tensors over head/tail modes and scale the product.
Definition: tile.h:1396