Program Listing for File eval.hpp

Return to documentation for file (SeQuant/core/eval/eval.hpp)

#ifndef SEQUANT_EVAL_EVAL_HPP
#define SEQUANT_EVAL_EVAL_HPP

#include <SeQuant/core/container.hpp>
#include <SeQuant/core/eval/cache_manager.hpp>
#include <SeQuant/core/eval/eval_fwd.hpp>
#include <SeQuant/core/eval/result.hpp>
#include <SeQuant/core/eval_node.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/logger.hpp>
#include <SeQuant/core/meta.hpp>
#include <SeQuant/core/parse.hpp>
#include <SeQuant/core/utility/macros.hpp>

#include <btas/btas.h>
#include <tiledarray.h>

#include <chrono>
#include <range/v3/numeric.hpp>
#include <range/v3/view.hpp>

#include <any>
#include <iostream>
#include <stdexcept>
#include <type_traits>

namespace sequant {

namespace log {

using Duration = std::chrono::nanoseconds;

struct Bytes {
  size_t value;
};

template <typename... T>
  requires((std::same_as<ResultPtr, T> && ...))
[[nodiscard]] auto bytes(T const&... args) {
  return Bytes{(args->size_in_bytes() + ...)};
}

[[nodiscard]] inline auto bytes(CacheManager const& cman) {
  return cman.size_in_bytes();
}

[[nodiscard]] inline auto to_string(Bytes bs) noexcept {
  return std::format("{}B", bs.value);
}

enum struct EvalMode {
  Constant,
  Variable,
  Tensor,
  Permute,
  Product,
  MultByPhase,
  Sum,
  SumInplace,
  Symmetrize,
  Antisymmetrize,
  BiorthogonalNNSProject,
  Unknown
};

[[nodiscard]] EvalMode eval_mode(meta::eval_node auto const& node) {
  if (node.leaf()) {
    return node->is_constant()   ? EvalMode::Constant
           : node->is_variable() ? EvalMode::Variable
           : node->is_tensor()   ? EvalMode::Tensor
                                 : EvalMode::Unknown;
  } else {
    return node->is_product() ? EvalMode::Product
           : node->is_sum()   ? EvalMode::Sum
                              : EvalMode::Unknown;
  }
}

[[nodiscard]] constexpr auto to_string(EvalMode mode) noexcept {
  return (mode == EvalMode::Constant)                 ? "Constant"
         : (mode == EvalMode::Variable)               ? "Variable"
         : (mode == EvalMode::Tensor)                 ? "Tensor"
         : (mode == EvalMode::Permute)                ? "Permute"
         : (mode == EvalMode::Product)                ? "Product"
         : (mode == EvalMode::MultByPhase)            ? "MultByPhase"
         : (mode == EvalMode::Sum)                    ? "Sum"
         : (mode == EvalMode::SumInplace)             ? "SumInplace"
         : (mode == EvalMode::Symmetrize)             ? "Symmetrize"
         : (mode == EvalMode::Antisymmetrize)         ? "Antisymmetrize"
         : (mode == EvalMode::BiorthogonalNNSProject) ? "BiorthogonalNNSProject"
                                                      : "??";
}

enum struct CacheMode { Store, Access, Release };

[[nodiscard]] constexpr auto to_string(CacheMode mode) noexcept {
  return (mode == CacheMode::Store)    ? "Store"
         : (mode == CacheMode::Access) ? "Access"
                                       : "Release";
}

enum struct TermMode { Begin, End };

[[nodiscard]] constexpr auto to_string(TermMode mode) noexcept {
  return (mode == TermMode::Begin) ? "Begin" : "End";
}

struct EvalStat {
  EvalMode mode;
  Duration time;
  Bytes memory;
};

struct CacheStat {
  CacheMode mode;
  size_t key;
  int curr_life, max_life;
  size_t num_alive;
  Bytes memory;
};

template <typename Arg, typename... Args>
void log(Arg const& arg, Args const&... args) {
  auto& l = Logger::instance();
  if (l.eval.level > 0) write_log(l, arg, std::format(" | {}", args)..., '\n');
}

template <typename... Args>
auto eval(EvalStat const& stat, Args const&... args) {
  log("Eval",                  //
      to_string(stat.mode),    //
      stat.time,               //
      to_string(stat.memory),  //
      args...);
}

template <typename... Args>
auto cache(CacheStat const& stat, Args const&... args) {
  log("Cache",                                              //
      to_string(stat.mode),                                 //
      stat.key,                                             //
      std::format("{}/{}", stat.curr_life, stat.max_life),  //
      stat.num_alive,                                       //
      to_string(stat.memory),                               //
      args...);
}

template <typename... Args>
auto cache(size_t key, CacheManager const& cm, Args const&... args) {
  using CacheMode::Access;
  using CacheMode::Release;
  using CacheMode::Store;
  auto const cur_l = cm.life(key);
  auto const max_l = cm.max_life(key);
  bool const release = cur_l == 0;
  bool const store = cur_l + 1 == max_l;
  cache(CacheStat{.mode = store     ? Store
                          : release ? Release
                                    : Access,
                  .key = key,
                  .curr_life = cur_l,
                  .max_life = max_l,
                  .num_alive = cm.alive_count(),
                  .memory = {bytes(cm)}},
        args...);
}

inline auto term(TermMode mode, std::string_view term) {
  log("Term", to_string(mode), term);
}

[[nodiscard]] auto label(meta::eval_node auto const& node) {
  return node->is_primary()
             ? node->label()
             : std::format("{} {} {} -> {}", node.left()->label(),
                           (node->is_product() ? "*"
                            : node->is_sum()   ? "+"
                                               : "??"),  //
                           node.right()->label(), node->label());
}

}  // namespace log

namespace {

template <typename F, typename... Args>
[[nodiscard]] log::Duration timed_eval_inplace(F&& fun, Args&&... args)
  requires(std::is_invocable_r_v<void, F, Args...>)
{
  using Clock = std::chrono::high_resolution_clock;
  auto tstart = Clock::now();
  std::forward<F>(fun)(std::forward<Args>(args)...);
  auto tend = Clock::now();
  return {tend - tstart};
}

template <typename... Args>
concept last_type_is_cache_manager =
    std::same_as<CacheManager, std::remove_cvref_t<std::tuple_element_t<
                                   sizeof...(Args) - 1, std::tuple<Args...>>>>;

template <typename... Args>
auto&& arg0(Args&&... args) {
  return std::get<0>(std::forward_as_tuple(std::forward<Args>(args)...));
}

auto&& node0(auto&& val) { return std::forward<decltype(val)>(val); }
auto&& node0(std::ranges::range auto&& rng) {
  return ranges::front(std::forward<decltype(rng)>(rng));
}

enum struct CacheCheck { Checked, Unchecked };

}  // namespace

enum struct Trace {
  On,
  Off,
  Default =
#ifdef SEQUANT_EVAL_TRACE
      On
#else
      Off
#endif
};
static_assert(Trace::Default == Trace::On || Trace::Default == Trace::Off);

namespace {
[[nodiscard]] consteval bool trace(Trace t) noexcept { return t == Trace::On; }
}  // namespace

template <Trace EvalTrace = Trace::Default,
          CacheCheck Cache = CacheCheck::Checked, meta::can_evaluate Node,
          typename F>
  requires meta::leaf_node_evaluator<Node, F>
ResultPtr evaluate(Node const& node,  //
                   F const& le,       //
                   CacheManager& cache) {
  if constexpr (Cache == CacheCheck::Checked) {  // return from cache if found

    auto mult_by_phase = [&node](ResultPtr res) {
      auto phase = node->canon_phase();
      if (phase == 1) return res;

      ResultPtr post;
      auto time =
          timed_eval_inplace([&]() { post = res->mult_by_phase(phase); });

      if constexpr (trace(EvalTrace)) {
        auto stat = log::EvalStat{.mode = log::EvalMode::MultByPhase,
                                  .time = time,
                                  .memory = log::bytes(res, post)};
        log::eval(stat, std::format("{} * {}", phase, node->label()));
      }
      return post;
    };

    auto const h = hash::value(*node);
    if (auto ptr = cache.access(h); ptr) {
      if constexpr (trace(EvalTrace)) log::cache(h, cache);

      return mult_by_phase(ptr);
    } else if (cache.exists(h)) {
      auto ptr = cache.store(
          h, mult_by_phase(
                 evaluate<EvalTrace, CacheCheck::Unchecked>(node, le, cache)));
      if constexpr (trace(EvalTrace)) log::cache(h, cache);

      return mult_by_phase(ptr);
    } else {
      // do nothing
    }
  }

  ResultPtr result;
  ResultPtr left;
  ResultPtr right;

  log::Duration time;

  if (node.leaf()) {
    time = timed_eval_inplace([&]() { result = le(node); });
  } else {
    left = evaluate<EvalTrace>(node.left(), le, cache);
    right = evaluate<EvalTrace>(node.right(), le, cache);
    SEQUANT_ASSERT(left);
    SEQUANT_ASSERT(right);

    std::array<std::any, 3> const ann{node.left()->annot(),
                                      node.right()->annot(), node->annot()};
    if (node->op_type() == EvalOp::Sum) {
      time = timed_eval_inplace([&]() { result = left->sum(*right, ann); });
    } else {
      SEQUANT_ASSERT(node->op_type() == EvalOp::Product);
      auto const de_nest =
          node.left()->tot() && node.right()->tot() && !node->tot();
      time = timed_eval_inplace([&]() {
        result = left->prod(*right, ann,
                            de_nest ? TA::DeNest::True : TA::DeNest::False);
      });
    }
  }

  SEQUANT_ASSERT(result);

  // logging
  if constexpr (trace(EvalTrace)) {
    auto stat =
        log::EvalStat{.mode = log::eval_mode(node),
                      .time = time,
                      .memory = node.leaf() ? log::bytes(result)
                                            : log::bytes(left, right, result)};
    log::eval(stat, log::label(node));
  }

  return result;
}

template <Trace EvalTrace = Trace::Default, meta::can_evaluate Node, typename F>
  requires meta::leaf_node_evaluator<Node, F>  //
ResultPtr evaluate(Node const& node,           //
                   auto const& layout,         //
                   F const& le,                //
                   CacheManager& cache) {
  // if the layout is not the default constructed value need to permute
  bool const perm = layout != decltype(layout){};

  std::string xpr;
  if constexpr (trace(EvalTrace)) {
    xpr = to_string(deparse(to_expr(node)));
    log::term(log::TermMode::Begin, xpr);
  }

  struct {
    ResultPtr pre, post;
  } result;

  result.pre = evaluate<EvalTrace>(node, le, cache);

  auto time = timed_eval_inplace([&]() {
    result.post = perm ? result.pre->permute(
                             std::array<std::any, 2>{node->annot(), layout})
                       : result.pre;
  });

  SEQUANT_ASSERT(result.post);

  // logging
  if constexpr (trace(EvalTrace)) {
    if (perm) {
      auto stat = log::EvalStat{.mode = log::EvalMode::Permute,
                                .time = time,
                                .memory = log::bytes(result.pre, result.post)};
      log::eval(stat, node->label());
    }
    log::term(log::TermMode::End, xpr);
  }
  return result.post;
}

template <Trace EvalTrace = Trace::Default, meta::can_evaluate_range Nodes,
          typename F>
  requires meta::leaf_node_evaluator<std::ranges::range_value_t<Nodes>, F>
ResultPtr evaluate(Nodes const& nodes,  //
                   auto const& layout,  //
                   F const& le, CacheManager& cache) {
  ResultPtr result;

  for (auto&& n : nodes) {
    if (!result) {
      result = evaluate<EvalTrace>(n, layout, le, cache);
      continue;
    }

    ResultPtr pre = evaluate<EvalTrace>(n, layout, le, cache);
    auto time = timed_eval_inplace([&]() { result->add_inplace(*pre); });

    // logging
    if constexpr (trace(EvalTrace)) {
      auto stat = log::EvalStat{.mode = log::EvalMode::SumInplace,
                                .time = time,
                                .memory = log::bytes(result, pre)};
      log::eval(stat, n->label());
    }
  }

  return result;
}

template <Trace EvalTrace = Trace::Default, meta::can_evaluate_range Nodes,
          typename F>
  requires meta::leaf_node_evaluator<std::ranges::range_value_t<Nodes>, F>
ResultPtr evaluate(Nodes const& nodes,  //
                   F const& le, CacheManager& cache) {
  using annot_type = decltype([](std::ranges::range_value_t<Nodes> const& n) {
    return n->annot();
  });

  static_assert(std::is_default_constructible_v<annot_type>);
  return evaluate(nodes, annot_type{}, le, cache);
}

template <Trace EvalTrace = Trace::Default, typename... Args>
  requires(!last_type_is_cache_manager<Args...>)
ResultPtr evaluate(Args&&... args) {
  auto cache = CacheManager::empty();
  return evaluate<EvalTrace>(std::forward<Args>(args)..., cache);
}

template <Trace EvalTrace = Trace::Default, typename... Args>
ResultPtr evaluate_symm(Args&&... args) {
  ResultPtr pre = evaluate<EvalTrace>(std::forward<Args>(args)...);
  SEQUANT_ASSERT(pre);
  ResultPtr result;
  auto time = timed_eval_inplace([&]() { result = pre->symmetrize(); });

  // logging
  if constexpr (trace(EvalTrace)) {
    auto stat = log::EvalStat{.mode = log::EvalMode::Symmetrize,
                              .time = time,
                              .memory = log::bytes(pre, result)};
    log::eval(stat, node0(arg0(std::forward<Args>(args)...))->label());
  }

  return result;
}

template <Trace EvalTrace = Trace::Default, typename... Args>
ResultPtr evaluate_antisymm(Args&&... args) {
  ResultPtr pre = evaluate<EvalTrace>(std::forward<Args>(args)...);
  SEQUANT_ASSERT(pre);

  auto const& n0 = node0(arg0(std::forward<Args>(args)...));

  ResultPtr result;
  auto time = timed_eval_inplace(
      [&]() { result = pre->antisymmetrize(n0->as_tensor().bra_rank()); });

  // logging
  if constexpr (trace(EvalTrace)) {
    auto stat = log::EvalStat{.mode = log::EvalMode::Antisymmetrize,
                              .time = time,
                              .memory = log::bytes(pre, result)};
    log::eval(stat, n0->label());
  }
  return result;
}

template <Trace EvalTrace = Trace::Default, typename... Args>
ResultPtr evaluate_biorthogonal_nns_project(Args&&... args) {
  ResultPtr pre = evaluate<EvalTrace>(std::forward<Args>(args)...);
  SEQUANT_ASSERT(pre);

  auto const& n0 = node0(arg0(std::forward<Args>(args)...));

  ResultPtr result;
  auto time = timed_eval_inplace([&]() {
    result = pre->biorthogonal_nns_project(n0->as_tensor().bra_rank());
  });

  // logging
  if constexpr (trace(EvalTrace)) {
    auto stat = log::EvalStat{.mode = log::EvalMode::BiorthogonalNNSProject,
                              .time = time,
                              .memory = log::bytes(pre, result)};
    log::eval(stat, n0->label());
  }
  return result;
}
}  // namespace sequant

#endif  // SEQUANT_EVAL_EVAL_HPP