Program Listing for File permutation.hpp

Return to documentation for file (SeQuant/core/utility/permutation.hpp)

#ifndef SEQUANT_PERMUTATION_HPP
#define SEQUANT_PERMUTATION_HPP

#include <SeQuant/core/container.hpp>
#include <SeQuant/core/index.hpp>
#include <SeQuant/core/utility/macros.hpp>

#include <range/v3/algorithm/all_of.hpp>
#include <range/v3/algorithm/find.hpp>
#include <range/v3/algorithm/for_each.hpp>

#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <numeric>
#include <ranges>

namespace sequant {


template <typename Seq0, typename Seq1>
std::size_t count_cycles(Seq0&& v0, Seq1&& v1) {
  using std::ranges::begin;
  using std::ranges::end;
  using std::ranges::size;
  const std::size_t n0 = size(v0);
  const std::size_t n1 = size(v1);
  const std::size_t n = n0 + n1;

  // preconditions (debug builds only):
  //  (1) the two rows have equal length: every column pairs a v0 slot with a
  //      v1 slot. The column-edge loop below only pairs min(n0, n1) slots, so
  //      with n0 != n1 the surplus slots in the longer row would stay unpaired
  //      and be miscounted as their own components -- a meaningless loop count
  //      rather than a detected error.
  //  (2) every value occurs in exactly two slots across v0 and v1 combined.
  //      This is what makes the column+contraction graph 2-regular on its
  //      internal slots, so it decomposes into disjoint cycles and the
  //      component count is the loop count. It generalizes the former "v0 is a
  //      permutation of v1" contract (which also implied exactly two
  //      occurrences, one per row) to allow both occurrences in the same row
  //      (bra-bra / ket-ket edges from reoriented bra-ket-symmetric tensors).
  //      Without it malformed input (a value appearing once, or 3+ times) is
  //      silently accepted.
  if constexpr (assert_enabled()) {
    SEQUANT_ASSERT(n0 == n1);
    container::map<std::ranges::range_value_t<Seq0>, std::size_t> counts;
    for (auto&& x : v0) ++counts[x];
    for (auto&& x : v1) ++counts[x];
    SEQUANT_ASSERT(
        ranges::all_of(counts, [](auto const& kv) { return kv.second == 2; }));
  }

  // slot ids: row-0 slot i -> i ; row-1 slot i -> n0 + i. Size the union-find
  // by n0 + n1 (not 2*n0) so it is safe even if the two rows differ in length.
  // union-find with path halving
  container::svector<std::size_t> parent(n);
  std::iota(parent.begin(), parent.end(), std::size_t{0});
  auto find = [&parent](std::size_t x) {
    while (parent[x] != x) {
      parent[x] = parent[parent[x]];
      x = parent[x];
    }
    return x;
  };
  auto unite = [&](std::size_t a, std::size_t b) { parent[find(a)] = find(b); };

  // column edges: the two rows of each column carry the same spin
  const std::size_t ncols = std::min(n0, n1);
  for (std::size_t i = 0; i < ncols; ++i) unite(i, n0 + i);

  // contraction edges: the slots sharing an index value carry the same spin
  container::map<std::ranges::range_value_t<Seq0>, std::size_t> first_slot;
  auto add_slot = [&first_slot, &unite](auto const& idx, std::size_t slot) {
    auto [it, inserted] = first_slot.try_emplace(idx, slot);
    if (!inserted)
      unite(it->second, slot);  // contraction with first occurrence
  };
  {
    std::size_t i = 0;
    for (auto it = begin(v0); it != end(v0); ++it, ++i) add_slot(*it, i);
    i = 0;
    for (auto it = begin(v1); it != end(v1); ++it, ++i) add_slot(*it, n0 + i);
  }

  // number of connected components = number of spin loops
  std::size_t n_cycles = 0;
  for (std::size_t i = 0; i < n; ++i)
    if (find(i) == i) ++n_cycles;
  return n_cycles;
};

/// @brief Computes the parity (sign) of a permutation of {0, ..., N-1}.
///
/// The permutation is decomposed into disjoint cycles. A cycle of length L is
/// a product of L-1 transpositions, so every even-length cycle flips the sign.
/// The elements of @p p serve as "visited" marks during the traversal: an
/// element is marked by adding N to it, so no auxiliary storage is allocated.
///
/// @tparam T integral element type; must be able to represent 2N-1
/// @param[in,out] p a permutation of {0, ..., N-1} with N = p.size(), i.e.
///        every value in [0, N) occurs exactly once (not checked)
/// @param[in] overwrite if false (default), @p p is restored to its original
///        contents before returning. If true, the caller allows @p p to be
///        clobbered, and it is left in the marked state (every element
///        increased by N).
/// @return +1 if @p p is an even permutation (this includes the empty and the
///         identity permutation), -1 if it is odd
template <std::integral T>
int permutation_parity(std::span<T> p, bool overwrite = false) {
  // cycle traversal with in-place marking, cf.
  // https://stackoverflow.com/a/20703469
  const std::size_t N = p.size();
  // original entries lie in [0, N); marked entries lie in [N, 2N)
  auto is_marked = [N](T e) { return static_cast<std::size_t>(e) >= N; };

  int parity = 1;
  for (std::size_t k = 0; k != N; ++k) {
    if (is_marked(p[k])) continue;  // k belongs to an already traversed cycle
    // Walk the cycle through k, marking each element as it is visited. The
    // walk stops when it returns to k, whose entry was marked first.
    std::size_t cycle_length = 0;
    for (std::size_t i = k; !is_marked(p[i]); ++cycle_length) {
      const auto next = static_cast<std::size_t>(p[i]);
      p[i] = static_cast<T>(p[i] + N);
      i = next;
    }
    // a cycle of even length is an odd permutation
    if (cycle_length % 2 == 0) parity = -parity;
  }

  if (!overwrite) {
    // for a valid permutation every element was marked exactly once above;
    // undo the marks so the caller's data is unchanged
    for (auto& e : p) e = static_cast<T>(e - N);
  }

  return parity;
}

}  // namespace sequant

#endif