25 #ifndef TILEDARRAY_MATH_LINALG_SCALAPACK_SVD_H__INCLUDED
26 #define TILEDARRAY_MATH_LINALG_SCALAPACK_SVD_H__INCLUDED
28 #include <TiledArray/config.h>
29 #if TILEDARRAY_HAS_SCALAPACK
34 #include <scalapackpp/svd.hpp>
// Distributed SVD built on scalapackpp::pgesvd, templated on which singular
// vectors are requested (Vectors) and on the array type (Array).
// NOTE(review): several lines of this definition are not visible in this view
// (the function signature, some declarations, and the closing braces); the
// comments below describe only the statements that are shown.
62 template <SVD::Vectors Vectors,
typename Array>
// Real-valued type associated with value_type, as provided by
// scalapackpp::detail::real_t; used as the element type of S below.
66 using real_type = scalapackpp::detail::real_t<value_type>;
// Fetch the world that A lives in and the MPI communicator obtained from it.
68 auto& world = A.
world();
69 auto world_comm = world.mpi.comm().Get_mpi_comm();
// Create the BLACS grid from that communicator via Grid::square_grid.
71 blacspp::Grid grid = blacspp::Grid::square_grid(world_comm);
// M and N are taken from matrix.dims(); `matrix` is defined on a line not
// shown here (presumably the block-cyclic copy of A — TODO confirm).
77 auto [M, N] = matrix.dims();
// Local dimensions, as reported by the distribution of `matrix`, for the input
// (M x N), for U (M x SVD_SIZE) and for VT (SVD_SIZE x N).
// NOTE(review): SVD_SIZE is defined elsewhere; presumably min(M, N) — confirm.
80 auto [AMloc, ANloc] = matrix.dist().get_local_dims(M, N);
81 auto [UMloc, UNloc] = matrix.dist().get_local_dims(M, SVD_SIZE);
82 auto [VTMloc, VTNloc] = matrix.dist().get_local_dims(SVD_SIZE, N);
// Descriptors for the three matrices; the third argument is the local row
// count computed above (presumably used as the local leading dimension).
84 auto desc_a = matrix.dist().descinit_noerror(M, N, AMloc);
85 auto desc_u = matrix.dist().descinit_noerror(M, SVD_SIZE, UMloc);
86 auto desc_vt = matrix.dist().descinit_noerror(SVD_SIZE, N, VTMloc);
// Buffer of SVD_SIZE real values; handed to pgesvd as S.data() and returned
// to the caller in every visible return statement.
88 std::vector<real_type> S(SVD_SIZE);
// U starts out null and is only allocated when need_u holds. The declaration
// continues on a line not shown (note the trailing comma); VT, used below, is
// presumably declared there.
94 std::shared_ptr<scalapack::BlockCyclicMatrix<value_type>> U =
nullptr,
// Both job flags default to NoVectors and are switched to Vectors below only
// for the factors that were requested.
97 scalapackpp::VectorFlag JOBU = scalapackpp::VectorFlag::NoVectors;
98 scalapackpp::VectorFlag JOBVT = scalapackpp::VectorFlag::NoVectors;
// Raw output pointers passed to pgesvd; they remain null for factors that are
// not requested.
100 value_type* U_ptr =
nullptr;
101 value_type* VT_ptr =
nullptr;
// need_u / need_vt are compile-time conditions defined elsewhere (presumably
// derived from the Vectors template argument — TODO confirm).
103 if constexpr (need_u) {
104 JOBU = scalapackpp::VectorFlag::Vectors;
// Allocate an M x SVD_SIZE block-cyclic matrix on the grid (MB, NB are defined
// elsewhere; presumably the block sizes) and expose its local storage.
105 U = std::make_shared<scalapack::BlockCyclicMatrix<value_type>>(
106 world, grid, M, SVD_SIZE, MB, NB);
108 U_ptr = U->local_mat().data();
111 if constexpr (need_vt) {
112 JOBVT = scalapackpp::VectorFlag::Vectors;
// Same as for U, but with dimensions SVD_SIZE x N.
113 VT = std::make_shared<scalapack::BlockCyclicMatrix<value_type>>(
114 world, grid, SVD_SIZE, N, MB, NB);
116 VT_ptr = VT->local_mat().data();
// Run the SVD on the local data of `matrix`. Each matrix argument is followed
// by the pair (1, 1) and its descriptor (presumably the 1-based starting
// row/column of the submatrix, per ScaLAPACK convention — confirm).
// NOTE(review): any check of `info` is not visible in this view.
119 auto info = scalapackpp::pgesvd(JOBU, JOBVT, M, N, matrix.local_mat().data(),
120 1, 1, desc_a, S.data(), U_ptr, 1, 1, desc_u,
121 VT_ptr, 1, 1, desc_vt);
// Convert the requested block-cyclic factors to Array objects using
// u_trange / vt_trange (defined elsewhere) and return them together with S.
126 if constexpr (need_uv) {
127 auto U_ta = scalapack::block_cyclic_to_array<Array>(*U, u_trange);
128 auto VT_ta = scalapack::block_cyclic_to_array<Array>(*VT, vt_trange);
// Both factors requested: return (S, U, VT).
131 return std::tuple(S, U_ta, VT_ta);
133 }
else if constexpr (need_u) {
134 auto U_ta = scalapack::block_cyclic_to_array<Array>(*U, u_trange);
// Only U requested: return (S, U).
137 return std::tuple(S, U_ta);
139 }
else if constexpr (need_vt) {
140 auto VT_ta = scalapack::block_cyclic_to_array<Array>(*VT, vt_trange);
// Only VT requested: return (S, VT).
143 return std::tuple(S, VT_ta);
152 #endif // TILEDARRAY_HAS_SCALAPACK
153 #endif // TILEDARRAY_MATH_LINALG_SCALAPACK_SVD_H__INCLUDED