Program Listing for File biorthogonalization.cpp

Return to documentation for file (SeQuant/domain/mbpt/biorthogonalization.cpp)

#include <SeQuant/domain/mbpt/biorthogonalization.hpp>
#include <SeQuant/domain/mbpt/detail/concepts.hpp>
#include <SeQuant/domain/mbpt/spin.hpp>

#include <SeQuant/core/container.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/index.hpp>
#include <SeQuant/core/math.hpp>
#include <SeQuant/core/reserved.hpp>
#include <SeQuant/core/tensor_canonicalizer.hpp>
#include <SeQuant/core/tensor_network.hpp>
#include <SeQuant/core/utility/expr.hpp>
#include <SeQuant/core/utility/indices.hpp>
#include <SeQuant/core/utility/macros.hpp>
#include <SeQuant/core/utility/permutation.hpp>

#include <Eigen/Core>
#include <Eigen/Eigenvalues>

#include <libperm/Permutation.hpp>
#include <libperm/Rank.hpp>
#include <libperm/Utils.hpp>

#include <algorithm>

namespace sequant::mbpt {

/// Comparison functor for pair-like objects that only looks at the `first`
/// member; the remaining members do not take part in the comparison.
///
/// @tparam T Type exposing a `first` member that can be compared via `<`
template <typename T>
struct compare_first_less {
  /// @returns Whether `lhs.first` compares less than `rhs.first`
  bool operator()(const T& lhs, const T& rhs) const {
    const auto& lhs_key = lhs.first;
    const auto& rhs_key = rhs.first;

    return lhs_key < rhs_key;
  }
};

// Two indices that are paired with each other. In this file such pairs are
// obtained from ResultExpr::index_particle_grouping<IndexPair>().
using IndexPair = std::pair<Index, Index>;
// List of index pairs; the code below uses its size as the number of particles.
using ParticlePairings = container::svector<IndexPair>;

// clang-format off
// clang-format on
/// Returns the tabulated (exact, rational) biorthogonalizer coefficients for
/// the given number of particles.
///
/// The returned list contains n_particles! entries (1, 2, 6, 24 and 120
/// entries for n_particles = 1, ..., 5). It is expanded into a full square
/// matrix by make_hardcoded_biorthogonalizer_matrix(), which indexes the list
/// by permutation rank.
///
/// @param n_particles Number of particles; tables exist for 1 through 5
/// @returns The tabulated coefficients for the requested number of particles
/// @throws std::runtime_error if no table exists for n_particles
std::vector<sequant::rational> hardcoded_biorthogonalizer_row(
    std::size_t n_particles) {
  switch (n_particles) {
    case 1:
      return std::vector<sequant::rational>{ratio(1, 2)};

    case 2:
      return std::vector<sequant::rational>{ratio(1, 3), ratio(1, 6)};

    case 3:
      return std::vector<sequant::rational>{ratio(17, 120), ratio(-7, 120),
                                            ratio(-1, 120), ratio(-1, 120),
                                            ratio(-1, 120), ratio(-7, 120)};

    case 4:
      return std::vector<sequant::rational>{
          ratio(43, 840), ratio(-19, 1680), ratio(-19, 1680),
          ratio(-1, 105), ratio(-19, 1680), ratio(-19, 1680),
          ratio(13, 840), ratio(1, 120),    ratio(-1, 105),
          ratio(1, 120),  ratio(-1, 105),   ratio(-19, 1680),
          ratio(-1, 105), ratio(1, 120),    ratio(1, 120),
          ratio(13, 840), ratio(-1, 105),   ratio(-1, 105),
          ratio(1, 120),  ratio(-19, 1680), ratio(-19, 1680),
          ratio(13, 840), ratio(-19, 1680), ratio(1, 120)};

    case 5:
      return std::vector<sequant::rational>{
          ratio(59, 3780),   ratio(-5, 3024),   ratio(-5, 3024),
          ratio(-5, 3024),   ratio(-31, 7560),  ratio(-5, 3024),
          ratio(-5, 3024),   ratio(-23, 30240), ratio(19, 7560),
          ratio(37, 15120),  ratio(-5, 3024),   ratio(-23, 30240),
          ratio(-5, 3024),   ratio(19, 7560),   ratio(37, 15120),
          ratio(-31, 7560),  ratio(37, 15120),  ratio(37, 15120),
          ratio(-31, 7560),  ratio(-5, 3024),   ratio(-5, 3024),
          ratio(-23, 30240), ratio(-23, 30240), ratio(-23, 30240),
          ratio(-13, 7560),  ratio(-5, 3024),   ratio(-5, 3024),
          ratio(19, 7560),   ratio(-23, 30240), ratio(37, 15120),
          ratio(19, 7560),   ratio(-23, 30240), ratio(19, 7560),
          ratio(-23, 30240), ratio(-13, 7560),  ratio(37, 15120),
          ratio(-13, 7560),  ratio(-13, 7560),  ratio(37, 15120),
          ratio(-23, 30240), ratio(-31, 7560),  ratio(-13, 7560),
          ratio(37, 15120),  ratio(37, 15120),  ratio(19, 7560),
          ratio(37, 15120),  ratio(37, 15120),  ratio(-13, 7560),
          ratio(-13, 7560),  ratio(-23, 30240), ratio(-31, 7560),
          ratio(37, 15120),  ratio(-31, 7560),  ratio(37, 15120),
          ratio(-5, 3024),   ratio(-5, 3024),   ratio(-23, 30240),
          ratio(19, 7560),   ratio(-5, 3024),   ratio(37, 15120),
          ratio(-31, 7560),  ratio(37, 15120),  ratio(37, 15120),
          ratio(-13, 7560),  ratio(19, 7560),   ratio(37, 15120),
          ratio(37, 15120),  ratio(-13, 7560),  ratio(-13, 7560),
          ratio(-23, 30240), ratio(37, 15120),  ratio(-13, 7560),
          ratio(37, 15120),  ratio(-13, 7560),  ratio(-23, 30240),
          ratio(19, 7560),   ratio(-23, 30240), ratio(-23, 30240),
          ratio(19, 7560),   ratio(-13, 7560),  ratio(-31, 7560),
          ratio(37, 15120),  ratio(-13, 7560),  ratio(37, 15120),
          ratio(19, 7560),   ratio(-31, 7560),  ratio(-31, 7560),
          ratio(37, 15120),  ratio(37, 15120),  ratio(-5, 3024),
          ratio(37, 15120),  ratio(-13, 7560),  ratio(37, 15120),
          ratio(-13, 7560),  ratio(-23, 30240), ratio(-5, 3024),
          ratio(19, 7560),   ratio(-23, 30240), ratio(-5, 3024),
          ratio(37, 15120),  ratio(-5, 3024),   ratio(-23, 30240),
          ratio(-23, 30240), ratio(-23, 30240), ratio(-13, 7560),
          ratio(19, 7560),   ratio(19, 7560),   ratio(-23, 30240),
          ratio(-23, 30240), ratio(-13, 7560),  ratio(-5, 3024),
          ratio(19, 7560),   ratio(-5, 3024),   ratio(-23, 30240),
          ratio(37, 15120),  ratio(37, 15120),  ratio(-13, 7560),
          ratio(-13, 7560),  ratio(37, 15120),  ratio(-23, 30240)};

    default:
      // No table available -> report the offending rank to the caller
      throw std::runtime_error(
          "hardcoded biorthogonal coefficients only available for ranks 1-5, "
          "requested rank is : " +
          std::to_string(n_particles));
  }
}

/// Expands a list of tabulated coefficients into the full, square
/// biorthogonalizer coefficient matrix.
///
/// With n = first_row.size(), entry (row, col) of the result is
/// first_row[rank(P)], where P is the permutation of rank `col` pre-multiplied
/// with the permutation of rank `n - 1 - row`.
///
/// @param first_row Tabulated coefficients, indexed by permutation rank
/// @param n_particles Number of elements the ranked permutations act on
/// @returns The n x n coefficient matrix
Eigen::Matrix<sequant::rational, Eigen::Dynamic, Eigen::Dynamic>
make_hardcoded_biorthogonalizer_matrix(
    const std::vector<sequant::rational>& first_row, std::size_t n_particles) {
  const auto n = first_row.size();
  Eigen::Matrix<sequant::rational, Eigen::Dynamic, Eigen::Dynamic> M(n, n);

  for (std::size_t row = 0; row < n; ++row) {
    // The row permutation is independent of the column, so it is built once
    // per row rather than once per matrix element
    perm::Permutation row_perm = perm::unrank(n - 1 - row, n_particles);

    for (std::size_t col = 0; col < n; ++col) {
      perm::Permutation col_perm = perm::unrank(col, n_particles);
      col_perm->preMultiply(row_perm);

      // The rank of the combined permutation selects the tabulated coefficient
      const std::size_t source_idx = perm::rank(col_perm, n_particles);
      M(row, col) = first_row[source_idx];
    }
  }

  return M;
}

/// Builds the biorthogonalizer coefficient matrix for the given number of
/// particles from the tabulated coefficients.
///
/// @param n_particles Number of particles; forwarded to
/// hardcoded_biorthogonalizer_row(), which throws if no table exists
/// @returns The square coefficient matrix
Eigen::Matrix<sequant::rational, Eigen::Dynamic, Eigen::Dynamic>
hardcoded_biorthogonalizer_matrix(std::size_t n_particles) {
  return make_hardcoded_biorthogonalizer_matrix(
      hardcoded_biorthogonalizer_row(n_particles), n_particles);
}

/// Biorthogonalizes a clone of the given expression, leaving the argument
/// itself untouched.
///
/// @param expr The expression to transform
/// @param threshold Threshold forwarded to the container-based
/// biorthogonal_transform() overload
/// @returns The transformed clone
ResultExpr biorthogonal_transform_copy(
    const ResultExpr& expr,
    double threshold = default_biorthogonalizer_pseudoinverse_threshold) {
  // The actual implementation works on containers -> wrap the clone in one
  container::svector<ResultExpr> single_element;
  single_element.push_back(expr.clone());

  biorthogonal_transform(single_element, threshold);

  return single_element.front();
}

/// Biorthogonalizes clones of the given expressions, leaving the arguments
/// themselves untouched.
///
/// @param exprs The expressions to transform (as one set)
/// @param threshold Threshold forwarded to the in-place
/// biorthogonal_transform() overload
/// @returns The transformed clones, in the order of the input
container::svector<ResultExpr> biorthogonal_transform_copy(
    const container::svector<ResultExpr>& exprs,
    double threshold = default_biorthogonalizer_pseudoinverse_threshold) {
  container::svector<ResultExpr> cloned;
  cloned.reserve(exprs.size());

  for (const ResultExpr& original : exprs) {
    cloned.push_back(original.clone());
  }

  // Transform the clones in place
  biorthogonal_transform(cloned, threshold);

  return cloned;
}

/// Biorthogonalizes the given expression in place.
///
/// @param expr The expression to transform; overwritten with the result
/// @param threshold Threshold forwarded to biorthogonal_transform_copy()
void biorthogonal_transform(ResultExpr& expr, double threshold) {
  // TODO: avoid copy
  ResultExpr transformed = biorthogonal_transform_copy(expr, threshold);
  expr = std::move(transformed);
}

/// Builds the symmetric overlap matrix between all permutations of
/// n_particles elements.
///
/// Rows and columns are indexed by permutation rank. Up to a global sign,
/// entry (row, col) is (-2)^c, where c is the number of disjoint cycles of the
/// permutation of rank `col` post-multiplied with the inverse of the
/// permutation of rank `row`. For odd n_particles the whole matrix is
/// multiplied by -1, so the diagonal entries always equal +2^n_particles.
///
/// @param n_particles Number of elements the permutations act on
/// @returns The n_particles! x n_particles! overlap matrix
Eigen::MatrixXd permutational_overlap_matrix(std::size_t n_particles) {
  const auto n = boost::numeric_cast<Eigen::Index>(factorial(n_particles));
  // n is a factorial and therefore positive. Looping over an unsigned copy
  // avoids mixing signed and unsigned operands in the loop conditions.
  const auto n_perms = static_cast<std::size_t>(n);

  // The matrix only contains integer entries but all operations we want to do
  // with the matrix will (in general) require non-integer scalars which in
  // Eigen only works if you start from a non-integer matrix.
  Eigen::MatrixXd M(n, n);
  M.setZero();

  // The identity permutation always has as many disjoint cycles as the number
  // of elements it acts on, so every diagonal entry has the same value
  const double diagonal_entry = std::pow(-2, n_particles);

  // TODO: Can we fill the entire matrix only by knowing the entries of one
  // row/column? For n_particles < 4, every consecutive col/row is only rotated
  // by one compared to the one before
  for (std::size_t row = 0; row < n_perms; ++row) {
    perm::Permutation ref = perm::unrank(row, n_particles);
    ref->invert();

    M(row, row) = diagonal_entry;

    // The matrix is symmetric: compute the upper triangle and mirror it
    for (std::size_t col = row + 1; col < n_perms; ++col) {
      // Combine the permutation of rank col with the inverse of the
      // permutation of rank row
      perm::Permutation current = perm::unrank(col, n_particles);
      current->postMultiply(ref);

      auto cycles = current->toDisjointCycles(n_particles);
      std::size_t n_cycles = std::distance(cycles.begin(), cycles.end());

      auto entry = std::pow(-2, n_cycles);

      M(row, col) = entry;
      M(col, row) = entry;
    }
  }

  if (n_particles % 2 != 0) {
    M *= -1;
  }

  SEQUANT_ASSERT(M.isApprox(M.transpose()));

  return M;
}

/// Numerically computes the biorthogonalizer matrix as the rescaled
/// pseudo-inverse of the permutational overlap matrix.
///
/// @param n_particles Number of particles
/// @param threshold Threshold handed to the complete orthogonal decomposition
/// that is used for obtaining the pseudo-inverse
/// @returns The pseudo-inverse of the overlap matrix, multiplied by
/// (#rows of the overlap matrix) / (rank of the overlap matrix)
Eigen::MatrixXd compute_biorthogonalizer_matrix(std::size_t n_particles,
                                                double threshold) {
  const Eigen::MatrixXd overlap = permutational_overlap_matrix(n_particles);
  SEQUANT_ASSERT(overlap.rows() == overlap.cols());
  SEQUANT_ASSERT(overlap.isApprox(overlap.transpose()));

  // Obtain the pseudo-inverse via a complete orthogonal decomposition
  Eigen::CompleteOrthogonalDecomposition<Eigen::MatrixXd> cod;
  cod.setThreshold(threshold);
  cod.compute(overlap);

  Eigen::MatrixXd scaled_pinv = cod.pseudoInverse();
  // A symmetric matrix is expected to have a symmetric pseudo-inverse
  SEQUANT_ASSERT(scaled_pinv.isApprox(scaled_pinv.transpose()));

  // Normalize to the amount of non-zero eigenvalues, i.e. scale by
  // #eigenvalues / #non-zero eigenvalues. The overlap matrix is symmetric and
  // thus diagonalizable, for which the matrix rank equals the amount of
  // non-zero eigenvalues.
  const double n_eigenvalues = static_cast<double>(overlap.rows());
  const double scale = n_eigenvalues / cod.rank();

  scaled_pinv *= scale;

  return scaled_pinv;
}

/// Stably sorts the given pairings in place by the first index of every pair.
///
/// @param pairing The pairings to sort
void sort_pairings(ParticlePairings& pairing) {
  const compare_first_less<IndexPair> by_first{};

  std::stable_sort(pairing.begin(), pairing.end(), by_first);
}

std::size_t rank_transformation_perms(const ParticlePairings& reference,
                                      const ParticlePairings& current) {
  SEQUANT_ASSERT(reference.size() == current.size());
  SEQUANT_ASSERT(std::is_sorted(reference.begin(), reference.end(),
                                compare_first_less<IndexPair>{}));
  SEQUANT_ASSERT(std::is_sorted(current.begin(), current.end(),
                                compare_first_less<IndexPair>{}));

  perm::Permutation perm = perm::computeTransformationPermutation(
      reference, current, [](const IndexPair& lhs, const IndexPair& rhs) {
        return lhs.second < rhs.second;
      });

  return perm::rank(perm, reference.size());
}

/// Creates the expression that corresponds to ref_pairing after applying perm
/// to the second index of every pair.
///
/// A suitable base expression is selected from base_exprs by looking for the
/// first entry in pairings whose index-space combinations are all contained in
/// the space combinations of the permuted reference pairing. A clone of that
/// base expression is then relabelled via transform_expr() such that its
/// pairing matches the permuted reference pairing.
///
/// @param ref_pairing The reference pairing (sorted w.r.t. first)
/// @param perm The permutation to apply to the second indices of ref_pairing
/// @param pairings The available pairings (each sorted w.r.t. first);
/// pairings[i] belongs to base_exprs[i]
/// @param base_exprs The expressions available as starting points
/// @returns The (cloned and, if required, relabelled) expression
/// @throws std::runtime_error if no entry of pairings is compatible with the
/// permuted reference pairing
ExprPtr create_expr_for(const ParticlePairings& ref_pairing,
                        const perm::Permutation& perm,
                        const container::svector<ParticlePairings>& pairings,
                        const container::svector<ExprPtr>& base_exprs) {
  // Note: perm only applies to the p->second for every pair p in ref_pairing

  // assert that all pairings are sorted w.r.t. first
  SEQUANT_ASSERT(std::all_of(
      pairings.begin(), pairings.end(), [](const ParticlePairings& pairing) {
        return std::is_sorted(pairing.begin(), pairing.end(),
                              compare_first_less<IndexPair>{});
      }));
  SEQUANT_ASSERT(std::is_sorted(ref_pairing.begin(), ref_pairing.end(),
                                compare_first_less<IndexPair>{}));

  // Collect the index-space combinations of the permuted reference pairing
  container::set<std::pair<IndexSpace, IndexSpace>> ref_space_pairing;
  ref_space_pairing.reserve(ref_pairing.size());
  for (std::size_t i = 0; i < ref_pairing.size(); ++i) {
    ref_space_pairing.emplace(ref_pairing[i].first.space(),
                              ref_pairing[perm->image(i)].second.space());
  }

  // Look for a ParticlePairings object that pairs indices belonging to index
  // spaces compatible with ref_space_pairing
  auto it = std::find_if(
      pairings.begin(), pairings.end(), [&](const ParticlePairings& p) {
        SEQUANT_ASSERT(p.size() == ref_pairing.size());

        for (const IndexPair& pair : p) {
          if (ref_space_pairing.find(
                  std::make_pair(pair.first.space(), pair.second.space())) ==
              ref_space_pairing.end()) {
            return false;
          }
        }

        return true;
      });

  if (it == pairings.end()) {
    throw std::runtime_error(
        "Missing explicit expression for a required index pairing in "
        "biorthogonalization");
  }

  // Position of the selected pairing; also used to index into base_exprs
  auto idx = std::distance(pairings.begin(), it);
  const ParticlePairings& base = *it;

  SEQUANT_ASSERT(base.size() == ref_pairing.size());

  // Index replacements that turn the base pairing into the permuted reference
  // pairing
  container::map<Index, Index> replacements;
  for (std::size_t i = 0; i < base.size(); ++i) {
    std::size_t ref_idx = perm->image(i);

    // Remember that all index pairings are sorted w.r.t. first and hence we are
    // only looking for permutations in second
    SEQUANT_ASSERT(base[i].first == ref_pairing[i].first);
    const bool differs_in_second =
        base[i].second != ref_pairing[ref_idx].second;

    if (!differs_in_second) {
      // This particle pairing is identical
      continue;
    }

    // Always holds at this point due to the early continue above
    SEQUANT_ASSERT(differs_in_second);

    // Note: we may only permute indices belonging to the same space
    // (otherwise, we would produce non-sensical expressions)
    if (base[i].second.space() == ref_pairing[ref_idx].second.space()) {
      // base and ref_pairing differ in the second index of the current
      // pairing and their index space matches -> can just permute them
      replacements.emplace(base[i].second, ref_pairing[ref_idx].second);
    } else {
      // Index spaces of the differing index (second) in the pairings are
      // different as well. Since the tensors are assumed to be
      // particle-symmetric, we can instead permute the first indices in the
      // pairings, which are of the same space (that's guaranteed by the way we
      // chose base).
      SEQUANT_ASSERT(base[i].first.space() ==
                     ref_pairing[ref_idx].first.space());
      replacements.emplace(base[i].first, ref_pairing[ref_idx].first);
    }
  }

  // Work on a clone so that the base expression itself stays untouched
  ExprPtr expr = base_exprs.at(idx)->clone();

  if (!replacements.empty()) {
    if constexpr (assert_enabled()) {
      // Every replacement must stay within a single index space
      for ([[maybe_unused]] const auto& [first, second] : replacements) {
        SEQUANT_ASSERT(first.space() == second.space());
      }
    }
    expr = transform_expr(expr, replacements);
  }

  return expr;
}

/// Biorthogonalizes the given set of result expressions in place.
///
/// All given expressions are expected to be identical up to a permutation of
/// their external indices, to be particle-symmetric (column-symmetric) and to
/// not contain any (anti)symmetrization operator; this is checked via
/// assertions only.
///
/// For up to five particles the tabulated coefficient matrix is used; beyond
/// that the coefficients are computed numerically. Either matrix is stored in a
/// function-local static cache. Every expression is replaced by the linear
/// combination (over all permutation ranks) of the coefficient times the
/// expression obtained from create_expr_for(), followed by simplify().
///
/// @param result_exprs The expressions to transform; nothing happens if empty
/// @param threshold Threshold used when computing the coefficient matrix
/// numerically and when converting its entries via to_rational()
void biorthogonal_transform(container::svector<ResultExpr>& result_exprs,
                            double threshold) {
  if (result_exprs.empty()) {
    return;
  }

  // We expect all ResultExpr objects to be equal except for the permutation of
  // indices
  // Also, we are assuming that all given ResultExpr objects are
  // particle-symmetric
  SEQUANT_ASSERT(std::all_of(
      result_exprs.begin(), result_exprs.end(), [&](const ResultExpr& expr) {
        return expr.has_label() == result_exprs.front().has_label() &&
               (!expr.has_label() ||
                expr.label() == result_exprs.front().label());
      }));
  SEQUANT_ASSERT(std::all_of(
      result_exprs.begin(), result_exprs.end(), [&](const ResultExpr& expr) {
        return expr.symmetry() == result_exprs.front().symmetry();
      }));
  SEQUANT_ASSERT(std::all_of(
      result_exprs.begin(), result_exprs.end(), [&](const ResultExpr& expr) {
        return expr.braket_symmetry() == result_exprs.front().braket_symmetry();
      }));
  SEQUANT_ASSERT(std::all_of(
      result_exprs.begin(), result_exprs.end(), [&](const ResultExpr& expr) {
        return expr.column_symmetry() == result_exprs.front().column_symmetry();
      }));
  SEQUANT_ASSERT(std::all_of(
      result_exprs.begin(), result_exprs.end(), [&](const ResultExpr& expr) {
        return expr.bra().size() == result_exprs.front().bra().size() &&
               std::is_permutation(expr.bra().begin(), expr.bra().end(),
                                   result_exprs.front().bra().begin());
      }));
  SEQUANT_ASSERT(std::all_of(
      result_exprs.begin(), result_exprs.end(), [&](const ResultExpr& expr) {
        return expr.ket().size() == result_exprs.front().ket().size() &&
               std::is_permutation(expr.ket().begin(), expr.ket().end(),
                                   result_exprs.front().ket().begin());
      }));
  SEQUANT_ASSERT(std::all_of(
      result_exprs.begin(), result_exprs.end(), [&](const ResultExpr& expr) {
        return expr.aux().size() == result_exprs.front().aux().size() &&
               std::is_permutation(expr.aux().begin(), expr.aux().end(),
                                   result_exprs.front().aux().begin());
      }));
  SEQUANT_ASSERT(std::all_of(
      result_exprs.begin(), result_exprs.end(), [](const ResultExpr& res) {
        return res.column_symmetry() == ColumnSymmetry::Symm;
      }));

  // Furthermore, we expect that there is no symmetrization operator present in
  // the expressions as that would imply transforming also the symmetrization
  // operator, which is incorrect. This is because the idea during
  // biorthogonalization is that we project onto e.g.
  // \tilde{E}^{IJ}_{AB} = c_1 E^{IJ}_{AB} + c_2 E^{JI}_{AB}
  // instead of E^{IJ}_{AB} directly. In either case though, the result looks
  // like R^{IJ}_{AB} and the index pairing of the result is what determines
  // the required symmetrization. Hence, the symmetrization operator must not
  // be changed when transforming from one representation into the other.
  SEQUANT_ASSERT(std::all_of(
      result_exprs.begin(), result_exprs.end(), [](const ResultExpr& res) {
        bool found = false;
        res.expression()->visit(
            [&](const ExprPtr& expr) {
              if (expr->is<Tensor>() &&
                  (expr->as<Tensor>().label() == reserved::symm_label() ||
                   expr->as<Tensor>().label() == reserved::antisymm_label())) {
                found = true;
              };
            },
            true);
        return !found;
      }));

  // External index pairings of every expression, sorted w.r.t. the first index
  // of every pair
  auto externals = result_exprs |
                   ranges::views::transform([](const ResultExpr& expr) {
                     return expr.index_particle_grouping<IndexPair>();
                   }) |
                   ranges::to<container::svector<ParticlePairings>>();
  ranges::for_each(externals, sort_pairings);

  // Rank of the permutation relating every pairing to the one of the first
  // expression
  auto ranks = externals | ranges::views::transform([&](const auto& p) {
                 return rank_transformation_perms(externals.front(), p);
               }) |
               ranges::to<container::svector<std::size_t>>();

  const std::size_t n_particles = externals.front().size();
  auto num_perms = factorial(n_particles);

  // Keep the untransformed expressions around as they are overwritten below
  auto original_exprs = result_exprs |
                        ranges::views::transform([](const ResultExpr& res) {
                          return res.expression();
                        }) |
                        ranges::to<container::svector<ExprPtr>>();

  using HardcodedMatrix =
      Eigen::Matrix<sequant::rational, Eigen::Dynamic, Eigen::Dynamic>;
  using ComputedMatrix = Eigen::MatrixXd;
  using CacheKey = std::pair<std::size_t, double>;

  // Function-local caches (shared between all calls) for the coefficient
  // matrices. The mutex and condition variable are handed to
  // sequant::detail::memoize() together with the cache.
  static std::mutex cache_mutex;
  static std::condition_variable cache_cv;
  static container::map<CacheKey, std::optional<HardcodedMatrix>>
      hardcoded_cache;
  static container::map<CacheKey, std::optional<ComputedMatrix>> computed_cache;

  constexpr std::size_t max_rank_hardcoded_biorthogonalizer_matrix = 5;
  // NOTE(review): the key contains the threshold also for the hardcoded
  // matrices, although hardcoded_biorthogonalizer_matrix() does not take a
  // threshold
  CacheKey key{n_particles, threshold};

  // Exactly one of these two pointers is set below
  const HardcodedMatrix* hardcoded_coefficients = nullptr;
  const ComputedMatrix* computed_coefficients = nullptr;

  if (n_particles <= max_rank_hardcoded_biorthogonalizer_matrix) {
    hardcoded_coefficients = &sequant::detail::memoize(
        hardcoded_cache, cache_mutex, cache_cv, key,
        [&] { return hardcoded_biorthogonalizer_matrix(n_particles); });
  } else {
    computed_coefficients = &sequant::detail::memoize(
        computed_cache, cache_mutex, cache_cv, key, [&] {
          return compute_biorthogonalizer_matrix(n_particles, threshold);
        });
    SEQUANT_ASSERT(num_perms == computed_coefficients->rows());
    SEQUANT_ASSERT(num_perms == computed_coefficients->cols());
  }

  for (std::size_t i = 0; i < result_exprs.size(); ++i) {
    // Start from zero and accumulate the contributions of all permutations
    result_exprs.at(i).expression() = ex<Constant>(0);
    perm::Permutation reference = perm::unrank(ranks.at(i), n_particles);
    reference->invert();

    for (std::size_t rank = 0; rank < num_perms; ++rank) {
      perm::Permutation perm = perm::unrank(rank, n_particles);
      perm->postMultiply(reference);

      // Coefficient from row ranks[i] of whichever matrix is in use;
      // numerically computed entries are converted via to_rational()
      sequant::rational coeff =
          (n_particles <= max_rank_hardcoded_biorthogonalizer_matrix)
              ? (*hardcoded_coefficients)(ranks.at(i), rank)
              : to_rational((*computed_coefficients)(ranks.at(i), rank),
                            threshold);

      result_exprs.at(i).expression() +=
          ex<Constant>(coeff) *
          create_expr_for(externals.at(i), perm, externals, original_exprs);
    }

    simplify(result_exprs.at(i).expression());
  }
}

/// Biorthogonalizes a single expression whose external indices are given as
/// index groups.
///
/// The expression is wrapped into a ResultExpr (Symmetry::Nonsymm,
/// BraKetSymmetry::Nonsymm, ColumnSymmetry::Symm, no aux indices), which is
/// then transformed in place.
///
/// @note The bra of the ResultExpr is assembled from get_ket_idx() of every
/// index group and its ket from get_bra_idx().
///
/// @param expr The expression to transform
/// @param ext_index_groups The groups of external indices
/// @param threshold Threshold forwarded to biorthogonal_transform()
/// @returns The transformed expression
template <detail::index_group_range IdxGroups>
ExprPtr biorthogonal_transform_impl(const sequant::ExprPtr& expr,
                                    IdxGroups&& ext_index_groups,
                                    const double threshold) {
  using IndexVec = container::svector<Index>;

  auto result_bra = ext_index_groups |
                    ranges::views::transform([](const auto& group) {
                      return get_ket_idx(group);
                    }) |
                    ranges::to<IndexVec>();
  auto result_ket = ext_index_groups |
                    ranges::views::transform([](const auto& group) {
                      return get_bra_idx(group);
                    }) |
                    ranges::to<IndexVec>();

  ResultExpr res(bra(std::move(result_bra)), ket(std::move(result_ket)),
                 aux(IndexList{}), Symmetry::Nonsymm, BraKetSymmetry::Nonsymm,
                 ColumnSymmetry::Symm, {}, expr);

  biorthogonal_transform(res, threshold);

  return res.expression();
}

/// Biorthogonalizes expr, with the external index groups given as
/// SlottedIndex objects.
///
/// The groups are passed through as_view_of_index_groups() before being handed
/// to biorthogonal_transform_impl().
///
/// @param expr The expression to transform
/// @param ext_index_groups The groups of external indices
/// @param threshold Threshold forwarded to biorthogonal_transform_impl()
/// @returns The transformed expression
ExprPtr biorthogonal_transform(
    const sequant::ExprPtr& expr,
    const container::svector<container::svector<sequant::SlottedIndex>>&
        ext_index_groups,
    const double threshold) {
  return biorthogonal_transform_impl(
      expr, as_view_of_index_groups(ext_index_groups), threshold);
}

/// Biorthogonalizes expr, with the external index groups given as plain Index
/// objects; forwards directly to biorthogonal_transform_impl().
///
/// @param expr The expression to transform
/// @param ext_index_groups The groups of external indices
/// @param threshold Threshold forwarded to biorthogonal_transform_impl()
/// @returns The transformed expression
ExprPtr biorthogonal_transform(
    const sequant::ExprPtr& expr,
    const container::svector<container::svector<sequant::Index>>&
        ext_index_groups,
    const double threshold) {
  return biorthogonal_transform_impl(expr, ext_index_groups, threshold);
}

/// Filters the terms of a sum such that, among all Product terms sharing the
/// same canonical tensor-network hash, only those with the largest absolute
/// scalar prefactor are kept (ties are all kept).
///
/// Expressions that are not a Sum and cases with at most two external index
/// groups are returned unchanged. Terms of the sum that are not a Product do
/// not end up in the result.
///
/// @param expr The expression to filter
/// @param ext_idxs The groups of external indices (only their count is used)
/// @returns The filtered sum (or expr itself if no filtering is performed)
template <detail::index_group_range IdxGroups>
ExprPtr WK_biorthogonalization_filter_impl(ExprPtr expr, IdxGroups&& ext_idxs) {
  // Only sums are filtered; R1 and R2 are always skipped
  const bool needs_filtering = expr->is<Sum>() && ext_idxs.size() > 2;
  if (!needs_filtering) {
    return expr;
  }

  // Maps the hash of a term's tensor network to the terms kept for that hash
  container::map<std::size_t, container::vector<ExprPtr>> kept_terms;

  for (const auto& summand : *expr) {
    if (summand->is<Product>()) {
      auto prod = summand.as_shared_ptr<Product>();
      auto coeff = prod->scalar();

      sequant::TensorNetwork network(*prod);
      auto key = network
                     .canonicalize_slots(
                         TensorCanonicalizer::cardinal_tensor_labels())
                     .hash_value();

      auto pos = kept_terms.find(key);
      if (pos == kept_terms.end()) {
        // First term with this hash
        kept_terms[key] = {summand};
        continue;
      }

      auto& bucket = pos->second;
      if (bucket.empty()) {
        continue;
      }

      auto kept_scalar = bucket[0]->as<Product>().scalar();
      auto kept_abs = abs(kept_scalar);
      auto new_abs = abs(coeff);

      if (new_abs > kept_abs) {
        // Strictly larger prefactor -> supersedes everything kept so far
        bucket.clear();
        bucket.push_back(summand);
      } else if (new_abs == kept_abs) {
        bucket.push_back(summand);
      }
    }
  }

  Sum survivors;
  for (const auto& entry : kept_terms) {
    for (const auto& kept : entry.second) {
      survivors.append(kept);
    }
  }

  return ex<Sum>(survivors);
}

/// Applies WK_biorthogonalization_filter_impl() to expr, with the external
/// index groups given as SlottedIndex objects (converted via
/// as_view_of_index_groups()).
///
/// @param expr The expression to filter
/// @param ext_idxs The groups of external indices
/// @returns The filtered expression
ExprPtr WK_biorthogonalization_filter(
    ExprPtr expr,
    const container::svector<container::svector<SlottedIndex>>& ext_idxs) {
  return WK_biorthogonalization_filter_impl(expr,
                                            as_view_of_index_groups(ext_idxs));
}
/// Applies WK_biorthogonalization_filter_impl() to expr, with the external
/// index groups given as plain Index objects.
///
/// @param expr The expression to filter
/// @param ext_idxs The groups of external indices
/// @returns The filtered expression
ExprPtr WK_biorthogonalization_filter(
    ExprPtr expr,
    const container::svector<container::svector<Index>>& ext_idxs) {
  return WK_biorthogonalization_filter_impl(expr, ext_idxs);
}

/// Biorthogonalizes expr (using the default pseudo-inverse threshold) and
/// multiplies the result with a symmetrizer tensor built from the external
/// indices.
///
/// @note expr is taken by non-const reference and its Product terms are
/// overwritten with the result of remove_tensor(..., reserved::symm_label())
/// before the transformation, i.e. the caller's expression is modified.
///
/// @param expr The expression to transform (modified, see note)
/// @param ext_idxs The groups of external indices
/// @param factor_out_nns_projector If true, the transformed expression is
/// additionally (for more than one index group) multiplied with the
/// symmetrizer, then simplified, passed through S_maps(), canonicalized and
/// filtered via WK_biorthogonalization_filter_impl() before the final
/// symmetrizer is applied
/// @returns The simplified product of the symmetrizer and the transformed
/// expression
template <detail::index_group_range IdxGroups>
ExprPtr biorthogonal_transform_pre_nnsproject_impl(
    ExprPtr& expr, IdxGroups&& ext_idxs, bool factor_out_nns_projector) {
  using ranges::views::transform;

  // Remove leading S operator if present
  for (auto& term : *expr) {
    if (term->is<Product>())
      term =
          remove_tensor(term.as_shared_ptr<Product>(), reserved::symm_label());
  }

  auto bt = biorthogonal_transform_impl(
      expr, ext_idxs, default_biorthogonalizer_pseudoinverse_threshold);

  // Note: the ket indices of the groups end up in the bra of the symmetrizer
  // and the bra indices in its ket
  auto bixs = ext_idxs | transform([](auto&& vec) { return get_bra_idx(vec); });
  auto kixs = ext_idxs | transform([](auto&& vec) { return get_ket_idx(vec); });
  ExprPtr S_tensor =
      ex<Tensor>(Tensor{reserved::symm_label(), bra(kixs), ket(bixs)});

  if (factor_out_nns_projector) {
    if (ext_idxs.size() > 1) {
      bt = S_tensor * bt;
    }
    simplify(bt);

    bt = S_maps(bt);
    canonicalize(bt);
    bt = WK_biorthogonalization_filter_impl(bt, ext_idxs);
  }

  // The symmetrizer is always applied to the final result
  bt = S_tensor * bt;
  simplify(bt);

  return bt;
}

/// Forwards to biorthogonal_transform_pre_nnsproject_impl(), with the external
/// index groups given as SlottedIndex objects (converted via
/// as_view_of_index_groups()).
///
/// @param expr The expression to transform; passed on by non-const reference
/// @param ext_idxs The groups of external indices
/// @param factor_out_nns_projector Forwarded to the implementation
/// @returns The result of the implementation
ExprPtr biorthogonal_transform_pre_nnsproject(
    ExprPtr& expr,
    const container::svector<container::svector<SlottedIndex>>& ext_idxs,
    bool factor_out_nns_projector) {
  return biorthogonal_transform_pre_nnsproject_impl(
      expr, as_view_of_index_groups(ext_idxs), factor_out_nns_projector);
}

/// Forwards to biorthogonal_transform_pre_nnsproject_impl(), with the external
/// index groups given as plain Index objects.
///
/// @param expr The expression to transform; passed on by non-const reference
/// @param ext_idxs The groups of external indices
/// @param factor_out_nns_projector Forwarded to the implementation
/// @returns The result of the implementation
ExprPtr biorthogonal_transform_pre_nnsproject(
    ExprPtr& expr,
    const container::svector<container::svector<Index>>& ext_idxs,
    bool factor_out_nns_projector) {
  return biorthogonal_transform_pre_nnsproject_impl(expr, ext_idxs,
                                                    factor_out_nns_projector);
}

namespace detail {

std::vector<double> compute_nns_p_coeffs(std::size_t n_particles,
                                         double threshold) {
  auto perm_ovlp_mat = permutational_overlap_matrix(n_particles);
  auto normalized_pinv =
      compute_biorthogonalizer_matrix(n_particles, threshold);
  Eigen::MatrixXd nns_matrix = perm_ovlp_mat * normalized_pinv;

  auto num_perms = nns_matrix.rows();
  std::vector<double> coeffs;
  coeffs.reserve(num_perms);
  for (std::size_t i = 0; i < num_perms; ++i) {
    coeffs.push_back(nns_matrix(num_perms - 1, i));
  }
  return coeffs;
}

/// Reorders the given indices according to the permutation of the given rank.
///
/// Entry i of the result is indices[p[i]], where p is the permutation obtained
/// from perm::unrank(perm_rank, n_particles).
///
/// @param indices The indices to permute; must contain at least n_particles
/// entries
/// @param perm_rank Rank of the permutation to apply
/// @param n_particles Number of elements the permutation acts on
/// @returns The n_particles permuted indices
container::svector<size_t> compute_permuted_indices(
    const container::svector<size_t>& indices, size_t perm_rank,
    size_t n_particles) {
  // The lookups below read indices[0 .. n_particles - 1] -> guard against
  // reading past the end of a too-short input
  SEQUANT_ASSERT(indices.size() >= n_particles);

  perm::Permutation perm_obj = perm::unrank(perm_rank, n_particles);

  container::svector<size_t> permuted_indices(n_particles);
  for (size_t i = 0; i < n_particles; ++i) {
    permuted_indices[i] = indices[perm_obj[i]];
  }

  return permuted_indices;
}

}  // namespace detail

}  // namespace sequant::mbpt