cholesky.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2020 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * David Williams-Young
19  * Computational Research Division, Lawrence Berkeley National Laboratory
20  *
21  * cholesky.h
22  * Created: 8 June, 2020
23  *
24  */
25 #ifndef TILEDARRAY_MATH_LINALG_SCALAPACK_CHOL_H__INCLUDED
26 #define TILEDARRAY_MATH_LINALG_SCALAPACK_CHOL_H__INCLUDED
27 
28 #include <TiledArray/config.h>
29 #if TILEDARRAY_HAS_SCALAPACK
30 
33 
34 #include <scalapackpp/factorizations/potrf.hpp>
35 #include <scalapackpp/linear_systems/posv.hpp>
36 #include <scalapackpp/linear_systems/trtrs.hpp>
37 #include <scalapackpp/matrix_inverse/trtri.hpp>
38 
40 
59 template <typename Array>
60 auto cholesky(const Array& A, TiledRange l_trange = TiledRange(),
61  size_t NB = default_block_size()) {
62  auto& world = A.world();
63  auto world_comm = world.mpi.comm().Get_mpi_comm();
64  blacspp::Grid grid = blacspp::Grid::square_grid(world_comm);
65 
66  world.gop.fence(); // stage ScaLAPACK execution
67  auto matrix = scalapack::array_to_block_cyclic(A, grid, NB, NB);
68  world.gop.fence(); // stage ScaLAPACK execution
69 
70  auto [M, N] = matrix.dims();
71  if (M != N) TA_EXCEPTION("Matrix must be square for Cholesky");
72 
73  auto [Mloc, Nloc] = matrix.dist().get_local_dims(N, N);
74  auto desc = matrix.dist().descinit_noerror(N, N, Mloc);
75 
76  auto info = scalapackpp::ppotrf(blacspp::Triangle::Lower, N,
77  matrix.local_mat().data(), 1, 1, desc);
78  if (info) TA_EXCEPTION("Cholesky Failed");
79 
80  // Zero out the upper triangle
81  zero_triangle(blacspp::Triangle::Upper, matrix);
82 
83  if (l_trange.rank() == 0) l_trange = A.trange();
84 
85  world.gop.fence();
86  auto L = scalapack::block_cyclic_to_array<Array>(matrix, l_trange);
87  world.gop.fence();
88 
89  return L;
90 }
91 
113 template <bool Both, typename Array>
114 auto cholesky_linv(const Array& A, TiledRange l_trange = TiledRange(),
115  size_t NB = default_block_size()) {
116  using value_type = typename Array::element_type;
117 
118  auto& world = A.world();
119  auto world_comm = world.mpi.comm().Get_mpi_comm();
120  blacspp::Grid grid = blacspp::Grid::square_grid(world_comm);
121 
122  world.gop.fence(); // stage ScaLAPACK execution
123  auto matrix = scalapack::array_to_block_cyclic(A, grid, NB, NB);
124  world.gop.fence(); // stage ScaLAPACK execution
125 
126  auto [M, N] = matrix.dims();
127  if (M != N) TA_EXCEPTION("Matrix must be square for Cholesky");
128 
129  auto [Mloc, Nloc] = matrix.dist().get_local_dims(N, N);
130  auto desc = matrix.dist().descinit_noerror(N, N, Mloc);
131 
132  auto info = scalapackpp::ppotrf(blacspp::Triangle::Lower, N,
133  matrix.local_mat().data(), 1, 1, desc);
134  if (info) TA_EXCEPTION("Cholesky Failed");
135 
136  // Zero out the upper triangle
137  zero_triangle(blacspp::Triangle::Upper, matrix);
138 
139  // Copy L if needed
140  std::shared_ptr<scalapack::BlockCyclicMatrix<value_type>> L_sca = nullptr;
141  if constexpr (Both) {
142  L_sca = std::make_shared<scalapack::BlockCyclicMatrix<value_type>>(
143  world, grid, N, N, NB, NB);
144  L_sca->local_mat() = matrix.local_mat();
145  }
146 
147  // Compute inverse
148  info =
149  scalapackpp::ptrtri(blacspp::Triangle::Lower, blacspp::Diagonal::NonUnit,
150  N, matrix.local_mat().data(), 1, 1, desc);
151  if (info) TA_EXCEPTION("TRTRI Failed");
152 
153  if (l_trange.rank() == 0) l_trange = A.trange();
154 
155  world.gop.fence();
156  auto Linv = scalapack::block_cyclic_to_array<Array>(matrix, l_trange);
157  world.gop.fence();
158 
159  if constexpr (Both) {
160  auto L = scalapack::block_cyclic_to_array<Array>(*L_sca, l_trange);
161  world.gop.fence();
162  return std::tuple(L, Linv);
163  } else {
164  return Linv;
165  }
166 }
167 
168 template <typename Array>
169 auto cholesky_solve(const Array& A, const Array& B,
170  TiledRange x_trange = TiledRange(),
171  size_t NB = default_block_size()) {
172  auto& world = A.world();
173  /*
174  if( world != B.world() ) {
175  TA_EXCEPTION("A and B must be distributed on same MADWorld context");
176  }
177  */
178  auto world_comm = world.mpi.comm().Get_mpi_comm();
179  blacspp::Grid grid = blacspp::Grid::square_grid(world_comm);
180 
181  world.gop.fence(); // stage ScaLAPACK execution
182  auto A_sca = scalapack::array_to_block_cyclic(A, grid, NB, NB);
183  auto B_sca = scalapack::array_to_block_cyclic(B, grid, NB, NB);
184  world.gop.fence(); // stage ScaLAPACK execution
185 
186  auto [M, N] = A_sca.dims();
187  if (M != N) TA_EXCEPTION("A must be square for Cholesky Solve");
188 
189  auto [B_N, NRHS] = B_sca.dims();
190  if (B_N != N) TA_EXCEPTION("A and B dims must agree");
191 
192  scalapackpp::scalapack_desc desc_a, desc_b;
193  {
194  auto [Mloc, Nloc] = A_sca.dist().get_local_dims(N, N);
195  desc_a = A_sca.dist().descinit_noerror(N, N, Mloc);
196  }
197 
198  {
199  auto [Mloc, Nloc] = B_sca.dist().get_local_dims(N, NRHS);
200  desc_b = B_sca.dist().descinit_noerror(N, NRHS, Mloc);
201  }
202 
203  auto info = scalapackpp::pposv(blacspp::Triangle::Lower, N, NRHS,
204  A_sca.local_mat().data(), 1, 1, desc_a,
205  B_sca.local_mat().data(), 1, 1, desc_b);
206  if (info) TA_EXCEPTION("Cholesky Solve Failed");
207 
208  if (x_trange.rank() == 0) x_trange = B.trange();
209 
210  world.gop.fence();
211  auto X = scalapack::block_cyclic_to_array<Array>(B_sca, x_trange);
212  world.gop.fence();
213 
214  return X;
215 }
216 
217 template <typename Array>
218 auto cholesky_lsolve(Op trans, const Array& A, const Array& B,
219  TiledRange l_trange = TiledRange(),
220  TiledRange x_trange = TiledRange(),
221  size_t NB = default_block_size()) {
222  auto& world = A.world();
223  /*
224  if( world != B.world() ) {
225  TA_EXCEPTION("A and B must be distributed on same MADWorld context");
226  }
227  */
228  auto world_comm = world.mpi.comm().Get_mpi_comm();
229  blacspp::Grid grid = blacspp::Grid::square_grid(world_comm);
230 
231  world.gop.fence(); // stage ScaLAPACK execution
232  auto A_sca = scalapack::array_to_block_cyclic(A, grid, NB, NB);
233  auto B_sca = scalapack::array_to_block_cyclic(B, grid, NB, NB);
234  world.gop.fence(); // stage ScaLAPACK execution
235 
236  auto [M, N] = A_sca.dims();
237  if (M != N) TA_EXCEPTION("A must be square for Cholesky Solve");
238 
239  auto [B_N, NRHS] = B_sca.dims();
240  if (B_N != N) TA_EXCEPTION("A and B dims must agree");
241 
242  scalapackpp::scalapack_desc desc_a, desc_b;
243  {
244  auto [Mloc, Nloc] = A_sca.dist().get_local_dims(N, N);
245  desc_a = A_sca.dist().descinit_noerror(N, N, Mloc);
246  }
247 
248  {
249  auto [Mloc, Nloc] = B_sca.dist().get_local_dims(N, NRHS);
250  desc_b = B_sca.dist().descinit_noerror(N, NRHS, Mloc);
251  }
252 
253  auto info = scalapackpp::ppotrf(blacspp::Triangle::Lower, N,
254  A_sca.local_mat().data(), 1, 1, desc_a);
255  if (info) TA_EXCEPTION("Cholesky Failed");
256 
257  info = scalapackpp::ptrtrs(
258  blacspp::Triangle::Lower, to_scalapackpp_transposeflag(trans),
259  blacspp::Diagonal::NonUnit, N, NRHS, A_sca.local_mat().data(), 1, 1,
260  desc_a, B_sca.local_mat().data(), 1, 1, desc_b);
261  if (info) TA_EXCEPTION("TRTRS Failed");
262 
263  // Zero out the upper triangle
264  zero_triangle(blacspp::Triangle::Upper, A_sca);
265 
266  if (l_trange.rank() == 0) l_trange = A.trange();
267  if (x_trange.rank() == 0) x_trange = B.trange();
268 
269  world.gop.fence();
270  auto L = scalapack::block_cyclic_to_array<Array>(A_sca, l_trange);
271  auto X = scalapack::block_cyclic_to_array<Array>(B_sca, x_trange);
272  world.gop.fence();
273 
274  return std::tuple(L, X);
275 }
276 
277 } // namespace TiledArray::math::linalg::scalapack
278 
279 #endif // TILEDARRAY_HAS_SCALAPACK
280 #endif // TILEDARRAY_MATH_LINALG_SCALAPACK_CHOL_H__INCLUDED
std::size_t default_block_size()
Definition: util.h:88
void zero_triangle(blacspp::Triangle tri, scalapack::BlockCyclicMatrix< T > &A, bool zero_diag=false)
Definition: util.h:50
#define TA_EXCEPTION(m)
Definition: error.h:83
scalapackpp::TransposeFlag to_scalapackpp_transposeflag(Op t)
Definition: util.h:36
auto cholesky_solve(const Array &A, const Array &B, TiledRange x_trange=TiledRange(), size_t NB=default_block_size())
Definition: cholesky.h:169
Range data of a tiled array.
Definition: tiled_range.h:32
const trange_type & trange() const
Tiled range accessor.
Definition: dist_array.h:917
Forward declarations.
Definition: dist_array.h:57
value_type::value_type element_type
Definition: dist_array.h:102
World & world() const
World accessor.
Definition: dist_array.h:1007
auto cholesky_lsolve(Op trans, const Array &A, const Array &B, TiledRange l_trange=TiledRange(), TiledRange x_trange=TiledRange(), size_t NB=default_block_size())
Definition: cholesky.h:218
auto cholesky_linv(const Array &A, TiledRange l_trange=TiledRange(), size_t NB=default_block_size())
Compute the inverse of the Cholesky factor of an HPD rank-2 tensor. Optionally return the Cholesky factor itself as well.
Definition: cholesky.h:114
BlockCyclicMatrix< typename std::remove_cv_t< Array >::element_type > array_to_block_cyclic(const Array &array, const blacspp::Grid &grid, size_t MB, size_t NB)
Convert a dense DistArray to block-cyclic storage format.
Definition: block_cyclic.h:308
auto cholesky(const Array &A, TiledRange l_trange=TiledRange(), size_t NB=default_block_size())
Compute the Cholesky factorization of an HPD rank-2 tensor.
Definition: cholesky.h:60