permopt.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2020 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * Eduard Valeyev
19  * Department of Chemistry, Virginia Tech
20  *
21  * permopt.h
22  * Nov 2, 2020
23  *
24  */
25 
26 #ifndef TILEDARRAY_EXPRESSIONS_PERMOPT_H__INCLUDED
27 #define TILEDARRAY_EXPRESSIONS_PERMOPT_H__INCLUDED
28 
31 #include <TiledArray/permutation.h>
32 #include <memory>
33 
34 namespace TiledArray {
35 namespace expressions {
36 
37 // clang-format off
43 // clang-format on
enum class PermutationType {
  identity = 1,          ///< no permutation is needed
  matrix_transpose = 2,  ///< permutation expressible as a matrix transpose
  general = 3            ///< arbitrary permutation (not a plain transpose)
};
45 
47  TA_ASSERT(permtype == PermutationType::matrix_transpose ||
48  permtype == PermutationType::identity);
49  return permtype == PermutationType::matrix_transpose
50  ? math::blas::Transpose
51  : math::blas::NoTranspose;
52 }
53 
56  public:
63  const IndexList& right_indices,
64  const bool prefer_to_permute_left = true)
65  : left_indices_(left_indices),
66  right_indices_(right_indices),
67  prefer_to_permute_left_(prefer_to_permute_left) {}
68 
78  const IndexList& left_indices,
79  const IndexList& right_indices,
80  const bool prefer_to_permute_left = true)
81  : result_indices_(result_indices),
82  left_indices_(left_indices),
83  right_indices_(right_indices),
84  prefer_to_permute_left_(prefer_to_permute_left) {}
85 
89  default;
90  virtual ~BinaryOpPermutationOptimizer() = default;
91 
93  const IndexList& result_indices() const {
94  TA_ASSERT(result_indices_);
95  return result_indices_;
96  }
98  const IndexList& left_indices() const { return left_indices_; }
100  const IndexList& right_indices() const { return right_indices_; }
102  bool prefer_to_permute_left() const { return prefer_to_permute_left_; }
103 
105  virtual const IndexList& target_left_indices() const = 0;
107  virtual const IndexList& target_right_indices() const = 0;
110  virtual const IndexList& target_result_indices() const = 0;
113  virtual PermutationType left_permtype() const = 0;
116  virtual PermutationType right_permtype() const = 0;
118  virtual TensorProduct op_type() const = 0;
119 
120  private:
121  IndexList result_indices_, left_indices_, right_indices_;
122  bool prefer_to_permute_left_;
123 };
124 
125 // clang-format off
129 // clang-format on
131  public:
134  default;
135  virtual ~GEMMPermutationOptimizer() = default;
136 
138  const IndexList& right_indices,
139  const bool prefer_to_permute_left = true)
142  std::tie(target_left_indices_, target_right_indices_,
143  target_result_indices_, left_permtype_, right_permtype_) =
144  compute_index_list_contraction(left_indices, right_indices,
146  }
147 
149  const IndexList& left_indices,
150  const IndexList& right_indices,
151  const bool prefer_to_permute_left = true)
154  std::tie(target_left_indices_, target_right_indices_,
155  target_result_indices_, left_permtype_, right_permtype_) =
156  compute_index_list_contraction(left_indices, right_indices,
158  }
159 
160  const IndexList& target_left_indices() const override final {
161  return target_left_indices_;
162  }
163  const IndexList& target_right_indices() const override final {
164  return target_right_indices_;
165  }
166  const IndexList& target_result_indices() const override final {
167  return target_result_indices_;
168  }
169  PermutationType left_permtype() const override final {
170  return left_permtype_;
171  }
172  PermutationType right_permtype() const override final {
173  return right_permtype_;
174  }
175  TensorProduct op_type() const override final {
176  return TensorProduct::Contraction;
177  }
178 
179  private:
180  IndexList target_left_indices_, target_right_indices_, target_result_indices_;
181  PermutationType left_permtype_, right_permtype_;
182 
183  static auto find(const IndexList& indices, const std::string& idx,
184  unsigned int i, const unsigned int n) {
185  const auto b = indices.begin() + i;
186  const auto e = indices.begin() + n;
187  const auto it = std::find(b, e, idx);
188  return i + std::distance(b, it);
189  };
190 
191  // clang-format off
199  // clang-format on
200  inline std::tuple<IndexList, IndexList, IndexList, PermutationType,
202  compute_index_list_contraction(const IndexList& left_indices,
203  const IndexList& right_indices,
204  const bool prefer_to_permute_left = true) {
205  const auto left_rank = left_indices.size();
206  const auto right_rank = right_indices.size();
207 
208  container::svector<std::string> result_left_indices;
209  result_left_indices.reserve(left_rank);
210  container::svector<std::string> result_right_indices;
211  result_right_indices.reserve(right_rank);
213  result_indices.reserve(std::max(left_rank, right_rank));
214 
215  // Extract left-most result and inner indices from the left-hand argument.
216  for (unsigned int i = 0ul; i < left_rank; ++i) {
217  const std::string& var = left_indices[i];
218  if (find(right_indices, var, 0u, right_rank) == right_rank) {
219  // Store outer left variable
220  result_left_indices.push_back(var);
221  result_indices.push_back(var);
222  } else {
223  // Store inner left variable
224  result_right_indices.push_back(var);
225  }
226  }
227 
228  // Compute the inner and outer dimension ranks.
229  const unsigned int inner_rank = result_right_indices.size();
230  const unsigned int left_outer_rank = result_left_indices.size();
231  const unsigned int right_outer_rank = right_rank - inner_rank;
232  const unsigned int result_rank = left_outer_rank + right_outer_rank;
233 
234  // Resize result indices if necessary.
235  result_indices.reserve(result_rank);
236 
237  // If an outer product, result = concat of free indices from left and right
238  if (inner_rank == 0u) {
239  // Extract the right most outer variables from right hand argument.
240  for (unsigned int i = 0ul; i < right_rank; ++i) {
241  const std::string& var = right_indices[i];
242  result_right_indices.push_back(var);
243  result_indices.push_back(var);
244  }
245  // early return for the inner product since will make the result to be
246  // pure concat of the left and right index lists
247  return std::make_tuple(
248  IndexList(result_left_indices), IndexList(result_right_indices),
251  }
252 
253  // Initialize flags that will be used to determine the type of permutation
254  // that will be applied to the arguments (i.e. no permutation, matrix
255  // transpose, or arbitrary permutation).
256  bool inner_indices_ordered = true, left_is_no_trans = true,
257  left_is_trans = true, right_is_no_trans = true, right_is_trans = true;
258 
259  // If the inner index lists of the arguments are not in the same
260  // order, one of them will need to be permuted. Here, we determine which
261  // argument, left or right, will be permuted if a permutation is
262  // required. The argument with the lowest rank is preferred since it is
263  // likely to have the smaller memory footprint.
264  const bool perm_left =
265  (left_rank < right_rank) ||
266  ((left_rank == right_rank) && prefer_to_permute_left);
267 
268  // Extract variables from the right-hand argument, collect information
269  // about the layout of the index lists, and ensure the inner variable
270  // lists are in the same order.
271  for (unsigned int i = 0ul; i < right_rank; ++i) {
272  const std::string& idx = right_indices[i];
273  const unsigned int j = find(left_indices, idx, 0u, left_rank);
274  if (j == left_rank) {
275  // Store outer right index
276  result_right_indices.push_back(idx);
277  result_indices.push_back(idx);
278  } else {
279  const unsigned int x = result_left_indices.size() - left_outer_rank;
280 
281  // Collect information about the relative position of variables
282  inner_indices_ordered =
283  inner_indices_ordered && (result_right_indices[x] == idx);
284  left_is_no_trans = left_is_no_trans && (j >= left_outer_rank);
285  left_is_trans = left_is_trans && (j < inner_rank);
286  right_is_no_trans = right_is_no_trans && (i < inner_rank);
287  right_is_trans = right_is_trans && (i >= right_outer_rank);
288 
289  // Store inner right index
290  if (inner_indices_ordered) {
291  // Left and right inner index list order is equal.
292  result_left_indices.push_back(idx);
293  } else if (perm_left) {
294  // Permute left so we need to store inner indices according to
295  // the order of the right-hand argument.
296  result_left_indices.push_back(idx);
297  result_right_indices[x] = idx;
298  left_is_no_trans = left_is_trans = false;
299  } else {
300  // Permute right so we need to store inner indices according to
301  // the order of the left-hand argument.
302  result_left_indices.push_back(result_right_indices[x]);
303  right_is_no_trans = right_is_trans = false;
304  }
305  }
306  }
307 
308  auto to_tensor_op = [](bool no_trans, bool trans) {
309  if (no_trans)
311  else if (trans)
312  return PermutationType::matrix_transpose;
313  else
315  };
316 
317  return std::make_tuple(IndexList(result_left_indices),
318  IndexList(result_right_indices),
319  IndexList(result_indices),
320  to_tensor_op(left_is_no_trans, left_is_trans),
321  to_tensor_op(right_is_no_trans, right_is_trans));
322  }
323 
324  // clang-format off
332  // clang-format on
333  inline std::tuple<IndexList, IndexList, IndexList, PermutationType,
335  compute_index_list_contraction(const IndexList& target_indices,
336  const IndexList& left_indices,
337  const IndexList& right_indices,
338  const bool prefer_to_permute_left = true) {
339  IndexList result_indices_, left_indices_,
340  right_indices_; // intermediate index lists, computed without taking
341  // target_indices into account
342  PermutationType left_op_, right_op_;
343  std::tie(result_indices_, left_indices_, right_indices_, left_op_,
344  right_op_) =
345  compute_index_list_contraction(left_indices, right_indices,
347 
348  container::svector<std::string> final_left_indices(left_indices_.size()),
349  final_right_indices(right_indices_.size()),
350  final_result_indices(result_indices_.size());
351 
352  // Only permute if the arguments can be permuted
353  if ((left_op_ == PermutationType::general) ||
354  (right_op_ == PermutationType::general)) {
355  // Compute ranks
356  const unsigned int result_rank = target_indices.size();
357  const unsigned int inner_rank =
358  (left_indices_.size() + right_indices_.size() - result_rank) >> 1;
359  const unsigned int left_outer_rank = left_indices_.size() - inner_rank;
360 
361  // Check that the left- and right-hand outer variables are correctly
362  // partitioned in the target index list.
363  bool target_partitioned = true;
364  for (unsigned int i = 0u; i < left_outer_rank; ++i)
365  target_partitioned =
366  target_partitioned && (find(target_indices, left_indices_[i], 0u,
367  left_outer_rank) < left_outer_rank);
368 
369  // If target is properly partitioned, then arguments can be permuted
370  // to fit the target.
371  if (target_partitioned) {
372  if (left_op_ == PermutationType::general) {
373  // Copy left-hand target variables to left and result index lists.
374  for (unsigned int i = 0u; i < left_outer_rank; ++i) {
375  const std::string& var = target_indices[i];
376  final_left_indices[i] = var;
377  final_result_indices[i] = var;
378  }
379  } else {
380  // Copy left-hand outer variables to that of result.
381  for (unsigned int i = 0u; i < left_outer_rank; ++i)
382  final_result_indices[i] = left_indices_[i];
383  }
384 
385  if (right_op_ == PermutationType::general) {
386  // Copy right-hand target variables to right and result variable
387  // lists.
388  for (unsigned int i = left_outer_rank, j = inner_rank;
389  i < result_rank; ++i, ++j) {
390  const std::string& var = target_indices[i];
391  final_right_indices[j] = var;
392  final_result_indices[i] = var;
393  }
394  } else {
395  // Copy right-hand outer variables to that of result.
396  for (unsigned int i = left_outer_rank, j = inner_rank;
397  i < result_rank; ++i, ++j)
398  final_result_indices[i] = right_indices_[j];
399  }
400  }
401  }
402 
403  return std::make_tuple(
404  IndexList(final_left_indices), IndexList(final_right_indices),
405  IndexList(final_result_indices), left_op_, right_op_);
406  }
407 };
408 
409 // clang-format off
412 // clang-format on
414  public:
417  default;
419 
421  const IndexList& right_indices,
422  const bool prefer_to_permute_left = true)
425  target_result_indices_(prefer_to_permute_left ? right_indices
426  : left_indices) {
428  }
429 
431  const IndexList& left_indices,
432  const IndexList& right_indices,
433  const bool prefer_to_permute_left = true)
438 
439  // Determine the equality of the index lists
440  bool left_target = true, right_target = true, left_right = true;
441  for (unsigned int i = 0u; i < result_indices.size(); ++i) {
442  left_target = left_target && left_indices[i] == result_indices[i];
443  right_target = right_target && right_indices[i] == result_indices[i];
444  left_right = left_right && left_indices[i] == right_indices[i];
445  }
446 
447  if (left_right) {
448  target_result_indices_ = left_indices;
449  } else {
450  // Determine which argument will be permuted
451  const bool perm_left =
452  (right_target ||
453  ((!(left_target || right_target)) && prefer_to_permute_left));
454 
455  target_result_indices_ = perm_left ? right_indices : left_indices;
456  }
457  }
458 
459  const IndexList& target_left_indices() const override final {
460  return target_result_indices_;
461  }
462  const IndexList& target_right_indices() const override final {
463  return target_result_indices_;
464  }
465  const IndexList& target_result_indices() const override final {
466  return target_result_indices_;
467  }
468  PermutationType left_permtype() const override final {
470  }
471  PermutationType right_permtype() const override final {
473  }
474  TensorProduct op_type() const override final {
475  return TensorProduct::Hadamard;
476  }
477 
478  private:
479  IndexList target_result_indices_;
480 };
481 
483  public:
485  default;
487  const NullBinaryOpPermutationOptimizer&) = default;
489 
491  const IndexList& right_indices,
492  const bool prefer_to_permute_left = true)
497  }
498 
500  const IndexList& left_indices,
501  const IndexList& right_indices,
502  const bool prefer_to_permute_left = true)
508  }
509 
510  const IndexList& target_left_indices() const override final {
511  return left_indices();
512  }
513  const IndexList& target_right_indices() const override final {
514  return right_indices();
515  }
516  const IndexList& target_result_indices() const override final {
517  return left_indices();
518  }
519  PermutationType left_permtype() const override final {
521  }
522  PermutationType right_permtype() const override final {
524  }
525  TensorProduct op_type() const override final {
526  return TensorProduct::Invalid;
527  }
528 };
529 
530 inline std::shared_ptr<BinaryOpPermutationOptimizer> make_permutation_optimizer(
531  TensorProduct product_type, const IndexList& left_indices,
532  const IndexList& right_indices, bool prefer_to_permute_left) {
533  switch (product_type) {
534  case TensorProduct::Hadamard:
535  return std::make_shared<HadamardPermutationOptimizer>(
536  left_indices, right_indices, prefer_to_permute_left);
537  case TensorProduct::Contraction:
538  return std::make_shared<GEMMPermutationOptimizer>(
539  left_indices, right_indices, prefer_to_permute_left);
541  return std::make_shared<NullBinaryOpPermutationOptimizer>(
542  left_indices, right_indices, prefer_to_permute_left);
543  default:
544  abort();
545  }
546 }
547 
548 inline std::shared_ptr<BinaryOpPermutationOptimizer> make_permutation_optimizer(
549  TensorProduct product_type, const IndexList& target_indices,
550  const IndexList& left_indices, const IndexList& right_indices,
551  bool prefer_to_permute_left) {
552  switch (product_type) {
553  case TensorProduct::Hadamard:
554  return std::make_shared<HadamardPermutationOptimizer>(
555  target_indices, left_indices, right_indices, prefer_to_permute_left);
556  case TensorProduct::Contraction:
557  return std::make_shared<GEMMPermutationOptimizer>(
558  target_indices, left_indices, right_indices, prefer_to_permute_left);
560  return std::make_shared<NullBinaryOpPermutationOptimizer>(
561  target_indices, left_indices, right_indices, prefer_to_permute_left);
562  default:
563  abort();
564  }
565 }
566 
567 inline std::shared_ptr<BinaryOpPermutationOptimizer> make_permutation_optimizer(
568  const IndexList& left_indices, const IndexList& right_indices,
569  bool prefer_to_permute_left) {
571  compute_product_type(left_indices, right_indices), left_indices,
572  right_indices, prefer_to_permute_left);
573 }
574 
575 inline std::shared_ptr<BinaryOpPermutationOptimizer> make_permutation_optimizer(
576  const IndexList& target_indices, const IndexList& left_indices,
577  const IndexList& right_indices, bool prefer_to_permute_left) {
579  compute_product_type(left_indices, right_indices, target_indices),
580  target_indices, left_indices, right_indices, prefer_to_permute_left);
581 }
582 
583 } // namespace expressions
584 } // namespace TiledArray
585 
586 #endif // TILEDARRAY_EXPRESSIONS_PERMOPT_H__INCLUDED
PermutationType left_permtype() const override final
Definition: permopt.h:519
HadamardPermutationOptimizer(const IndexList &result_indices, const IndexList &left_indices, const IndexList &right_indices, const bool prefer_to_permute_left=true)
Definition: permopt.h:430
boost::container::small_vector< T, N > svector
Definition: vector.h:43
::blas::Op Op
Definition: blas.h:46
bool is_permutation(const IndexList &other) const
Check that this index list is a permutation of other.
Definition: index_list.h:254
HadamardPermutationOptimizer(const HadamardPermutationOptimizer &)=default
const IndexList & target_left_indices() const override final
Definition: permopt.h:160
PermutationType right_permtype() const override final
Definition: permopt.h:172
NullBinaryOpPermutationOptimizer(const NullBinaryOpPermutationOptimizer &)=default
virtual const IndexList & target_left_indices() const =0
const IndexList & target_left_indices() const override final
Definition: permopt.h:510
NullBinaryOpPermutationOptimizer(const IndexList &left_indices, const IndexList &right_indices, const bool prefer_to_permute_left=true)
Definition: permopt.h:490
std::shared_ptr< BinaryOpPermutationOptimizer > make_permutation_optimizer(TensorProduct product_type, const IndexList &left_indices, const IndexList &right_indices, bool prefer_to_permute_left)
Definition: permopt.h:530
const IndexList & target_result_indices() const override final
Definition: permopt.h:465
const_iterator begin() const
Returns an iterator to the first index.
Definition: index_list.h:178
const IndexList & target_right_indices() const override final
Definition: permopt.h:462
TensorProduct op_type() const override final
Definition: permopt.h:474
const IndexList & target_right_indices() const override final
Definition: permopt.h:163
NullBinaryOpPermutationOptimizer & operator=(const NullBinaryOpPermutationOptimizer &)=default
GEMMPermutationOptimizer(const IndexList &result_indices, const IndexList &left_indices, const IndexList &right_indices, const bool prefer_to_permute_left=true)
Definition: permopt.h:148
blas::Op to_cblas_op(PermutationType permtype)
Definition: permopt.h:46
#define TA_ASSERT(EXPR,...)
Definition: error.h:39
PermutationType left_permtype() const override final
Definition: permopt.h:468
virtual PermutationType right_permtype() const =0
TensorProduct op_type() const override final
Definition: permopt.h:525
const IndexList & target_result_indices() const override final
Definition: permopt.h:166
KroneckerDeltaTile< _N >::numeric_type max(const KroneckerDeltaTile< _N > &arg)
BinaryOpPermutationOptimizer(const IndexList &result_indices, const IndexList &left_indices, const IndexList &right_indices, const bool prefer_to_permute_left=true)
Definition: permopt.h:77
const IndexList & target_right_indices() const override final
Definition: permopt.h:513
BinaryOpPermutationOptimizer(const BinaryOpPermutationOptimizer &)=default
PermutationType left_permtype() const override final
Definition: permopt.h:169
GEMMPermutationOptimizer(const GEMMPermutationOptimizer &)=default
virtual PermutationType left_permtype() const =0
TensorProduct compute_product_type(const IndexList &left_indices, const IndexList &right_indices)
Definition: product.h:51
PermutationType right_permtype() const override final
Definition: permopt.h:522
TensorProduct
types of binary tensor products known to TiledArray
Definition: product.h:35
GEMMPermutationOptimizer & operator=(const GEMMPermutationOptimizer &)=default
virtual const IndexList & target_result_indices() const =0
BinaryOpPermutationOptimizer(const IndexList &left_indices, const IndexList &right_indices, const bool prefer_to_permute_left=true)
Definition: permopt.h:62
const IndexList & target_left_indices() const override final
Definition: permopt.h:459
HadamardPermutationOptimizer(const IndexList &left_indices, const IndexList &right_indices, const bool prefer_to_permute_left=true)
Definition: permopt.h:420
HadamardPermutationOptimizer & operator=(const HadamardPermutationOptimizer &)=default
unsigned int size() const
Returns the number of elements in the index list.
Definition: index_list.h:197
const IndexList & target_result_indices() const override final
Definition: permopt.h:516
NullBinaryOpPermutationOptimizer(const IndexList &result_indices, const IndexList &left_indices, const IndexList &right_indices, const bool prefer_to_permute_left=true)
Definition: permopt.h:499
BinaryOpPermutationOptimizer & operator=(const BinaryOpPermutationOptimizer &)=default
TensorProduct op_type() const override final
Definition: permopt.h:175
Abstract optimizer of permutations for a binary operation.
Definition: permopt.h:55
virtual const IndexList & target_right_indices() const =0
PermutationType right_permtype() const override final
Definition: permopt.h:471
GEMMPermutationOptimizer(const IndexList &left_indices, const IndexList &right_indices, const bool prefer_to_permute_left=true)
Definition: permopt.h:137