26 #ifndef TILEDARRAY_MATH_LINALG_SCALAPACK_HEIG_H__INCLUDED
27 #define TILEDARRAY_MATH_LINALG_SCALAPACK_HEIG_H__INCLUDED
29 #include <TiledArray/config.h>
30 #if TILEDARRAY_HAS_SCALAPACK
34 #include <scalapackpp/eigenvalue_problem/gevp.hpp>
35 #include <scalapackpp/eigenvalue_problem/sevp.hpp>
// Eigenvalue-problem driver built on scalapackpp::hereig, templated on the
// input array type `Array`. Returns a std::tuple holding the eigenvalue
// vector and the eigenvector array (see the return statement below).
58 template <
typename Array>
// Real-valued counterpart of `value_type` as defined by
// scalapackpp::detail::real_t; used as the element type of `evals`.
62 using real_type = scalapackpp::detail::real_t<value_type>;
// World object that owns A, and the MPI communicator obtained from it.
64 auto& world = A.
world();
65 auto world_comm = world.mpi.comm().Get_mpi_comm();
// BLACS process grid created by Grid::square_grid over that communicator.
67 blacspp::Grid grid = blacspp::Grid::square_grid(world_comm);
// NOTE(review): `matrix` is presumably the block-cyclic copy of A — confirm
// against its declaration.
// Reject non-square inputs: M and N are the two dimensions reported by
// matrix.dims().
73 auto [M, N] = matrix.dims();
74 if (M != N)
TA_EXCEPTION(
"Matrix must be square for EVP");
// Local dimensions for an N x N matrix under matrix's distribution, and a
// descriptor built from them. Mloc is passed as the last descinit_noerror
// argument (presumably the local leading dimension — TODO confirm).
76 auto [Mloc, Nloc] = matrix.dist().get_local_dims(N, N);
77 auto desc = matrix.dist().descinit_noerror(N, N, Mloc);
// Storage for N eigenvalues of real type.
79 std::vector<real_type> evals(N);
// Call hereig with VectorFlag::Vectors and Triangle::Lower, passing the local
// data of `matrix` and `desc`; eigenvalues are written to `evals`.
// NOTE(review): the `1, 1` pair presumably gives the 1-based global starting
// row/column of the submatrix — confirm against the scalapackpp API.
82 auto info = scalapackpp::hereig(
83 scalapackpp::VectorFlag::Vectors, blacspp::Triangle::Lower, N,
84 matrix.local_mat().data(), 1, 1, desc, evals.data(),
// If the supplied evec_trange has rank 0, fall back to A's own TiledRange.
88 if (evec_trange.rank() == 0) evec_trange = A.
trange();
// Convert `evecs` into an `Array` tiled according to evec_trange.
91 auto evecs_ta = scalapack::block_cyclic_to_array<Array>(evecs, evec_trange);
// Return (eigenvalues, eigenvector array).
94 return std::tuple(evals, evecs_ta);
// Two-matrix eigenvalue-problem driver built on scalapackpp::hereig_gen.
// Template parameters: ArrayA / ArrayB are the types of the inputs A and B;
// EVecType is the type of the returned eigenvector array and defaults to
// ArrayA. Returns a std::tuple of the eigenvalue vector and the eigenvector
// array (see the return statement below).
121 template <
typename ArrayA,
typename ArrayB,
typename EVecType = ArrayA>
122 auto heig(
const ArrayA& A,
const ArrayB& B,
// Element type is taken from ArrayA; ArrayB must have the same element type,
// enforced at compile time by the static_assert.
125 using value_type =
typename ArrayA::element_type;
126 static_assert(std::is_same_v<typename ArrayB::element_type, value_type>);
// Real-valued counterpart of `value_type` as defined by
// scalapackpp::detail::real_t; used as the element type of `evals`.
127 using real_type = scalapackpp::detail::real_t<value_type>;
// World object that owns A, and the MPI communicator obtained from it.
129 auto& world = A.world();
130 auto world_comm = world.mpi.comm().Get_mpi_comm();
// BLACS process grid created by Grid::square_grid over that communicator.
132 blacspp::Grid grid = blacspp::Grid::square_grid(world_comm);
// NOTE(review): A_sca / B_sca are presumably the block-cyclic copies of A and
// B — confirm against their declarations.
// Reject a non-square A: M and N are the two dimensions from A_sca.dims().
139 auto [M, N] = A_sca.dims();
140 if (M != N)
TA_EXCEPTION(
"Matrix must be square for EVP");
// Compare B's dimensions against A's (M, N).
142 auto [B_M, B_N] = B_sca.dims();
143 if (B_M != M or B_N != N)
// Local dimensions for an N x N matrix under A_sca's distribution, and a
// descriptor built from them. The same `desc` is passed for A, B and the
// eigenvectors in the hereig_gen call below. Mloc is passed as the last
// descinit_noerror argument (presumably the local leading dimension — TODO
// confirm).
146 auto [Mloc, Nloc] = A_sca.dist().get_local_dims(N, N);
147 auto desc = A_sca.dist().descinit_noerror(N, N, Mloc);
// Storage for N eigenvalues of real type.
149 std::vector<real_type> evals(N);
// Call hereig_gen with VectorFlag::Vectors and Triangle::Lower, passing the
// local data of A_sca, B_sca and evecs; eigenvalues are written to `evals`.
// NOTE(review): each `1, 1` pair presumably gives the 1-based global starting
// row/column of the corresponding submatrix — confirm against the
// scalapackpp API.
152 auto info = scalapackpp::hereig_gen(
153 scalapackpp::VectorFlag::Vectors, blacspp::Triangle::Lower, N,
154 A_sca.local_mat().data(), 1, 1, desc, B_sca.local_mat().data(), 1, 1,
155 desc, evals.data(), evecs.
local_mat().data(), 1, 1, desc);
// If the supplied evec_trange has rank 0, fall back to A's own TiledRange.
158 if (evec_trange.rank() == 0) evec_trange = A.trange();
// Convert `evecs` into an `EVecType` tiled according to evec_trange.
162 scalapack::block_cyclic_to_array<EVecType>(evecs, evec_trange);
// Return (eigenvalues, eigenvector array).
165 return std::tuple(evals, evecs_ta);
170 #endif // TILEDARRAY_HAS_SCALAPACK
171 #endif // TILEDARRAY_MATH_LINALG_SCALAPACK_HEIG_H__INCLUDED