25 #ifndef TILEDARRAY_MATH_LINALG_SCALAPACK_CHOL_H__INCLUDED
26 #define TILEDARRAY_MATH_LINALG_SCALAPACK_CHOL_H__INCLUDED
28 #include <TiledArray/config.h>
29 #if TILEDARRAY_HAS_SCALAPACK
34 #include <scalapackpp/factorizations/potrf.hpp>
35 #include <scalapackpp/linear_systems/posv.hpp>
36 #include <scalapackpp/linear_systems/trtrs.hpp>
37 #include <scalapackpp/matrix_inverse/trtri.hpp>
// NOTE(review): this view is missing lines. The function signature, the
// declaration of `matrix` and `l_trange`, any check of `info`, the return
// statement and the closing brace are not shown. These comments describe
// only the visible statements.
//
// Computes a lower-triangular Cholesky factor of the square matrix derived
// from `A` with ScaLAPACK p?potrf (blacspp::Triangle::Lower). It then
// converts the block-cyclic result back into an `Array` tiled by `l_trange`.
//
// Template parameter Array: distributed array type. `A` must provide
//   world() and trange().
// Throws (via TA_EXCEPTION) if the block-cyclic matrix is not square.
59 template <
typename Array>
// Build a square BLACS process grid over the MPI communicator of A's world.
62 auto& world = A.
world();
63 auto world_comm = world.mpi.comm().Get_mpi_comm();
64 blacspp::Grid grid = blacspp::Grid::square_grid(world_comm);
// `matrix` is declared on lines not shown here. Presumably it is the
// block-cyclic redistribution of A over `grid`; confirm.
70 auto [M, N] = matrix.dims();
71 if (M != N)
TA_EXCEPTION(
"Matrix must be square for Cholesky");
// Local extents and the ScaLAPACK array descriptor for the N x N matrix.
73 auto [Mloc, Nloc] = matrix.dist().get_local_dims(N, N);
74 auto desc = matrix.dist().descinit_noerror(N, N, Mloc);
// Factor the lower triangle of `matrix` in place. `matrix` itself is
// converted to the result L below.
// NOTE(review): `info` is not inspected in the visible lines. Confirm that a
// nonzero result (failed factorization) is reported on a hidden line.
76 auto info = scalapackpp::ppotrf(blacspp::Triangle::Lower, N,
77 matrix.local_mat().data(), 1, 1, desc);
// A rank-0 (empty) tiled range means "use A's tiling" for the result.
83 if (l_trange.rank() == 0) l_trange = A.
trange();
// Redistribute the factored block-cyclic matrix into an Array tiled by
// l_trange.
// NOTE(review): nothing visible clears the strictly-upper triangle of
// `matrix`, which p?potrf leaves untouched. Confirm it is zeroed elsewhere.
86 auto L = scalapack::block_cyclic_to_array<Array>(matrix, l_trange);
// NOTE(review): this view is missing lines. Not shown are:
//   - the signature;
//   - the declarations of `matrix`, `value_type`, `NB` and `l_trange`;
//   - any check of `info`;
//   - the closing brace of the first `if constexpr`;
//   - the non-`Both` return.
// These comments describe only the visible statements.
//
// Cholesky-factorizes the square matrix derived from `A` (lower triangle,
// p?potrf). It then inverts that triangular factor in place (p?trtri) and
// converts the inverse into an `Array` (Linv).
//
// Template parameter Both: when true, a copy of the factor is taken before
//   inversion, and the visible return is std::tuple(L, Linv).
// Template parameter Array: distributed array type. `A` must provide
//   world() and trange().
// Throws (via TA_EXCEPTION) if the block-cyclic matrix is not square.
113 template <
bool Both,
typename Array>
// Build a square BLACS process grid over the MPI communicator of A's world.
118 auto& world = A.
world();
119 auto world_comm = world.mpi.comm().Get_mpi_comm();
120 blacspp::Grid grid = blacspp::Grid::square_grid(world_comm);
// `matrix` is declared on lines not shown here. Presumably it is the
// block-cyclic redistribution of A over `grid`; confirm.
126 auto [M, N] = matrix.dims();
127 if (M != N)
TA_EXCEPTION(
"Matrix must be square for Cholesky");
// Local extents and the ScaLAPACK array descriptor for the N x N matrix.
129 auto [Mloc, Nloc] = matrix.dist().get_local_dims(N, N);
130 auto desc = matrix.dist().descinit_noerror(N, N, Mloc);
// Factor the lower triangle of `matrix` in place.
// NOTE(review): `info` is not inspected in the visible lines. Confirm it is
// checked on a hidden line.
132 auto info = scalapackpp::ppotrf(blacspp::Triangle::Lower, N,
133 matrix.local_mat().data(), 1, 1, desc);
// Holds a snapshot of the factor. It stays nullptr unless the caller asked
// for both L and its inverse.
140 std::shared_ptr<scalapack::BlockCyclicMatrix<value_type>> L_sca =
nullptr;
141 if constexpr (Both) {
// Allocate an N x N block-cyclic matrix with NB x NB blocks on the same
// world/grid. Copy the factor's local data into it before p?trtri
// overwrites `matrix`.
142 L_sca = std::make_shared<scalapack::BlockCyclicMatrix<value_type>>(
143 world, grid, N, N, NB, NB);
144 L_sca->local_mat() = matrix.local_mat();
// Invert the lower-triangular, non-unit-diagonal factor in place. After
// this call `matrix` holds the inverse of the factor.
// NOTE(review): the return value of ptrtri is discarded here. Confirm that
// inversion failures do not need to be reported.
149 scalapackpp::ptrtri(blacspp::Triangle::Lower, blacspp::Diagonal::NonUnit,
150 N, matrix.local_mat().data(), 1, 1, desc);
// A rank-0 (empty) tiled range means "use A's tiling" for the results.
153 if (l_trange.rank() == 0) l_trange = A.
trange();
// Redistribute the inverted factor into an Array tiled by l_trange.
156 auto Linv = scalapack::block_cyclic_to_array<Array>(matrix, l_trange);
159 if constexpr (Both) {
// Also convert the pre-inversion snapshot, and return both results.
160 auto L = scalapack::block_cyclic_to_array<Array>(*L_sca, l_trange);
162 return std::tuple(L, Linv);
// NOTE(review): this view is missing lines. Not shown are:
//   - the signature;
//   - the declarations of `A_sca`, `B_sca` and `x_trange`;
//   - the scopes around the descriptor setup;
//   - any check of `info`;
//   - the return statement and closing brace.
// These comments describe only the visible statements.
//
// Solves A X = B for the square matrix A with ScaLAPACK p?posv, which uses
// the lower triangle / Cholesky factorization. The solution is converted
// into an `Array` (X) tiled by `x_trange`.
//
// Template parameter Array: distributed array type. `A` must provide
//   world(), and `B` must provide trange().
// Throws (via TA_EXCEPTION) if A's block-cyclic matrix is not square.
168 template <
typename Array>
// Build a square BLACS process grid over the MPI communicator of A's world.
172 auto& world = A.
world();
178 auto world_comm = world.mpi.comm().Get_mpi_comm();
179 blacspp::Grid grid = blacspp::Grid::square_grid(world_comm);
// `A_sca` and `B_sca` are declared on lines not shown here. Presumably they
// are the block-cyclic redistributions of A and B over `grid`; confirm.
186 auto [M, N] = A_sca.dims();
187 if (M != N)
TA_EXCEPTION(
"A must be square for Cholesky Solve");
// NRHS = number of right-hand-side columns in B.
// NOTE(review): B_N is not compared against N in the visible lines. Confirm
// that B's row count is validated on a hidden line.
189 auto [B_N, NRHS] = B_sca.dims();
// ScaLAPACK descriptors for A (N x N) and B (N x NRHS).
// NOTE(review): as shown, Mloc/Nloc are bound twice. Presumably each binding
// sits in its own `{}` scope on lines not shown; confirm.
192 scalapackpp::scalapack_desc desc_a, desc_b;
194 auto [Mloc, Nloc] = A_sca.dist().get_local_dims(N, N);
195 desc_a = A_sca.dist().descinit_noerror(N, N, Mloc);
199 auto [Mloc, Nloc] = B_sca.dist().get_local_dims(N, NRHS);
200 desc_b = B_sca.dist().descinit_noerror(N, NRHS, Mloc);
// Factor A (lower triangle) and solve in place. Per the ScaLAPACK p?posv
// contract, B_sca is overwritten with the solution, which is why X is built
// from B_sca below.
// NOTE(review): `info` is not inspected in the visible lines.
203 auto info = scalapackpp::pposv(blacspp::Triangle::Lower, N, NRHS,
204 A_sca.local_mat().data(), 1, 1, desc_a,
205 B_sca.local_mat().data(), 1, 1, desc_b);
// A rank-0 (empty) tiled range means "use B's tiling" for the solution.
208 if (x_trange.rank() == 0) x_trange = B.
trange();
// Redistribute the solution into an Array tiled by x_trange.
211 auto X = scalapack::block_cyclic_to_array<Array>(B_sca, x_trange);
// NOTE(review): this view is missing lines. Not shown are:
//   - the signature;
//   - the declarations of `A_sca`, `B_sca`, `l_trange` and `x_trange`;
//   - the scopes around the descriptor setup;
//   - any `info` checks;
//   - the triangle/transpose arguments of the ptrtrs call;
//   - the closing brace.
// These comments describe only the visible statements.
//
// Cholesky-factorizes A (lower triangle, p?potrf). It then runs a
// triangular solve (p?trtrs) with that factor against B. Returns
// std::tuple(L, X): the factor tiled by `l_trange`, and the solve result
// tiled by `x_trange`.
//
// Template parameter Array: distributed array type. `A` and `B` must
//   provide trange(), and `A` must provide world().
// Throws (via TA_EXCEPTION) if A's block-cyclic matrix is not square.
217 template <
typename Array>
// Build a square BLACS process grid over the MPI communicator of A's world.
222 auto& world = A.
world();
228 auto world_comm = world.mpi.comm().Get_mpi_comm();
229 blacspp::Grid grid = blacspp::Grid::square_grid(world_comm);
// `A_sca` and `B_sca` are declared on lines not shown here. Presumably they
// are the block-cyclic redistributions of A and B over `grid`; confirm.
236 auto [M, N] = A_sca.dims();
237 if (M != N)
TA_EXCEPTION(
"A must be square for Cholesky Solve");
// NRHS = number of right-hand-side columns in B.
// NOTE(review): B_N is not compared against N in the visible lines.
239 auto [B_N, NRHS] = B_sca.dims();
// ScaLAPACK descriptors for A (N x N) and B (N x NRHS).
// NOTE(review): as shown, Mloc/Nloc are bound twice. Presumably each binding
// sits in its own `{}` scope on lines not shown; confirm.
242 scalapackpp::scalapack_desc desc_a, desc_b;
244 auto [Mloc, Nloc] = A_sca.dist().get_local_dims(N, N);
245 desc_a = A_sca.dist().descinit_noerror(N, N, Mloc);
249 auto [Mloc, Nloc] = B_sca.dist().get_local_dims(N, NRHS);
250 desc_b = B_sca.dist().descinit_noerror(N, NRHS, Mloc);
// Factor the lower triangle of A_sca in place. A_sca becomes the L returned
// below.
253 auto info = scalapackpp::ppotrf(blacspp::Triangle::Lower, N,
254 A_sca.local_mat().data(), 1, 1, desc_a);
// Triangular solve with the factor. Per the ScaLAPACK p?trtrs contract,
// B_sca is overwritten with the result, which is why X is built from B_sca.
// The triangle/transpose arguments are on a line not shown here.
// NOTE(review): `info` from ppotrf is overwritten here. Confirm it was
// checked on a hidden line before this point.
257 info = scalapackpp::ptrtrs(
259 blacspp::Diagonal::NonUnit, N, NRHS, A_sca.local_mat().data(), 1, 1,
260 desc_a, B_sca.local_mat().data(), 1, 1, desc_b);
// Rank-0 (empty) tiled ranges mean "use the input's tiling".
266 if (l_trange.rank() == 0) l_trange = A.
trange();
267 if (x_trange.rank() == 0) x_trange = B.
trange();
// Redistribute the factor and the solve result back into Arrays.
270 auto L = scalapack::block_cyclic_to_array<Array>(A_sca, l_trange);
271 auto X = scalapack::block_cyclic_to_array<Array>(B_sca, x_trange);
274 return std::tuple(L, X);
279 #endif // TILEDARRAY_HAS_SCALAPACK
280 #endif // TILEDARRAY_MATH_LINALG_SCALAPACK_CHOL_H__INCLUDED