TiledArray  0.7.0
contract_reduce.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2013 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * Justus Calvin
19  * Department of Chemistry, Virginia Tech
20  *
21  * contract_reduce.h
22  * Oct 9, 2013
23  *
24  */
25 
26 #ifndef TILEDARRAY_TILE_OP_CONTRACT_REDUCE_H__INCLUDED
27 #define TILEDARRAY_TILE_OP_CONTRACT_REDUCE_H__INCLUDED
28 
29 #include <TiledArray/permutation.h>
32 #include "../tile_interface/add.h"
33 #include "../tile_interface/permute.h"
35 
36 namespace TiledArray {
37  namespace detail {
38 
40 
47  template <typename Result, typename Left, typename Right, typename Scalar> // Contract and (sum) reduce base: holds the GEMM metadata, scaling factor and result permutation used by the ContractReduce functors
49  public:
52  typedef const Left& first_argument_type; ///< The left tile type
53  typedef const Right& second_argument_type; ///< The right tile type
54  typedef Result result_type; ///< The result tile type
55  typedef Scalar scalar_type; ///< The scaling factor type
56 
57  private:
58 // Implementation data; held through a std::shared_ptr (pimpl_) so copies of the functor share it
59  struct Impl {
60  Impl(const madness::cblas::CBLAS_TRANSPOSE left_op, // transpose op for the left argument
61  const madness::cblas::CBLAS_TRANSPOSE right_op, // transpose op for the right argument
62  const scalar_type alpha, const unsigned int result_rank, // alpha is stored and later returned by factor()
63  const unsigned int left_rank, const unsigned int right_rank,
64  const Permutation& perm = Permutation()) : // perm defaults to a default-constructed Permutation
65  gemm_helper_(left_op, right_op, result_rank, left_rank, right_rank),
66  alpha_(alpha), perm_(perm)
67  { }
68 
69  math::GemmHelper gemm_helper_; ///< Transpose ops and ranks used to configure gemm
70  scalar_type alpha_; ///< Scaling factor, returned by factor()
71  Permutation perm_; ///< Result permutation, returned by perm()
73  };
75 
76  std::shared_ptr<Impl> pimpl_; ///< Shared implementation data; null after default construction, which is why every accessor TA_ASSERTs it
77 
78  public:
79 
80  // Compiler generated defaults are fine
81  // N.B. copies are shallow: the copy shares the same pimpl_ as the source
82  ContractReduceBase() = default; // leaves pimpl_ null; the accessors below fail TA_ASSERT until a constructed object is assigned
83  ContractReduceBase(const ContractReduceBase_&) = default;
85  ~ContractReduceBase() = default;
88 
90 /// Construct contract/reduce functor; all arguments are forwarded into a newly allocated Impl
99  ContractReduceBase(const madness::cblas::CBLAS_TRANSPOSE left_op,
100  const madness::cblas::CBLAS_TRANSPOSE right_op,
101  const scalar_type alpha, const unsigned int result_rank,
102  const unsigned int left_rank, const unsigned int right_rank,
103  const Permutation& perm = Permutation()) :
104  pimpl_(std::make_shared<Impl>(left_op, right_op, alpha, result_rank, left_rank,
105  right_rank, perm))
106  { }
107 
108 
110 /// Gemm meta data accessor
112  const math::GemmHelper& gemm_helper() const {
113  TA_ASSERT(pimpl_);
114  return pimpl_->gemm_helper_;
115  }
116 
118 /// Permutation accessor
120  const Permutation& perm() const {
121  TA_ASSERT(pimpl_);
122  return pimpl_->perm_;
123  }
124 
125 
127 /// Scaling factor accessor (the alpha passed to the constructor)
129  scalar_type factor() const {
130  TA_ASSERT(pimpl_);
131  return pimpl_->alpha_;
132  }
133 
134  //-------------- these are only used for unit tests -----------------
135 
137 /// Number of contracted ranks, as reported by the GemmHelper
139  unsigned int num_contract_ranks() const {
140  TA_ASSERT(pimpl_);
141  return pimpl_->gemm_helper_.num_contract_ranks();
142  }
143 
145 /// Result rank accessor
147  unsigned int result_rank() const {
148  TA_ASSERT(pimpl_);
149  return pimpl_->gemm_helper_.result_rank();
150  }
151 
153 /// Left-hand argument rank accessor
155  unsigned int left_rank() const {
156  TA_ASSERT(pimpl_);
157  return pimpl_->gemm_helper_.left_rank();
158  }
159 
161 /// Right-hand argument rank accessor
163  unsigned int right_rank() const {
164  TA_ASSERT(pimpl_);
165  return pimpl_->gemm_helper_.right_rank();
166  }
167 
168  }; // class ContractReduceBase
169 
171 
177  template <typename Result, typename Left, typename Right, typename Scalar> // Contract and (sum) reduce operation for a general scaling factor type
179  public ContractReduceBase<Result, Left, Right, Scalar>
180  {
181  public:
190  typedef Result result_type; ///< The result tile type
191  typedef Scalar scalar_type; ///< The scaling factor type
192 
193  // Compiler generated defaults are fine. N.B. this is shallow-copy.
194 
195  ContractReduce() = default;
196  ContractReduce(const ContractReduce_&) = default;
197  ContractReduce(ContractReduce_&&) = default;
198  ~ContractReduce() = default;
199  ContractReduce_& operator=(const ContractReduce_&) = default;
201 
203 /// Construct contract/reduce functor; all arguments are forwarded to ContractReduceBase
212  ContractReduce(const madness::cblas::CBLAS_TRANSPOSE left_op,
213  const madness::cblas::CBLAS_TRANSPOSE right_op,
214  const scalar_type alpha, const unsigned int result_rank,
215  const unsigned int left_rank, const unsigned int right_rank,
216  const Permutation& perm = Permutation()) :
217  ContractReduceBase_(left_op, right_op, alpha, result_rank, left_rank,
218  right_rank, perm)
219  { }
220 
221 
223 /// Create a result type object (a default-constructed result_type)
226  return result_type();
227  }
228 /// Post processing step: asserts temp is non-empty, then returns it either as-is or permuted by perm()
230  result_type operator()(const result_type& temp) const {
231  using TiledArray::empty;
232  TA_ASSERT(! empty(temp));
233 // NOTE(review): the branch condition is not shown in this listing; presumably temp is returned unchanged when perm() is empty, as in the specializations below
235  return temp;
236 
238  return permute(temp, ContractReduceBase_::perm());
239  }
240 
242 /// Reduce two result objects: accumulate arg into result via add_to
247  void operator()(result_type& result, const result_type& arg) const {
248  using TiledArray::add_to;
249  add_to(result, arg);
250  }
251 
253 /// Contract a pair of tiles, scaled by factor(), and add the product to a target tile
260  second_argument_type right) const
261  {
262  using TiledArray::empty;
263  using TiledArray::gemm;
264  if(empty(result)) // first contribution: result is created from the gemm output
265  result = gemm(left, right, ContractReduceBase_::factor(),
267  else // later contributions are accumulated into the existing result
268  gemm(result, left, right, ContractReduceBase_::factor(),
270  }
271 
272  }; // class ContractReduce
273 
274 
276 
281  template <typename Result, typename Left, typename Right> // Specialization for a ComplexConjugate<void> factor: contract with factor 1, then complex-conjugate in post-processing
282  class ContractReduce<Result, Left, Right,
284  public ContractReduceBase<Result, Left, Right,
285  TiledArray::detail::ComplexConjugate<void> >
286  {
287  public:
288  typedef ContractReduce<Result, Left, Right, // this class type
291  typedef ContractReduceBase<Result, Left, Right, // base class type
298  typedef decltype(gemm(std::declval<Left>(), std::declval<Right>(), 1, // result tile type, deduced from the type gemm returns
299  std::declval<math::GemmHelper>()))
301  typedef TiledArray::detail::ComplexConjugate<void> scalar_type; ///< The scaling factor type
302 
303  // Compiler generated defaults are fine. N.B. This has shallow copy semantics.
304 
305  ContractReduce() = default;
306  ContractReduce(const ContractReduce_&) = default;
307  ContractReduce(ContractReduce_&&) = default;
308  ~ContractReduce() = default;
309  ContractReduce_& operator=(const ContractReduce_&) = default;
310  ContractReduce_& operator=(ContractReduce_&&) = default;
311 
313 /// Construct contract/reduce functor; all arguments are forwarded to ContractReduceBase
322  ContractReduce(const madness::cblas::CBLAS_TRANSPOSE left_op,
323  const madness::cblas::CBLAS_TRANSPOSE right_op,
324  const scalar_type alpha, const unsigned int result_rank,
325  const unsigned int left_rank, const unsigned int right_rank,
326  const Permutation& perm = Permutation()) :
327  ContractReduceBase_(left_op, right_op, alpha, result_rank, left_rank,
328  right_rank, perm)
329  { }
330 
331 
333 /// Create a result type object (a default-constructed result_type)
336  return result_type();
337  }
338 /// Post processing step: asserts temp is non-empty, then complex-conjugates it (and permutes it when perm() is non-empty)
341  using TiledArray::empty;
342  TA_ASSERT(! empty(temp));
343 
344  if(! ContractReduceBase_::perm()) { // no permutation: conjugate temp through conj_to
345  using TiledArray::conj_to;
346  return conj_to(temp);
347  }
348 
349  using TiledArray::conj;
350  return conj(temp, ContractReduceBase_::perm()); // conjugate and permute in a single conj call
351  }
352 
354 /// Reduce two result objects: accumulate arg into result via add_to
359  void operator()(result_type& result, const result_type& arg) const {
360  using TiledArray::add_to;
361  add_to(result, arg);
362  }
363 
365 /// Contract a pair of tiles with factor 1 and add the product to a target tile; conjugation is deferred to post-processing
372  second_argument_type right) const
373  {
374  using TiledArray::empty;
375  using TiledArray::gemm;
376  if(empty(result)) // first contribution: result is created from the gemm output
377  result = gemm(left, right, 1, ContractReduceBase_::gemm_helper());
378  else // later contributions are accumulated into the existing result
379  gemm(result, left, right, 1, ContractReduceBase_::gemm_helper());
380  }
381 
382  }; // class ContractReduce
383 
384 
386 
393  template <typename Result, typename Left, typename Right, typename Scalar> // Specialization for a ComplexConjugate<Scalar> factor: contract with factor 1, then conjugate and scale by factor().factor() in post-processing
394  class ContractReduce<Result, Left, Right,
396  public ContractReduceBase<Result, Left, Right,
397  TiledArray::detail::ComplexConjugate<Scalar> >
398  {
399  public:
400  typedef ContractReduce<Result, Left, Right, // this class type
403  typedef ContractReduceBase<Result, Left, Right, // base class type
410  typedef decltype(gemm(std::declval<Left>(), std::declval<Right>(), 1, // result tile type, deduced from the type gemm returns
411  std::declval<math::GemmHelper>()))
413  typedef TiledArray::detail::ComplexConjugate<Scalar> scalar_type; ///< The scaling factor type
414 // Compiler generated defaults are fine. N.B. copies are shallow: they share the base's pimpl_
416  ContractReduce() = default;
417  ContractReduce(const ContractReduce_&) = default;
418  ContractReduce(ContractReduce_&&) = default;
419  ~ContractReduce() = default;
420  ContractReduce_& operator=(const ContractReduce_&) = default;
421  ContractReduce_& operator=(ContractReduce_&&) = default;
422 
424 /// Construct contract/reduce functor; all arguments are forwarded to ContractReduceBase
433  ContractReduce(const madness::cblas::CBLAS_TRANSPOSE left_op,
434  const madness::cblas::CBLAS_TRANSPOSE right_op,
435  const scalar_type alpha, const unsigned int result_rank,
436  const unsigned int left_rank, const unsigned int right_rank,
437  const Permutation& perm = Permutation()) :
438  ContractReduceBase_(left_op, right_op, alpha, result_rank, left_rank,
439  right_rank, perm)
440  { }
441 
442 
444 /// Create a result type object (a default-constructed result_type)
447  return result_type();
448  }
449 /// Post processing step: asserts temp is non-empty, then conjugates it with the numeric factor factor().factor() (and permutes when perm() is non-empty)
452  using TiledArray::empty;
453  TA_ASSERT(! empty(temp));
454 
455  if(! ContractReduceBase_::perm()) { // no permutation: conjugate and scale temp through conj_to
456  using TiledArray::conj_to;
457  return conj_to(temp, ContractReduceBase_::factor().factor());
458  }
459 
460  using TiledArray::conj;
461  return conj(temp, ContractReduceBase_::factor().factor(), // conjugate, scale and permute in a single conj call
463  }
464 
466 /// Reduce two result objects: accumulate arg into result via add_to
471  void operator()(result_type& result, const result_type& arg) const {
472  using TiledArray::add_to;
473  add_to(result, arg);
474  }
475 
477 /// Contract a pair of tiles with factor 1 and add the product to a target tile; the scale is applied once in post-processing, not per contraction
484  second_argument_type right) const
485  {
486  using TiledArray::empty;
487  using TiledArray::gemm;
488  if(empty(result)) // first contribution: result is created from the gemm output
489  result = gemm(left, right, 1, ContractReduceBase_::gemm_helper());
490  else // later contributions are accumulated into the existing result
491  gemm(result, left, right, 1, ContractReduceBase_::gemm_helper());
492  }
493 
494  }; // class ContractReduce
495 
496  } // namespace detail
497 } // namespace TiledArray
498 
499 #endif // TILEDARRAY_TILE_OP_CONTRACT_REDUCE_H__INCLUDED
ContractReduceBase_::first_argument_type first_argument_type
The left tile type.
ContractReduceBase< Result, Left, Right, TiledArray::detail::ComplexConjugate< void > > ContractReduceBase_
This class type.
void operator()(result_type &result, const result_type &arg) const
Reduce two result objects.
ContractReduceBase< Result, Left, Right, TiledArray::detail::ComplexConjugate< Scalar > > ContractReduceBase_
This class type.
void operator()(result_type &result, const result_type &arg) const
Reduce two result objects.
ContractReduce_ & operator=(const ContractReduce_ &)=default
void operator()(result_type &result, const result_type &arg) const
Reduce two result objects.
const Left & first_argument_type
The left tile type.
void permute(InputOp &&input_op, OutputOp &&output_op, Result &result, const Permutation &perm, const Arg0 &arg0, const Args &... args)
Construct a permuted tensor copy.
Definition: permute.h:122
unsigned int num_contract_ranks() const
Compute the number of contracted ranks.
result_type operator()(const result_type &temp) const
Post processing step.
result_type operator()() const
Create a result type object.
ContractReduce< Result, Left, Right, Scalar > ContractReduce_
This class type.
void operator()(result_type &result, first_argument_type left, second_argument_type right) const
Contract a pair of tiles and add to a target tile.
STL namespace.
Tile< Result > & add_to(Tile< Result > &result, const Tile< Arg > &arg)
Add to the result tile.
Definition: tile.h:441
ContractReduceBase< Result, Left, Right, Scalar > ContractReduceBase_
This class type.
unsigned int right_rank() const
Right-hand argument rank accessor.
const math::GemmHelper & gemm_helper() const
Gemm meta data accessor.
decltype(auto) conj(const Tile< Arg > &arg)
Create a complex conjugated copy of a tile.
Definition: tile.h:772
Permute a tile.
Definition: permute.h:130
unsigned int result_rank() const
Result rank accessor.
const Right & second_argument_type
The right tile type.
ContractReduceBase< Result, Left, Right, Scalar > ContractReduceBase_
This class type.
ContractReduceBase(const madness::cblas::CBLAS_TRANSPOSE left_op, const madness::cblas::CBLAS_TRANSPOSE right_op, const scalar_type alpha, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank, const Permutation &perm=Permutation())
Construct contract/reduce functor.
Result result_type
The result tile type.
Contract and (sum) reduce base.
ContractReduceBase_ & operator=(const ContractReduceBase_ &)=default
constexpr bool empty()
Test for empty tensors in an empty list.
Definition: utility.h:374
ContractReduceBase_::second_argument_type second_argument_type
The right tile type.
#define TA_ASSERT(a)
Definition: error.h:107
Scalar scalar_type
The scaling factor type.
scalar_type factor() const
Scaling factor accessor.
void operator()(result_type &result, first_argument_type left, second_argument_type right) const
Contract a pair of tiles and add to a target tile.
TILEDARRAY_FORCE_INLINE R conj(const R r)
Wrapper function for std::conj
Definition: complex.h:44
Contraction to *GEMM helper.
Definition: gemm_helper.h:39
Permutation of a sequence of objects indexed by base-0 indices.
Definition: permutation.h:119
bool empty(const Tile< Arg > &arg)
Check whether arg is empty (has no data).
Definition: tile.h:305
Contract and (sum) reduce operation.
unsigned int left_rank() const
Left-hand argument rank accessor.
ContractReduce(const madness::cblas::CBLAS_TRANSPOSE left_op, const madness::cblas::CBLAS_TRANSPOSE right_op, const scalar_type alpha, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank, const Permutation &perm=Permutation())
Construct contract/reduce functor.
void operator()(result_type &result, first_argument_type left, second_argument_type right) const
Contract a pair of tiles and add to a target tile.
ContractReduce< Result, Left, Right, TiledArray::detail::ComplexConjugate< Scalar > > ContractReduce_
This class type.
ContractReduce< Result, Left, Right, TiledArray::detail::ComplexConjugate< void > > ContractReduce_
This class type.
const Permutation & perm() const
Permutation accessor.
decltype(auto) gemm(const Tile< Left > &left, const Tile< Right > &right, const Scalar factor, const math::GemmHelper &gemm_config)
Contract and scale tile arguments.
Definition: tile.h:857
Result & conj_to(Tile< Result > &result)
In-place complex conjugate a tile.
Definition: tile.h:820
Specialization of ComplexConjugate<S> for the case of a unit/identity factor.
Definition: complex.h:135