Program Listing for File sum.cpp

Return to documentation for file (SeQuant/core/optimize/sum.cpp)

#include <SeQuant/core/container.hpp>
#include <SeQuant/core/eval/eval_expr.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/optimize/sum.hpp>

#include <stack>

namespace sequant::opt {

bool has_only_single_atom(const ExprPtr& term) {
  if (term->is_atom()) {
    return true;
  }

  // Recursively check that all elements in the expression tree have only a
  // single element in them. At this point this means checking for Sum or
  // Product objects that only have a single addend or factor respectively.
  return term->size() == 1 && has_only_single_atom(*term->begin());
}

container::vector<container::vector<size_t>> clusters(Sum const& expr) {
  using ranges::views::tail;
  using ranges::views::transform;
  using hash_t = size_t;
  using pos_t = size_t;
  using stack_t = std::stack<pos_t, container::vector<pos_t>>;

  container::map<hash_t, container::set<pos_t>> positions;
  {
    pos_t pos = 0;
    auto visitor = [&positions, &pos](auto const& n) {
      auto h = hash::value(*n);
      if (auto&& found = positions.find(h); found != positions.end()) {
        found->second.emplace(pos);
      } else {
        positions.emplace(h, decltype(positions)::mapped_type{pos});
      }
    };

    for (auto const& term : expr) {
      auto const node = binarize(term);
      if (has_only_single_atom(term)) {
        visitor(node);
      } else {
        node.visit_internal(visitor);
      }
      ++pos;
    }
  }

  container::map<pos_t, container::vector<pos_t>> connections;
  {
    for (auto const& [_, v] : positions) {
      auto const v0 = ranges::front(v);
      auto const v_ = ranges::views::tail(v) |
                      ranges::to<decltype(connections)::mapped_type>;
      if (auto&& found = connections.find(v0); found != connections.end())
        for (auto p : v_) found->second.push_back(p);
      else
        connections.emplace(v0, v_);
    }
  }
  positions.clear();

  container::vector<container::vector<pos_t>> result;
  {
    container::set<pos_t> visited;
    for (auto k : connections | ranges::views::keys)
      if (!visited.contains(k)) {
        stack_t dfs_stack;
        dfs_stack.push(k);
        container::vector<pos_t> clstr;
        while (!dfs_stack.empty()) {
          auto p = dfs_stack.top();
          dfs_stack.pop();
          if (!visited.contains(p)) {
            clstr.push_back(p);
            visited.emplace(p);
          }
          if (auto&& found = connections.find(p); found != connections.end())
            for (auto p_ : ranges::views::reverse(found->second))
              dfs_stack.push(p_);
        }
        result.emplace_back(std::move(clstr));
      }
  }
  return result;
}

Sum reorder(Sum const& sum) {
  Sum result;

  for (auto const& clstr : clusters(sum))
    for (auto p : clstr) result.append(sum.at(p));

  SEQUANT_ASSERT(result.size() == sum.size());
  return result;
}

}  // namespace sequant::opt