Program Listing for File sum.cpp

Return to documentation for file (SeQuant/core/optimize/sum.cpp)

#include <SeQuant/core/binary_node.hpp>
#include <SeQuant/core/container.hpp>
#include <SeQuant/core/eval/eval_expr.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/hash.hpp>
#include <SeQuant/core/optimize/sum.hpp>
#include <SeQuant/core/utility/macros.hpp>

#include <range/v3/range/operations.hpp>
#include <range/v3/view/map.hpp>
#include <range/v3/view/reverse.hpp>
#include <range/v3/view/tail.hpp>
#include <range/v3/view/transform.hpp>

#include <stack>

namespace sequant::opt {

bool has_only_single_atom(const ExprPtr& term) {
  if (term->is_atom()) {
    return true;
  }

  // Recursively check that all elements in the expression tree have only a
  // single element in them. At this point this means checking for Sum or
  // Product objects that only have a single addend or factor respectively.
  return term->size() == 1 && has_only_single_atom(*term->begin());
}

/// Groups the summands of @p expr into clusters of terms whose evaluation
/// trees contain nodes with identical hash values.
///
/// @param expr  the sum whose terms are to be clustered.
/// @param nodes one evaluation tree per summand; `nodes[i]` must correspond to
///              the i-th summand of @p expr (the sizes are asserted to match).
/// @return a list of clusters. Each cluster lists summand positions in the
///         order a depth-first traversal first reaches them; clusters are
///         emitted in ascending order of their starting position, and no
///         position is reported more than once.
///
/// @note Edges are recorded only from the smallest position of each hash group
///       to the other members of that group, and the traversal follows them
///       in that direction only. NOTE(review): two terms that are linked
///       solely through a common *later* term therefore end up in different
///       clusters -- confirm that this is the intended notion of a cluster.
container::vector<container::vector<size_t>> clusters(
    Sum const& expr, container::vector<FullBinaryNode<EvalExpr>> const& nodes) {
  SEQUANT_ASSERT(nodes.size() == expr.size());
  using hash_t = size_t;
  using pos_t = size_t;
  using stack_t = std::stack<pos_t, container::vector<pos_t>>;

  // Step 1: for every hash value, collect the (ordered) set of summand
  // positions whose tree contains a node with that hash.
  container::map<hash_t, container::set<pos_t>> positions;
  {
    pos_t pos = 0;
    auto visitor = [&positions, &pos](auto const& n) {
      auto h = hash::value(*n);
      // operator[] creates the empty set on first sight of this hash.
      positions[h].emplace(pos);
    };

    for (auto const& term : expr) {
      auto const& node = nodes[pos];
      if (has_only_single_atom(term)) {
        // A lone atom is represented by the tree's root itself.
        visitor(node);
      } else {
        // Otherwise register what visit_internal() walks over (presumably the
        // non-leaf nodes of the tree).
        node.visit_internal(visitor);
      }
      ++pos;
    }
  }

  // Step 2: turn every hash group into edges "smallest position -> each other
  // position of the group". Groups with a single member still create an
  // (edge-less) entry so that the position is used as a traversal start.
  container::map<pos_t, container::vector<pos_t>> connections;
  for (auto const& [_, group] : positions) {
    auto& neighbours = connections[ranges::front(group)];
    for (auto p : ranges::views::tail(group)) neighbours.push_back(p);
  }
  positions.clear();

  // Step 3: iterative depth-first traversal from every not-yet-visited start
  // position; each traversal yields one cluster.
  container::vector<container::vector<pos_t>> result;
  {
    container::set<pos_t> visited;
    for (auto k : connections | ranges::views::keys) {
      if (visited.contains(k)) continue;

      stack_t dfs_stack;
      dfs_stack.push(k);
      container::vector<pos_t> clstr;
      while (!dfs_stack.empty()) {
        auto const p = dfs_stack.top();
        dfs_stack.pop();

        // Expand a position only the first time it is reached. Re-pushing
        // the neighbours of an already visited position cannot discover
        // anything new, but it makes the traversal re-walk every path of the
        // graph (exponential in the worst case).
        if (!visited.emplace(p).second) continue;
        clstr.push_back(p);

        if (auto found = connections.find(p); found != connections.end())
          // Push in reverse so that neighbours are popped in stored order.
          for (auto p_ : ranges::views::reverse(found->second))
            dfs_stack.push(p_);
      }
      result.emplace_back(std::move(clstr));
    }
  }
  return result;
}

/// Convenience overload: builds one evaluation tree per summand of @p expr
/// and forwards to the two-argument clusters().
///
/// @param expr the sum whose terms are to be clustered.
/// @return whatever the two-argument overload returns for the freshly built
///         trees.
container::vector<container::vector<size_t>> clusters(Sum const& expr) {
  container::vector<FullBinaryNode<EvalExpr>> trees;
  trees.reserve(expr.size());

  // binarize(ExprPtr) is deprecated, hence the warning suppression around the
  // loop. The trees built here are handed only to the clustering routine and
  // are never returned to the caller.
  SEQUANT_PRAGMA_IGNORE_DEPRECATED_BEGIN
  for (auto const& summand : expr) {
    trees.emplace_back(binarize(summand));
  }
  SEQUANT_PRAGMA_IGNORE_DEPRECATED_END

  return clusters(expr, trees);
}

/// Builds a new Sum holding the summands of @p sum rearranged so that the
/// members of each cluster (as computed by clusters()) are adjacent.
///
/// @param sum   the sum to reorder.
/// @param nodes one evaluation tree per summand of @p sum, passed through to
///              clusters().
/// @return the rearranged sum; it is asserted to have as many summands as
///         @p sum.
Sum reorder(Sum const& sum,
            container::vector<FullBinaryNode<EvalExpr>> const& nodes) {
  auto const groups = clusters(sum, nodes);

  Sum reordered;
  for (auto const& group : groups) {
    for (auto const idx : group) {
      reordered.append(sum.at(idx));
    }
  }

  SEQUANT_ASSERT(reordered.size() == sum.size());
  return reordered;
}

/// Convenience overload: builds one evaluation tree per summand of @p sum and
/// forwards to the two-argument reorder().
///
/// @param sum the sum to reorder.
/// @return the result of the two-argument overload for the freshly built
///         trees.
Sum reorder(Sum const& sum) {
  container::vector<FullBinaryNode<EvalExpr>> trees;
  trees.reserve(sum.size());

  // binarize(ExprPtr) is deprecated, hence the warning suppression. The trees
  // are used only to decide the ordering and are not returned to the caller.
  SEQUANT_PRAGMA_IGNORE_DEPRECATED_BEGIN
  for (auto const& summand : sum) {
    trees.emplace_back(binarize(summand));
  }
  SEQUANT_PRAGMA_IGNORE_DEPRECATED_END

  return reorder(sum, trees);
}

}  // namespace sequant::opt