Program Listing for File eval.hpp
Return to documentation for file (SeQuant/core/eval/eval.hpp)
#ifndef SEQUANT_EVAL_EVAL_HPP
#define SEQUANT_EVAL_EVAL_HPP
#include <SeQuant/core/eval/fwd.hpp>
#include <SeQuant/core/container.hpp>
#include <SeQuant/core/eval/cache_manager.hpp>
#include <SeQuant/core/eval/eval_node.hpp>
#include <SeQuant/core/eval/result.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/io/serialization/serialization.hpp>
#include <SeQuant/core/logger.hpp>
#include <SeQuant/core/meta.hpp>
#include <SeQuant/core/utility/macros.hpp>
#include <SeQuant/core/utility/string.hpp>
#include <range/v3/range/operations.hpp>
#include <algorithm>
#include <any>
#include <chrono>
#include <iostream>
#include <optional>
#include <stdexcept>
#include <type_traits>
// Headers for process_rss_bytes() — see log::process_rss_bytes() below.
#if defined(__APPLE__)
#include <mach/mach.h>
#elif defined(__linux__)
#include <unistd.h>
#include <fstream>
#endif
namespace sequant {
namespace log {
/// Time unit used for every duration reported in the evaluation log.
using Duration = std::chrono::nanoseconds;

/// A byte count kept in its own type so that it is rendered through
/// to_string(Bytes) (as "<n>B") rather than as a bare integer.
struct Bytes {
  std::size_t value;  ///< number of bytes
};
/// True when the evaluation logger is enabled, i.e. its `eval.level` is
/// positive. Used to skip formatting and memory walks when nothing is printed.
[[nodiscard]] inline bool printing() noexcept {
  auto const& logger = Logger::instance();
  return logger.eval.level > 0;
}
/// Sums the `size_in_bytes()` of one or more operands.
///
/// Each operand may be an object with `size_in_bytes()` or a pointer-like
/// handle to one. A pointer-like handle that is also testable as `bool`
/// (e.g. a ResultPtr) counts as 0 bytes when empty. This lets callers such as
/// the Adjoint path, which leaves `right` unevaluated, pass a null handle
/// without a guard.
template <typename T, typename... Ts>
[[nodiscard]] inline auto bytes(T const& arg, Ts const&... args) {
  auto const size_of = [](auto const& operand) -> size_t {
    if constexpr (!requires { operand->size_in_bytes(); }) {
      // Plain object.
      return operand.size_in_bytes();
    } else if constexpr (requires { static_cast<bool>(operand); }) {
      // Nullable pointer-like handle: empty means nothing to count.
      return operand ? operand->size_in_bytes() : size_t{0};
    } else {
      // Non-nullable pointer-like handle.
      return operand->size_in_bytes();
    }
  };
  size_t total = size_of(arg);
  ((total += size_of(args)), ...);
  return Bytes{total};
}
/// Bytes held by `cache` plus the bytes of each extra operand in `args`.
/// Returns 0 without walking the cache when logging is disabled.
template <typename N, bool F, typename... Ts>
[[nodiscard]] inline Bytes bytes(CacheManager<N, F> const& cache,
                                 Ts const&... args) {
  if (printing()) {
    size_t total = cache.size_in_bytes();
    ((total += bytes(args).value), ...);
    return Bytes{total};
  }
  return Bytes{0};
}
/// Renders a byte count as "<n>B" (e.g. "1024B"), with no unit scaling.
[[nodiscard]] inline auto to_string(Bytes bs) noexcept {
  std::string text = std::format("{}", bs.value);
  text += 'B';
  return text;
}
/// Current memory footprint of this process in bytes, or 0 when it cannot be
/// determined.
///
/// - macOS: `phys_footprint` from TASK_VM_INFO when the kernel filled the
///   full structure; otherwise `resident_size` from MACH_TASK_BASIC_INFO.
///   The fallback is larger because it includes shared pages.
/// - Linux: resident page count from /proc/self/statm times the page size.
/// - Other platforms: always 0.
[[nodiscard]] inline std::size_t process_rss_bytes() noexcept {
#if defined(__APPLE__)
  {
    ::task_vm_info_data_t vm{};
    ::mach_msg_type_number_t vm_len = TASK_VM_INFO_COUNT;
    auto const vm_status =
        ::task_info(::mach_task_self(), TASK_VM_INFO,
                    reinterpret_cast<::task_info_t>(&vm), &vm_len);
    // Only trust phys_footprint if the kernel returned the whole struct.
    if (vm_status == KERN_SUCCESS && vm_len >= TASK_VM_INFO_COUNT)
      return static_cast<std::size_t>(vm.phys_footprint);
  }
  // Fallback: raw resident-set size (larger; includes shared pages).
  ::mach_task_basic_info_data_t basic{};
  ::mach_msg_type_number_t basic_len = MACH_TASK_BASIC_INFO_COUNT;
  auto const basic_status =
      ::task_info(::mach_task_self(), MACH_TASK_BASIC_INFO,
                  reinterpret_cast<::task_info_t>(&basic), &basic_len);
  return basic_status == KERN_SUCCESS
             ? static_cast<std::size_t>(basic.resident_size)
             : std::size_t{0};
#elif defined(__linux__)
  // /proc/self/statm columns are page counts:
  //   total resident shared text lib data dt
  std::ifstream statm("/proc/self/statm");
  std::size_t total_pages = 0;
  std::size_t resident_pages = 0;
  statm >> total_pages >> resident_pages;
  if (statm.fail()) return 0;
  static const long page_size = ::sysconf(_SC_PAGESIZE);
  if (page_size > 0)
    return resident_pages * static_cast<std::size_t>(page_size);
  return 0;
#else
  return 0;
#endif
}
/// The process memory footprint (see process_rss_bytes()) as a Bytes value.
[[nodiscard]] inline Bytes rss() noexcept {
  Bytes const footprint{process_rss_bytes()};
  return footprint;
}
/// Kind of evaluation step reported in "Eval" log lines.
enum class EvalMode {
  Constant,        ///< leaf that is a constant
  Variable,        ///< leaf that is a variable
  Power,           ///< leaf that is a power
  Tensor,          ///< leaf that is a tensor
  Permute,         ///< result permuted into a requested layout (also used by
                   ///< eval_mode() for adjoint nodes)
  Product,         ///< binary product node
  MultByPhase,     ///< result multiplied by a node's canonical phase
  Sum,             ///< binary sum node
  SumInplace,      ///< accumulation of a range of nodes into one result
  Symmetrize,      ///< symmetrization of a final result
  Antisymmetrize,  ///< antisymmetrization of a final result
  Unknown          ///< none of the above
};
/// Classifies `node` for logging.
///
/// Leaves map to Constant/Variable/Power/Tensor. Internal nodes map to
/// Product/Sum, and adjoint nodes are reported as Permute (EvalMode has no
/// dedicated adjoint entry). Anything else maps to Unknown.
[[nodiscard]] EvalMode eval_mode(meta::eval_node auto const& node) {
  if (node.leaf()) {
    if (node->is_constant()) return EvalMode::Constant;
    if (node->is_variable()) return EvalMode::Variable;
    if (node->is_power()) return EvalMode::Power;
    if (node->is_tensor()) return EvalMode::Tensor;
    return EvalMode::Unknown;
  }
  if (node->is_product()) return EvalMode::Product;
  if (node->is_sum()) return EvalMode::Sum;
  if (node->is_adjoint()) return EvalMode::Permute;
  return EvalMode::Unknown;
}
/// Name of an EvalMode as printed in log lines; "??" for Unknown or
/// out-of-range values.
[[nodiscard]] constexpr auto to_string(EvalMode mode) noexcept {
  switch (mode) {
    case EvalMode::Constant:
      return "Constant";
    case EvalMode::Variable:
      return "Variable";
    case EvalMode::Power:
      return "Power";
    case EvalMode::Tensor:
      return "Tensor";
    case EvalMode::Permute:
      return "Permute";
    case EvalMode::Product:
      return "Product";
    case EvalMode::MultByPhase:
      return "MultByPhase";
    case EvalMode::Sum:
      return "Sum";
    case EvalMode::SumInplace:
      return "SumInplace";
    case EvalMode::Symmetrize:
      return "Symmetrize";
    case EvalMode::Antisymmetrize:
      return "Antisymmetrize";
    case EvalMode::Unknown:
      break;
  }
  return "??";
}
/// What happened to a cache entry: it was stored, accessed, or released.
enum struct CacheMode { Store, Access, Release };

/// Name of a CacheMode as printed in log lines. Any value other than Store
/// or Access prints as "Release".
[[nodiscard]] constexpr auto to_string(CacheMode mode) noexcept {
  switch (mode) {
    case CacheMode::Store:
      return "Store";
    case CacheMode::Access:
      return "Access";
    default:
      return "Release";
  }
}
/// Marks the start or the end of a term's evaluation in the log.
enum struct TermMode { Begin, End };

/// Name of a TermMode as printed in log lines. Anything other than Begin
/// prints as "End".
[[nodiscard]] constexpr auto to_string(TermMode mode) noexcept {
  if (mode == TermMode::Begin) return "Begin";
  return "End";
}
// clang-format off
// clang-format on
/// Statistics for one evaluation step, consumed by log::eval().
///
/// Every member has a default value, so a default-constructed (or partially
/// designated-initialized) stat never carries indeterminate data. Callers set
/// the fields they know with designated initializers, in declaration order.
struct EvalStat {
  EvalMode mode = EvalMode::Unknown;  ///< kind of step
  Duration time{};                    ///< time spent in the step itself
  Bytes mem_result{};                 ///< size of the step's result
  Bytes mem_alloc{};                  ///< bytes newly allocated by the step
  Bytes mem_hwmark{};                 ///< working-set high-water mark reported
                                      ///< for this step
  std::optional<Bytes> mem_left;      ///< left-operand size; when set, log::eval
                                      ///< also expects mem_right to be set
  std::optional<Bytes> mem_right;     ///< right-operand size
};
/// Statistics for one cache event, consumed by log::cache(CacheStat, ...).
///
/// Every member is value-initialized by default so a default-constructed stat
/// holds no indeterminate values. Member order is unchanged, so existing
/// designated initializers keep working.
struct CacheStat {
  CacheMode mode{};        ///< store / access / release
  size_t key{};            ///< hash of the cached node
  int curr_life{};         ///< remaining life of the entry
  int max_life{};          ///< life the entry started with
  size_t num_alive{};      ///< number of live entries in the cache
  Bytes entry_memory{};    ///< bytes held by this entry
  Bytes total_memory{};    ///< bytes held by the whole cache
};
/// Writes one log line: `arg`, then each of `args` prefixed by " | ",
/// terminated by a newline. Does nothing unless the eval logger level is
/// positive.
template <typename Arg, typename... Args>
void log(Arg const& arg, Args const&... args) {
  auto& logger = Logger::instance();
  if (logger.eval.level <= 0) return;
  write_log(logger, arg, std::format(" | {}", args)..., '\n');
}
/// Logs an "Eval" line for `stat`: mode, time, optional left/right operand
/// sizes, result/alloc/high-water sizes, the current process RSS, and any
/// extra `args`.
template <typename... Args>
auto eval(EvalStat const& stat, Args const&... args) {
  // Nothing to emit when logging is off: skip rss() and all formatting.
  if (!printing()) return;
  auto const field = [](char const* name, Bytes amount) {
    return std::format("{}={}", name, to_string(amount));
  };
  auto const result_s = field("result", stat.mem_result);
  auto const alloc_s = field("alloc", stat.mem_alloc);
  auto const hw_s = field("hw", stat.mem_hwmark);
  auto const rss_s = field("rss", rss());
  if (!stat.mem_left) {
    // Step without operand sizes.
    log("Eval", to_string(stat.mode), stat.time, result_s, alloc_s, hw_s,
        rss_s, args...);
    return;
  }
  // Operand sizes come in pairs.
  SEQUANT_ASSERT(stat.mem_right);
  log("Eval", to_string(stat.mode), stat.time,
      field("left", *stat.mem_left), field("right", *stat.mem_right),
      result_s, alloc_s, hw_s, rss_s, args...);
}
/// Logs a "Cache" line for `stat` (mode, key, life, live count, entry and
/// total memory) followed by any extra `args`.
template <typename... Args>
auto cache(CacheStat const& stat, Args const&... args) {
  auto const key_s = std::format("key={}", stat.key);
  auto const life_s =
      std::format("life={}/{}", stat.curr_life, stat.max_life);
  auto const alive_s = std::format("alive={}", stat.num_alive);
  auto const entry_s =
      std::format("entry={}", to_string(stat.entry_memory));
  auto const total_s =
      std::format("total={}", to_string(stat.total_memory));
  log("Cache", to_string(stat.mode), key_s, life_s, alive_s, entry_s,
      total_s, args...);
}
template <typename N, bool F, typename... Args>
auto cache(N const& node, CacheManager<N, F>& cm, Args const&... args) {
if (!printing()) return; // skip the entry/total size walks and formatting
using CacheMode::Access;
using CacheMode::Release;
using CacheMode::Store;
auto const key = hash::value(*node);
auto const cur_l = cm.life(node);
auto const max_l = cm.max_life(node);
bool const release = cur_l == 0;
bool const store = cur_l + 1 == max_l;
cache(CacheStat{.mode = store ? Store
: release ? Release
: Access,
.key = key,
.curr_life = cur_l,
.max_life = max_l,
.num_alive = cm.alive_count(),
.entry_memory = {cm.entry_size_in_bytes(node)},
.total_memory = {bytes(cm)}},
args...);
}
/// Logs a "Term" line marking the beginning or end of evaluating `term`.
inline auto term(TermMode mode, std::string_view term) {
  auto const* const phase = to_string(mode);
  log("Term", phase, term);
}
/// Human-readable description of `node` for log lines.
///
/// - Primary nodes: their own label.
/// - Adjoint nodes: "adjoint(<left>) -> <node>". Their right child is only
///   the Constant(1) sentinel, so the generic form would print a meaningless
///   "?? 1".
/// - Other nodes: "<left> <op> <right> -> <node>", with op "*" for products,
///   "+" for sums, "??" otherwise.
[[nodiscard]] auto label(meta::eval_node auto const& node) {
  if (node->is_primary()) return node->label();
  if (node->is_adjoint())
    return std::format("adjoint({}) -> {}", node.left()->label(),
                       node->label());
  char const* const op = node->is_product() ? "*"
                         : node->is_sum()   ? "+"
                                            : "??";
  return std::format("{} {} {} -> {}", node.left()->label(), op,
                     node.right()->label(), node->label());
}
} // namespace log
namespace {
template <typename F, typename... Args>
[[nodiscard]] log::Duration timed_eval_inplace(F&& fun, Args&&... args)
requires(std::is_invocable_r_v<void, F, Args...>)
{
using Clock = std::chrono::high_resolution_clock;
auto tstart = Clock::now();
std::forward<F>(fun)(std::forward<Args>(args)...);
auto tend = Clock::now();
return {tend - tstart};
}
template <typename T>
constexpr bool is_cache_manager_v = false;
template <typename N, bool F>
constexpr bool is_cache_manager_v<CacheManager<N, F>> = true;
template <typename... Args>
concept last_type_is_cache_manager = is_cache_manager_v<std::remove_cvref_t<
std::tuple_element_t<sizeof...(Args) - 1, std::tuple<Args...>>>>;
template <typename... Args>
auto&& arg0(Args&&... args) {
return std::get<0>(std::forward_as_tuple(std::forward<Args>(args)...));
}
auto&& node0(auto&& val) { return std::forward<decltype(val)>(val); }
auto&& node0(std::ranges::range auto&& rng) {
return ranges::front(std::forward<decltype(rng)>(rng));
}
enum struct CacheCheck { Checked, Unchecked };
} // namespace
/// Compile-time switch for evaluation tracing (logging of each step).
/// `Default` is `On` when SEQUANT_EVAL_TRACE is defined, `Off` otherwise.
#ifdef SEQUANT_EVAL_TRACE
enum struct Trace { On, Off, Default = On };
#else
enum struct Trace { On, Off, Default = Off };
#endif
static_assert(Trace::Default == Trace::On || Trace::Default == Trace::Off);
// Named rather than unnamed namespace: this is a header, and `trace` is
// referenced from external-linkage templates (evaluate()). The
// using-directive keeps unqualified lookup unchanged.
namespace eval_detail {
/// Compile-time predicate: does `t` request tracing?
[[nodiscard]] consteval bool trace(Trace t) noexcept { return t == Trace::On; }
}  // namespace eval_detail
using namespace eval_detail;
/// Indices summed over at a product node.
///
/// These are the indices present in both children's canon_indices() but
/// absent from the node's own, in the order they appear in the left child.
/// The result is empty for leaves and non-product nodes.
[[nodiscard]] inline Index::index_vector contracted_indices(
    meta::eval_node auto const& node) {
  Index::index_vector contracted;
  if (!node.leaf() && node->is_product()) {
    auto const& lhs = node.left()->canon_indices();
    auto const& rhs = node.right()->canon_indices();
    auto const& out = node->canon_indices();
    auto const holds = [](auto const& seq, Index const& target) {
      return std::find(seq.begin(), seq.end(), target) != seq.end();
    };
    for (Index const& candidate : lhs) {
      if (!holds(rhs, candidate)) continue;
      if (holds(out, candidate)) continue;
      contracted.push_back(candidate);
    }
  }
  return contracted;
}
/// Chooses the contracted index of `node` to batch over.
///
/// Among the contracted indices accepted by `accept`, returns the one whose
/// space has the largest approximate_size(); on ties the earliest wins.
/// Returns std::nullopt when no contracted index is accepted.
template <typename IndexPredicate>
[[nodiscard]] inline std::optional<Index> batch_axis(
    meta::eval_node auto const& node, IndexPredicate const& accept) {
  std::optional<Index> widest;
  for (Index const& candidate : contracted_indices(node)) {
    if (!accept(candidate)) continue;
    bool const larger = widest && widest->space().approximate_size() <
                                      candidate.space().approximate_size();
    if (!widest || larger) widest = candidate;
  }
  return widest;
}
/// batch_axis() with every contracted index acceptable.
[[nodiscard]] inline std::optional<Index> batch_axis(
    meta::eval_node auto const& node) {
  auto const any_index = [](Index const&) noexcept { return true; };
  return batch_axis(node, any_index);
}
/// Position of the first occurrence of `ix` in `node->canon_indices()`, or
/// std::nullopt if `ix` does not occur there.
[[nodiscard]] inline std::optional<std::size_t> index_position(
    meta::eval_node auto const& node, Index const& ix) {
  auto const& idxs = node->canon_indices();
  std::optional<std::size_t> where;
  for (std::size_t pos = 0; !where && pos < idxs.size(); ++pos)
    if (idxs[pos] == ix) where = pos;
  return where;
}
/// Depth-first, left-before-right search for the first leaf under `node`
/// whose canon_indices() contain `ix`.
///
/// Returns that leaf together with the position of `ix` in its
/// canon_indices(), or std::nullopt when no leaf carries `ix`.
template <typename Node>
[[nodiscard]] std::optional<std::pair<Node, std::size_t>> find_leaf_carrying(
    Node const& node, Index const& ix) {
  if (!node.leaf()) {
    if (auto hit = find_leaf_carrying(node.left(), ix); hit) return hit;
    return find_leaf_carrying(node.right(), ix);
  }
  auto const pos = index_position(node, ix);
  if (!pos) return std::nullopt;
  return std::pair{node, *pos};
}
/// Evaluates one node of an evaluation tree, consulting and populating
/// `cache`.
///
/// With `Cache == CacheCheck::Checked` (the default) the cache is queried
/// first:
///  - if `cache.access(node)` yields a result, it is returned after
///    multiplication by `node->canon_phase()`;
///  - else if `cache.exists(node)`:
///      1. the node is evaluated with CacheCheck::Unchecked;
///      2. the result is multiplied by its phase and stored;
///      3. the stored value is multiplied by the phase again and returned;
///  - otherwise evaluation falls through to the uncached scheme below.
///
/// Uncached scheme:
///  - a non-leaf node is first offered to `cache.custom_evaluator()`; a
///    non-null result is returned as-is;
///  - a leaf is evaluated with `le(node)`;
///  - an Adjoint node evaluates only its left child and calls `adjoint()`;
///  - Sum/Product nodes evaluate both children and combine them with
///    `sum()`/`prod()`.
///
/// @tparam EvalTrace when Trace::On, each step is logged with timing and
///         memory statistics
/// @tparam Cache whether to consult the cache before evaluating
/// @param node node to evaluate
/// @param le leaf evaluator, called as `le(leaf)`
/// @param cache cache manager shared by the whole recursion
/// @return the node's result (asserted non-null)
template <Trace EvalTrace = Trace::Default,
          CacheCheck Cache = CacheCheck::Checked, meta::can_evaluate Node,
          typename F, typename N, bool FHC>
  requires meta::leaf_node_evaluator<Node, F>
ResultPtr evaluate(Node const& node,  //
                   F const& le,       //
                   CacheManager<N, FHC>& cache) {
  if constexpr (Cache == CacheCheck::Checked) {  // return from cache if found
    // Multiplies `res` by the node's canonical phase. A phase of 1 returns
    // `res` itself (same buffer); any other phase yields a new result.
    auto mult_by_phase = [&node, &cache](ResultPtr res) {
      auto phase = node->canon_phase();
      if (phase == 1) return res;
      ResultPtr post;
      auto time =
          timed_eval_inplace([&]() { post = res->mult_by_phase(phase); });
      if constexpr (trace(EvalTrace)) {
        // Working set: cache + new result, plus the input when the node's
        // cache entry is not alive.
        size_t hwmark = log::bytes(cache, post).value;
        if (!cache.alive(node)) hwmark += log::bytes(res).value;
        auto stat =
            log::EvalStat{.mode = log::EvalMode::MultByPhase,
                          .time = time,
                          .mem_result = log::bytes(post),
                          .mem_alloc = log::bytes(post),
                          .mem_hwmark = {cache.note_working_set(hwmark)}};
        log::eval(stat, std::format("{} * {}", phase, node->label()));
      }
      return post;
    };
    if (auto ptr = cache.access(node); ptr) {
      // Cache hit.
      if constexpr (trace(EvalTrace)) log::cache(node, cache, log::label(node));
      return mult_by_phase(ptr);
    } else if (cache.exists(node)) {
      // Registered for caching but not stored yet. Evaluate without the
      // cache check, apply the phase before storing, and apply it again to
      // the stored value on the way out.
      // NOTE(review): this reads as the cache holding the canonical-phase
      // value, where applying a +-1 phase twice gives back the node's own
      // phase. Confirm canon_phase() is always +-1.
      auto ptr = cache.store(
          node, mult_by_phase(evaluate<EvalTrace, CacheCheck::Unchecked>(
                    node, le, cache)));
      if constexpr (trace(EvalTrace)) log::cache(node, cache, log::label(node));
      return mult_by_phase(ptr);
    } else {
      // do nothing
    }
  }
  ResultPtr result;
  ResultPtr left;
  ResultPtr right;
  log::Duration time;
  // Custom-evaluator interception: before the standard scheme, a non-leaf node
  // may be evaluated by the cache's custom evaluator (e.g. blocked over a
  // contracted index to bound peak memory). A non-null result is used (and
  // cached by the Checked wrapper) as-is; null declines to the standard scheme
  // below. See CacheManager::custom_evaluator_type.
  if (!node.leaf()) {
    if (auto const& custom_eval = cache.custom_evaluator(); custom_eval) {
      ResultPtr intercepted;
      time =
          timed_eval_inplace([&]() { intercepted = custom_eval(node, cache); });
      if (intercepted) {
        if constexpr (trace(EvalTrace)) {
          log::eval(log::EvalStat{.mode = log::eval_mode(node),
                                  .time = time,
                                  .mem_result = log::bytes(intercepted),
                                  .mem_alloc = log::bytes(intercepted),
                                  .mem_hwmark = {cache.note_working_set(
                                      log::bytes(cache, intercepted).value)}},
                    log::label(node));
        }
        return intercepted;
      }
    }
  }
  if (node.leaf()) {
    // Leaf: delegate entirely to the leaf evaluator.
    time = timed_eval_inplace([&]() { result = le(node); });
  } else if (node->op_type() == EvalOp::Adjoint) {
    // Unary IR op: dispatch on left operand only; right is the Constant(1)
    // sentinel kept around to preserve FullBinaryNode's invariant. We
    // intentionally skip evaluating the sentinel — leaf evaluators that
    // can't manufacture scalar constants (rare in practice but possible)
    // would otherwise be invoked needlessly.
    left = evaluate<EvalTrace>(node.left(), le, cache);
    SEQUANT_ASSERT(left);
    std::array<std::any, 2> const adj_ann{node.left()->annot(), node->annot()};
    time = timed_eval_inplace([&]() { result = left->adjoint(adj_ann); });
  } else {
    // Binary op: evaluate both children (each through the cache), then
    // combine. Annotations are passed as {left, right, this}.
    left = evaluate<EvalTrace>(node.left(), le, cache);
    right = evaluate<EvalTrace>(node.right(), le, cache);
    SEQUANT_ASSERT(left);
    SEQUANT_ASSERT(right);
    std::array<std::any, 3> const ann{node.left()->annot(),
                                      node.right()->annot(), node->annot()};
    if (node->op_type() == EvalOp::Sum) {
      time = timed_eval_inplace([&]() { result = left->sum(*right, ann); });
    } else {
      SEQUANT_ASSERT(node->op_type() == EvalOp::Product);
      // Request de-nesting when both operands are tensors-of-tensors but the
      // product itself is not.
      auto const de_nest =
          node.left()->tot() && node.right()->tot() && !node->tot();
      time = timed_eval_inplace([&]() {
        result =
            left->prod(*right, ann, de_nest ? DeNest::True : DeNest::False);
      });
    }
  }
  SEQUANT_ASSERT(result);
  // logging
  if constexpr (trace(EvalTrace)) {
    if (node.leaf()) {
      log::eval(log::EvalStat{.mode = log::eval_mode(node),
                              .time = time,
                              .mem_result = log::bytes(result),
                              .mem_alloc = log::bytes(result),
                              .mem_hwmark = {cache.note_working_set(
                                  log::bytes(cache, result).value)}},
                log::label(node));
    } else {
      // A cached child is *distinct* from the local left/right when its
      // canon_phase != 1, because mult_by_phase allocates a fresh buffer
      // while the cache still holds the pre-phase data. So only skip the
      // local's bytes when the cache aliases the same buffer (phase == 1).
      // Adjoint nodes evaluate only the left operand (the right child is the
      // sentinel Constant(1) — see the Adjoint branch above), so `right` is
      // null; log::bytes() tolerates a null shared_ptr for that reason.
      size_t hwmark = log::bytes(cache, result).value;
      if (!cache.alive(node.left()) || node.left()->canon_phase() != 1)
        hwmark += log::bytes(left).value;
      if (right &&
          (!cache.alive(node.right()) || node.right()->canon_phase() != 1))
        hwmark += log::bytes(right).value;
      log::eval(log::EvalStat{.mode = log::eval_mode(node),
                              .time = time,
                              .mem_result = log::bytes(result),
                              .mem_alloc = log::bytes(result),
                              .mem_hwmark = {cache.note_working_set(hwmark)},
                              .mem_left = log::bytes(left),
                              .mem_right = log::bytes(right)},
                log::label(node));
    }
  }
  return result;
}
/// Evaluates `node` through the cache, then permutes the result into
/// `layout` unless `layout` equals a default-constructed value of its type.
///
/// With tracing on, the term's serialized expression is logged at Begin/End,
/// and the permutation (if any) is logged as EvalMode::Permute.
template <Trace EvalTrace = Trace::Default, meta::can_evaluate Node, typename F,
          typename N, bool FHC>
  requires meta::leaf_node_evaluator<Node, F>  //
ResultPtr evaluate(Node const& node,            //
                   auto const& layout,          //
                   F const& le,                 //
                   CacheManager<N, FHC>& cache) {
  // A default-constructed layout means "no permutation requested".
  bool const perm = layout != decltype(layout){};

  std::string xpr;
  if constexpr (trace(EvalTrace)) {
    xpr = toUtf8(io::serialization::to_string(to_expr(node)));
    log::term(log::TermMode::Begin, xpr);
  }

  ResultPtr const pre = evaluate<EvalTrace>(node, le, cache);
  ResultPtr post;
  auto const time = timed_eval_inplace([&]() {
    if (perm)
      post = pre->permute(std::array<std::any, 2>{node->annot(), layout});
    else
      post = pre;
  });
  SEQUANT_ASSERT(post);

  if constexpr (trace(EvalTrace)) {
    if (perm) {
      // `pre` shares the cache's buffer only when the node is cached AND no
      // phase multiplication made a fresh copy (phase == 1). Otherwise it is
      // a separate allocation, live alongside `post`.
      size_t hwmark = log::bytes(cache, post).value;
      bool const pre_is_cached =
          cache.alive(node) && node->canon_phase() == 1;
      if (!pre_is_cached) hwmark += log::bytes(pre).value;
      log::eval(log::EvalStat{.mode = log::EvalMode::Permute,
                              .time = time,
                              .mem_result = log::bytes(post),
                              .mem_alloc = log::bytes(post),
                              .mem_hwmark = {cache.note_working_set(hwmark)}},
                node->label());
    }
    log::term(log::TermMode::End, xpr);
  }
  return post;
}
/// Evaluates every node of `nodes` (each permuted into `layout`) and sums
/// the results.
///
/// The first node's result becomes the accumulator, and the remaining
/// results are added into it in place. That result may be the very buffer
/// held by the cache: the cache entry is still alive, the phase is 1, and no
/// permutation happened. In that case it is first replaced by a private copy
/// (multiplication by a unit phase), so the in-place accumulation cannot
/// corrupt the cached value seen by later accesses.
/// NOTE(review): sharing that does not go through `cache` (e.g. a leaf
/// evaluator handing out a shared buffer) is not detected here.
///
/// @return the summed result, or a null ResultPtr if `nodes` is empty
template <Trace EvalTrace = Trace::Default, meta::can_evaluate_range Nodes,
          typename F, typename N, bool FHC>
  requires meta::leaf_node_evaluator<std::ranges::range_value_t<Nodes>, F>
ResultPtr evaluate(Nodes const& nodes,  //
                   auto const& layout,  //
                   F const& le, CacheManager<N, FHC>& cache) {
  ResultPtr result;
  bool const layout_is_default = (layout == decltype(layout){});
  // evaluate(n, layout, ...) hands back the cache's own buffer exactly when
  // the entry is still alive, no phase multiplication happened (phase == 1),
  // and no permutation happened (default layout).
  auto const aliases_cache = [&cache, layout_is_default](auto const& n) {
    return cache.alive(n) && n->canon_phase() == 1 && layout_is_default;
  };
  for (auto&& n : nodes) {
    if (!result) {
      result = evaluate<EvalTrace>(n, layout, le, cache);
      if (result && aliases_cache(n)) {
        // The accumulator is modified in place below: never let it be the
        // cached buffer itself.
        using phase_type = std::remove_cvref_t<decltype(n->canon_phase())>;
        result = result->mult_by_phase(phase_type{1});
      }
      continue;
    }
    ResultPtr pre = evaluate<EvalTrace>(n, layout, le, cache);
    auto time = timed_eval_inplace([&]() { result->add_inplace(*pre); });
    // logging
    if constexpr (trace(EvalTrace)) {
      // SumInplace allocates nothing: it writes into the accumulator.
      // hwmark counts the cache plus both operands live at this moment;
      // skip pre's bytes only when pre is the cached buffer itself.
      size_t hwmark = log::bytes(cache, result).value;
      if (!aliases_cache(n)) hwmark += log::bytes(pre).value;
      auto stat = log::EvalStat{.mode = log::EvalMode::SumInplace,
                                .time = time,
                                .mem_result = log::bytes(result),
                                .mem_alloc = {0},
                                .mem_hwmark = {cache.note_working_set(hwmark)}};
      log::eval(stat, n->label());
    }
  }
  return result;
}
/// Evaluates and sums a range of nodes in their natural layout (no
/// permutation).
///
/// The "layout" passed on is a default-constructed value of a stateless
/// annotation-lambda type. The layout-taking overload compares it with a
/// default-constructed value of the same type, finds them equal, and
/// therefore does not permute.
template <Trace EvalTrace = Trace::Default, meta::can_evaluate_range Nodes,
          typename F, typename N, bool FHC>
  requires meta::leaf_node_evaluator<std::ranges::range_value_t<Nodes>, F>
ResultPtr evaluate(Nodes const& nodes,  //
                   F const& le, CacheManager<N, FHC>& cache) {
  using annot_type = decltype([](std::ranges::range_value_t<Nodes> const& n) {
    return n->annot();
  });
  static_assert(std::is_default_constructible_v<annot_type>);
  // Forward EvalTrace explicitly. Without it the callee falls back to
  // Trace::Default, silently dropping a caller's Trace::On / Trace::Off.
  return evaluate<EvalTrace>(nodes, annot_type{}, le, cache);
}
/// Convenience overload for calls that do not pass a CacheManager.
///
/// Evaluates with `CacheManager<Node>::empty()`, where Node is the type of
/// the first argument: a node, or the first element of a node range.
template <Trace EvalTrace = Trace::Default, typename... Args>
  requires(!last_type_is_cache_manager<Args...>)
ResultPtr evaluate(Args&&... args) {
  using FirstNode = std::remove_cvref_t<decltype(node0(
      arg0(std::forward<Args>(args)...)))>;
  auto empty_cache = CacheManager<FirstNode>::empty();
  return evaluate<EvalTrace>(std::forward<Args>(args)..., empty_cache);
}
/// Evaluates via evaluate<EvalTrace>(args...) and symmetrizes the result.
template <Trace EvalTrace = Trace::Default, typename... Args>
ResultPtr evaluate_symm(Args&&... args) {
  ResultPtr pre = evaluate<EvalTrace>(std::forward<Args>(args)...);
  SEQUANT_ASSERT(pre);
  ResultPtr result;
  auto time = timed_eval_inplace([&]() { result = pre->symmetrize(); });
  // logging
  if constexpr (trace(EvalTrace)) {
    // cache is owned by the inner evaluate call and out of scope here;
    // hwmark reflects only the local working set (pre + freshly allocated
    // result both live during the symmetrize op).
    auto stat = log::EvalStat{.mode = log::EvalMode::Symmetrize,
                              .time = time,
                              .mem_result = log::bytes(result),
                              .mem_alloc = log::bytes(result),
                              .mem_hwmark = log::bytes(pre, result)};
    // `args` were already forwarded above. Refer to them as lvalues here, so
    // this lookup can never observe a moved-from argument.
    log::eval(stat, node0(arg0(args...))->label());
  }
  return result;
}
/// Evaluates via evaluate<EvalTrace>(args...) and antisymmetrizes the result.
/// The bra rank of the first node (or first node of a range) is passed to
/// antisymmetrize().
template <Trace EvalTrace = Trace::Default, typename... Args>
ResultPtr evaluate_antisymm(Args&&... args) {
  ResultPtr pre = evaluate<EvalTrace>(std::forward<Args>(args)...);
  SEQUANT_ASSERT(pre);
  // `args` were already forwarded above. Access them as lvalues so the node
  // reference can never bind to a moved-from argument.
  auto const& n0 = node0(arg0(args...));
  ResultPtr result;
  auto time = timed_eval_inplace(
      [&]() { result = pre->antisymmetrize(n0->as_tensor().bra_rank()); });
  // logging
  if constexpr (trace(EvalTrace)) {
    // The inner evaluate owns the cache; hwmark is the local working set
    // (pre + freshly allocated result).
    auto stat = log::EvalStat{.mode = log::EvalMode::Antisymmetrize,
                              .time = time,
                              .mem_result = log::bytes(result),
                              .mem_alloc = log::bytes(result),
                              .mem_hwmark = log::bytes(pre, result)};
    log::eval(stat, n0->label());
  }
  return result;
}
/// Index predicate that accepts every candidate batch axis; the default
/// filter of make_batched_custom_evaluator.
struct accept_any_index {
  bool operator()(Index const& /*candidate*/) const noexcept {
    constexpr bool accepted = true;
    return accepted;
  }
};
/// A scope guard that does nothing.
struct no_scope_guard {};

/// Default scope-guard factory for make_batched_custom_evaluator: ignores
/// the batch count and hands out a no_scope_guard.
struct make_no_scope_guard {
  no_scope_guard operator()(std::size_t) const noexcept {
    return no_scope_guard{};
  }
};
/// Builds a custom evaluator that evaluates a node in batches over one of its
/// contracted indices and sums the partial results.
///
/// The returned callable takes `(node, cache)` and returns a ResultPtr. It
/// returns nullptr ("decline", leaving the node to the standard scheme) when:
///  - no contracted index of `node` is accepted by `accept`;
///  - no leaf under `node` carries the chosen index;
///  - the backend reports at most one batch for it;
///  - every batch is empty.
///
/// Otherwise the node is re-evaluated once per non-empty batch, with every
/// leaf carrying the index sliced to that batch, and the parts are
/// accumulated.
///
/// @param le leaf evaluator used for all leaves of the intercepted node
/// @param target_batch_size forwarded to `mode_batches()` on the leaf that
///        carries the chosen index
/// @param accept predicate restricting which contracted indices may be
///        batched (see batch_axis)
/// @param make_scope_guard called once per interception with the number of
///        batches; its result is kept alive across the whole batched loop
template <typename F, typename IndexPredicate = accept_any_index,
          typename ScopeGuardFactory = make_no_scope_guard>
[[nodiscard]] auto make_batched_custom_evaluator(
    F le, std::size_t target_batch_size, IndexPredicate accept = {},
    ScopeGuardFactory make_scope_guard = {}) {
  return [le = std::move(le), target_batch_size, accept, make_scope_guard](
             auto const& node, auto& cache) -> ResultPtr {
    using cache_t = std::remove_reference_t<decltype(cache)>;
    // K: the contracted index to batch over (largest accepted space).
    auto const K = batch_axis(node, accept);
    if (!K) return nullptr;
    auto const leaf = find_leaf_carrying(node, *K);
    if (!leaf) return nullptr;
    // NOTE: this evaluates the whole carrying leaf once, only to ask it for
    // its batch boundaries along K; the leaves are evaluated again (sliced)
    // for every batch below.
    auto const batches =
        le(leaf->first)->mode_batches(leaf->second, target_batch_size);
    if (batches.size() <= 1)
      return nullptr;  // nothing to gain (or unbatchable)
    // RAII scope for the batched partial contractions; a backend-supplied
    // factory may relax block-sparse screening here (scaled by the batch count)
    // so per-batch screening does not drop contributions that survive over the
    // full batch axis.
    auto const scope_guard = make_scope_guard(batches.size());
    (void)scope_guard;
    ResultPtr acc;
    for (auto const& [e_lo, e_hi] : batches) {
      if (e_lo == e_hi) continue;  // empty batch contributes nothing
      // leaf evaluator that slices every leaf carrying K to this element batch;
      // others pass through unchanged.
      auto le_g = [&le, &K, e_lo = e_lo,
                   e_hi = e_hi](auto const& leaf_node) -> ResultPtr {
        ResultPtr r = le(leaf_node);
        if (auto const p = index_position(leaf_node, *K))
          return r->slice_mode(*p, e_lo, e_hi);
        return r;
      };
      // standard scheme on a fresh scratch cache: no re-interception, and the
      // (partial, sliced) intermediates do not pollute the real cache.
      auto scratch = cache_t::empty();
      ResultPtr part = evaluate(node, le_g, scratch);
      // The first non-empty batch seeds the accumulator; later batches are
      // added into it in place.
      if (!acc)
        acc = std::move(part);
      else
        acc->add_inplace(*part);
    }
    return acc;
  };
}
} // namespace sequant
#endif // SEQUANT_EVAL_EVAL_HPP