Program Listing for File biorthogonalization.hpp

Return to documentation for file (SeQuant/domain/mbpt/biorthogonalization.hpp)

#ifndef SEQUANT_DOMAIN_MBPT_BIORTHOGONALIZE_HPP
#define SEQUANT_DOMAIN_MBPT_BIORTHOGONALIZE_HPP

#include <SeQuant/core/container.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/index.hpp>
#include <SeQuant/core/slotted_index.hpp>
#include <SeQuant/core/utility/indices.hpp>
#include <SeQuant/core/utility/memoize.hpp>

#if defined(SEQUANT_HAS_TILEDARRAY)
#include <SeQuant/core/eval/backends/tiledarray/eval_expr.hpp>
#include <SeQuant/core/eval/backends/tiledarray/result.hpp>
#endif
#if defined(SEQUANT_HAS_BTAS)
#include <SeQuant/core/eval/backends/btas/eval_expr.hpp>
#include <SeQuant/core/eval/backends/btas/result.hpp>
#endif
#if defined(SEQUANT_HAS_TAPP)
#include <SeQuant/core/eval/backends/tapp/ops.hpp>
#include <SeQuant/core/eval/backends/tapp/tensor.hpp>
#endif

#include <concepts>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <optional>
#include <utility>
#include <vector>

namespace sequant::mbpt {

/// Default value of the `pseudoinverse_threshold` argument accepted by the
/// biorthogonalization routines declared in this header.
///
/// Declared `inline constexpr` (rather than `static constexpr`) so that all
/// translation units including this header share one object instead of each
/// getting its own internal-linkage copy.
inline constexpr double default_biorthogonalizer_pseudoinverse_threshold =
    1e-12;

/// Applies the biorthogonal transformation to a single result expression.
///
/// NOTE(review): the non-const reference parameter together with the `void`
/// return type suggests that @p expr is modified in place -- confirm against
/// the definition (not part of this header).
///
/// @param expr the result expression to transform
/// @param pseudoinverse_threshold numerical threshold forwarded to the
///        transformation; defaults to
///        default_biorthogonalizer_pseudoinverse_threshold
void biorthogonal_transform(
    ResultExpr& expr, double pseudoinverse_threshold =
                          default_biorthogonalizer_pseudoinverse_threshold);

/// Overload of biorthogonal_transform() taking a container of result
/// expressions.
///
/// NOTE(review): presumably every element of @p exprs is transformed in
/// place -- confirm against the definition.
///
/// @param exprs the result expressions to transform
/// @param pseudoinverse_threshold numerical threshold forwarded to the
///        transformation; defaults to
///        default_biorthogonalizer_pseudoinverse_threshold
void biorthogonal_transform(
    container::svector<ResultExpr>& exprs,
    double pseudoinverse_threshold =
        default_biorthogonalizer_pseudoinverse_threshold);

/// Biorthogonal transformation of an expression, with the external index
/// groups given as SlottedIndex objects.
///
/// This is the overload selected when @p ext_index_groups is omitted, i.e.
/// for calls of the form `biorthogonal_transform(expr)`.
///
/// @param expr the expression to transform (taken by const reference; the
///        transformed expression is returned)
/// @param ext_index_groups groups of external indices; defaults to an empty
///        container (NOTE(review): how an empty list is interpreted is
///        decided by the definition, which is not visible here)
/// @param pseudoinverse_threshold numerical threshold forwarded to the
///        transformation; defaults to
///        default_biorthogonalizer_pseudoinverse_threshold
/// @return the transformed expression; must not be discarded
[[nodiscard]] ExprPtr biorthogonal_transform(
    const ExprPtr& expr,
    const container::svector<container::svector<sequant::SlottedIndex>>&
        ext_index_groups = {},
    double pseudoinverse_threshold =
        default_biorthogonalizer_pseudoinverse_threshold);

/// Biorthogonal transformation of an expression, with the external index
/// groups given as plain Index objects.
///
/// @note @p ext_index_groups intentionally has no default argument here.
///       Defaulting it in both ExprPtr overloads makes the one-argument call
///       `biorthogonal_transform(expr)` ambiguous, because the two
///       candidates then differ only in a parameter that is not supplied.
///
/// @param expr the expression to transform (taken by const reference; the
///        transformed expression is returned)
/// @param ext_index_groups groups of external indices
/// @param pseudoinverse_threshold numerical threshold forwarded to the
///        transformation; defaults to
///        default_biorthogonalizer_pseudoinverse_threshold
/// @return the transformed expression; must not be discarded
[[nodiscard]] ExprPtr biorthogonal_transform(
    const ExprPtr& expr,
    const container::svector<container::svector<sequant::Index>>&
        ext_index_groups,
    double pseudoinverse_threshold =
        default_biorthogonalizer_pseudoinverse_threshold);

/// Filters an expression for the "WK" biorthogonalization scheme, with the
/// external indices given as SlottedIndex objects.
///
/// NOTE(review): the filtering criterion lives in the definition, which is
/// not part of this header -- consult it before relying on specifics.
///
/// @param expr the expression to filter (passed by value)
/// @param ext_idxs groups of external indices
/// @return the filtered expression
ExprPtr WK_biorthogonalization_filter(
    ExprPtr expr,
    const container::svector<container::svector<SlottedIndex>>& ext_idxs);
/// Overload of WK_biorthogonalization_filter() taking the external indices
/// as plain Index objects.
///
/// @param expr the expression to filter (passed by value)
/// @param ext_idxs groups of external indices
/// @return the filtered expression
ExprPtr WK_biorthogonalization_filter(
    ExprPtr expr,
    const container::svector<container::svector<Index>>& ext_idxs);

/// Biorthogonal transformation step that precedes the NNS projection, with
/// the external indices given as SlottedIndex objects.
///
/// NOTE(review): @p expr is taken by non-const reference AND a new ExprPtr is
/// returned; whether the argument itself is modified must be confirmed
/// against the definition (not visible in this header).
///
/// @param expr the expression to transform
/// @param ext_idxs groups of external indices
/// @param factor_out_nns_projector defaults to `true`; presumably controls
///        whether the NNS projector is factored out of the result -- confirm
///        against the definition
/// @return the resulting expression
ExprPtr biorthogonal_transform_pre_nnsproject(
    ExprPtr& expr,
    const container::svector<container::svector<SlottedIndex>>& ext_idxs,
    bool factor_out_nns_projector = true);
/// Overload of biorthogonal_transform_pre_nnsproject() taking the external
/// indices as plain Index objects.
///
/// @param expr the expression to transform
/// @param ext_idxs groups of external indices
/// @param factor_out_nns_projector defaults to `true`; presumably controls
///        whether the NNS projector is factored out of the result -- confirm
///        against the definition
/// @return the resulting expression
ExprPtr biorthogonal_transform_pre_nnsproject(
    ExprPtr& expr,
    const container::svector<container::svector<Index>>& ext_idxs,
    bool factor_out_nns_projector = true);

namespace detail {

/// Computes the NNS projector coefficients for the given particle rank in
/// double precision.
///
/// Within this header it is used by nns_projection_weights() as the fallback
/// when no hardcoded table is available; the result is then converted
/// element-wise to the requested numeric type.
///
/// @param n_particles the particle rank
/// @param pseudoinverse_threshold numerical threshold used by the
///        computation; defaults to
///        default_biorthogonalizer_pseudoinverse_threshold
/// @return the computed coefficients; must not be discarded
[[nodiscard]] std::vector<double> compute_nns_p_coeffs(
    std::size_t n_particles,
    double pseudoinverse_threshold =
        default_biorthogonalizer_pseudoinverse_threshold);

/// Produces a permuted copy of @p indices selected by @p perm_rank.
///
/// The NNS projection routines in this header call it with the ket mode
/// ordinals, a running permutation rank in `[0, number of weights)` and the
/// ket rank.
///
/// NOTE(review): the mapping from @p perm_rank to a concrete permutation is
/// determined by the definition, which is not visible in this header.
///
/// @param indices the indices to permute
/// @param perm_rank the rank (ordinal) of the permutation to apply
/// @param n_particles the number of particles (callers pass the ket rank)
/// @return the permuted indices
container::svector<size_t> compute_permuted_indices(
    const container::svector<size_t>& indices, size_t perm_rank,
    size_t n_particles);

/// Returns a tabulated set of NNS projector coefficients for small particle
/// ranks.
///
/// Tables are provided for `n_particles` in `[1, 5]` and contain
/// 1, 2, 6, 24 and 120 entries respectively (i.e. `n_particles!`). Every
/// entry is spelled as an exact ratio `T(p) / T(q)` so that the division is
/// carried out in the requested numeric type.
///
/// NOTE(review): the projection routines in this header index the weights by
/// a permutation rank that is also passed to compute_permuted_indices(), so
/// entry `k` is presumably the weight of the permutation of rank `k` --
/// confirm against the code that generated these tables.
///
/// @tparam T a floating-point or complex numeric type
/// @param n_particles the particle rank
/// @return the tabulated coefficients, or `std::nullopt` if no table exists
///         for @p n_particles
template <typename T>
  requires(std::floating_point<T> || meta::is_complex_v<T>)
std::optional<std::vector<T>> hardcoded_nns_projector(std::size_t n_particles) {
  switch (n_particles) {
    // 1! = 1 coefficient
    case 1:
      return std::vector<T>{T(1) / T(1)};

    // 2! = 2 coefficients
    case 2:
      return std::vector<T>{T(0) / T(1), T(1) / T(1)};

    // 3! = 6 coefficients
    case 3:
      return std::vector<T>{T(-1) / T(5), T(-1) / T(5), T(-1) / T(5),
                            T(-1) / T(5), T(-1) / T(5), T(1) / T(1)};

    // 4! = 24 coefficients
    case 4:
      return std::vector<T>{
          T(1) / T(7),   T(1) / T(7),   T(1) / T(7),   T(-1) / T(14),
          T(1) / T(7),   T(1) / T(7),   T(1) / T(7),   T(-1) / T(14),
          T(-1) / T(14), T(-1) / T(14), T(1) / T(7),   T(-2) / T(7),
          T(-1) / T(14), T(1) / T(7),   T(-1) / T(14), T(-2) / T(7),
          T(1) / T(7),   T(-1) / T(14), T(-1) / T(14), T(-2) / T(7),
          T(-2) / T(7),  T(-2) / T(7),  T(-2) / T(7),  T(1) / T(1)};

    // 5! = 120 coefficients
    case 5:
      return std::vector<T>{
          T(-1) / T(14), T(-1) / T(14), T(-1) / T(14), T(-1) / T(14),
          T(2) / T(21),  T(-1) / T(14), T(-1) / T(14), T(-1) / T(14),
          T(-1) / T(14), T(2) / T(21),  T(-1) / T(14), T(-1) / T(14),
          T(-1) / T(14), T(-1) / T(14), T(2) / T(21),  T(2) / T(21),
          T(2) / T(21),  T(2) / T(21),  T(-1) / T(21), T(0) / T(1),
          T(-1) / T(14), T(-1) / T(14), T(-1) / T(14), T(-1) / T(14),
          T(2) / T(21),  T(-1) / T(14), T(-1) / T(14), T(-1) / T(14),
          T(-1) / T(14), T(2) / T(21),  T(-1) / T(14), T(-1) / T(14),
          T(-1) / T(14), T(-1) / T(14), T(2) / T(21),  T(2) / T(21),
          T(2) / T(21),  T(2) / T(21),  T(-1) / T(21), T(0) / T(1),
          T(2) / T(21),  T(2) / T(21),  T(-1) / T(21), T(2) / T(21),
          T(0) / T(1),   T(2) / T(21),  T(2) / T(21),  T(-1) / T(21),
          T(2) / T(21),  T(0) / T(1),   T(-1) / T(21), T(-1) / T(21),
          T(-1) / T(21), T(-1) / T(21), T(1) / T(7),   T(0) / T(1),
          T(0) / T(1),   T(1) / T(7),   T(1) / T(7),   T(-1) / T(3),
          T(2) / T(21),  T(-1) / T(21), T(2) / T(21),  T(2) / T(21),
          T(0) / T(1),   T(-1) / T(21), T(-1) / T(21), T(-1) / T(21),
          T(-1) / T(21), T(1) / T(7),   T(2) / T(21),  T(-1) / T(21),
          T(2) / T(21),  T(2) / T(21),  T(0) / T(1),   T(0) / T(1),
          T(1) / T(7),   T(0) / T(1),   T(1) / T(7),   T(-1) / T(3),
          T(-1) / T(21), T(-1) / T(21), T(-1) / T(21), T(-1) / T(21),
          T(1) / T(7),   T(-1) / T(21), T(2) / T(21),  T(2) / T(21),
          T(2) / T(21),  T(0) / T(1),   T(-1) / T(21), T(2) / T(21),
          T(2) / T(21),  T(2) / T(21),  T(0) / T(1),   T(1) / T(7),
          T(0) / T(1),   T(0) / T(1),   T(1) / T(7),   T(-1) / T(3),
          T(0) / T(1),   T(1) / T(7),   T(1) / T(7),   T(0) / T(1),
          T(-1) / T(3),  T(1) / T(7),   T(0) / T(1),   T(1) / T(7),
          T(0) / T(1),   T(-1) / T(3),  T(1) / T(7),   T(1) / T(7),
          T(0) / T(1),   T(0) / T(1),   T(-1) / T(3),  T(-1) / T(3),
          T(-1) / T(3),  T(-1) / T(3),  T(-1) / T(3),  T(1) / T(1)};

    // no table available for this rank
    default:
      return std::nullopt;
  }
}

/// Provides the NNS projection weights for a given particle rank, memoized
/// per `(n_particles, pseudoinverse_threshold)` pair.
///
/// For `n_particles < 3` a reference to a shared empty vector is returned.
/// Otherwise the weights are looked up through sequant::detail::memoize; on
/// a miss they are taken from hardcoded_nns_projector() when a table exists
/// (ranks up to 5) and are otherwise obtained from compute_nns_p_coeffs()
/// and converted element-wise to `T`.
///
/// The cache, its mutex and its condition variable are function-local
/// statics, hence there is one independent cache per instantiation of `T`.
///
/// @tparam T a floating-point or complex numeric type
/// @param n_particles the particle rank
/// @param pseudoinverse_threshold threshold forwarded to
///        compute_nns_p_coeffs(); also part of the cache key
/// @return reference to the weights (empty for `n_particles < 3`)
template <typename T>
  requires(std::floating_point<T> || meta::is_complex_v<T>)
[[nodiscard]] const std::vector<T>& nns_projection_weights(
    std::size_t n_particles,
    double pseudoinverse_threshold =
        default_biorthogonalizer_pseudoinverse_threshold) {
  // shared empty result handed out for ranks below 3
  static const std::vector<T> no_weights{};
  if (n_particles < 3) return no_weights;

  // memoization state shared by all callers using this T
  static std::mutex cache_mutex;
  static std::condition_variable cache_cv;
  static container::map<std::pair<std::size_t, double>,
                        std::optional<std::vector<T>>>
      cache;

  std::pair<std::size_t, double> key{n_particles, pseudoinverse_threshold};

  return sequant::detail::memoize(
      cache, cache_mutex, cache_cv, key, [&]() -> std::vector<T> {
        // prefer the tabulated weights whenever the rank is covered
        constexpr std::size_t max_rank_hardcoded_nns_projector = 5;
        if (n_particles <= max_rank_hardcoded_nns_projector) {
          auto tabulated = hardcoded_nns_projector<T>(n_particles);
          if (tabulated.has_value()) return std::move(*tabulated);
        }

        // otherwise compute in double precision and convert to T
        const auto computed =
            detail::compute_nns_p_coeffs(n_particles, pseudoinverse_threshold);
        std::vector<T> converted;
        converted.reserve(computed.size());
        for (std::size_t i = 0; i != computed.size(); ++i) {
          converted.push_back(static_cast<T>(computed[i]));
        }
        return converted;
      });
}

}  // namespace detail

#if defined(SEQUANT_HAS_TILEDARRAY)

/// Applies the NNS projection to a TiledArray tensor.
///
/// The trailing `rank - bra_rank` modes are treated as ket modes. When the
/// tensor rank exceeds 4, the ket rank exceeds 2 and weights are available,
/// the result is the weighted sum over ket-mode permutations of @p arr, with
/// weight `k` paired with the permutation produced by
/// detail::compute_permuted_indices() for rank `k`. Tensors of rank <= 4 are
/// returned as is; in the remaining cases a copy of @p arr is produced via a
/// TiledArray assignment expression.
///
/// @param arr the tensor to project
/// @param bra_rank number of leading (bra) modes; must not exceed the rank
/// @return the projected tensor
template <typename... Args>
auto biorthogonal_nns_project_ta(TA::DistArray<Args...> const& arr,
                                 size_t bra_rank) {
  using ranges::views::iota;
  using array_type = TA::DistArray<Args...>;
  using numeric_type = typename array_type::numeric_type;

  size_t const rank = arr.trange().rank();
  SEQUANT_ASSERT(bra_rank <= rank);
  size_t const ket_rank = rank - bra_rank;

  // Residuals of rank 4 or less have no redundancy and don't require NNS
  // projection
  if (rank <= 4) return arr;

  const auto& weights = detail::nns_projection_weights<numeric_type>(ket_rank);

  array_type result;

  // mode ordinals: all modes, the bra part and the ket part
  perm_t perm = iota(size_t{0}, rank) | ranges::to<perm_t>;
  perm_t bra_perm = iota(size_t{0}, bra_rank) | ranges::to<perm_t>;
  perm_t ket_perm = iota(bra_rank, rank) | ranges::to<perm_t>;

  // annotation of the result (identity ordering of all modes)
  const auto lannot = ords_to_annot(perm);

  bool const needs_projection = ket_rank > 2 && !weights.empty();
  if (!needs_projection) {
    // nothing to project: plain copy
    result(lannot) = arr(lannot);
  } else {
    const auto bra_annot = bra_rank == 0 ? "" : ords_to_annot(bra_perm);

    for (size_t perm_rank = 0, num_perms = weights.size();
         perm_rank < num_perms; ++perm_rank) {
      numeric_type const weight = weights[perm_rank];

      perm_t permuted_ket =
          detail::compute_permuted_indices(ket_perm, perm_rank, ket_rank);
      const auto ket_annot = ords_to_annot(permuted_ket);

      // bra modes stay in place, only the ket modes are permuted
      const auto annot =
          !bra_annot.empty() ? bra_annot + "," + ket_annot : ket_annot;

      // the first term initializes the result, later terms accumulate
      if (!result.is_initialized()) {
        result(lannot) = weight * arr(annot);
      } else {
        result(lannot) += weight * arr(annot);
      }
    }
  }

  array_type::wait_for_lazy_cleanup(result.world());
  return result;
}

/// Backend-agnostic entry point for the NNS projection of a TiledArray
/// tensor; forwards to biorthogonal_nns_project_ta().
///
/// @param arr the tensor to project
/// @param bra_rank number of leading (bra) modes of @p arr
/// @return whatever biorthogonal_nns_project_ta() returns
template <typename... Args>
auto biorthogonal_nns_project(TA::DistArray<Args...> const& arr,
                              size_t bra_rank) {
  return biorthogonal_nns_project_ta(arr, bra_rank);
}

#endif  // defined(SEQUANT_HAS_TILEDARRAY)

#if defined(SEQUANT_HAS_BTAS)

/// Applies the NNS projection to a BTAS tensor.
///
/// The trailing `rank - bra_rank` modes are treated as ket modes. When the
/// tensor rank exceeds 4, the ket rank exceeds 2 and weights are available,
/// the result is the weighted sum over ket-mode permutations of @p arr, with
/// weight `k` paired with the permutation produced by
/// detail::compute_permuted_indices() for rank `k`. Otherwise a copy of
/// @p arr is returned.
///
/// @param arr the tensor to project
/// @param bra_rank number of leading (bra) modes; must not exceed the rank
/// @return the projected tensor
template <typename... Args>
auto biorthogonal_nns_project_btas(btas::Tensor<Args...> const& arr,
                                   size_t bra_rank) {
  using ranges::views::iota;
  size_t const rank = arr.rank();
  SEQUANT_ASSERT(bra_rank <= rank);
  size_t const ket_rank = rank - bra_rank;

  // Residuals of rank 4 or less have no redundancy and don't require NNS
  // projection
  if (rank <= 4) return arr;

  using numeric_type = typename btas::Tensor<Args...>::numeric_type;

  const auto& nns_p_coeffs =
      detail::nns_projection_weights<numeric_type>(ket_rank);

  btas::Tensor<Args...> result;

  // mode ordinals: all modes, the bra part and the ket part
  perm_t perm = iota(size_t{0}, rank) | ranges::to<perm_t>;
  perm_t bra_perm = iota(size_t{0}, bra_rank) | ranges::to<perm_t>;
  perm_t ket_perm = iota(bra_rank, rank) | ranges::to<perm_t>;

  if (ket_rank > 2 && !nns_p_coeffs.empty()) {
    bool result_initialized = false;

    size_t const num_perms = nns_p_coeffs.size();
    for (size_t perm_rank = 0; perm_rank < num_perms; ++perm_rank) {
      perm_t permuted_ket =
          detail::compute_permuted_indices(ket_perm, perm_rank, ket_rank);

      numeric_type coeff = nns_p_coeffs[perm_rank];

      // bra modes stay in place, only the ket modes are permuted
      perm_t annot = bra_perm;
      annot.insert(annot.end(), permuted_ket.begin(), permuted_ket.end());

      // temp = coeff * (arr with its modes relabelled from annot to perm)
      btas::Tensor<Args...> temp;
      btas::permute(arr, annot, temp, perm);
      btas::scal(coeff, temp);

      if (result_initialized) {
        result += temp;
      } else {
        // temp is discarded at the end of the iteration, so hand its storage
        // over instead of deep-copying the whole tensor
        result = std::move(temp);
        result_initialized = true;
      }
    }

  } else {
    result = arr;
  }

  return result;
}

/// Backend-agnostic entry point for the NNS projection of a BTAS tensor;
/// forwards to biorthogonal_nns_project_btas().
///
/// @param arr the tensor to project
/// @param bra_rank number of leading (bra) modes of @p arr
/// @return whatever biorthogonal_nns_project_btas() returns
template <typename... Args>
auto biorthogonal_nns_project(btas::Tensor<Args...> const& arr,
                              size_t bra_rank) {
  return biorthogonal_nns_project_btas(arr, bra_rank);
}

#endif  // defined(SEQUANT_HAS_BTAS)

#if defined(SEQUANT_HAS_TAPP)

/// Applies the NNS projection to a TAPP tensor.
///
/// The trailing `rank - bra_rank` modes are treated as ket modes. When the
/// tensor rank exceeds 4, the ket rank exceeds 2 and weights are available,
/// the result is the weighted sum over ket-mode permutations of @p arr, with
/// weight `k` paired with the permutation produced by
/// detail::compute_permuted_indices() for rank `k`. Otherwise a copy of
/// @p arr is returned.
///
/// @param arr the tensor to project
/// @param bra_rank number of leading (bra) modes; must not exceed the rank
/// @return the projected tensor
template <typename T, typename Alloc>
auto biorthogonal_nns_project_tapp(TAPPTensor<T, Alloc> const& arr,
                                   size_t bra_rank) {
  using ranges::views::iota;
  size_t const rank = arr.rank();
  SEQUANT_ASSERT(bra_rank <= rank);
  size_t const ket_rank = rank - bra_rank;

  // Residuals of rank 4 or less have no redundancy and don't require NNS
  // projection
  if (rank <= 4) return arr;

  using numeric_type = T;

  const auto& nns_p_coeffs =
      detail::nns_projection_weights<numeric_type>(ket_rank);

  using perm_type = container::svector<size_t>;

  TAPPTensor<T, Alloc> result;

  // mode ordinals: all modes, the bra part and the ket part
  perm_type perm = iota(size_t{0}, rank) | ranges::to<perm_type>;
  perm_type bra_perm = iota(size_t{0}, bra_rank) | ranges::to<perm_type>;
  perm_type ket_perm = iota(bra_rank, rank) | ranges::to<perm_type>;

  if (ket_rank > 2 && !nns_p_coeffs.empty()) {
    bool result_initialized = false;

    // the target ordering is the same for every permutation, so convert it
    // to the 64-bit form expected by tapp_ops::permute only once
    container::svector<int64_t> perm_i64(perm.begin(), perm.end());

    size_t const num_perms = nns_p_coeffs.size();
    for (size_t perm_rank = 0; perm_rank < num_perms; ++perm_rank) {
      perm_type permuted_ket =
          detail::compute_permuted_indices(ket_perm, perm_rank, ket_rank);

      numeric_type coeff = nns_p_coeffs[perm_rank];

      // bra modes stay in place, only the ket modes are permuted
      perm_type annot = bra_perm;
      annot.insert(annot.end(), permuted_ket.begin(), permuted_ket.end());

      container::svector<int64_t> annot_i64(annot.begin(), annot.end());

      // temp = coeff * (arr with its modes relabelled from annot to perm)
      TAPPTensor<T, Alloc> temp;
      tapp_ops::permute(arr, annot_i64, temp, perm_i64);
      tapp_ops::scal(coeff, temp);

      if (result_initialized) {
        result += temp;
      } else {
        // temp is discarded at the end of the iteration, so hand its storage
        // over instead of copying the whole tensor
        result = std::move(temp);
        result_initialized = true;
      }
    }

  } else {
    result = arr;
  }

  return result;
}

/// Backend-agnostic entry point for the NNS projection of a TAPP tensor;
/// forwards to biorthogonal_nns_project_tapp().
///
/// @param arr the tensor to project
/// @param bra_rank number of leading (bra) modes of @p arr
/// @return whatever biorthogonal_nns_project_tapp() returns
template <typename T, typename Alloc>
auto biorthogonal_nns_project(TAPPTensor<T, Alloc> const& arr,
                              size_t bra_rank) {
  return biorthogonal_nns_project_tapp(arr, bra_rank);
}

#endif  // defined(SEQUANT_HAS_TAPP)

}  // namespace sequant::mbpt

#endif