.. _program_listing_file_SeQuant_domain_eval_eval.hpp: Program Listing for File eval.hpp ================================= |exhale_lsh| :ref:`Return to documentation for file ` (``SeQuant/domain/eval/eval.hpp``) .. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS .. code-block:: cpp #ifndef SEQUANT_EVAL_EVAL_HPP #define SEQUANT_EVAL_EVAL_HPP #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace sequant { /* NOTE(review): the unnamed namespace below is defined in a header, so every translation unit that includes this file gets its own copy of its contents; a named detail namespace with inline functions would avoid that. Its contents: log_eval, log_cache_access and log_cache_store emit [EVAL] / [CACHE] trace lines only when SEQUANT_EVAL_TRACE is defined and Logger::instance()->log_level_eval > 0; perm_groups_string renders each permutation group as (a,b,c), separated by single spaces, and calls result.pop_back() unconditionally, which is undefined behavior for an empty perm_groups (evaluate_symm and evaluate_antisymm in this file always pass a non-empty list); HasAnnotMethod detects an annot() member via std::void_t, and IsEvaluable / IsIterableOfEvaluableNodes are detection traits over FullBinaryNode (template arguments are stripped in this listing, so check the header for the exact conditions). */ namespace { /** Writes an [EVAL] trace line built from args via write_log; a no-op unless SEQUANT_EVAL_TRACE is defined and the logger's log_level_eval is positive. */ template void log_eval(Args const&... args) noexcept { #ifdef SEQUANT_EVAL_TRACE auto l = Logger::instance(); if (l->log_level_eval > 0) write_log(l, "[EVAL] ", args...); #endif } /** Traces an access to cache entry key: asserts the key exists, reports cm.life(key) out of cm.max_life(key) remaining lives, and reports a release once no lives remain. log_cache_store, which follows, logs a store and then calls this function because a store implies an immediate access. */ [[maybe_unused]] void log_cache_access(size_t key, CacheManager const& cm) { #ifdef SEQUANT_EVAL_TRACE auto l = Logger::instance(); if (l->log_level_eval > 0) { assert(cm.exists(key)); auto max_l = cm.max_life(key); auto cur_l = cm.life(key); write_log(l, // "[CACHE] Accessed key: ", key, ". 
", // cur_l, "/", max_l, " lives remain.\n"); if (cur_l == 0) { write_log(l, // "[CACHE] Released key: ", key, ".\n"); } } #endif } [[maybe_unused]] void log_cache_store(size_t key, CacheManager const& cm) { #ifdef SEQUANT_EVAL_TRACE auto l = Logger::instance(); if (l->log_level_eval > 0) { assert(cm.exists(key)); write_log(l, // "[CACHE] Stored key: ", key, ".\n"); // because storing automatically implies immediately accessing it log_cache_access(key, cm); } #endif } [[maybe_unused]] std::string perm_groups_string( container::svector> const& perm_groups) { std::string result; for (auto const& g : perm_groups) result += "(" + std::to_string(g[0]) + "," + std::to_string(g[1]) + "," + std::to_string(g[2]) + ") "; result.pop_back(); // remove last space return result; } template constexpr bool HasAnnotMethod{}; template constexpr bool HasAnnotMethod< T, std::void_t>().annot())>> = true; template constexpr bool IsEvaluable{}; template constexpr bool IsEvaluable< FullBinaryNode, std::enable_if_t && HasAnnotMethod>> = true; template constexpr bool IsEvaluable< const FullBinaryNode, std::enable_if_t && HasAnnotMethod>> = true; template constexpr bool IsIterableOfEvaluableNodes{}; template constexpr bool IsIterableOfEvaluableNodes< Iterable, std::enable_if_t>>> = true; } // namespace template constexpr bool IsLeafEvaluator{}; template constexpr bool IsLeafEvaluator{}; template constexpr bool IsLeafEvaluator< NodeT, Le, std::enable_if_t< IsEvaluable && std::is_same_v< ERPtr, std::remove_reference_t>>>> = true; class EvalExprTA final : public EvalExpr { public: [[nodiscard]] std::string const& annot() const; explicit EvalExprTA(Tensor const&); explicit EvalExprTA(Constant const&); explicit EvalExprTA(Variable const&); EvalExprTA(EvalExprTA const&, EvalExprTA const&, EvalOp); private: std::string annot_; }; // class EvalExprTA class EvalExprBTAS final : public EvalExpr { public: using annot_t = container::svector; [[nodiscard]] annot_t const& annot() const noexcept; explicit 
EvalExprBTAS(Tensor const&) noexcept; explicit EvalExprBTAS(Constant const&) noexcept; explicit EvalExprBTAS(Variable const&) noexcept; EvalExprBTAS(EvalExprBTAS const&, EvalExprBTAS const&, EvalOp) noexcept; private: annot_t annot_; }; // EvalExprBTAS /** EvalExprTA and EvalExprBTAS above are final EvalExpr subclasses that add an annot() accessor (a std::string for EvalExprTA, a container::svector for EvalExprBTAS) and are constructible from a Tensor, Constant, Variable, or two child expressions joined by an EvalOp. IsLeafEvaluator (further above) presumably checks that the result of a leaf evaluator Le, with references removed, is ERPtr for an evaluable NodeT; its template arguments are stripped in this listing, so confirm against the header. The two declarations below forward-declare evaluate_crust, the per-node step evaluate_core recurses through: one overload evaluates directly, the other consults a CacheManager. */ template , bool> = true> ERPtr evaluate_crust(NodeT const&, Le const&); template , bool> = true> ERPtr evaluate_crust(NodeT const&, Le const&, CacheManager&); /** Recursively evaluates node. A leaf is traced as [CONSTANT], [VARIABLE] or [TENSOR] and handed to le. Otherwise both children are evaluated through evaluate_crust and combined with the annotations {left, right, node}: summed for EvalOp::Sum, or multiplied for EvalOp::Prod, requesting TA::DeNest::True when both children are tot() but node is not. NOTE(review): args are forwarded into both child calls; that is harmless for the lvalue CacheManager reference used in this file but would move from an rvalue argument twice. */ template , bool> = true> ERPtr evaluate_core(NodeT const& node, Le const& le, Args&&... args) { if (node.leaf()) { log_eval(node->is_constant() ? "[CONSTANT] " : node->is_variable() ? "[VARIABLE] " : "[TENSOR] ", node->label(), "\n"); return le(node); } else { ERPtr const left = evaluate_crust(node.left(), le, std::forward(args)...); ERPtr const right = evaluate_crust(node.right(), le, std::forward(args)...); assert(left); assert(right); std::array const ann{node.left()->annot(), node.right()->annot(), node->annot()}; if (node->op_type() == EvalOp::Sum) { log_eval("[SUM] ", node.left()->label(), " + ", node.right()->label(), " = ", node->label(), "\n"); return left->sum(*right, ann); } else { assert(node->op_type() == EvalOp::Prod); log_eval("[PRODUCT] ", node.left()->label(), " * ", node.right()->label(), " = ", node->label(), "\n"); auto const de_nest = node.left()->tot() && node.right()->tot() && !node->tot(); return left->prod(*right, ann, de_nest ? TA::DeNest::True : TA::DeNest::False); } } } /** Evaluates node without a cache by delegating to evaluate_core. */ template , bool>> ERPtr evaluate_crust(NodeT const& node, Le const& le) { return evaluate_core(node, le); } /** Evaluates node with memoization keyed on hash::value(*node): returns the cached result when cache.access yields one; if the key is registered (cache.exists) but holds no value yet, evaluates, stores and returns the stored result; otherwise evaluates without storing. */ template , bool>> ERPtr evaluate_crust(NodeT const& node, Le const& le, CacheManager& cache) { auto const h = hash::value(*node); if (auto ptr = cache.access(h); ptr) { log_cache_access(h, cache); return ptr; } else if (cache.exists(h)) { auto ptr = cache.store(h, evaluate_core(node, le, cache)); log_cache_store(h, cache); return ptr; } else { return evaluate_core(node, le, cache); } } /** Entry point for a single evaluable node: forwards to evaluate_crust, where extra args (e.g. a CacheManager) select the cached overload. le is taken as Le&& but passed on as an lvalue. */ template , bool> = true> auto evaluate(NodeT const& node, Le&& le, Args&&... 
args) { return evaluate_crust(node, le, std::forward(args)...); } /** Evaluates every node of the range nodes, which is asserted to be non-empty, and accumulates the results into the first one with add_inplace. */ template , bool> = true, std::enable_if_t, Le>, bool> = true> auto evaluate(NodesT const& nodes, Le const& le, Args&&... args) { auto iter = std::begin(nodes); auto end = std::end(nodes); assert(iter != end); auto result = evaluate(*iter, le, std::forward(args)...); for (++iter; iter != end; ++iter) { auto right = evaluate(*iter, le, std::forward(args)...); result->add_inplace(*right); } return result; } /** Evaluates node, then permutes the result from node->annot() to layout. */ template , bool> = true, std::enable_if_t, bool> = true> auto evaluate(NodeT const& node, // Annot const& layout, // Le const& le, Args&&... args) { auto result = evaluate_crust(node, le, std::forward(args)...); log_eval("[PERMUTE] ", node->label(), "\n"); return result->permute(std::array{node->annot(), layout}); } /** Evaluates each node of the range nodes, which is asserted to be non-empty, in layout and sums the results into the first one with add_inplace. NOTE(review): the [ADD_INPLACE] trace prints the current node's label += the first node's label, but the data flow is the reverse (the first result accumulates the current one). */ template , bool> = true, std::enable_if_t, Le>, bool> = true> auto evaluate(NodesT const& nodes, // Annot const& layout, // Le const& le, Args&&... args) { auto iter = std::begin(nodes); auto end = std::end(nodes); assert(iter != end); auto const pnode_label = (*iter)->label(); auto result = evaluate(*iter, layout, le, std::forward(args)...); for (++iter; iter != end; ++iter) { auto right = evaluate(*iter, layout, le, std::forward(args)...); log_eval("[ADD_INPLACE] ", (*iter)->label(), " += ", pnode_label, "\n"); result->add_inplace(*right); } return result; } /** Evaluates node (a single node or a range of nodes) in layout and symmetrizes the result over perm_groups, each entry being (bra pos, ket pos, length). When perm_groups is empty, a single group {0, r, r} is used, where r is the bra rank of the (first) node's expression. That expression is asserted to have equal bra and ket ranks; the is/as template argument is stripped in this listing and is presumably Tensor. */ template auto evaluate_symm(NodeT const& node, Annot const& layout, container::svector> const& perm_groups, Le const& le, Args&&... 
args) { container::svector> pgs; if (perm_groups.empty()) { // asked for symmetrization without specifying particle // symmetric index ranges assume both bra indices and ket indices are // symmetric in the particle exchange ExprPtr expr_ptr{}; if constexpr (IsIterableOfEvaluableNodes) { expr_ptr = (*std::begin(node))->expr(); } else { expr_ptr = node->expr(); } assert(expr_ptr->is()); auto const& t = expr_ptr->as(); assert(t.bra_rank() == t.ket_rank()); size_t const half_rank = t.bra_rank(); pgs = {{0, half_rank, half_rank}}; } auto result = evaluate(node, layout, le, std::forward(args)...); log_eval("[SYMMETRIZE] (bra pos, ket pos, length) ", perm_groups_string(perm_groups.empty() ? pgs : perm_groups), "\n"); return result->symmetrize(perm_groups.empty() ? pgs : perm_groups); } /** Same as evaluate_symm, but antisymmetrizes the result; when perm_groups is empty, the default group is built in the same way. */ template auto evaluate_antisymm( NodeT const& node, // Annot const& layout, // container::svector> const& perm_groups, // Le const& le, // Args&&... args) { container::svector> pgs; if (perm_groups.empty()) { // asked for anti-symmetrization without specifying particle // antisymmetric index ranges assume both bra indices and ket indices are // antisymmetric in the particle exchange ExprPtr expr_ptr{}; if constexpr (IsIterableOfEvaluableNodes) { expr_ptr = (*std::begin(node))->expr(); } else { expr_ptr = node->expr(); } assert(expr_ptr->is()); auto const& t = expr_ptr->as(); assert(t.bra_rank() == t.ket_rank()); size_t const half_rank = t.bra_rank(); pgs = {{0, half_rank, half_rank}}; } auto result = evaluate(node, layout, le, std::forward(args)...); log_eval("[ANTISYMMETRIZE] (bra pos, ket pos, length) ", perm_groups_string(perm_groups.empty() ? pgs : perm_groups), "\n"); return result->antisymmetrize(perm_groups.empty() ? pgs : perm_groups); } } // namespace sequant #endif // SEQUANT_EVAL_EVAL_HPP