Program Listing for File eval_node.hpp

Return to documentation for file (SeQuant/core/eval_node.hpp)

//
// Created by Bimal Gaudel on 5/24/21.
//

#ifndef SEQUANT_EVAL_NODE_HPP
#define SEQUANT_EVAL_NODE_HPP

#include <SeQuant/core/asy_cost.hpp>
#include <SeQuant/core/binary_node.hpp>
#include <SeQuant/core/eval_expr.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/math.hpp>
#include <SeQuant/core/utility/macros.hpp>

namespace sequant {

/// Converts an evaluation tree back into a single expression.
///
/// A leaf is converted directly with `to_expr`. An internal node is rebuilt
/// from its linearized children: `EvalOp::Sum` yields a two-term `Sum`, and
/// the only other accepted operation, `EvalOp::Product`, yields a `Product`
/// with unit scalar built with `Product::Flatten::Yes`.
///
/// @param node root of the (sub)tree to linearize
/// @return expression equivalent to the tree rooted at @p node
template <meta::eval_node Node>
ExprPtr linearize_eval_node(Node const& node) {
  if (node.leaf()) return to_expr(node);

  // children are converted left first, then right
  ExprPtr left_expr = linearize_eval_node(node.left());
  ExprPtr right_expr = linearize_eval_node(node.right());
  SEQUANT_ASSERT(left_expr);
  SEQUANT_ASSERT(right_expr);

  auto const op = node->op_type();
  if (op != EvalOp::Sum) {
    SEQUANT_ASSERT(op == EvalOp::Product);
    return ex<Product>(Product{1, ExprPtrList{left_expr, right_expr},
                               Product::Flatten::Yes});
  }
  return ex<Sum>(ExprPtrList{left_expr, right_expr});
}

namespace {

/// Positions of the three nodes taking part in one binary evaluation step.
/// The enumerators are used directly as indices into per-position arrays,
/// so their values must stay 0, 1 and 2.
enum NodePos {
  Left = 0,   // left child
  Right = 1,  // right child
  This = 2,   // the node itself
};

/// Splits the bra/ket indices of a tensor into occupied and virtual counts.
///
/// An index counts as occupied when its space equals the hole space that the
/// default context's index-space registry returns for the index's quantum
/// numbers. Everything else in the net bra + ket rank counts as virtual.
///
/// @param t tensor to inspect
/// @return {number of occupied indices, number of virtual indices}
[[maybe_unused]] std::pair<size_t, size_t> occ_virt(Tensor const& t) {
  std::size_t nocc = 0;
  for (Index const& idx : t.const_braket_indices()) {
    if (idx.space() ==
        get_default_context().index_space_registry()->hole_space(
            idx.space().qns()))
      ++nocc;
  }
  auto const nvirt = t.bra_net_rank() + t.ket_net_rank() - nocc;
  return {nocc, nvirt};
}

/// Index bookkeeping for a binary product whose two operands are tensors.
///
/// Records, for the left child, the right child and the node itself
/// ("This"), how many bra/ket indices are occupied and how many are virtual
/// (as computed by `occ_virt`). From these it derives how many occupied and
/// virtual indices are contracted between the two children.
class ContractedIndexCount {
 public:
  /// @param n a non-leaf eval node whose left and right children hold
  ///          tensors. The node itself normally holds a tensor as well. If it
  ///          does not (e.g. a scalar produced by contracting every index), it
  ///          is treated as carrying no indices instead of failing.
  explicit ContractedIndexCount(meta::eval_node auto const& n) {
    auto const L = NodePos::Left;
    auto const R = NodePos::Right;
    auto const T = NodePos::This;

    // The operands must be tensors; the result may be a tensor or a scalar.
    SEQUANT_ASSERT(n.left()->is_tensor() && n.right()->is_tensor());

    for (auto p : {L, R, T}) {
      auto const& node = (p == L ? n.left() : p == R ? n.right() : n);
      // A non-tensor result has no external indices: its counts stay zero
      // (all arrays are zero-initialized below).
      if (p == T && !node->is_tensor()) continue;
      std::tie(occs_[p], virts_[p]) = occ_virt(node->as_tensor());
      ranks_[p] = occs_[p] + virts_[p];
    }

    // no. of contractions in occupied index space (always a whole number);
    // assumes each contracted index appears once in each child and not in
    // the result
    occ_ = (occs_[L] + occs_[R] - occs_[T]) / 2;

    // no. of contractions in virtual index space (always a whole number)
    virt_ = (virts_[L] + virts_[R] - virts_[T]) / 2;

    // the result keeps every index of both children => nothing contracted
    is_outerprod_ = ranks_[L] + ranks_[R] == ranks_[T];
  }

  /// Occupied-index count of the node at position @p p.
  [[nodiscard]] size_t occ(NodePos p) const noexcept { return occs_[p]; }

  /// Virtual-index count of the node at position @p p.
  [[nodiscard]] size_t virt(NodePos p) const noexcept { return virts_[p]; }

  /// Total (occupied + virtual) index count of the node at position @p p.
  [[nodiscard]] size_t rank(NodePos p) const noexcept { return ranks_[p]; }

  /// Number of contracted occupied indices.
  [[nodiscard]] size_t occ() const noexcept { return occ_; }

  /// Number of contracted virtual indices.
  [[nodiscard]] size_t virt() const noexcept { return virt_; }

  /// True when the result rank equals the sum of the children's ranks.
  [[nodiscard]] bool is_outerprod() const noexcept { return is_outerprod_; }

  /// Misspelled alias of is_outerprod(), kept for existing callers.
  [[nodiscard]] bool is_outerpod() const noexcept { return is_outerprod_; }

  /// Number of distinct occupied indices across both children
  /// (contracted ones counted once).
  [[nodiscard]] size_t unique_occs() const noexcept {
    return occ(NodePos::Left) + occ(NodePos::Right) - occ();
  }

  /// Number of distinct virtual indices across both children
  /// (contracted ones counted once).
  [[nodiscard]] size_t unique_virts() const noexcept {
    return virt(NodePos::Left) + virt(NodePos::Right) - virt();
  }

 private:
  std::array<size_t, 3> occs_{0, 0, 0};
  std::array<size_t, 3> virts_{0, 0, 0};
  std::array<size_t, 3> ranks_{0, 0, 0};
  size_t occ_ = 0;
  size_t virt_ = 0;
  bool is_outerprod_ = false;
};
}  // namespace

/// Asymptotic floating-point operation count of evaluating one node.
///
/// - leaf: nothing is evaluated, `AsyCost::zero()`;
/// - product of two tensors: `AsyCost` over the distinct occupied and
///   virtual indices of the operands, doubled unless it is an outer product;
/// - otherwise, if the result is a tensor (scalar times tensor, or tensor
///   plus tensor): `AsyCost` of the result tensor's indices;
/// - scalar (+|*) scalar: `AsyCost::zero()`.
struct Flops {
  [[nodiscard]] AsyCost operator()(meta::eval_node auto const& n) const {
    if (n.leaf()) return AsyCost::zero();

    bool const tensor_product = n->op_type() == EvalOp::Product &&
                                n.left()->is_tensor() &&
                                n.right()->is_tensor();
    if (tensor_product) {
      auto const counts = ContractedIndexCount{n};
      auto per_index_set =
          AsyCost{counts.unique_occs(), counts.unique_virts()};
      // contractions are charged twice the cost of an outer product
      // (presumably one multiply and one add per term)
      return counts.is_outerpod() ? per_index_set : 2 * per_index_set;
    }

    // scalar times a tensor, or a tensor plus a tensor
    if (n->is_tensor()) return AsyCost{occ_virt(n->as_tensor())};

    // scalar (+|*) scalar
    return AsyCost::zero();
  }
};

/// Asymptotic storage touched while evaluating a single node.
///
/// For an internal node this is the sum of the sizes of the left operand,
/// the right operand and the result. Only tensor expressions contribute;
/// anything else counts as `AsyCost::zero()`. A leaf has no operands, so only
/// its own expression is counted, the same treatment `min_storage` gives
/// leaves.
struct Memory {
  [[nodiscard]] AsyCost operator()(meta::eval_node auto const& n) const {
    auto result = AsyCost::zero();
    auto add_cost = [&result](ExprPtr const& expr) {
      result += expr.is<Tensor>() ? AsyCost{occ_virt(expr.as<Tensor>())}
                                  : AsyCost::zero();
    };

    // asy_cost() also invokes the cost function on leaves, which have no
    // children to inspect.
    if (!n.leaf()) {
      add_cost(n.left()->expr());
      add_cost(n.right()->expr());
    }
    add_cost(n->expr());
    return result;
  }
};

/// `Flops` cost reduced by the permutational symmetry of the result tensor.
///
/// A reduction is applied only when the node and both of its children hold
/// tensors, and the result tensor is `Symmetry::Symm` or
/// `Symmetry::Antisymm`. The rules are taken from
/// doi:10.1016/j.procs.2012.04.044:
///  - Sum: Symm divides by bra_rank! * ket_rank!; Antisymm divides by
///    bra_rank!.
///  - Product: if both operands are Nonsymm, divide by bra_rank!; otherwise
///    divide by bra_rank! * ket_rank!.
///
/// In every other case the plain `Flops` cost is returned.
struct FlopsWithSymm {
  [[nodiscard]] AsyCost operator()(meta::eval_node auto const& n) const {
    auto cost = Flops{}(n);

    bool const all_tensors = !n.leaf() && n->is_tensor() &&
                             n.left()->is_tensor() && n.right()->is_tensor();
    if (!all_tensors) return cost;

    auto const& result = n->as_tensor();
    auto const symm = result.symmetry();
    if (symm != Symmetry::Symm && symm != Symmetry::Antisymm) return cost;

    auto const op = n->op_type();
    auto const nbra = result.bra_rank();
    auto const nket = result.ket_rank();

    if (op == EvalOp::Sum) {
      if (symm == Symmetry::Symm)
        cost = cost / (factorial(nbra) * factorial(nket));
      else
        cost = cost / factorial(nbra);
    } else if (op == EvalOp::Product) {
      auto const left_symm = n.left()->as_tensor().symmetry();
      auto const right_symm = n.right()->as_tensor().symmetry();
      bool const both_nonsymm = left_symm == Symmetry::Nonsymm &&
                                right_symm == Symmetry::Nonsymm;
      if (both_nonsymm)
        cost = cost / factorial(nbra);
      else
        cost = cost / (factorial(nbra) * factorial(nket));
    } else {
      SEQUANT_ASSERT(false &&
                     "Unsupported evaluation operation for asymptotic cost "
                     "computation.");
    }
    return cost;
  }
};

/// Total asymptotic cost of the evaluation tree rooted at @p node.
///
/// Adds up @p cost_fn over every node of the tree, leaves included.
///
/// @tparam F callable mapping a node to that node's own `AsyCost`
///           (defaults to `Flops`)
/// @param node root of the evaluation tree
/// @param cost_fn per-node cost function
/// @return accumulated cost over the whole tree
template <meta::eval_node Node, typename F = Flops>
  requires requires(F const& fn, Node const& n) {
    { fn(n) } -> std::same_as<AsyCost>;
  }
AsyCost asy_cost(Node const& node, F const& cost_fn = {}) {
  if (node.leaf()) return cost_fn(node);

  auto const left_total = asy_cost(node.left(), cost_fn);
  auto const right_total = asy_cost(node.right(), cost_fn);
  return left_total + right_total + cost_fn(node);
}

/// Largest single-step storage requirement over all nodes of a tree.
///
/// For each visited node the storage is:
/// - for a leaf: its own tensor, or `AsyCost::zero()` if it is not a tensor;
/// - for an internal node: the sum of the left operand, right operand and
///   result, where non-tensor expressions contribute `AsyCost::zero()`.
///
/// The maximum of these per-node values is returned.
AsyCost min_storage(meta::eval_node auto const& node) {
  // storage held by one node's expression: its tensor size, else zero
  auto const tensor_storage = [](auto const& x) {
    return x->is_tensor() ? AsyCost{occ_virt(x->as_tensor())}
                          : AsyCost::zero();
  };

  auto peak = AsyCost::zero();
  auto visitor = [&peak,
                  &tensor_storage](meta::eval_node auto const& n) {
    auto step = AsyCost::zero();
    if (n.leaf()) {
      step = tensor_storage(n);
    } else {
      step += tensor_storage(n.left());
      step += tensor_storage(n.right());
      step += tensor_storage(n);
    }
    peak = std::max(peak, step);
  };
  node.visit(visitor);
  return peak;
}

}  // namespace sequant

#endif  // SEQUANT_EVAL_NODE_HPP