Program Listing for File eval.hpp

Return to documentation for file (SeQuant/core/eval/eval.hpp)

#ifndef SEQUANT_EVAL_EVAL_HPP
#define SEQUANT_EVAL_EVAL_HPP

#include <SeQuant/core/eval/fwd.hpp>

#include <SeQuant/core/container.hpp>
#include <SeQuant/core/eval/cache_manager.hpp>
#include <SeQuant/core/eval/eval_node.hpp>
#include <SeQuant/core/eval/result.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/io/serialization/serialization.hpp>
#include <SeQuant/core/logger.hpp>
#include <SeQuant/core/meta.hpp>
#include <SeQuant/core/utility/macros.hpp>
#include <SeQuant/core/utility/string.hpp>

#include <range/v3/range/operations.hpp>

#include <algorithm>
#include <any>
#include <array>
#include <chrono>
#include <cstddef>
#include <format>
#include <iostream>
#include <optional>
#include <ranges>
#include <stdexcept>
#include <string>
#include <string_view>
#include <tuple>
#include <type_traits>
#include <utility>

// Headers for process_rss_bytes() — see log::process_rss_bytes() below.
#if defined(__APPLE__)
#include <mach/mach.h>
#elif defined(__linux__)
#include <unistd.h>
#include <fstream>
#endif

namespace sequant {

namespace log {

/// Time unit used for every timing recorded in eval trace records.
using Duration = std::chrono::nanoseconds;

/// A byte count, wrapped so it can be formatted distinctly (see to_string).
struct Bytes {
  std::size_t value;  // number of bytes
};

/// True when the eval logger is enabled (its level is positive). Callers use
/// this to skip building records that would be discarded anyway.
[[nodiscard]] inline bool printing() noexcept {
  auto const& eval_channel = Logger::instance().eval;
  return eval_channel.level > 0;
}

/// Total size in bytes of one or more operands.
///
/// Each operand may be a pointer-like handle (null contributes 0), an
/// object reached through `->`, or an object exposing `size_in_bytes()`
/// directly.
template <typename T, typename... Ts>
[[nodiscard]] inline auto bytes(T const& arg, Ts const&... args) {
  auto size_of = [](auto const& operand) -> size_t {
    if constexpr (requires {
                    static_cast<bool>(operand);
                    operand->size_in_bytes();
                  }) {
      // Pointer-like operand. A null handle counts as empty, so callers
      // (e.g. the EvalOp::Adjoint path, which leaves `right` unevaluated)
      // need no external guard.
      if (!operand) return size_t{0};
      return operand->size_in_bytes();
    } else if constexpr (requires { operand->size_in_bytes(); }) {
      return operand->size_in_bytes();
    } else {
      return operand.size_in_bytes();
    }
  };
  size_t total = size_of(arg);
  ((total += size_of(args)), ...);
  return Bytes{total};
}

/// Bytes held by \p cache plus the bytes of any extra operands.
///
/// Returns 0 without walking the cache when the eval logger is silent.
/// NOTE(review): this value is also passed to note_working_set() by the
/// tracing code, so any working-set mark recorded while the logger is silent
/// omits the cache (and these operands). Confirm that is intended.
template <typename N, bool F, typename... Ts>
[[nodiscard]] inline Bytes bytes(CacheManager<N, F> const& cache,
                                 Ts const&... args) {
  if (!printing()) return Bytes{0};
  size_t operands = 0;
  ((operands += bytes(args).value), ...);
  return Bytes{cache.size_in_bytes() + operands};
}

/// Formats a byte count as "<n>B", e.g. Bytes{42} -> "42B".
[[nodiscard]] inline auto to_string(Bytes bs) noexcept {
  auto const count = bs.value;
  return std::format("{}B", count);
}

/// Best-effort memory footprint of the current process, in bytes.
///
/// - macOS: `phys_footprint` from TASK_VM_INFO. If that query fails or
///   returns a structure shorter than TASK_VM_INFO_COUNT, falls back to
///   `resident_size` from MACH_TASK_BASIC_INFO.
/// - Linux: resident page count (second column of /proc/self/statm)
///   multiplied by the system page size.
/// - Other platforms, and any failure on the ones above: 0.
[[nodiscard]] inline std::size_t process_rss_bytes() noexcept {
#if defined(__APPLE__)
  // Preferred metric: the kernel-reported physical footprint.
  ::task_vm_info_data_t vm_info{};
  ::mach_msg_type_number_t vm_count = TASK_VM_INFO_COUNT;
  if (::task_info(::mach_task_self(), TASK_VM_INFO,
                  reinterpret_cast<::task_info_t>(&vm_info),
                  &vm_count) == KERN_SUCCESS &&
      vm_count >= TASK_VM_INFO_COUNT) {
    return static_cast<std::size_t>(vm_info.phys_footprint);
  }
  // Fallback: raw resident-set size (larger; includes shared pages).
  ::mach_task_basic_info_data_t info{};
  ::mach_msg_type_number_t count = MACH_TASK_BASIC_INFO_COUNT;
  if (::task_info(::mach_task_self(), MACH_TASK_BASIC_INFO,
                  reinterpret_cast<::task_info_t>(&info),
                  &count) != KERN_SUCCESS) {
    return 0;
  }
  return static_cast<std::size_t>(info.resident_size);
#elif defined(__linux__)
  // /proc/self/statm columns are page counts:
  //   total resident shared text lib data dt
  std::ifstream f("/proc/self/statm");
  std::size_t pages_total = 0, pages_resident = 0;
  if (!(f >> pages_total >> pages_resident)) return 0;
  // Queried once (function-local static); a non-positive result means
  // sysconf failed, in which case no size is reported.
  static const long page_size = ::sysconf(_SC_PAGESIZE);
  if (page_size <= 0) return 0;
  return pages_resident * static_cast<std::size_t>(page_size);
#else
  // No supported query on this platform.
  return 0;
#endif
}

/// Current process memory footprint as Bytes (0 when it cannot be queried).
[[nodiscard]] inline Bytes rss() noexcept {
  Bytes const footprint{process_rss_bytes()};
  return footprint;
}

/// Kind of operation reported by an eval trace record (see EvalStat and
/// to_string(EvalMode)).
enum struct EvalMode {
  // leaf kinds (see eval_mode())
  Constant,
  Variable,
  Power,
  Tensor,
  // operations performed while evaluating
  Permute,
  Product,
  MultByPhase,
  Sum,
  SumInplace,
  Symmetrize,
  Antisymmetrize,
  // anything not classified above
  Unknown
};

/// Classifies \p node for trace records.
///
/// Leaves are classified by what they hold, internal nodes by their
/// operation. An adjoint node is reported as Permute.
[[nodiscard]] EvalMode eval_mode(meta::eval_node auto const& node) {
  if (!node.leaf()) {
    if (node->is_product()) return EvalMode::Product;
    if (node->is_sum()) return EvalMode::Sum;
    if (node->is_adjoint()) return EvalMode::Permute;
    return EvalMode::Unknown;
  }
  if (node->is_constant()) return EvalMode::Constant;
  if (node->is_variable()) return EvalMode::Variable;
  if (node->is_power()) return EvalMode::Power;
  if (node->is_tensor()) return EvalMode::Tensor;
  return EvalMode::Unknown;
}

/// Name of an EvalMode as it appears in trace output. Unknown and
/// out-of-range values map to "??".
[[nodiscard]] constexpr auto to_string(EvalMode mode) noexcept {
  switch (mode) {
    case EvalMode::Constant:
      return "Constant";
    case EvalMode::Variable:
      return "Variable";
    case EvalMode::Power:
      return "Power";
    case EvalMode::Tensor:
      return "Tensor";
    case EvalMode::Permute:
      return "Permute";
    case EvalMode::Product:
      return "Product";
    case EvalMode::MultByPhase:
      return "MultByPhase";
    case EvalMode::Sum:
      return "Sum";
    case EvalMode::SumInplace:
      return "SumInplace";
    case EvalMode::Symmetrize:
      return "Symmetrize";
    case EvalMode::Antisymmetrize:
      return "Antisymmetrize";
    default:
      return "??";
  }
}

/// Kind of cache event reported by a cache trace record.
enum struct CacheMode { Store, Access, Release };

/// Name of a CacheMode for trace output. Any value other than Store or
/// Access prints as "Release".
[[nodiscard]] constexpr auto to_string(CacheMode mode) noexcept {
  switch (mode) {
    case CacheMode::Store:
      return "Store";
    case CacheMode::Access:
      return "Access";
    default:
      return "Release";
  }
}

/// Marks whether a term record opens or closes a term's evaluation.
enum struct TermMode { Begin, End };

/// Name of a TermMode for trace output. Anything but Begin prints as "End".
[[nodiscard]] constexpr auto to_string(TermMode mode) noexcept {
  if (mode != TermMode::Begin) return "End";
  return "Begin";
}

// clang-format off
// clang-format on
/// Payload of one "Eval" trace record (see eval()).
struct EvalStat {
  EvalMode mode;       // which operation was performed
  Duration time;       // elapsed time of the timed operation
  Bytes mem_result = {};  // size of the produced result
  Bytes mem_alloc = {};   // bytes newly allocated by the operation
  Bytes mem_hwmark = {};  // working-set mark reported for the operation
  // Operand sizes; set only for binary operations. When mem_left is set,
  // mem_right must be set too (asserted in eval()).
  std::optional<Bytes> mem_left;
  std::optional<Bytes> mem_right;
};

/// Payload of one "Cache" trace record (see cache()).
struct CacheStat {
  CacheMode mode;      // event classification
  size_t key;          // hash of the cached node
  int curr_life;       // remaining life of the entry
  int max_life;        // maximum life of the entry
  size_t num_alive;    // number of live cache entries
  Bytes entry_memory;  // size of this entry
  Bytes total_memory;  // size of the whole cache
};

/// Emits one record to the eval log.
///
/// \p arg is passed to write_log unchanged. Each further field is formatted
/// as " | <field>", and a trailing newline is appended. Nothing is written
/// unless the eval logger level is positive.
template <typename Arg, typename... Args>
void log(Arg const& arg, Args const&... args) {
  auto& logger = Logger::instance();
  if (!(logger.eval.level > 0)) return;
  write_log(logger, arg, std::format(" | {}", args)..., '\n');
}

/// Emits an "Eval" record for \p stat followed by any extra \p args.
///
/// Operand sizes (left/right) are included only when stat.mem_left is set.
/// Process RSS is sampled once per record.
template <typename... Args>
auto eval(EvalStat const& stat, Args const&... args) {
  // Silent logger: skip sampling rss() and building any strings.
  if (!printing()) return;

  auto const result_field =
      std::format("result={}", to_string(stat.mem_result));
  auto const alloc_field = std::format("alloc={}", to_string(stat.mem_alloc));
  auto const hwmark_field = std::format("hw={}", to_string(stat.mem_hwmark));
  auto const rss_field = std::format("rss={}", to_string(rss()));

  if (!stat.mem_left) {
    // No operand sizes to report.
    log("Eval", to_string(stat.mode), stat.time, result_field, alloc_field,
        hwmark_field, rss_field, args...);
    return;
  }

  // Binary record: both operand sizes must be present.
  SEQUANT_ASSERT(stat.mem_right);
  log("Eval", to_string(stat.mode), stat.time,
      std::format("left={}", to_string(*stat.mem_left)),
      std::format("right={}", to_string(*stat.mem_right)), result_field,
      alloc_field, hwmark_field, rss_field, args...);
}

/// Emits a "Cache" record for \p stat followed by any extra \p args.
///
/// Returns before formatting any field when the eval logger is silent. This
/// matches eval() and the node overload of cache(); otherwise every cache
/// event would pay for several std::format calls whose output is discarded.
template <typename... Args>
auto cache(CacheStat const& stat, Args const&... args) {
  if (!printing()) return;  // nothing to emit; skip all formatting
  log("Cache",                                                   //
      to_string(stat.mode),                                      //
      std::format("key={}", stat.key),                           //
      std::format("life={}/{}", stat.curr_life, stat.max_life),  //
      std::format("alive={}", stat.num_alive),                   //
      std::format("entry={}", to_string(stat.entry_memory)),     //
      std::format("total={}", to_string(stat.total_memory)),     //
      args...);
}

template <typename N, bool F, typename... Args>
auto cache(N const& node, CacheManager<N, F>& cm, Args const&... args) {
  if (!printing()) return;  // skip the entry/total size walks and formatting
  using CacheMode::Access;
  using CacheMode::Release;
  using CacheMode::Store;
  auto const key = hash::value(*node);
  auto const cur_l = cm.life(node);
  auto const max_l = cm.max_life(node);
  bool const release = cur_l == 0;
  bool const store = cur_l + 1 == max_l;
  cache(CacheStat{.mode = store     ? Store
                          : release ? Release
                                    : Access,
                  .key = key,
                  .curr_life = cur_l,
                  .max_life = max_l,
                  .num_alive = cm.alive_count(),
                  .entry_memory = {cm.entry_size_in_bytes(node)},
                  .total_memory = {bytes(cm)}},
        args...);
}

/// Logs the beginning or end of a term's evaluation together with its text.
inline auto term(TermMode mode, std::string_view term) {
  auto const mode_name = to_string(mode);
  log("Term", mode_name, term);
}

/// Human-readable description of \p node for trace records.
///
/// - Primary nodes use their own label.
/// - Adjoint nodes are unary; their right child is only a placeholder kept
///   for the binary-tree shape. They render as "adjoint(<left>) -> <node>"
///   rather than exposing the placeholder behind a "??" operator.
/// - Other internal nodes render as "<left> <op> <right> -> <node>", where op
///   is "*" for products, "+" for sums, and "??" otherwise.
[[nodiscard]] auto label(meta::eval_node auto const& node) {
  return node->is_primary() ? node->label()
         : node->is_adjoint()
             ? std::format("adjoint({}) -> {}", node.left()->label(),
                           node->label())
             : std::format("{} {} {} -> {}", node.left()->label(),
                           (node->is_product() ? "*"
                            : node->is_sum()   ? "+"
                                               : "??"),  //
                           node.right()->label(), node->label());
}

}  // namespace log

namespace {

/// Runs `fun(args...)` and returns how long the call took.
///
/// Timed with std::chrono::steady_clock. high_resolution_clock is allowed to
/// alias the adjustable system_clock, whose wall-clock corrections can make
/// a measured interval jump or even go negative.
template <typename F, typename... Args>
[[nodiscard]] log::Duration timed_eval_inplace(F&& fun, Args&&... args)
  requires(std::is_invocable_r_v<void, F, Args...>)
{
  using Clock = std::chrono::steady_clock;
  auto const tstart = Clock::now();
  std::forward<F>(fun)(std::forward<Args>(args)...);
  auto const tend = Clock::now();
  return std::chrono::duration_cast<log::Duration>(tend - tstart);
}

/// True iff T is a specialization of CacheManager.
template <typename T>
constexpr bool is_cache_manager_v = false;

template <typename N, bool F>
constexpr bool is_cache_manager_v<CacheManager<N, F>> = true;

/// Satisfied when the last type of a non-empty pack is a CacheManager,
/// ignoring cv-ref qualifiers.
template <typename... Args>
concept last_type_is_cache_manager = is_cache_manager_v<std::remove_cvref_t<
    std::tuple_element_t<sizeof...(Args) - 1, std::tuple<Args...>>>>;

/// Returns the first argument of a pack, preserving its value category.
template <typename... Args>
auto&& arg0(Args&&... args) {
  return std::get<0>(std::forward_as_tuple(std::forward<Args>(args)...));
}

/// Returns `val` itself or, for a range of nodes, its front element.
auto&& node0(auto&& val) { return std::forward<decltype(val)>(val); }
auto&& node0(std::ranges::range auto&& rng) {
  return ranges::front(std::forward<decltype(rng)>(rng));
}

/// Whether evaluate() consults the cache before evaluating a node.
enum struct CacheCheck { Checked, Unchecked };

}  // namespace

/// Compile-time switch for eval tracing.
///
/// `Default` equals `On` when the build defines SEQUANT_EVAL_TRACE and `Off`
/// otherwise.
#ifdef SEQUANT_EVAL_TRACE
enum struct Trace { On, Off, Default = On };
#else
enum struct Trace { On, Off, Default = Off };
#endif
static_assert(Trace::Default == Trace::On || Trace::Default == Trace::Off);

namespace {
/// Compile-time predicate: true only for Trace::On.
[[nodiscard]] consteval bool trace(Trace t) noexcept {
  switch (t) {
    case Trace::On:
      return true;
    default:
      return false;
  }
}
}  // namespace

/// Indices summed over by the product at \p node.
///
/// These are the indices carried by both children but absent from the
/// node's own indices, listed in the left child's order. The result is empty
/// for leaves and for non-product nodes.
[[nodiscard]] inline Index::index_vector contracted_indices(
    meta::eval_node auto const& node) {
  Index::index_vector contracted;
  if (!node.leaf() && node->is_product()) {
    auto const& lhs = node.left()->canon_indices();
    auto const& rhs = node.right()->canon_indices();
    auto const& out = node->canon_indices();
    auto const has = [](auto const& seq, Index const& needle) {
      return std::find(seq.begin(), seq.end(), needle) != seq.end();
    };
    for (Index const& candidate : lhs) {
      bool const shared = has(rhs, candidate);
      if (shared && !has(out, candidate)) contracted.push_back(candidate);
    }
  }
  return contracted;
}

/// Chooses the batch axis for the product at \p node.
///
/// Among the contracted indices accepted by \p accept, returns the one whose
/// space has the largest approximate_size(); the first one wins on ties.
/// Returns std::nullopt when no contracted index is accepted.
template <typename IndexPredicate>
[[nodiscard]] inline std::optional<Index> batch_axis(
    meta::eval_node auto const& node, IndexPredicate const& accept) {
  std::optional<Index> widest;
  for (Index const& candidate : contracted_indices(node)) {
    if (accept(candidate) &&
        (!widest || widest->space().approximate_size() <
                        candidate.space().approximate_size()))
      widest = candidate;
  }
  return widest;
}

/// batch_axis() with every contracted index eligible.
[[nodiscard]] inline std::optional<Index> batch_axis(
    meta::eval_node auto const& node) {
  auto const accept_all = [](Index const&) { return true; };
  return batch_axis(node, accept_all);
}

/// Position of the first occurrence of \p ix in node->canon_indices(), or
/// std::nullopt if \p ix is not present.
[[nodiscard]] inline std::optional<std::size_t> index_position(
    meta::eval_node auto const& node, Index const& ix) {
  std::size_t pos = 0;
  for (Index const& candidate : node->canon_indices()) {
    if (candidate == ix) return pos;
    ++pos;
  }
  return std::nullopt;
}

/// Finds the first leaf under \p node whose indices include \p ix.
///
/// Searches depth-first, left subtree before right. Returns that leaf
/// together with the position of \p ix in its canon_indices(), or
/// std::nullopt when no leaf carries \p ix.
template <typename Node>
[[nodiscard]] std::optional<std::pair<Node, std::size_t>> find_leaf_carrying(
    Node const& node, Index const& ix) {
  using result_type = std::optional<std::pair<Node, std::size_t>>;
  if (node.leaf()) {
    auto const pos = index_position(node, ix);
    return pos ? result_type{std::pair{node, *pos}} : result_type{};
  }
  auto hit = find_leaf_carrying(node.left(), ix);
  if (!hit) hit = find_leaf_carrying(node.right(), ix);
  return hit;
}

/// Evaluates one node of an evaluation tree, recursing into its children.
///
/// \tparam EvalTrace when Trace::On, each operation emits a log::eval record
///         (a compile-time switch, see trace()).
/// \tparam Cache CacheCheck::Checked consults \p cache before evaluating.
///         CacheCheck::Unchecked skips that lookup; the checked path
///         re-enters with Unchecked to compute a value it is about to store.
/// \param node the node to evaluate
/// \param le leaf evaluator, invoked as `le(node)` for leaf nodes
/// \param cache cache of intermediate results; it may also provide a custom
///        evaluator, which is offered every non-leaf node first
/// \return the evaluated result
template <Trace EvalTrace = Trace::Default,
          CacheCheck Cache = CacheCheck::Checked, meta::can_evaluate Node,
          typename F, typename N, bool FHC>
  requires meta::leaf_node_evaluator<Node, F>
ResultPtr evaluate(Node const& node,  //
                   F const& le,       //
                   CacheManager<N, FHC>& cache) {
  if constexpr (Cache == CacheCheck::Checked) {  // return from cache if found

    // Multiplies `res` by node->canon_phase(). Returns `res` itself when the
    // phase is 1; otherwise returns a newly computed result, logged as
    // MultByPhase.
    auto mult_by_phase = [&node, &cache](ResultPtr res) {
      auto phase = node->canon_phase();
      if (phase == 1) return res;

      ResultPtr post;
      auto time =
          timed_eval_inplace([&]() { post = res->mult_by_phase(phase); });

      if constexpr (trace(EvalTrace)) {
        // `res` is added separately only when the node is not alive in the
        // cache (presumably an alive entry is already counted in the
        // cache's own size).
        size_t hwmark = log::bytes(cache, post).value;
        if (!cache.alive(node)) hwmark += log::bytes(res).value;
        auto stat =
            log::EvalStat{.mode = log::EvalMode::MultByPhase,
                          .time = time,
                          .mem_result = log::bytes(post),
                          .mem_alloc = log::bytes(post),
                          .mem_hwmark = {cache.note_working_set(hwmark)}};
        log::eval(stat, std::format("{} * {}", phase, node->label()));
      }
      return post;
    };

    if (auto ptr = cache.access(node); ptr) {
      // Cache hit: return the stored value, phase-adjusted for this node.
      if constexpr (trace(EvalTrace)) log::cache(node, cache, log::label(node));

      return mult_by_phase(ptr);
    } else if (cache.exists(node)) {
      // The node is tracked by the cache but holds no value yet. Evaluate it
      // (Unchecked, so this lookup is not repeated), phase-adjust, store,
      // then phase-adjust the stored value again for the caller.
      // NOTE(review): the returned value is phase^2 times the evaluated one.
      // That matches the uncached path only if canon_phase() is +1 or -1;
      // presumably guaranteed, but confirm.
      auto ptr = cache.store(
          node, mult_by_phase(evaluate<EvalTrace, CacheCheck::Unchecked>(
                    node, le, cache)));
      if constexpr (trace(EvalTrace)) log::cache(node, cache, log::label(node));

      return mult_by_phase(ptr);
    } else {
      // do nothing
      // (node not tracked by the cache: fall through to direct evaluation)
    }
  }

  ResultPtr result;
  ResultPtr left;
  ResultPtr right;

  log::Duration time;

  // Custom-evaluator interception: before the standard scheme, a non-leaf node
  // may be evaluated by the cache's custom evaluator (e.g. blocked over a
  // contracted index to bound peak memory). A non-null result is used (and
  // cached by the Checked wrapper) as-is; null declines to the standard scheme
  // below. See CacheManager::custom_evaluator_type.
  if (!node.leaf()) {
    if (auto const& custom_eval = cache.custom_evaluator(); custom_eval) {
      ResultPtr intercepted;
      time =
          timed_eval_inplace([&]() { intercepted = custom_eval(node, cache); });
      if (intercepted) {
        if constexpr (trace(EvalTrace)) {
          log::eval(log::EvalStat{.mode = log::eval_mode(node),
                                  .time = time,
                                  .mem_result = log::bytes(intercepted),
                                  .mem_alloc = log::bytes(intercepted),
                                  .mem_hwmark = {cache.note_working_set(
                                      log::bytes(cache, intercepted).value)}},
                    log::label(node));
        }
        return intercepted;
      }
    }
  }

  if (node.leaf()) {
    // Leaves are produced by the user-supplied leaf evaluator.
    time = timed_eval_inplace([&]() { result = le(node); });
  } else if (node->op_type() == EvalOp::Adjoint) {
    // Unary IR op: dispatch on left operand only; right is the Constant(1)
    // sentinel kept around to preserve FullBinaryNode's invariant. We
    // intentionally skip evaluating the sentinel — leaf evaluators that
    // can't manufacture scalar constants (rare in practice but possible)
    // would otherwise be invoked needlessly.
    left = evaluate<EvalTrace>(node.left(), le, cache);
    SEQUANT_ASSERT(left);
    std::array<std::any, 2> const adj_ann{node.left()->annot(), node->annot()};
    time = timed_eval_inplace([&]() { result = left->adjoint(adj_ann); });
  } else {
    // Binary op: evaluate both children (each through the cache), then
    // combine them.
    left = evaluate<EvalTrace>(node.left(), le, cache);
    right = evaluate<EvalTrace>(node.right(), le, cache);
    SEQUANT_ASSERT(left);
    SEQUANT_ASSERT(right);

    // Annotations of left, right and result, in that order.
    std::array<std::any, 3> const ann{node.left()->annot(),
                                      node.right()->annot(), node->annot()};
    if (node->op_type() == EvalOp::Sum) {
      time = timed_eval_inplace([&]() { result = left->sum(*right, ann); });
    } else {
      SEQUANT_ASSERT(node->op_type() == EvalOp::Product);
      // Request DeNest when both operands are tot() but the product is not
      // (tot presumably means tensor-of-tensors).
      auto const de_nest =
          node.left()->tot() && node.right()->tot() && !node->tot();
      time = timed_eval_inplace([&]() {
        result =
            left->prod(*right, ann, de_nest ? DeNest::True : DeNest::False);
      });
    }
  }

  SEQUANT_ASSERT(result);

  // logging
  if constexpr (trace(EvalTrace)) {
    if (node.leaf()) {
      log::eval(log::EvalStat{.mode = log::eval_mode(node),
                              .time = time,
                              .mem_result = log::bytes(result),
                              .mem_alloc = log::bytes(result),
                              .mem_hwmark = {cache.note_working_set(
                                  log::bytes(cache, result).value)}},
                log::label(node));
    } else {
      // A cached child is *distinct* from the local left/right when its
      // canon_phase != 1, because mult_by_phase allocates a fresh buffer
      // while the cache still holds the pre-phase data. So only skip the
      // local's bytes when the cache aliases the same buffer (phase == 1).
      // Adjoint nodes evaluate only the left operand (the right child is the
      // sentinel Constant(1) — see the Adjoint branch above), so `right` is
      // null; log::bytes() tolerates a null shared_ptr for that reason.
      size_t hwmark = log::bytes(cache, result).value;
      if (!cache.alive(node.left()) || node.left()->canon_phase() != 1)
        hwmark += log::bytes(left).value;
      if (right &&
          (!cache.alive(node.right()) || node.right()->canon_phase() != 1))
        hwmark += log::bytes(right).value;
      log::eval(log::EvalStat{.mode = log::eval_mode(node),
                              .time = time,
                              .mem_result = log::bytes(result),
                              .mem_alloc = log::bytes(result),
                              .mem_hwmark = {cache.note_working_set(hwmark)},
                              .mem_left = log::bytes(left),
                              .mem_right = log::bytes(right)},
                log::label(node));
    }
  }

  return result;
}

/// Evaluates \p node and permutes the result into \p layout if needed.
///
/// A permutation happens only when \p layout differs from a
/// default-constructed value of its type. With tracing on, the evaluation is
/// bracketed by Term Begin/End records carrying the serialized expression,
/// and the permutation (if any) is logged.
///
/// \param node the node to evaluate
/// \param layout target annotation, passed to permute() together with
///        node->annot()
/// \param le leaf evaluator
/// \param cache cache of intermediate results
/// \return the (possibly permuted) result; asserted non-null
template <Trace EvalTrace = Trace::Default, meta::can_evaluate Node, typename F,
          typename N, bool FHC>
  requires meta::leaf_node_evaluator<Node, F>  //
ResultPtr evaluate(Node const& node,           //
                   auto const& layout,         //
                   F const& le,                //
                   CacheManager<N, FHC>& cache) {
  // if the layout is not the default constructed value need to permute
  bool const perm = layout != decltype(layout){};

  // Serialized expression for the Term records; only built when tracing.
  std::string xpr;
  if constexpr (trace(EvalTrace)) {
    xpr = toUtf8(io::serialization::to_string(to_expr(node)));
    log::term(log::TermMode::Begin, xpr);
  }

  // pre: as evaluated; post: after the optional permutation (== pre if none).
  struct {
    ResultPtr pre, post;
  } result;

  result.pre = evaluate<EvalTrace>(node, le, cache);

  auto time = timed_eval_inplace([&]() {
    result.post = perm ? result.pre->permute(
                             std::array<std::any, 2>{node->annot(), layout})
                       : result.pre;
  });

  SEQUANT_ASSERT(result.post);

  // logging
  if constexpr (trace(EvalTrace)) {
    if (perm) {
      // result.pre aliases the cache only when the inner evaluate returned
      // the cached buffer unchanged — i.e. the node is cached AND no
      // mult_by_phase fresh allocation happened (phase == 1).
      size_t hwmark = log::bytes(cache, result.post).value;
      if (!cache.alive(node) || node->canon_phase() != 1)
        hwmark += log::bytes(result.pre).value;
      auto stat = log::EvalStat{.mode = log::EvalMode::Permute,
                                .time = time,
                                .mem_result = log::bytes(result.post),
                                .mem_alloc = log::bytes(result.post),
                                .mem_hwmark = {cache.note_working_set(hwmark)}};
      log::eval(stat, node->label());
    }
    log::term(log::TermMode::End, xpr);
  }
  return result.post;
}

/// Evaluates every node of \p nodes (each in \p layout) and returns their sum.
///
/// The sum is accumulated in place (add_inplace) into the first term's
/// result. Returns a null ResultPtr when \p nodes is empty.
/// NOTE(review): `result` is the pointer returned for the first term and is
/// mutated by add_inplace. The hwmark logic below allows for a term's result
/// being the cached buffer itself (cached, phase 1, default layout). If that
/// can happen for the first term, the cache entry would be modified in
/// place. Confirm this cannot occur.
template <Trace EvalTrace = Trace::Default, meta::can_evaluate_range Nodes,
          typename F, typename N, bool FHC>
  requires meta::leaf_node_evaluator<std::ranges::range_value_t<Nodes>, F>
ResultPtr evaluate(Nodes const& nodes,  //
                   auto const& layout,  //
                   F const& le, CacheManager<N, FHC>& cache) {
  ResultPtr result;

  // pre comes back from the permute-wrapping evaluate; it aliases the
  // cache only when the inner evaluate returned the cached buffer
  // unchanged — i.e. node cached, phase == 1, AND no permute happened.
  bool const layout_is_default = (layout == decltype(layout){});

  for (auto&& n : nodes) {
    // The first term seeds the accumulator.
    if (!result) {
      result = evaluate<EvalTrace>(n, layout, le, cache);
      continue;
    }

    ResultPtr pre = evaluate<EvalTrace>(n, layout, le, cache);
    auto time = timed_eval_inplace([&]() { result->add_inplace(*pre); });

    // logging
    if constexpr (trace(EvalTrace)) {
      // SumInplace allocates nothing: it writes into the accumulator.
      // hwmark counts the cache plus both operands live at this moment;
      // skip pre's bytes only when pre is the cached buffer itself.
      size_t hwmark = log::bytes(cache, result).value;
      if (!cache.alive(n) || n->canon_phase() != 1 || !layout_is_default)
        hwmark += log::bytes(pre).value;
      auto stat = log::EvalStat{.mode = log::EvalMode::SumInplace,
                                .time = time,
                                .mem_result = log::bytes(result),
                                .mem_alloc = {0},
                                .mem_hwmark = {cache.note_working_set(hwmark)}};
      log::eval(stat, n->label());
    }
  }

  return result;
}

/// Evaluates and sums \p nodes without permuting the result.
///
/// Each term keeps its own annotation as the layout.
///
/// A fresh closure type is used as the layout: the layout-taking overload
/// compares the layout with a default-constructed value of its type, and a
/// default-constructed annot_type{} always compares equal, so no permutation
/// happens.
/// EvalTrace is forwarded explicitly. Previously the delegating call omitted
/// it, so `evaluate<Trace::On>(nodes, le, cache)` silently fell back to
/// Trace::Default.
template <Trace EvalTrace = Trace::Default, meta::can_evaluate_range Nodes,
          typename F, typename N, bool FHC>
  requires meta::leaf_node_evaluator<std::ranges::range_value_t<Nodes>, F>
ResultPtr evaluate(Nodes const& nodes,  //
                   F const& le, CacheManager<N, FHC>& cache) {
  using annot_type = decltype([](std::ranges::range_value_t<Nodes> const& n) {
    return n->annot();
  });

  static_assert(std::is_default_constructible_v<annot_type>);
  return evaluate<EvalTrace>(nodes, annot_type{}, le, cache);
}

/// Convenience overload for calls that do not pass a CacheManager.
///
/// Creates an empty cache whose node type is taken from the first argument
/// (or from its front element when that argument is a range of nodes), then
/// forwards everything to the cache-taking overload.
template <Trace EvalTrace = Trace::Default, typename... Args>
  requires(!last_type_is_cache_manager<Args...>)
ResultPtr evaluate(Args&&... args) {
  using FirstNode =
      std::remove_cvref_t<decltype(node0(arg0(std::forward<Args>(args)...)))>;
  auto fresh_cache = CacheManager<FirstNode>::empty();
  return evaluate<EvalTrace>(std::forward<Args>(args)..., fresh_cache);
}

/// Evaluates \p args (forwarded to evaluate()) and symmetrizes the result.
///
/// \return the symmetrized result; asserted non-null
template <Trace EvalTrace = Trace::Default, typename... Args>
ResultPtr evaluate_symm(Args&&... args) {
  ResultPtr pre = evaluate<EvalTrace>(std::forward<Args>(args)...);
  SEQUANT_ASSERT(pre);
  ResultPtr result;
  auto time = timed_eval_inplace([&]() { result = pre->symmetrize(); });
  SEQUANT_ASSERT(result);

  // logging
  if constexpr (trace(EvalTrace)) {
    // hwmark reflects only the local working set: `pre` and the freshly
    // allocated `result` are both live during symmetrize. A cache passed in
    // `args`, if any, is not included.
    auto stat = log::EvalStat{.mode = log::EvalMode::Symmetrize,
                              .time = time,
                              .mem_result = log::bytes(result),
                              .mem_alloc = log::bytes(result),
                              .mem_hwmark = log::bytes(pre, result)};
    // `args` were already forwarded into evaluate() above; use them as
    // lvalues here so nothing is read after a potential move.
    log::eval(stat, node0(arg0(args...))->label());
  }

  return result;
}

/// Evaluates \p args (forwarded to evaluate()) and antisymmetrizes the
/// result, passing the bra rank of the first node's tensor.
///
/// \return the antisymmetrized result; asserted non-null
template <Trace EvalTrace = Trace::Default, typename... Args>
ResultPtr evaluate_antisymm(Args&&... args) {
  ResultPtr pre = evaluate<EvalTrace>(std::forward<Args>(args)...);
  SEQUANT_ASSERT(pre);

  // `args` were already forwarded into evaluate() above; refer to them as
  // lvalues from here on so nothing is read after a potential move.
  auto const& n0 = node0(arg0(args...));

  ResultPtr result;
  auto time = timed_eval_inplace(
      [&]() { result = pre->antisymmetrize(n0->as_tensor().bra_rank()); });
  SEQUANT_ASSERT(result);

  // logging
  if constexpr (trace(EvalTrace)) {
    // hwmark reflects only the local working set: `pre` and the freshly
    // allocated `result` are both live during antisymmetrize. A cache passed
    // in `args`, if any, is not included.
    auto stat = log::EvalStat{.mode = log::EvalMode::Antisymmetrize,
                              .time = time,
                              .mem_result = log::bytes(result),
                              .mem_alloc = log::bytes(result),
                              .mem_hwmark = log::bytes(pre, result)};
    log::eval(stat, n0->label());
  }
  return result;
}

/// Index predicate that accepts every index. It is the default `accept`
/// argument of make_batched_custom_evaluator().
struct accept_any_index {
  bool operator()(Index const& /*ignored*/) const noexcept {
    constexpr bool accepted = true;
    return accepted;
  }
};

/// Scope-guard type that does nothing.
struct no_scope_guard {};

/// Default scope-guard factory for make_batched_custom_evaluator(). It
/// ignores the batch count and returns a no_scope_guard.
struct make_no_scope_guard {
  no_scope_guard operator()(std::size_t) const noexcept {
    return no_scope_guard{};
  }
};

/// Builds a custom evaluator that splits a product along one contracted
/// index.
///
/// The returned callable has the shape `(node, cache) -> ResultPtr` used by
/// the cache's custom-evaluator hook. It evaluates a product node in batches
/// over one contracted index and sums the partial results, in order to bound
/// peak memory.
///
/// For a given node the callable:
///  1. picks the contracted index K via batch_axis(node, accept), and
///     declines (returns nullptr) if there is none;
///  2. finds the first leaf carrying K, evaluates it with \p le, and asks the
///     result for batch ranges along K's position via
///     mode_batches(position, target_batch_size); it declines if there are
///     at most one;
///  3. for each non-empty batch [e_lo, e_hi), evaluates the node on a fresh
///     empty cache, using a leaf evaluator that slices every K-carrying leaf
///     to that range, and accumulates the partial results with add_inplace.
///
/// The callable also returns nullptr when every batch is empty.
/// NOTE(review): step 2 evaluates the K-carrying leaf once in full just to
/// obtain the batch ranges.
///
/// \param le leaf evaluator (moved into the closure)
/// \param target_batch_size forwarded to mode_batches()
/// \param accept predicate restricting which contracted indices may be batched
/// \param make_scope_guard called with the number of batches; the returned
///        object stays alive while the batches are evaluated
template <typename F, typename IndexPredicate = accept_any_index,
          typename ScopeGuardFactory = make_no_scope_guard>
[[nodiscard]] auto make_batched_custom_evaluator(
    F le, std::size_t target_batch_size, IndexPredicate accept = {},
    ScopeGuardFactory make_scope_guard = {}) {
  return [le = std::move(le), target_batch_size, accept, make_scope_guard](
             auto const& node, auto& cache) -> ResultPtr {
    using cache_t = std::remove_reference_t<decltype(cache)>;

    // Batch axis: the accepted contracted index with the largest space.
    auto const K = batch_axis(node, accept);
    if (!K) return nullptr;

    // Batch ranges come from the first leaf (left-to-right) that carries K.
    auto const leaf = find_leaf_carrying(node, *K);
    if (!leaf) return nullptr;
    auto const batches =
        le(leaf->first)->mode_batches(leaf->second, target_batch_size);

    if (batches.size() <= 1)
      return nullptr;  // nothing to gain (or unbatchable)

    // RAII scope for the batched partial contractions; a backend-supplied
    // factory may relax block-sparse screening here (scaled by the batch count)
    // so per-batch screening does not drop contributions that survive over the
    // full batch axis.
    auto const scope_guard = make_scope_guard(batches.size());
    (void)scope_guard;

    // Running sum of the per-batch partial results.
    ResultPtr acc;
    for (auto const& [e_lo, e_hi] : batches) {
      if (e_lo == e_hi) continue;

      // leaf evaluator that slices every leaf carrying K to this element batch;
      // others pass through unchanged.
      auto le_g = [&le, &K, e_lo = e_lo,
                   e_hi = e_hi](auto const& leaf_node) -> ResultPtr {
        ResultPtr r = le(leaf_node);
        if (auto const p = index_position(leaf_node, *K))
          return r->slice_mode(*p, e_lo, e_hi);
        return r;
      };

      // standard scheme on a fresh scratch cache: no re-interception, and the
      // (partial, sliced) intermediates do not pollute the real cache.
      auto scratch = cache_t::empty();
      ResultPtr part = evaluate(node, le_g, scratch);
      if (!acc)
        acc = std::move(part);
      else
        acc->add_inplace(*part);
    }
    return acc;
  };
}

}  // namespace sequant

#endif  // SEQUANT_EVAL_EVAL_HPP