svd.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2020 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * David Williams-Young
19  * Computational Research Division, Lawrence Berkeley National Laboratory
20  *
21  * svd.h
22  * Created: 12 June, 2020
23  *
24  */
25 #ifndef TILEDARRAY_MATH_LINALG_SCALAPACK_SVD_H__INCLUDED
26 #define TILEDARRAY_MATH_LINALG_SCALAPACK_SVD_H__INCLUDED
27 
28 #include <TiledArray/config.h>
29 #if TILEDARRAY_HAS_SCALAPACK
30 
33 
34 #include <scalapackpp/svd.hpp>
35 
37 
62 template <SVD::Vectors Vectors, typename Array>
63 auto svd(const Array& A, TiledRange u_trange, TiledRange vt_trange,
64  size_t MB = default_block_size(), size_t NB = default_block_size()) {
65  using value_type = typename Array::element_type;
66  using real_type = scalapackpp::detail::real_t<value_type>;
67 
68  auto& world = A.world();
69  auto world_comm = world.mpi.comm().Get_mpi_comm();
70  // auto world_comm = MPI_COMM_WORLD;
71  blacspp::Grid grid = blacspp::Grid::square_grid(world_comm);
72 
73  world.gop.fence(); // stage ScaLAPACK execution
74  auto matrix = scalapack::array_to_block_cyclic(A, grid, MB, NB);
75  world.gop.fence(); // stage ScaLAPACK execution
76 
77  auto [M, N] = matrix.dims();
78  auto SVD_SIZE = std::min(M, N);
79 
80  auto [AMloc, ANloc] = matrix.dist().get_local_dims(M, N);
81  auto [UMloc, UNloc] = matrix.dist().get_local_dims(M, SVD_SIZE);
82  auto [VTMloc, VTNloc] = matrix.dist().get_local_dims(SVD_SIZE, N);
83 
84  auto desc_a = matrix.dist().descinit_noerror(M, N, AMloc);
85  auto desc_u = matrix.dist().descinit_noerror(M, SVD_SIZE, UMloc);
86  auto desc_vt = matrix.dist().descinit_noerror(SVD_SIZE, N, VTMloc);
87 
88  std::vector<real_type> S(SVD_SIZE);
89 
90  constexpr bool need_uv = (Vectors == SVD::AllVectors);
91  constexpr bool need_u = (Vectors == SVD::LeftVectors) or need_uv;
92  constexpr bool need_vt = (Vectors == SVD::RightVectors) or need_uv;
93 
94  std::shared_ptr<scalapack::BlockCyclicMatrix<value_type>> U = nullptr,
95  VT = nullptr;
96 
97  scalapackpp::VectorFlag JOBU = scalapackpp::VectorFlag::NoVectors;
98  scalapackpp::VectorFlag JOBVT = scalapackpp::VectorFlag::NoVectors;
99 
100  value_type* U_ptr = nullptr;
101  value_type* VT_ptr = nullptr;
102 
103  if constexpr (need_u) {
104  JOBU = scalapackpp::VectorFlag::Vectors;
105  U = std::make_shared<scalapack::BlockCyclicMatrix<value_type>>(
106  world, grid, M, SVD_SIZE, MB, NB);
107 
108  U_ptr = U->local_mat().data();
109  }
110 
111  if constexpr (need_vt) {
112  JOBVT = scalapackpp::VectorFlag::Vectors;
113  VT = std::make_shared<scalapack::BlockCyclicMatrix<value_type>>(
114  world, grid, SVD_SIZE, N, MB, NB);
115 
116  VT_ptr = VT->local_mat().data();
117  }
118 
119  auto info = scalapackpp::pgesvd(JOBU, JOBVT, M, N, matrix.local_mat().data(),
120  1, 1, desc_a, S.data(), U_ptr, 1, 1, desc_u,
121  VT_ptr, 1, 1, desc_vt);
122  if (info) TA_EXCEPTION("SVD Failed");
123 
124  world.gop.fence();
125 
126  if constexpr (need_uv) {
127  auto U_ta = scalapack::block_cyclic_to_array<Array>(*U, u_trange);
128  auto VT_ta = scalapack::block_cyclic_to_array<Array>(*VT, vt_trange);
129  world.gop.fence();
130 
131  return std::tuple(S, U_ta, VT_ta);
132 
133  } else if constexpr (need_u) {
134  auto U_ta = scalapack::block_cyclic_to_array<Array>(*U, u_trange);
135  world.gop.fence();
136 
137  return std::tuple(S, U_ta);
138 
139  } else if constexpr (need_vt) {
140  auto VT_ta = scalapack::block_cyclic_to_array<Array>(*VT, vt_trange);
141  world.gop.fence();
142 
143  return std::tuple(S, VT_ta);
144 
145  } else {
146  return S;
147  }
148 }
149 
150 } // namespace TiledArray::math::linalg::scalapack
151 
152 #endif // TILEDARRAY_HAS_SCALAPACK
153 #endif // TILEDARRAY_MATH_LINALG_SCALAPACK_SVD_H__INCLUDED
auto svd(const Array &A, TiledRange u_trange, TiledRange vt_trange, size_t MB=default_block_size(), size_t NB=default_block_size())
Compute the singular value decomposition (SVD) via ScaLAPACK.
Definition: svd.h:63
std::size_t default_block_size()
Definition: util.h:88
KroneckerDeltaTile< _N >::numeric_type min(const KroneckerDeltaTile< _N > &arg)
#define TA_EXCEPTION(m)
Definition: error.h:83
Range data of a tiled array.
Definition: tiled_range.h:32
Forward declarations.
Definition: dist_array.h:57
value_type::value_type element_type
Definition: dist_array.h:102
World & world() const
World accessor.
Definition: dist_array.h:1007
BlockCyclicMatrix< typename std::remove_cv_t< Array >::element_type > array_to_block_cyclic(const Array &array, const blacspp::Grid &grid, size_t MB, size_t NB)
Convert a dense DistArray to block-cyclic storage format.
Definition: block_cyclic.h:308