lu.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2020 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * David Williams-Young
19  * Computational Research Division, Lawrence Berkeley National Laboratory
20  *
21  * lu.h
22  * Created: 19 June, 2020
23  *
24  */
25 #ifndef TILEDARRAY_MATH_LINALG_SCALAPACK_LU_H__INCLUDED
26 #define TILEDARRAY_MATH_LINALG_SCALAPACK_LU_H__INCLUDED
27 
28 #include <TiledArray/config.h>
29 #if TILEDARRAY_HAS_SCALAPACK
30 
32 
33 #include <scalapackpp/factorizations/getrf.hpp>
34 #include <scalapackpp/linear_systems/gesv.hpp>
35 #include <scalapackpp/matrix_inverse/getri.hpp>
36 
38 
42 template <typename ArrayA, typename ArrayB>
43 auto lu_solve(const ArrayA& A, const ArrayB& B,
44  TiledRange x_trange = TiledRange(),
45  size_t NB = default_block_size(),
46  size_t MB = default_block_size()) {
47  using value_type = typename ArrayA::element_type;
48  static_assert(std::is_same_v<value_type, typename ArrayB::element_type>);
49 
50  auto& world = A.world();
51  auto world_comm = world.mpi.comm().Get_mpi_comm();
52  blacspp::Grid grid = blacspp::Grid::square_grid(world_comm);
53 
54  world.gop.fence(); // stage ScaLAPACK execution
55  auto A_sca = scalapack::array_to_block_cyclic(A, grid, MB, NB);
56  auto B_sca = scalapack::array_to_block_cyclic(B, grid, MB, NB);
57  world.gop.fence(); // stage ScaLAPACK execution
58 
59  auto [M, N] = A_sca.dims();
60  if (M != N) TA_EXCEPTION("A must be square for LU Solve");
61  auto [B_N, NRHS] = B_sca.dims();
62  if (B_N != N) TA_EXCEPTION("A and B dims must agree");
63 
64  auto [A_Mloc, A_Nloc] = A_sca.dist().get_local_dims(N, N);
65  auto desc_a = A_sca.dist().descinit_noerror(N, N, A_Mloc);
66 
67  auto [B_Mloc, B_Nloc] = B_sca.dist().get_local_dims(N, NRHS);
68  auto desc_b = B_sca.dist().descinit_noerror(N, NRHS, B_Mloc);
69 
70  std::vector<scalapackpp::scalapack_int> IPIV(A_Mloc + MB);
71 
72  auto info =
73  scalapackpp::pgesv(N, NRHS, A_sca.local_mat().data(), 1, 1, desc_a,
74  IPIV.data(), B_sca.local_mat().data(), 1, 1, desc_b);
75  if (info) TA_EXCEPTION("LU Solve Failed");
76 
77  if (x_trange.rank() == 0) x_trange = B.trange();
78 
79  world.gop.fence();
80  auto X = scalapack::block_cyclic_to_array<ArrayB>(B_sca, x_trange);
81  world.gop.fence();
82 
83  return X;
84 }
85 
89 template <typename Array>
90 auto lu_inv(const Array& A, TiledRange ainv_trange = TiledRange(),
91  size_t NB = default_block_size(),
92  size_t MB = default_block_size()) {
93  auto& world = A.world();
94  auto world_comm = world.mpi.comm().Get_mpi_comm();
95  blacspp::Grid grid = blacspp::Grid::square_grid(world_comm);
96 
97  world.gop.fence(); // stage ScaLAPACK execution
98  auto A_sca = scalapack::array_to_block_cyclic(A, grid, MB, NB);
99  world.gop.fence(); // stage ScaLAPACK execution
100 
101  auto [M, N] = A_sca.dims();
102  if (M != N) TA_EXCEPTION("A must be square for LU Inverse");
103 
104  auto [A_Mloc, A_Nloc] = A_sca.dist().get_local_dims(N, N);
105  auto desc_a = A_sca.dist().descinit_noerror(N, N, A_Mloc);
106 
107  std::vector<scalapackpp::scalapack_int> IPIV(A_Mloc + MB);
108 
109  {
110  auto info = scalapackpp::pgetrf(N, N, A_sca.local_mat().data(), 1, 1,
111  desc_a, IPIV.data());
112  if (info) TA_EXCEPTION("LU Failed");
113  }
114 
115  {
116  auto info = scalapackpp::pgetri(N, A_sca.local_mat().data(), 1, 1, desc_a,
117  IPIV.data());
118  if (info) TA_EXCEPTION("LU Inverse Failed");
119  }
120 
121  if (ainv_trange.rank() == 0) ainv_trange = A.trange();
122 
123  world.gop.fence();
124  auto Ainv = scalapack::block_cyclic_to_array<Array>(A_sca, ainv_trange);
125  world.gop.fence();
126 
127  return Ainv;
128 }
129 
130 } // namespace TiledArray::math::linalg::scalapack
131 
132 #endif // TILEDARRAY_HAS_SCALAPACK
133 #endif // TILEDARRAY_MATH_LINALG_SCALAPACK_LU_H__INCLUDED
auto lu_solve(const ArrayA &A, const ArrayB &B, TiledRange x_trange=TiledRange(), size_t NB=default_block_size(), size_t MB=default_block_size())
Solve a linear system via LU factorization.
Definition: lu.h:43
std::size_t default_block_size()
Definition: util.h:88
#define TA_EXCEPTION(m)
Definition: error.h:83
auto lu_inv(const Array &A, TiledRange ainv_trange=TiledRange(), size_t NB=default_block_size(), size_t MB=default_block_size())
Invert a matrix via LU.
Definition: lu.h:90
Range data of a tiled array.
Definition: tiled_range.h:32
const trange_type & trange() const
Tiled range accessor.
Definition: dist_array.h:917
Forward declarations.
Definition: dist_array.h:57
World & world() const
World accessor.
Definition: dist_array.h:1007
BlockCyclicMatrix< typename std::remove_cv_t< Array >::element_type > array_to_block_cyclic(const Array &array, const blacspp::Grid &grid, size_t MB, size_t NB)
Convert a dense DistArray to block-cyclic storage format.
Definition: block_cyclic.h:308