Program Listing for File indices.hpp

Return to documentation for file (SeQuant/core/utility/indices.hpp)

#ifndef SEQUANT_CORE_UTILITY_INDICES_HPP
#define SEQUANT_CORE_UTILITY_INDICES_HPP

#include <SeQuant/core/attr.hpp>
#include <SeQuant/core/container.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/index.hpp>
#include <SeQuant/core/meta.hpp>
#include <SeQuant/core/op.hpp>
#include <SeQuant/core/reserved.hpp>
#include <SeQuant/core/slotted_index.hpp>
#include <SeQuant/core/utility/macros.hpp>

#include <range/v3/view.hpp>

#include <algorithm>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <map>
#include <optional>
#include <ranges>
#include <set>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

namespace sequant {

namespace detail {
/// Unary predicate that is satisfied by every element which does NOT occur in
/// the referenced range.
///
/// Only a reference to the range is stored, hence the range has to stay alive
/// for as long as the predicate is used.
template <typename Range>
struct not_in {
  const Range& range;

  not_in(const Range& range) : range(range) {}

  /// @returns `true` if no entry of the stored range compares equal to
  ///          @p element, `false` otherwise
  template <typename T>
  bool operator()(const T& element) const {
    return std::none_of(
        range.begin(), range.end(),
        [&element](const auto& candidate) { return candidate == element; });
  }
};

/// Erases the first element of @p container that compares equal to @p e.
/// The container is left untouched if it holds no such element.
template <typename Container, typename Element>
void remove_one(Container& container, const Element& e) {
  auto position = std::find(container.begin(), container.end(), e);

  if (position == container.end()) {
    // Nothing to remove
    return;
  }

  container.erase(position);
}

}  // namespace detail

/// Bundles indices according to the kind of slot (bra, ket or auxiliary) they
/// belong to.
///
/// @tparam Container The container type used to store each group
template <typename Container = std::vector<Index>>
struct IndexGroups {
  Container bra;
  Container ket;
  Container aux;

  /// Two objects are equal iff their bra, ket and aux groups are pairwise
  /// equal (compared in that order).
  bool operator==(const IndexGroups<Container>& other) const {
    return std::tie(bra, ket, aux) ==
           std::tie(other.bra, other.ket, other.aux);
  }

  bool operator!=(const IndexGroups<Container>& other) const {
    return !operator==(other);
  }
};

/// Result type of tot_indices(): a list of indices partitioned into an
/// "outer" and an "inner" part.
///
/// @tparam Container The container type used to store each part
template <typename Container = std::vector<Index>>
struct TensorOfTensorIndices {
  /// As filled by tot_indices(): the proto-indices of the processed indices
  /// plus all processed indices that carry no proto-indices themselves
  Container outer;
  /// As filled by tot_indices(): all processed indices that do carry
  /// proto-indices
  Container inner;
};

/// Forward declaration of the ExprPtr overload (its definition follows after
/// the overloads for the concrete expression types), so that overloads
/// handling composite expressions can call it for their sub-expressions.
template <typename Container = std::vector<Index>>
IndexGroups<Container> get_unique_indices(const ExprPtr& expr);

/// Overload for constants: always yields empty index groups.
template <typename Container = std::vector<Index>>
IndexGroups<Container> get_unique_indices(const Constant&) {
  return IndexGroups<Container>{};
}

/// Overload for variables: always yields empty index groups.
template <typename Container = std::vector<Index>>
IndexGroups<Container> get_unique_indices(const Variable&) {
  return IndexGroups<Container>{};
}

/// Collects the indices of @p t1 and @p t2 that have no partner in the
/// respective other tensor.
///
/// A bra index is looked up among the other tensor's ket indices, a ket index
/// among the other tensor's bra indices and an auxiliary index among the other
/// tensor's auxiliary indices. Within every group the indices of @p t1 precede
/// those of @p t2.
///
/// @tparam Container Container type of the returned groups; its value_type
///         must be Index
template <typename Container = container::svector<Index>>
IndexGroups<Container> get_uncontracted_indices(const Tensor& t1,
                                                const Tensor& t2) {
  static_assert(std::is_same_v<typename Container::value_type, Index>);

  // Appends every index of `source` that is not contained in `partners` to
  // `target`
  const auto append_unpaired = [](auto&& source, auto&& partners,
                                  Container& target) {
    std::copy_if(source.begin(), source.end(), std::back_inserter(target),
                 detail::not_in{partners});
  };

  IndexGroups<Container> result;

  // bra <-> ket of the other tensor
  append_unpaired(t1.bra(), t2.ket(), result.bra);
  append_unpaired(t2.bra(), t1.ket(), result.bra);

  // ket <-> bra of the other tensor
  append_unpaired(t1.ket(), t2.bra(), result.ket);
  append_unpaired(t2.ket(), t1.bra(), result.ket);

  // aux <-> aux of the other tensor
  append_unpaired(t1.aux(), t2.aux(), result.aux);
  append_unpaired(t2.aux(), t1.aux(), result.aux);

  return result;
}

/// Determines the indices that appear exactly once in @p tensor.
///
/// The slots are scanned in the order bra, ket, aux. An index seen for the
/// first time is recorded in the group of the slot it was found in; any
/// further occurrence removes it from the result again (and it is never
/// re-added afterwards).
template <typename Container = std::vector<Index>>
IndexGroups<Container> get_unique_indices(const Tensor& tensor) {
  IndexGroups<Container> groups;
  std::set<Index> seen;

  const auto scan = [&groups, &seen](auto&& slots, Container& target) {
    for (const Index& idx : slots) {
      if (seen.insert(idx).second) {
        // First occurrence
        target.push_back(idx);
      } else {
        // Repeated occurrence -> drop the index from wherever it was stored.
        // Groups belonging to slots that have not been scanned yet are still
        // empty, so trying to remove from them has no effect.
        detail::remove_one(groups.bra, idx);
        detail::remove_one(groups.ket, idx);
        detail::remove_one(groups.aux, idx);
      }
    }
  };

  scan(tensor.bra(), groups.bra);
  scan(tensor.ket(), groups.ket);
  scan(tensor.aux(), groups.aux);

  return groups;
}

/// Overload for sums: reports the unique indices of the first summand, or
/// empty groups if the sum has no summands.
template <typename Container = std::vector<Index>>
IndexGroups<Container> get_unique_indices(const Sum& sum) {
  if (sum.summands().empty()) {
    return IndexGroups<Container>{};
  }

  // A valid sum requires all of its summands to carry the same external
  // indices. Therefore, inspecting a single summand is sufficient.
  return get_unique_indices<Container>(sum.summand(0));
}

/// Determines the indices of @p product that are not repeated across its
/// factors.
///
/// The factors are processed in iteration order and for each factor its own
/// unique indices are obtained first. An index encountered for the first time
/// is recorded in the group (bra/ket/aux) it was reported in; an index that is
/// encountered again is removed from the result and is never re-added, no
/// matter how often it shows up afterwards.
template <typename Container = std::vector<Index>>
IndexGroups<Container> get_unique_indices(const Product& product) {
  std::set<Index> encounteredIndices;
  IndexGroups<Container> groups;

  // Merges the indices in `incoming` (unique indices of a single factor) into
  // the overall result, where `target` is the result group corresponding to
  // the slot type of `incoming`.
  const auto merge = [&encounteredIndices, &groups](Container& incoming,
                                                    Container& target) {
    for (Index& idx : incoming) {
      // insert() tells us whether the index is new -> no separate lookup
      if (encounteredIndices.insert(idx).second) {
        target.push_back(std::move(idx));
      } else {
        // Repeated index -> remove it from whichever group holds it
        detail::remove_one(groups.bra, idx);
        detail::remove_one(groups.ket, idx);
        detail::remove_one(groups.aux, idx);
      }
    }
  };

  for (const ExprPtr& factor : product) {
    IndexGroups<Container> factorGroups =
        get_unique_indices<Container>(factor);

    merge(factorGroups.bra, groups.bra);
    merge(factorGroups.ket, groups.ket);
    merge(factorGroups.aux, groups.aux);
  }

  return groups;
}

/// Dispatches to the get_unique_indices overload matching the dynamic type of
/// @p expr.
///
/// @throws std::runtime_error if @p expr is neither a Constant, Variable,
///         Tensor, Sum nor Product
template <typename Container = std::vector<Index>>
IndexGroups<Container> get_unique_indices(const Expr& expr) {
  if (expr.is<Constant>()) {
    return get_unique_indices<Container>(expr.as<Constant>());
  }
  if (expr.is<Variable>()) {
    return get_unique_indices<Container>(expr.as<Variable>());
  }
  if (expr.is<Tensor>()) {
    return get_unique_indices<Container>(expr.as<Tensor>());
  }
  if (expr.is<Sum>()) {
    return get_unique_indices<Container>(expr.as<Sum>());
  }
  if (expr.is<Product>()) {
    return get_unique_indices<Container>(expr.as<Product>());
  }

  throw std::runtime_error(
      "Encountered unsupported expression type in get_unique_indices");
}
/// Convenience overload that dereferences @p expr and forwards to the
/// overload taking an Expr. The default for Container is specified on the
/// forward declaration of this overload.
template <typename Container>
IndexGroups<Container> get_unique_indices(const ExprPtr& expr) {
  return get_unique_indices<Container>(*expr);
}

/// Keeps track of how often an index occurred in each kind of slot.
struct IndexSlotCounters {
  std::int64_t bra = 0;
  std::int64_t ket = 0;
  std::int64_t aux = 0;
  std::int64_t proto = 0;

  /// @returns The number of occurrences summed over all slot kinds
  std::int64_t total() const { return nonproto() + proto; }
  /// @returns The number of occurrences in bra, ket and aux slots
  std::int64_t nonproto() const { return bra + ket + aux; }

  /// Increases the counter that belongs to @p slot by one. Slot types other
  /// than Bra, Ket, Aux and Proto leave all counters unchanged.
  ///
  /// @returns A reference to this object
  IndexSlotCounters& increment(SlotType slot) {
    switch (slot) {
      case SlotType::Bra:
        ++bra;
        return *this;
      case SlotType::Ket:
        ++ket;
        return *this;
      case SlotType::Aux:
        ++aux;
        return *this;
      case SlotType::Proto:
        ++proto;
        return *this;
    }
    return *this;
  }
};

template <typename Map = container::map<Index, IndexSlotCounters>>
Map get_used_indices_with_counts(const Expr& expr) {
  Map all_indices;
  auto emplace = [&all_indices](const Index& idx, SlotType slot) {
    if (!idx.nonnull()) return;
    auto add = [&all_indices](const Index& i, SlotType slot) {
      auto it = all_indices.find(i);
      if (it == all_indices.end()) {
        std::tie(it, std::ignore) = all_indices.emplace(i, IndexSlotCounters{});
      }
      it->second.increment(slot);
    };
    add(idx, slot);
    for (const auto& i : idx.proto_indices()) {
      add(i, SlotType::Proto);
    }
  };
  auto process_abstract_tensor = [&emplace](const AbstractTensor& t) {
    ranges::for_each(
        t._bra(), [&emplace](const auto& idx) { emplace(idx, SlotType::Bra); });
    ranges::for_each(
        t._ket(), [&emplace](const auto& idx) { emplace(idx, SlotType::Ket); });
    ranges::for_each(
        t._aux(), [&emplace](const auto& idx) { emplace(idx, SlotType::Aux); });
  };

  auto collect_indices = [&process_abstract_tensor](const Expr& expr) {
    if (expr.is<AbstractTensor>()) {
      process_abstract_tensor(expr.as<AbstractTensor>());
    } else if (expr.is<NormalOperatorSequence<Statistics::FermiDirac>>()) {
      for (auto& op :
           expr.as<NormalOperatorSequence<Statistics::FermiDirac>>()) {
        process_abstract_tensor(op.as<AbstractTensor>());
      }
    } else if (expr.is<NormalOperatorSequence<Statistics::BoseEinstein>>()) {
      for (auto& op :
           expr.as<NormalOperatorSequence<Statistics::BoseEinstein>>()) {
        process_abstract_tensor(op.as<AbstractTensor>());
      }
    } else if (expr.is<NormalOperatorSequence<Statistics::Arbitrary>>()) {
      for (auto& op :
           expr.as<NormalOperatorSequence<Statistics::Arbitrary>>()) {
        process_abstract_tensor(op.as<AbstractTensor>());
      }
    }
  };

  if (expr.is_atom()) {
    collect_indices(expr);
  } else {
    expr.visit([&](const ExprPtr& expr) { collect_indices(*expr); },
               /*atoms_only=*/true);
  }

  return all_indices;
}

/// Convenience overload that dereferences @p expr and forwards to the
/// overload taking an Expr.
template <typename Map = container::map<Index, IndexSlotCounters>>
Map get_used_indices_with_counts(const ExprPtr& expr) {
  return get_used_indices_with_counts<Map>(*expr);
}

/// Collects all indices used in @p expr, i.e. the keys of the map produced by
/// get_used_indices_with_counts().
///
/// @tparam Set Container type the indices are gathered in
template <typename Set = container::set<Index>>
Set get_used_indices(const Expr& expr) {
  const auto index_of = [](const auto& idx_count) { return idx_count.first; };

  const auto counts = get_used_indices_with_counts(expr);

  return counts | std::views::transform(index_of) | ranges::to<Set>;
}

/// Convenience overload that dereferences @p expr and forwards to the
/// overload taking an Expr.
template <typename Set = container::set<Index>>
Set get_used_indices(const ExprPtr& expr) {
  return get_used_indices<Set>(*expr);
}

/// Partitions the given indices into "outer" and "inner" indices.
///
/// - outer: the proto-indices of all given indices, followed by those given
///   indices that carry no proto-indices themselves. No index is added to the
///   outer part twice.
/// - inner: all given indices that do carry proto-indices (no deduplication
///   is performed beyond what the container itself may do).
///
/// @tparam Container Container type for the outer and inner parts. Has to
///         support either emplace_back (sequence containers) or emplace
///         (associative containers).
/// @param idxs Range of Index objects to partition
template <typename Container = std::vector<Index>, typename Rng>
TensorOfTensorIndices<Container> tot_indices(Rng const& idxs) {
  using ranges::not_fn;
  using ranges::views::filter;
  using ranges::views::join;
  using ranges::views::transform;

  constexpr auto emplace_into = [](Container& target, auto&& value) {
    if constexpr (requires { target.emplace_back(value); }) {
      // for sequence containers like vectors, lists
      target.emplace_back(value);
    } else if constexpr (requires { target.emplace(value); }) {
      // for associative containers like set
      target.emplace(value);
    } else {
      // The condition must depend on a template parameter of this lambda:
      // compilers that do not implement P2593 reject a plain
      // static_assert(false, ...) even if this branch is never taken.
      static_assert(!std::is_same_v<decltype(value), decltype(value)>,
                    "Container does not support emplace_back or emplace");
    }
  };

  TensorOfTensorIndices<Container> result;
  auto& outer = result.outer;

  // Proto-indices come first ...
  for (auto&& i : idxs | transform(&Index::proto_indices) | join)
    if (!ranges::contains(outer, i)) emplace_into(outer, i);

  // ... followed by the indices that don't have any proto-indices
  for (auto&& i : idxs | filter(not_fn(&Index::has_proto_indices)))
    if (!ranges::contains(outer, i)) emplace_into(outer, i);

  auto& inner = result.inner;
  for (auto&& i : idxs | filter(&Index::has_proto_indices))
    emplace_into(inner, i);

  return result;
}

/// Strict "less than" comparison of two indices based on the values returned
/// by their ordinal() member function.
inline bool ordinal_compare(Index const& idx1, Index const& idx2) {
  return idx1.ordinal() < idx2.ordinal();
}

/// Builds a comma-separated string from the labels of the given indices.
///
/// The text generated for a single index is its own label immediately
/// followed by the labels of its proto-indices (without any separator in
/// between). The per-index strings are joined with "," (no trailing comma).
///
/// @param idxs Range of Index objects
std::string csv_labels(meta::range_of<Index> auto&& idxs) {
  using ranges::views::concat;
  using ranges::views::intersperse;
  using ranges::views::join;
  using ranges::views::single;
  using ranges::views::transform;

  // Produces the string for a single index: the characters of its label and
  // of all proto-index labels are gathered in a wide string, which is then
  // passed through toUtf8
  auto str = [](Index const& i) {
    auto v = concat(single(i.label()),                             //
                    i.proto_indices() | transform(&Index::label))  //
             | join;
    return toUtf8(v | ranges::to<std::wstring>);
  };

  return std::forward<decltype(idxs)>(idxs)  //
         | transform(str)                    //
         | intersperse(",")                  //
         | join                              //
         | ranges::to<std::string>;
}

/// Determines the external indices of @p expr, arranged in groups.
///
/// Each of the leading groups pairs the i-th bra index with the i-th ket
/// index (bra first; a group holds only one entry if the other slot kind has
/// fewer indices). These are followed by one single-element group per
/// auxiliary index.
///
/// - Expressions other than Sum, Product and Tensor yield an empty result.
/// - For a Tensor, the groups are built directly from its bra, ket and aux
///   slots.
/// - For a Sum or Product that contains a symmetrization or
///   antisymmetrization tensor (identified via the reserved labels), the
///   groups are built from that tensor.
/// - Otherwise, the groups are built from the result of
///   get_unique_indices().
///
/// @tparam Container Outer container template (holds the groups)
/// @tparam Group Container template used for a single group
template <template <class> class Container = container::svector,
          template <class> class Group = container::svector>
Container<Group<SlottedIndex>> external_indices(const Expr& expr) {
  using HolderType = Container<Group<SlottedIndex>>;

  if (!expr.is<Sum>() && !expr.is<Product>() && !expr.is<Tensor>()) {
    return {};
  }

  if (expr.is<Tensor>()) {
    const Tensor& tensor = expr.as<Tensor>();

    // Number of groups needed to hold all bra-ket pairs
    const std::size_t num_braket =
        std::max(tensor.bra_rank(), tensor.ket_rank());
    HolderType cont;
    cont.resize(num_braket + tensor.aux_rank());

    // Bra indices are added before ket indices, so within a group the bra
    // entry precedes the ket entry
    for (std::size_t i = 0; i < tensor.bra_rank(); ++i) {
      cont.at(i).emplace_back(tensor.bra()[i], SlotType::Bra);
    }
    for (std::size_t i = 0; i < tensor.ket_rank(); ++i) {
      cont.at(i).emplace_back(tensor.ket()[i], SlotType::Ket);
    }
    // Aux indices each get their own group after all bra-ket groups
    for (std::size_t i = 0; i < tensor.aux_rank(); ++i) {
      cont.at(i + num_braket).emplace_back(tensor.aux()[i], SlotType::Aux);
    }

    if (tensor.label() == reserved::symm_label() ||
        tensor.label() == reserved::antisymm_label()) {
      // Note: In tensors representing symmetrization operators, bra and ket
      // indices are conjugated (reversed). Hence, we have to swap the
      // determined bra and ket indices.
      // Only the Index objects are exchanged; the slot types of the two
      // entries stay as they are. This accesses two entries in each of the
      // first num_braket groups.
      for (std::size_t i = 0; i < num_braket; ++i) {
        std::swap(cont.at(i).at(0).index(), cont.at(i).at(1).index());
      }
    }

    return cont;
  }

  // Search the expression for a (anti)symmetrization tensor. If there are
  // multiple, all of them are asserted to be equal.
  std::optional<Tensor> symmetrizer;
  expr.visit(
      [&](const ExprPtr& expr) {
        if (expr.is<Tensor>() &&
            (expr.as<Tensor>().label() == reserved::symm_label() ||
             expr.as<Tensor>().label() == reserved::antisymm_label())) {
          SEQUANT_ASSERT(!symmetrizer.has_value() ||
                         symmetrizer.value() == expr.as<Tensor>());
          symmetrizer = expr.as<Tensor>();
        }
      },
      true);

  if (symmetrizer.has_value()) {
    // Generate external index list from symmetrization operator
    return external_indices<Container, Group>(symmetrizer.value());
  }

  // No symmetrizer present -> fall back to the indices that are not repeated
  // within the expression
  IndexGroups groups = get_unique_indices<container::svector<Index>>(expr);

  const std::size_t num_braket = std::max(groups.bra.size(), groups.ket.size());

  HolderType cont;
  cont.resize(num_braket + groups.aux.size());

  // Same grouping scheme as used for a single tensor above
  for (std::size_t i = 0; i < groups.bra.size(); ++i) {
    cont.at(i).emplace_back(groups.bra[i], SlotType::Bra);
  }
  for (std::size_t i = 0; i < groups.ket.size(); ++i) {
    cont.at(i).emplace_back(groups.ket[i], SlotType::Ket);
  }
  for (std::size_t i = 0; i < groups.aux.size(); ++i) {
    cont.at(num_braket + i).emplace_back(groups.aux[i], SlotType::Aux);
  }

  return cont;
}

/// Convenience overload that dereferences @p expr and forwards to the
/// overload taking an Expr.
template <template <class> class Container = container::svector,
          template <class> class Group = container::svector>
Container<Group<SlottedIndex>> external_indices(const ExprPtr& expr) {
  return external_indices<Container, Group>(*expr);
}

/// A range whose value type is SlottedIndex
template <typename T>
concept SlottedIndexContainer =
    std::ranges::range<T> &&
    std::same_as<std::ranges::range_value_t<T>, SlottedIndex>;

/// A range whose value type is Index
template <typename T>
concept IndexContainer =
    std::ranges::range<T> && std::same_as<std::ranges::range_value_t<T>, Index>;

/// A tuple-like type (after removing cv- and ref-qualifiers) whose elements 0
/// and 1 are both exactly SlottedIndex. Further elements are not inspected.
template <typename T>
concept SlottedIndexTuple =
    std::is_same_v<std::tuple_element_t<0, std::remove_cvref_t<T>>,
                   SlottedIndex> &&
    std::is_same_v<std::tuple_element_t<1, std::remove_cvref_t<T>>,
                   SlottedIndex>;
/// A tuple-like type (after removing cv- and ref-qualifiers) whose elements 0
/// and 1 are both exactly Index. Further elements are not inspected.
template <typename T>
concept IndexTuple =
    std::is_same_v<std::tuple_element_t<0, std::remove_cvref_t<T>>, Index> &&
    std::is_same_v<std::tuple_element_t<1, std::remove_cvref_t<T>>, Index>;

/// A group of SlottedIndex objects, given either as a range or as a tuple
template <typename T>
concept SlottedIndexGroup = SlottedIndexContainer<T> || SlottedIndexTuple<T>;

/// A group of Index objects, given either as a range or as a tuple
template <typename T>
concept IndexGroup = IndexContainer<T> || IndexTuple<T>;

/// A range of SlottedIndexGroup objects
template <typename T>
concept SlottedIndexGroupContainer =
    std::ranges::range<T> && SlottedIndexGroup<std::ranges::range_value_t<T>>;

/// A range of IndexGroup objects
template <typename T>
concept IndexGroupContainer =
    std::ranges::range<T> && IndexGroup<std::ranges::range_value_t<T>>;

/// Obtains the index of the entry in @p group that occupies a bra slot.
///
/// For tuples only the first two elements are considered (the first one is
/// preferred); for ranges the first matching entry is used. The existence of
/// a bra entry is only checked via SEQUANT_ASSERT.
template <SlottedIndexGroup Group>
decltype(auto) get_bra_idx(Group&& group) {
  if constexpr (SlottedIndexTuple<Group>) {
    const bool first_is_bra = std::get<0>(group).slot_type() == SlotType::Bra;
    SEQUANT_ASSERT(first_is_bra ||
                   std::get<1>(group).slot_type() == SlotType::Bra);

    if (first_is_bra) {
      return std::get<0>(group).index();
    }
    return std::get<1>(group).index();
  } else {
    using std::begin;
    using std::end;

    const auto occupies_bra = [](const SlottedIndex& idx) {
      return idx.slot_type() == SlotType::Bra;
    };

    auto it = std::find_if(begin(group), end(group), occupies_bra);
    SEQUANT_ASSERT(it != end(group));
    return it->index();
  }
}

/// Obtains the bra index of a group of plain indices, which is taken to be
/// the first entry of the group (non-emptiness of a range is only checked via
/// SEQUANT_ASSERT).
template <IndexGroup Group>
decltype(auto) get_bra_idx(Group&& group) {
  if constexpr (!IndexTuple<Group>) {
    using std::begin;
    using std::end;

    auto first = begin(group);
    SEQUANT_ASSERT(first != end(group));
    return *first;
  } else {
    return std::get<0>(group);
  }
}

/// Obtains the index of the entry in @p group that occupies a ket slot.
///
/// For tuples only the first two elements are considered (the second one is
/// preferred); for ranges the last matching entry is used. The existence of a
/// ket entry is only checked via SEQUANT_ASSERT.
template <SlottedIndexGroup Group>
decltype(auto) get_ket_idx(Group&& group) {
  if constexpr (SlottedIndexTuple<Group>) {
    const bool second_is_ket = std::get<1>(group).slot_type() == SlotType::Ket;
    SEQUANT_ASSERT(second_is_ket ||
                   std::get<0>(group).slot_type() == SlotType::Ket);

    if (second_is_ket) {
      return std::get<1>(group).index();
    }
    return std::get<0>(group).index();
  } else {
    using std::rbegin;
    using std::rend;

    const auto occupies_ket = [](const SlottedIndex& idx) {
      return idx.slot_type() == SlotType::Ket;
    };

    // The ket entry is typically expected to be the second of two entries,
    // which is why the search starts at the back of the group
    auto it = std::find_if(rbegin(group), rend(group), occupies_ket);
    SEQUANT_ASSERT(it != rend(group));
    return it->index();
  }
}

/// Obtains the ket index of a group of plain indices, which is taken to be
/// the second entry of the group (for ranges, the existence of two entries is
/// only checked via SEQUANT_ASSERT).
template <IndexGroup Group>
decltype(auto) get_ket_idx(Group&& group) {
  if constexpr (!IndexTuple<Group>) {
    using std::begin;
    using std::end;

    auto second = begin(group);
    SEQUANT_ASSERT(second != end(group));
    // Step from the first on to the second entry
    ++second;
    SEQUANT_ASSERT(second != end(group));
    return *second;
  } else {
    return std::get<1>(group);
  }
}

/// Creates a view on @p group that exposes only the Index of every entry,
/// dropping the slot type information.
///
/// If the group is not immutable (as determined by meta::is_immutable_v) it is
/// sorted IN PLACE by slot type first. Otherwise, it is only asserted to be
/// sorted already.
///
/// NOTE(review): the returned view is built on top of `group` itself, so
/// presumably the group has to outlive the view - confirm with callers.
decltype(auto) as_index_group_view(SlottedIndexGroup auto&& group) {
  // Sorting by slot type relies on Bra comparing less than Ket
  static_assert(static_cast<int>(SlotType::Bra) == 0);
  static_assert(static_cast<int>(SlotType::Ket) == 1);

  // We have to ensure a unique order of indices if we're getting rid of the
  // SlotType tag
  if constexpr (!meta::is_immutable_v<decltype(group)>) {
    std::ranges::sort(group, std::less{}, &SlottedIndex::slot_type);
  } else {
    SEQUANT_ASSERT(
        std::ranges::is_sorted(group, std::less{}, &SlottedIndex::slot_type));
  }

  return group | std::ranges::views::transform(
                     [](auto&& idx) -> decltype(auto) { return idx.index(); });
}

/// Creates a view on @p groups in which every group is replaced by the result
/// of as_index_group_view(), i.e. a view of groups of plain indices.
///
/// Note that as_index_group_view() may sort the individual groups in place.
decltype(auto) as_view_of_index_groups(
    SlottedIndexGroupContainer auto&& groups) {
  return groups |
         std::ranges::views::transform([](auto&& group) -> decltype(auto) {
           return as_index_group_view(group);
         });
}

/// For every subset of the N index ranges in @p rng, counts how often each
/// index occurs within that subset.
///
/// A subset is identified by a bitmask that is used as position into the
/// returned vector (which therefore has 2^N entries); the entry at position 0
/// (empty subset) is always an empty map.
///
/// @param rng A range of ranges of Index objects. N must not exceed 24
///        (checked via SEQUANT_ASSERT).
auto subset_index_counts(meta::range_of<Index, 2> auto const& rng) {
  using CountMap = container::map<Index, size_t, Index::FullLabelCompare>;

  size_t const N = ranges::distance(rng);
  SEQUANT_ASSERT(N <= 24 &&
                 "subset_index_counts: N > 24 would require excessive memory");

  container::vector<CountMap> result((size_t{1} << N));

  for (size_t mask = 1; mask < result.size(); ++mask) {
    CountMap& tally = result[mask];

    for (auto&& ixs : bits::on_bits_index(mask) | bits::sieve(rng)) {
      for (auto&& ix : ixs) {
        auto [entry, inserted] = tally.try_emplace(ix, 1);
        if (!inserted) {
          // Index seen before within this subset
          ++(entry->second);
        }
      }
    }
  }

  return result;
}

/// For every subset of the N index ranges in @p rng, determines a set of
/// indices associated with that subset. Subsets are identified by bitmasks
/// that are used as positions into the returned vector (2^N entries).
///
/// The set of a subset is made up of
/// - all indices of the contained range, if the subset consists of a single
///   range only
/// - for the subset containing all ranges: the indices listed in @p tixs
///   (provided they occur in @p rng at all)
/// - every index that occurs exactly once within the subset
/// - every index of the subset that also occurs in the complementary subset
///
/// @param rng A range of ranges of Index objects. N must not exceed 24
///        (checked via SEQUANT_ASSERT).
/// @param tixs A range of Index objects
auto subset_target_indices(meta::range_of<Index, 2> auto const& rng,
                           meta::range_of<Index> auto const& tixs) {
  using IndexSet = container::set<Index, Index::FullLabelCompare>;

  size_t const N = ranges::distance(rng);
  SEQUANT_ASSERT(
      N <= 24 &&
      "subset_target_indices: N > 24 would require excessive memory");

  container::vector<IndexSet> result((size_t{1} << N));

  // Subsets consisting of a single range
  for (size_t pos = 0; pos < N; ++pos) {
    IndexSet& single = result[(size_t{1} << pos)];
    for (auto&& ix : ranges::at(rng, pos)) {
      single.emplace(ix);
    }
  }

  auto counts = subset_index_counts(rng);

  // The subset containing all ranges (last entry)
  IndexSet& everything = result.back();
  for (auto&& [k, v] : counts.back()) {
    if (v == 1 || ranges::contains(tixs, k)) {
      everything.emplace(k);
    }
  }

  // The bitmask of all ranges; subtracting a subset's bitmask from it yields
  // the bitmask of the complementary subset
  size_t const all_ranges = counts.size() - 1;

  for (size_t mask = 0; mask < result.size(); ++mask) {
    auto const& complement = counts.at(all_ranges - mask);

    for (auto&& [k, v] : counts[mask]) {
      if (v == 1 || (v > 0 && complement.contains(k))) {
        result[mask].emplace(k);
      }
    }
  }

  return result;
}

/// Result type of left_to_right_binarization_indices().
///
/// @tparam T The index type
/// @tparam Set The type used to represent a set of indices
/// @tparam Vec The type used to represent a sequence of index sets
template <typename T, typename Set = std::set<T>,
          typename Vec = std::vector<Set>>
struct LTRUncontractedIndices {
  /// As filled by left_to_right_binarization_indices(): one (filtered) index
  /// set per input index set
  Vec children;
  /// As filled by left_to_right_binarization_indices(): one index set per
  /// prefix of the sequence of input index sets
  Vec imed;
};

/// Determines which indices survive when the index sets in @p rng are
/// combined one after another from left to right.
///
/// - `children[i]` holds those indices of the i-th set that occur more than
///   once across all sets or that are contained in @p uncontract.
/// - `imed[i]` holds those indices of the first i+1 sets that occur again in
///   one of the remaining sets (i.e. whose number of occurrences so far is
///   smaller than their overall number of occurrences) or that are contained
///   in @p uncontract.
///
/// If @p rng is empty, both sequences in the result are empty.
///
/// @tparam T The index type
/// @tparam Set The type of a set of indices
/// @param rng The sequence of index sets
/// @param uncontract Indices that are to be retained regardless of how often
///        they occur
/// @returns A LTRUncontractedIndices<T, Set> object
template <typename T, typename Set = std::set<T>>
auto left_to_right_binarization_indices(meta::range_of<Set> auto const& rng,
                                        Set const& uncontract) {
  using ranges::views::filter;
  using CountMap = std::map<T, size_t, typename Set::key_compare>;
  LTRUncontractedIndices<T, Set> result;

  // counts[i] holds the number of occurrences of each index accumulated over
  // the first i+1 sets
  std::vector<CountMap> counts;
  for (auto acc = CountMap{}; auto&& ixs : rng) {
    for (auto&& ix : ixs) {
      auto [it, inserted] = acc.emplace(ix, 1);
      if (!inserted) ++(it->second);
    }
    counts.push_back(acc);
  }

  // Without any input there is nothing to do (and calling back() on an empty
  // vector below would be undefined behavior)
  if (counts.empty()) return result;

  // The last entry contains the overall number of occurrences
  auto const& max_count = counts.back();

  auto survives_in_children = [&max_count,
                               &uncontract](auto const& ix) -> bool {
    return max_count.at(ix) > 1 || uncontract.contains(ix);
  };

  for (auto&& ixs : rng)
    result.children.emplace_back(ixs                             //
                                 | filter(survives_in_children)  //
                                 | ranges::to<Set>);

  auto survives_in_imed = [&max_count, &uncontract](auto&& kv) -> bool {
    auto&& [k, v] = kv;
    auto mk = max_count.at(k);
    return v < mk || uncontract.contains(k);
  };

  for (auto&& ixcs : counts) {
    result.imed.emplace_back(ixcs                        //
                             | filter(survives_in_imed)  //
                             | std::views::elements<0>   //
                             | ranges::to<Set>);
  }

  return result;
}

}  // namespace sequant

#endif  // SEQUANT_CORE_UTILITY_INDICES_HPP