Program Listing for File eval.hpp
Return to documentation for file (SeQuant/domain/eval/eval.hpp)
#ifndef SEQUANT_EVAL_EVAL_HPP
#define SEQUANT_EVAL_EVAL_HPP
#include <SeQuant/core/container.hpp>
#include <SeQuant/core/eval_node.hpp>
#include <SeQuant/core/logger.hpp>
#include <SeQuant/core/meta.hpp>
#include <SeQuant/core/tensor.hpp>
#include <SeQuant/domain/eval/cache_manager.hpp>
#include <SeQuant/domain/eval/eval_result.hpp>
#include <btas/btas.h>
#include <tiledarray.h>
#include <range/v3/numeric.hpp>
#include <range/v3/view.hpp>
#include <any>
#include <iostream>
#include <stdexcept>
#include <type_traits>
namespace sequant {
namespace {
/// Writes an "[EVAL] "-prefixed trace line built from `args`.
///
/// Compiled to a no-op unless SEQUANT_EVAL_TRACE is defined. Even then,
/// output is produced only when the logger's `log_level_eval` is positive.
template <typename... Args>
void log_eval(Args const&... args) noexcept {
#ifdef SEQUANT_EVAL_TRACE
  if (auto l = Logger::instance(); l.log_level_eval > 0) {
    write_log(l, "[EVAL] ", args...);
  }
#endif
}
/// Traces an access to cache entry `key` (only with SEQUANT_EVAL_TRACE).
///
/// Reports the entry's remaining lives out of its maximum. It additionally
/// reports the entry as released once `cm.life(key)` has reached zero.
/// Nothing is written unless the logger's `log_level_eval` is positive.
[[maybe_unused]] void log_cache_access(size_t key, CacheManager const& cm) {
#ifdef SEQUANT_EVAL_TRACE
  auto l = Logger::instance();
  if (l.log_level_eval <= 0) return;

  assert(cm.exists(key));
  auto const total = cm.max_life(key);
  auto const remaining = cm.life(key);
  write_log(l,                                    //
            "[CACHE] Accessed key: ", key, ". ",  //
            remaining, "/", total, " lives remain.\n");
  if (remaining == 0) write_log(l, "[CACHE] Released key: ", key, ".\n");
#endif
}
/// Traces that cache entry `key` was stored (only with SEQUANT_EVAL_TRACE).
///
/// Afterwards it also emits the access trace via log_cache_access. The caller
/// hands the freshly stored value out right away, so storing counts as an
/// access too. Nothing is written unless `log_level_eval` is positive.
[[maybe_unused]] void log_cache_store(size_t key, CacheManager const& cm) {
#ifdef SEQUANT_EVAL_TRACE
  auto l = Logger::instance();
  if (l.log_level_eval <= 0) return;

  assert(cm.exists(key));
  write_log(l,  //
            "[CACHE] Stored key: ", key, ".\n");
  log_cache_access(key, cm);
#endif
}
/// Formats permutation groups for trace output.
///
/// Each group is a triple (bra pos, ket pos, length) rendered as "(a,b,c)".
/// Groups are separated by single spaces, with no trailing space.
///
/// @param perm_groups the groups to format; may be empty.
/// @return the formatted string, or "" when `perm_groups` is empty. The
///         previous trailing-space removal called pop_back() on an empty
///         string in that case, which is undefined behaviour.
[[maybe_unused]] std::string perm_groups_string(
    container::svector<std::array<size_t, 3>> const& perm_groups) {
  std::string result;
  for (auto const& g : perm_groups) {
    // Separator only *between* groups, so nothing needs trimming afterwards.
    if (!result.empty()) result += ' ';
    result += '(';
    result += std::to_string(g[0]);
    result += ',';
    result += std::to_string(g[1]);
    result += ',';
    result += std::to_string(g[2]);
    result += ')';
  }
  return result;
}
// HasAnnotMethod<T>: true iff `T`, after stripping cv/ref qualifiers, has a
// member function callable as `.annot()` on a value of that type. The primary
// template's `{}` value-initializes the variable to false.
template <typename, typename = void>
constexpr bool HasAnnotMethod{};
template <typename T>
constexpr bool HasAnnotMethod<
T, std::void_t<decltype(std::declval<meta::remove_cvref_t<T>>().annot())>> =
true;
// IsEvaluable<N>: true iff N is FullBinaryNode<T> or const FullBinaryNode<T>,
// where T converts to EvalExpr and has an annot() method (see above).
// Every other type, including references to such nodes, yields false.
template <typename, typename = void>
constexpr bool IsEvaluable{};
template <typename T>
constexpr bool IsEvaluable<
FullBinaryNode<T>,
std::enable_if_t<std::is_convertible_v<T, EvalExpr> && HasAnnotMethod<T>>> =
true;
// Same condition, for a const-qualified node type.
template <typename T>
constexpr bool IsEvaluable<
const FullBinaryNode<T>,
std::enable_if_t<std::is_convertible_v<T, EvalExpr> && HasAnnotMethod<T>>> =
true;
// IsIterableOfEvaluableNodes<R>: true iff the element type of range R, as
// given by meta::range_value_t, satisfies IsEvaluable.
template <typename, typename = void>
constexpr bool IsIterableOfEvaluableNodes{};
template <typename Iterable>
constexpr bool IsIterableOfEvaluableNodes<
Iterable, std::enable_if_t<IsEvaluable<meta::range_value_t<Iterable>>>> =
true;
} // namespace
// IsLeafEvaluator<NodeT, Le>: true iff NodeT is evaluable (IsEvaluable) and
// invoking `Le` with a NodeT returns ERPtr, either by value or by non-const
// reference (the reference is stripped before comparing). The primary
// template value-initializes to false.
// NOTE(review): the check invokes Le with an rvalue NodeT, while evaluate_core
// calls `le(node)` with a `NodeT const&`; confirm both forms are always viable.
template <typename, typename, typename = void>
constexpr bool IsLeafEvaluator{};
// A CacheManager is explicitly never a leaf evaluator.
// NOTE(review): presumably this keeps `evaluate(node, le, cache)` from also
// matching the `evaluate(node, layout, le, args...)` overload, which would
// bind `le` to `layout` and the cache to `le` — confirm.
template <typename NodeT>
constexpr bool IsLeafEvaluator<NodeT, CacheManager, void>{};
template <typename NodeT, typename Le>
constexpr bool IsLeafEvaluator<
NodeT, Le,
std::enable_if_t<
IsEvaluable<NodeT> &&
std::is_same_v<
ERPtr, std::remove_reference_t<std::invoke_result_t<Le, NodeT>>>>> =
true;
// EvalExpr that additionally carries a string annotation, returned by annot().
// NOTE(review): presumably this is the TiledArray index-annotation string used
// by the TA backend — confirm against the out-of-line definitions.
class EvalExprTA final : public EvalExpr {
public:
// The annotation string of this expression (presumably backed by annot_).
[[nodiscard]] std::string const& annot() const;
// Construct from a single Tensor, Constant, or Variable.
explicit EvalExprTA(Tensor const&);
explicit EvalExprTA(Constant const&);
explicit EvalExprTA(Variable const&);
// Construct the expression that combines two operand expressions with `EvalOp`.
EvalExprTA(EvalExprTA const&, EvalExprTA const&, EvalOp);
private:
std::string annot_;
}; // class EvalExprTA
// EvalExpr that additionally carries an annotation stored as a small vector of
// `long`s, returned by annot().
// NOTE(review): presumably these are integer index labels for BTAS contractions
// — confirm against the out-of-line definitions.
class EvalExprBTAS final : public EvalExpr {
public:
// Annotation type: one `long` per index.
using annot_t = container::svector<long>;
// The annotation of this expression (presumably backed by annot_).
[[nodiscard]] annot_t const& annot() const noexcept;
// Construct from a single Tensor, Constant, or Variable.
explicit EvalExprBTAS(Tensor const&) noexcept;
explicit EvalExprBTAS(Constant const&) noexcept;
explicit EvalExprBTAS(Variable const&) noexcept;
// Construct the expression that combines two operand expressions with `EvalOp`.
EvalExprBTAS(EvalExprBTAS const&, EvalExprBTAS const&, EvalOp) noexcept;
private:
annot_t annot_;
}; // EvalExprBTAS
// Forward declarations: evaluate_core (defined next) recurses into child nodes
// through evaluate_crust, which is defined after it. The default template
// argument lives only here, so the later definitions must not repeat it.
// Overload without a cache: evaluates the node directly.
template <typename NodeT, typename Le,
std::enable_if_t<IsLeafEvaluator<NodeT, Le>, bool> = true>
ERPtr evaluate_crust(NodeT const&, Le const&);
// Overload with a cache: consults and possibly populates the CacheManager.
template <typename NodeT, typename Le,
std::enable_if_t<IsLeafEvaluator<NodeT, Le>, bool> = true>
ERPtr evaluate_crust(NodeT const&, Le const&, CacheManager&);
/// Evaluates `node` without consulting a cache for `node` itself.
///
/// - Leaf node: returns `le(node)`.
/// - Internal node: evaluates both children via evaluate_crust, so a
///   CacheManager passed in `args` is consulted for the children. The results
///   are then combined according to `node->op_type()`:
///   - EvalOp::Sum  -> `left->sum(*right, ann)`
///   - EvalOp::Prod -> `left->prod(*right, ann, de_nest)`
///
///   Here `ann` holds the annotations {left, right, result}.
///
/// @param node the node to evaluate.
/// @param le   leaf evaluator; invoked only on leaf nodes.
/// @param args extra arguments passed to the recursive evaluate_crust calls
///             (e.g. a CacheManager&).
/// @return the evaluated result.
template <typename NodeT, typename Le, typename... Args,
          std::enable_if_t<IsLeafEvaluator<NodeT, Le>, bool> = true>
ERPtr evaluate_core(NodeT const& node, Le const& le, Args&&... args) {
  if (node.leaf()) {
    log_eval(node->is_constant()   ? "[CONSTANT] "
             : node->is_variable() ? "[VARIABLE] "
                                   : "[TENSOR] ",
             node->label(), "\n");
    return le(node);
  }

  // `args` feeds two calls, so it is passed as lvalues rather than
  // std::forward-ed: forwarding the same pack twice would give the second
  // call a moved-from argument whenever an rvalue was supplied.
  ERPtr const left = evaluate_crust(node.left(), le, args...);
  ERPtr const right = evaluate_crust(node.right(), le, args...);
  assert(left);
  assert(right);

  std::array<std::any, 3> const ann{node.left()->annot(),
                                    node.right()->annot(), node->annot()};

  if (node->op_type() == EvalOp::Sum) {
    log_eval("[SUM] ", node.left()->label(), " + ", node.right()->label(),
             " = ", node->label(), "\n");
    return left->sum(*right, ann);
  }

  assert(node->op_type() == EvalOp::Prod);
  log_eval("[PRODUCT] ", node.left()->label(), " * ", node.right()->label(),
           " = ", node->label(), "\n");
  // De-nest when both operands are tot() but the result is not.
  // NOTE(review): tot() presumably flags tensor-of-tensors — confirm.
  auto const de_nest =
      node.left()->tot() && node.right()->tot() && !node->tot();
  return left->prod(*right, ann,
                    de_nest ? TA::DeNest::True : TA::DeNest::False);
}
/// Cache-free evaluation of `node`: a thin pass-through to evaluate_core.
template <typename NodeT, typename Le,
          std::enable_if_t<IsLeafEvaluator<NodeT, Le>, bool>>
ERPtr evaluate_crust(NodeT const& node, Le const& le) {
  ERPtr evaluated = evaluate_core(node, le);
  return evaluated;
}

/// Cache-aware evaluation of `node`, keyed by `hash::value(*node)`.
///
/// 1. If `cache.access(key)` yields a non-null result, that result is returned.
/// 2. Otherwise, if the cache does not know `key`, the node is evaluated and
///    the result is returned without being stored.
/// 3. Otherwise the node is evaluated, the result is stored under `key`, and
///    the handle returned by `cache.store` is returned.
template <typename NodeT, typename Le,
          std::enable_if_t<IsLeafEvaluator<NodeT, Le>, bool>>
ERPtr evaluate_crust(NodeT const& node, Le const& le, CacheManager& cache) {
  auto const key = hash::value(*node);

  if (auto cached = cache.access(key)) {
    log_cache_access(key, cache);
    return cached;
  }

  if (!cache.exists(key)) return evaluate_core(node, le, cache);

  auto stored = cache.store(key, evaluate_core(node, le, cache));
  log_cache_store(key, cache);
  return stored;
}
/// Evaluates a single node with leaf evaluator `le`.
///
/// @param node the root node to evaluate.
/// @param le   leaf evaluator (must satisfy IsLeafEvaluator for NodeT).
/// @param args optional extra arguments for evaluate_crust, e.g. a
///             CacheManager& to enable caching.
/// @return the evaluated result (ERPtr).
template <typename NodeT, typename Le, typename... Args,
          std::enable_if_t<IsLeafEvaluator<NodeT, Le>, bool> = true>
auto evaluate(NodeT const& node, Le&& le, Args&&... args) {
  auto evaluated = evaluate_crust(node, le, std::forward<Args>(args)...);
  return evaluated;
}
/// Evaluates every node in `nodes` and returns the sum of the results.
///
/// The first node's result serves as the accumulator; each later result is
/// added to it with `add_inplace`.
///
/// @pre `nodes` is non-empty (asserted).
/// @param nodes range of evaluable nodes.
/// @param le    leaf evaluator.
/// @param args  extra arguments for each evaluation (e.g. a CacheManager&).
/// @return the accumulated result.
template <typename NodesT, typename Le, typename... Args,
          std::enable_if_t<IsIterableOfEvaluableNodes<NodesT>, bool> = true,
          std::enable_if_t<IsLeafEvaluator<meta::range_value_t<NodesT>, Le>,
                           bool> = true>
auto evaluate(NodesT const& nodes, Le const& le, Args&&... args) {
  auto iter = std::begin(nodes);
  auto const end = std::end(nodes);
  assert(iter != end);

  // `args` is reused for every node, so pass it as lvalues; forwarding inside
  // the loop could move from the same argument repeatedly.
  auto result = evaluate(*iter, le, args...);
  for (++iter; iter != end; ++iter) {
    auto const right = evaluate(*iter, le, args...);
    result->add_inplace(*right);
  }
  return result;
}
/// Evaluates `node`, then permutes the result into `layout`.
///
/// The permutation receives {node's own annotation, layout} as `std::any`s.
///
/// @param node   the root node to evaluate.
/// @param layout target annotation of the returned result.
/// @param le     leaf evaluator.
/// @param args   extra arguments for evaluate_crust (e.g. a CacheManager&).
/// @return the permuted result.
template <typename NodeT, typename Annot, typename Le, typename... Args,
          std::enable_if_t<IsEvaluable<NodeT>, bool> = true,
          std::enable_if_t<IsLeafEvaluator<NodeT, Le>, bool> = true>
auto evaluate(NodeT const& node,    //
              Annot const& layout,  //
              Le const& le, Args&&... args) {
  auto unpermuted = evaluate_crust(node, le, std::forward<Args>(args)...);
  log_eval("[PERMUTE] ", node->label(), "\n");
  return unpermuted->permute(std::array<std::any, 2>{node->annot(), layout});
}
/// Evaluates every node in `nodes`, permuting each result into `layout`, and
/// returns their sum.
///
/// The first node's result is the accumulator; each later result is added to
/// it in place.
///
/// @pre `nodes` is non-empty (asserted).
/// @param nodes  range of evaluable nodes.
/// @param layout target annotation for every permuted result.
/// @param le     leaf evaluator.
/// @param args   extra arguments for each evaluation (e.g. a CacheManager&).
/// @return the accumulated result.
template <typename NodesT, typename Annot, typename Le, typename... Args,
          std::enable_if_t<IsIterableOfEvaluableNodes<NodesT>, bool> = true,
          std::enable_if_t<IsLeafEvaluator<meta::range_value_t<NodesT>, Le>,
                           bool> = true>
auto evaluate(NodesT const& nodes,  //
              Annot const& layout,  //
              Le const& le, Args&&... args) {
  auto iter = std::begin(nodes);
  auto const end = std::end(nodes);
  assert(iter != end);

  // The accumulator is reported under the first node's label.
  auto const pnode_label = (*iter)->label();
  // `args` is reused for every node, so pass it as lvalues rather than
  // forwarding it (and possibly moving from it) on each iteration.
  auto result = evaluate(*iter, layout, le, args...);
  for (++iter; iter != end; ++iter) {
    auto const right = evaluate(*iter, layout, le, args...);
    // The accumulator is the left-hand side: result += right.
    log_eval("[ADD_INPLACE] ", pnode_label, " += ", (*iter)->label(), "\n");
    result->add_inplace(*right);
  }
  return result;
}
/// Evaluates `node` (a single node, or a range of nodes whose results are
/// summed), permutes the result into `layout`, and symmetrizes it.
///
/// @param node        a node, or a range of nodes.
/// @param layout      target annotation of the result.
/// @param perm_groups symmetrization groups as (bra pos, ket pos, length)
///                    triples. If empty, the (first) node's expression must be
///                    a Tensor with bra_rank() == ket_rank() == r (asserted),
///                    and the single group {0, r, r} is used.
/// @param le          leaf evaluator.
/// @param args        extra arguments for evaluation (e.g. a CacheManager&).
/// @return the symmetrized result.
template <typename NodeT, typename Annot, typename Le, typename... Args>
auto evaluate_symm(NodeT const& node, Annot const& layout,
                   container::svector<std::array<size_t, 3>> const& perm_groups,
                   Le const& le, Args&&... args) {
  container::svector<std::array<size_t, 3>> pgs;
  if (perm_groups.empty()) {
    // asked for symmetrization without specifying particle
    // symmetric index ranges assume both bra indices and ket indices are
    // symmetric in the particle exchange
    ExprPtr expr_ptr{};
    if constexpr (IsIterableOfEvaluableNodes<NodeT>) {
      expr_ptr = (*std::begin(node))->expr();
    } else {
      expr_ptr = node->expr();
    }
    assert(expr_ptr->is<Tensor>());
    auto const& t = expr_ptr->as<Tensor>();
    assert(t.bra_rank() == t.ket_rank());
    size_t const half_rank = t.bra_rank();
    pgs = {{0, half_rank, half_rank}};
  }

  auto result = evaluate(node, layout, le, std::forward<Args>(args)...);

  // Both operands are lvalues of the same (const) type, so this binds
  // without copying.
  auto const& groups = perm_groups.empty() ? pgs : perm_groups;
#ifdef SEQUANT_EVAL_TRACE
  // Guarded so the string is not built (and allocated) when tracing is off.
  log_eval("[SYMMETRIZE] (bra pos, ket pos, length) ",
           perm_groups_string(groups), "\n");
#endif
  return result->symmetrize(groups);
}
/// Evaluates `node` (a single node, or a range of nodes whose results are
/// summed), permutes the result into `layout`, and antisymmetrizes it.
///
/// @param node        a node, or a range of nodes.
/// @param layout      target annotation of the result.
/// @param perm_groups antisymmetrization groups as (bra pos, ket pos, length)
///                    triples. If empty, the (first) node's expression must be
///                    a Tensor with bra_rank() == ket_rank() == r (asserted),
///                    and the single group {0, r, r} is used.
/// @param le          leaf evaluator.
/// @param args        extra arguments for evaluation (e.g. a CacheManager&).
/// @return the antisymmetrized result.
template <typename NodeT, typename Annot, typename Le,
          typename... Args>
auto evaluate_antisymm(
    NodeT const& node,                                               //
    Annot const& layout,                                             //
    container::svector<std::array<size_t, 3>> const& perm_groups,  //
    Le const& le,                                                    //
    Args&&... args) {
  container::svector<std::array<size_t, 3>> pgs;
  if (perm_groups.empty()) {
    // asked for anti-symmetrization without specifying particle
    // antisymmetric index ranges assume both bra indices and ket indices are
    // antisymmetric in the particle exchange
    ExprPtr expr_ptr{};
    if constexpr (IsIterableOfEvaluableNodes<NodeT>) {
      expr_ptr = (*std::begin(node))->expr();
    } else {
      expr_ptr = node->expr();
    }
    assert(expr_ptr->is<Tensor>());
    auto const& t = expr_ptr->as<Tensor>();
    assert(t.bra_rank() == t.ket_rank());
    size_t const half_rank = t.bra_rank();
    pgs = {{0, half_rank, half_rank}};
  }

  auto result = evaluate(node, layout, le, std::forward<Args>(args)...);

  // Both operands are lvalues of the same (const) type, so this binds
  // without copying.
  auto const& groups = perm_groups.empty() ? pgs : perm_groups;
#ifdef SEQUANT_EVAL_TRACE
  // Guarded so the string is not built (and allocated) when tracing is off.
  log_eval("[ANTISYMMETRIZE] (bra pos, ket pos, length) ",
           perm_groups_string(groups), "\n");
#endif
  return result->antisymmetrize(groups);
}
} // namespace sequant
#endif // SEQUANT_EVAL_EVAL_HPP