25 #ifndef TILEDARRAY_MATH_LINALG_SCALAPACK_LU_H__INCLUDED
26 #define TILEDARRAY_MATH_LINALG_SCALAPACK_LU_H__INCLUDED
28 #include <TiledArray/config.h>
29 #if TILEDARRAY_HAS_SCALAPACK
33 #include <scalapackpp/factorizations/getrf.hpp>
34 #include <scalapackpp/linear_systems/gesv.hpp>
35 #include <scalapackpp/matrix_inverse/getri.hpp>
// lu_solve (excerpt: parts of the signature and body are not shown here).
// Visible flow: build a square BLACS grid on the world's MPI communicator,
// verify that A_sca is square, derive local dimensions and descriptors for
// A_sca and B_sca, call scalapackpp::pgesv, and convert B_sca back into an
// ArrayB named X.
// NOTE(review): A_sca, B_sca, MB and x_trange are defined in elided lines;
// presumably A_sca/B_sca are block-cyclic copies of A and B -- confirm.
42 template <
typename ArrayA,
typename ArrayB>
43 auto lu_solve(
const ArrayA& A,
const ArrayB& B,
// Compile-time check: A and B must share the same element type.
47 using value_type =
typename ArrayA::element_type;
48 static_assert(std::is_same_v<value_type, typename ArrayB::element_type>);
// Process grid is created from the communicator of A's world.
50 auto& world = A.world();
51 auto world_comm = world.mpi.comm().Get_mpi_comm();
52 blacspp::Grid grid = blacspp::Grid::square_grid(world_comm);
// Reject a non-square A_sca before any descriptor is built.
59 auto [M, N] = A_sca.dims();
60 if (M != N)
TA_EXCEPTION(
"A must be square for LU Solve");
// NRHS is the second dimension of B_sca. B_N is not referenced in the visible
// lines -- presumably compared against N in an elided line; confirm.
61 auto [B_N, NRHS] = B_sca.dims();
// Local extents and descriptor for the N x N matrix; the local row count is
// the third descinit_noerror argument (presumably the local leading
// dimension -- confirm against the distribution API).
64 auto [A_Mloc, A_Nloc] = A_sca.dist().get_local_dims(N, N);
65 auto desc_a = A_sca.dist().descinit_noerror(N, N, A_Mloc);
// Same for the right-hand side, using N (not B_N) as its global row count.
67 auto [B_Mloc, B_Nloc] = B_sca.dist().get_local_dims(N, NRHS);
68 auto desc_b = B_sca.dist().descinit_noerror(N, NRHS, B_Mloc);
// Pivot buffer handed to pgesv, sized local rows of A plus MB.
// NOTE(review): MB is declared outside this excerpt; presumably the row
// block size, which would match ScaLAPACK's LOCr(M_A) + MB_A requirement.
70 std::vector<scalapackpp::scalapack_int> IPIV(A_Mloc + MB);
// Solve call on the local buffers; the "1, 1" pairs are presumably the
// (row, column) start offsets of each matrix, as in ScaLAPACK's P?GESV.
// NOTE(review): the start of this statement may sit on an elided line;
// confirm that the status returned by pgesv is checked.
73 scalapackpp::pgesv(N, NRHS, A_sca.local_mat().data(), 1, 1, desc_a,
74 IPIV.data(), B_sca.local_mat().data(), 1, 1, desc_b);
// A rank-0 x_trange falls back to B's tiled range.
77 if (x_trange.rank() == 0) x_trange = B.trange();
// X is built from B_sca after the pgesv call (B_sca presumably holds the
// solution at this point, since P?GESV overwrites B -- confirm).
80 auto X = scalapack::block_cyclic_to_array<ArrayB>(B_sca, x_trange);
// LU-based inverse routine (excerpt: the function name, its signature and
// parts of the body are not shown here).
// Visible flow: build a square BLACS grid on the world's MPI communicator,
// verify that A_sca is square, derive its local dimensions and descriptor,
// call scalapackpp::pgetrf followed by scalapackpp::pgetri on the same local
// buffer, and convert A_sca back into an Array named Ainv.
// NOTE(review): A_sca, MB and ainv_trange are defined in elided lines;
// presumably A_sca is a block-cyclic copy of A -- confirm.
89 template <
typename Array>
// Process grid is created from the communicator of A's world.
93 auto& world = A.
world();
94 auto world_comm = world.mpi.comm().Get_mpi_comm();
95 blacspp::Grid grid = blacspp::Grid::square_grid(world_comm);
// Reject a non-square A_sca before any descriptor is built.
101 auto [M, N] = A_sca.dims();
102 if (M != N)
TA_EXCEPTION(
"A must be square for LU Inverse");
// Local extents and descriptor for the N x N matrix; the local row count is
// the third descinit_noerror argument (presumably the local leading
// dimension -- confirm against the distribution API).
104 auto [A_Mloc, A_Nloc] = A_sca.dist().get_local_dims(N, N);
105 auto desc_a = A_sca.dist().descinit_noerror(N, N, A_Mloc);
// Pivot buffer handed to pgetrf, sized local rows of A plus MB.
// NOTE(review): MB is declared outside this excerpt; presumably the row
// block size, which would match ScaLAPACK's LOCr(M_A) + MB_A requirement.
107 std::vector<scalapackpp::scalapack_int> IPIV(A_Mloc + MB);
// Step 1: factorization call on A_sca's local buffer, filling IPIV.
// NOTE(review): `info` is declared again below, so the two calls presumably
// live in separate scopes whose braces and checks are elided -- confirm.
110 auto info = scalapackpp::pgetrf(N, N, A_sca.local_mat().data(), 1, 1,
111 desc_a, IPIV.data());
// Step 2: inversion call on the same local buffer. The remaining arguments
// of this call are on a line not shown in this excerpt.
116 auto info = scalapackpp::pgetri(N, A_sca.local_mat().data(), 1, 1, desc_a,
// A rank-0 ainv_trange falls back to A's tiled range.
121 if (ainv_trange.rank() == 0) ainv_trange = A.
trange();
// Ainv is built from A_sca after the pgetrf/pgetri calls.
124 auto Ainv = scalapack::block_cyclic_to_array<Array>(A_sca, ainv_trange);
132 #endif // TILEDARRAY_HAS_SCALAPACK
133 #endif // TILEDARRAY_MATH_LINALG_SCALAPACK_LU_H__INCLUDED