svd.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2020 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * David Williams-Young
19  * Computational Research Division, Lawrence Berkeley National Laboratory
20  *
21  * svd.h
22  * Created: 12 June, 2020
23  *
24  */
25 #ifndef TILEDARRAY_MATH_LINALG_NON_DISTRIBUTED_SVD_H__INCLUDED
26 #define TILEDARRAY_MATH_LINALG_NON_DISTRIBUTED_SVD_H__INCLUDED
27 
28 #include <TiledArray/config.h>
#include <utility>
29 
33 
35 
58 template<SVD::Vectors Vectors, typename Array>
59 auto svd(const Array& A, TiledRange u_trange = TiledRange(), TiledRange vt_trange = TiledRange()) {
60 
61  using T = typename Array::numeric_type;
63 
64  World& world = A.world();
65  auto A_eig = detail::make_matrix(A);
66 
67  constexpr bool svd_all_vectors = (Vectors == SVD::AllVectors);
68  constexpr bool need_u = (Vectors == SVD::LeftVectors) or svd_all_vectors;
69  constexpr bool need_vt = (Vectors == SVD::RightVectors) or svd_all_vectors;
70 
71  std::vector<T> S;
72  std::unique_ptr<Matrix> U, VT;
73 
74  if constexpr (need_u) U = std::make_unique<Matrix>();
75  if constexpr (need_vt) VT = std::make_unique<Matrix>();
76 
77  if (world.rank() == 0) {
78  linalg::rank_local::svd(A_eig, S, U.get(), VT.get());
79  }
80 
81  world.gop.broadcast_serializable(S, 0);
82  if (U) world.gop.broadcast_serializable(*U, 0);
83  if (VT) world.gop.broadcast_serializable(*VT, 0);
84 
85  auto make_array = [&world](auto && ... args) {
86  return eigen_to_array<Array>(world, args...);
87  };
88 
89  if constexpr (need_u && need_vt) {
90  return std::tuple(S, make_array(u_trange, *U), make_array(vt_trange, *VT));
91  }
92  if constexpr (need_u && !need_vt) {
93  return std::tuple(S, make_array(u_trange, *U));
94  }
95  if constexpr (!need_u && need_vt) {
96  return std::tuple(S, make_array(vt_trange, *VT));
97  }
98 
99  if constexpr (!need_u && !need_vt) return S;
100 
101 }
102 
103 } // namespace TiledArray::math::linalg::non_distributed
104 
105 #endif // TILEDARRAY_MATH_LINALG_NON_DISTRIBUTED_SVD_H__INCLUDED
detail::numeric_type< Tile >::type numeric_type
Definition: dist_array.h:69
auto svd(const Array &A, TiledRange u_trange=TiledRange(), TiledRange vt_trange=TiledRange())
Compute the singular value decomposition (SVD) on rank 0 with the rank-local (non-distributed) solver, then broadcast the results to all ranks.
Definition: svd.h:59
void svd(Job jobu, Job jobvt, Matrix< T > &A, std::vector< T > &S, Matrix< T > *U, Matrix< T > *VT)
Definition: rank-local.cpp:143
Array make_array(World &world, const detail::trange_t< Array > &trange, const std::shared_ptr< detail::pmap_t< Array > > &pmap, Op &&op)
Construct dense Array.
Definition: make_array.h:73
Range data of a tiled array.
Definition: tiled_range.h:32
Forward declarations.
Definition: dist_array.h:57
::Eigen::Matrix< T, ::Eigen::Dynamic, ::Eigen::Dynamic, Options > Matrix
Definition: rank-local.h:16
World & world() const
World accessor.
Definition: dist_array.h:1007
::Eigen::Matrix< T, ::Eigen::Dynamic, ::Eigen::Dynamic, Options > Matrix
Definition: blas.h:63
auto make_matrix(const DistArray< Tile, Policy > &A)
Definition: util.h:46