cutt.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2018 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * Chong Peng
19  * Department of Chemistry, Virginia Tech
20  * Aug 15, 2018
21  *
22  */
23 
24 #ifndef TILEDARRAY_EXTERNAL_CUTT_H__INCLUDED
25 #define TILEDARRAY_EXTERNAL_CUTT_H__INCLUDED
26 
27 #include <TiledArray/config.h>
28 
29 #ifdef TILEDARRAY_HAS_CUDA
30 
31 #include <algorithm>
32 #include <vector>
33 
34 #include <cutt.h>
35 
36 #include <TiledArray/permutation.h>
37 #include <TiledArray/range.h>
38 
39 namespace TiledArray {
40 
/// Converts an extent from row-major to column-major ordering in place.
///
/// The dimension order is simply reversed: the first extent becomes the
/// last one and vice versa.
///
/// \param[in,out] extent the extent of each dimension; reversed on return
inline void extent_to_col_major(std::vector<int>& extent) {
  if (extent.empty()) return;

  // Walk inwards from both ends, exchanging the elements pairwise.
  auto front = extent.begin();
  auto back = extent.end() - 1;
  while (front < back) {
    std::iter_swap(front, back);
    ++front;
    --back;
  }
}
49 
/// Converts a permutation from row-major to column-major index numbering,
/// in place.
///
/// Every entry `perm[i] == j` of the input is rewritten as
/// `result[n - 1 - i] == n - 1 - j`, where `n` is the number of entries,
/// i.e. both the position and the value are mirrored.
///
/// \param[in,out] perm the permutation; replaced by its column-major
///                equivalent on return
inline void permutation_to_col_major(std::vector<int>& perm) {
  const int rank = static_cast<int>(perm.size());
  const int last = rank - 1;

  std::vector<int> mirrored(rank, 0);

  for (int src = 0; src != rank; ++src) {
    // mirror both the source position and the target it maps to
    mirrored[last - src] = last - perm[src];
  }

  perm.swap(mirrored);
}
71 
79 template <typename T>
80 void cutt_permute(T* inData, T* outData, const TiledArray::Range& range,
81  const TiledArray::Permutation& perm, cudaStream_t stream) {
82  auto extent = range.extent();
83  std::vector<int> extent_int(extent.begin(), extent.end());
84 
85  // cuTT uses FROM notation
86  auto perm_inv = perm.inv();
87  std::vector<int> perm_int(perm_inv.begin(), perm_inv.end());
88 
89  // cuTT uses ColMajor
90  TiledArray::extent_to_col_major(extent_int);
91  TiledArray::permutation_to_col_major(perm_int);
92 
93  cuttResult_t status;
94 
95  cuttHandle plan;
96  status = cuttPlan(&plan, range.rank(), extent_int.data(), perm_int.data(),
97  sizeof(T), stream);
98 
99  TA_ASSERT(status == CUTT_SUCCESS);
100 
101  status = cuttExecute(plan, inData, outData);
102 
103  TA_ASSERT(status == CUTT_SUCCESS);
104 
105  status = cuttDestroy(plan);
106 
107  TA_ASSERT(status == CUTT_SUCCESS);
108 }
109 
110 } // namespace TiledArray
111 
112 #endif // TILEDARRAY_HAS_CUDA
113 
114 #endif // TILEDARRAY_EXTERNAL_CUTT_H__INCLUDED
Permutation inv() const
Construct the inverse of this permutation.
Definition: permutation.h:334
Permutation of a sequence of objects indexed by base-0 indices.
Definition: permutation.h:130
#define TA_ASSERT(EXPR,...)
Definition: error.h:39
index_view_type extent() const
Range extent accessor.
Definition: range.h:741
unsigned int rank() const
Rank accessor.
Definition: range.h:669
A (hyperrectangular) interval on $\mathbb{Z}^n$, the space of integer $n$-indices.
Definition: range.h:46