Program Listing for File eval_expr.hpp
Return to documentation for file (SeQuant/core/eval_expr.hpp)
#ifndef SEQUANT_EVAL_EXPR_HPP
#define SEQUANT_EVAL_EXPR_HPP
#include <SeQuant/core/binary_node.hpp>
#include <SeQuant/core/container.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/index.hpp>
#include <SeQuant/core/utility/macros.hpp>
#include <SeQuant/external/bliss/graph.hh>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <ranges>
#include <string>
#include <type_traits>
#include <utility>
namespace sequant {
class Tensor;

/// Operation an evaluation node applies to its two children
/// (to_expr() maps Sum to a Sum expression and Product to a Product).
enum class EvalOp {
  Sum,
  Product,
};

/// Kind of object an evaluation node yields.
enum class ResultType {
  Tensor,
  Scalar,
};

struct EvalOpSetter;
/// Payload of a node in a binary evaluation tree (see binarize()).
///
/// Bundles a SeQuant expression with the operation that produces it (if any),
/// the kind of its result, canonical indices and phase, a hash value, and an
/// optional connectivity graph. Member functions are declared here and
/// defined out of line.
class EvalExpr {
 public:
  // Allows EvalOpSetter to overwrite op_type_ after construction.
  friend struct EvalOpSetter;

  using index_vector = Index::index_vector;

  /// Constructs a node from a single tensor (presumably a terminal node).
  explicit EvalExpr(Tensor const& tnsr);
  /// Constructs a node from a single constant (presumably a terminal node).
  explicit EvalExpr(Constant const& c);
  /// Constructs a node from a single variable (presumably a terminal node).
  explicit EvalExpr(Variable const& v);

  /// Constructs a node from all of its components.
  /// @param op operation producing this node
  /// @param res kind of result
  /// @param expr wrapped expression
  /// @param ixs canonical indices
  /// @param phase canonical phase
  /// @param hash precomputed hash value
  /// @param connectivity connectivity graph (shared ownership)
  EvalExpr(EvalOp op, ResultType res, ExprPtr const& expr, index_vector ixs,
           std::int8_t phase, size_t hash,
           std::shared_ptr<bliss::Graph> connectivity);

  /// Operation producing this node; std::nullopt unless one was set.
  [[nodiscard]] const std::optional<EvalOp>& op_type() const noexcept;
  [[nodiscard]] ResultType result_type() const noexcept;
  [[nodiscard]] size_t hash_value() const noexcept;
  /// The wrapped expression.
  [[nodiscard]] ExprPtr expr() const noexcept;
  [[nodiscard]] bool tot() const noexcept;
  [[nodiscard]] std::wstring to_latex() const noexcept;

  [[nodiscard]] bool is_tensor() const noexcept;
  [[nodiscard]] bool is_scalar() const noexcept;
  [[nodiscard]] bool is_constant() const noexcept;
  [[nodiscard]] bool is_variable() const noexcept;
  [[nodiscard]] bool is_primary() const noexcept;
  [[nodiscard]] bool is_product() const noexcept;
  [[nodiscard]] bool is_sum() const noexcept;

  [[nodiscard]] Tensor const& as_tensor() const;
  [[nodiscard]] Constant const& as_constant() const;
  [[nodiscard]] Variable const& as_variable() const;

  [[nodiscard]] std::string label() const noexcept;
  [[nodiscard]] std::string indices_annot() const noexcept;
  [[nodiscard]] index_vector const& canon_indices() const noexcept;
  [[nodiscard]] std::int8_t canon_phase() const noexcept;

  [[nodiscard]] bool has_connectivity_graph() const noexcept;
  // NOTE(review): returns a reference; presumably only valid when
  // has_connectivity_graph() is true — confirm against the definition.
  [[nodiscard]] const bliss::Graph& connectivity_graph() const noexcept;
  [[nodiscard]] std::shared_ptr<bliss::Graph> copy_connectivity_graph()
      const noexcept;

 protected:
  // Operation producing this node; std::nullopt until set (e.g. by the
  // 7-argument constructor or EvalOpSetter).
  std::optional<EvalOp> op_type_ = std::nullopt;
  // Value-initialized so the member is never indeterminate; the out-of-line
  // constructors are expected to overwrite it.
  ResultType result_type_{};
  ExprPtr expr_;
  index_vector canon_indices_;
  std::int8_t canon_phase_{1};
  // Value-initialized for the same reason as result_type_.
  size_t hash_value_{};
  std::shared_ptr<bliss::Graph> connectivity_;
};
/// Privileged writer for EvalExpr::op_type_ (it is a friend of EvalExpr).
/// binarize(ResultExpr const&) uses it to turn a terminal's copy into a
/// Product node.
struct EvalOpSetter {
  void set(EvalExpr& expr, EvalOp op) {
    // Engages (or replaces) the optional operation with `op`.
    expr.op_type_.emplace(op);
  }
};
/// EvalExpr that caches its index annotation string (indices_annot()) at
/// construction (presumably for TiledArray-style evaluation).
class EvalExprTA final : public EvalExpr {
 public:
  /// Forwards @p args to an EvalExpr constructor, then caches indices_annot().
  ///
  /// A single argument whose decayed type is EvalExprTA is rejected, so that
  /// copies and moves go through the implicit copy/move constructors. Without
  /// that exclusion, this forwarding constructor would win overload
  /// resolution for non-const lvalues and needlessly recompute the
  /// annotation. The base is initialized with parentheses so that it performs
  /// the same direct-initialization that std::is_constructible_v checks.
  /// Braces would additionally reject narrowing conversions that the
  /// constraint accepts.
  template <typename... Args,
            typename = std::enable_if_t<
                std::is_constructible_v<EvalExpr, Args...> &&
                !(sizeof...(Args) == 1 &&
                  (std::is_same_v<std::decay_t<Args>, EvalExprTA> && ...))>>
  EvalExprTA(Args&&... args)
      : EvalExpr(std::forward<Args>(args)...), annot_{indices_annot()} {}

  /// Index annotation string cached at construction.
  [[nodiscard]] auto const& annot() const noexcept { return annot_; }

 private:
  std::string annot_;
};
class EvalExprBTAS final : public EvalExpr {
public:
using annot_t = container::svector<long>;
template <typename Iterable>
static auto index_hash(Iterable&& bk) {
return ranges::views::transform(
std::forward<Iterable>(bk), [](auto const& idx) {
//
// WARNING!
// The BTAS uses long for scalar indexing by default.
// Hence, here we explicitly cast the size_t values to long
// Which is a potentially narrowing conversion leading to
// integral overflow. Hence, the values in the returned
// container are mixed negative and positive integers (long type)
//
return static_cast<long>(sequant::hash::value(Index{idx}.label()));
});
}
template <typename... Args, typename = std::enable_if_t<
std::is_constructible_v<EvalExpr, Args...>>>
EvalExprBTAS(Args&&... args) : EvalExpr{std::forward<Args>(args)...} {
annot_ = index_hash(canon_indices()) | ranges::to<annot_t>;
}
[[nodiscard]] inline annot_t const& annot() const noexcept { return annot_; }
private:
annot_t annot_;
};
namespace meta {
namespace detail {

// Trait: true iff T is implicitly convertible to EvalExpr (EvalExpr itself
// and classes publicly derived from it, such as EvalExprTA/EvalExprBTAS).
// Header-scope trait variables are `inline constexpr`, like the std `_v`
// traits, so that every translation unit refers to the same entity.
template <typename, typename = void>
inline constexpr bool is_eval_expr = false;

template <typename T>
inline constexpr bool
    is_eval_expr<T, std::enable_if_t<std::is_convertible_v<T, EvalExpr>>> =
        true;

// Trait: true iff the type is FullBinaryNode<T> with is_eval_expr<T>.
template <typename, typename = void>
inline constexpr bool is_eval_node = false;

template <typename T>
inline constexpr bool
    is_eval_node<FullBinaryNode<T>, std::enable_if_t<is_eval_expr<T>>> = true;

}  // namespace detail

/// Satisfied by types implicitly convertible to EvalExpr.
template <typename T>
concept eval_expr = detail::is_eval_expr<T>;

/// Satisfied by (cv/ref-qualified) FullBinaryNode<T> with eval_expr<T>.
template <typename T>
concept eval_node = detail::is_eval_node<std::remove_cvref_t<T>>;

/// Satisfied by ranges whose value type is an eval_node.
template <typename Rng>
concept eval_node_range =
    std::ranges::range<Rng> && eval_node<std::ranges::range_value_t<Rng>>;

}  // namespace meta
namespace impl {
FullBinaryNode<EvalExpr> binarize(ExprPtr const&);
}
template <meta::eval_expr T>
using EvalNode = FullBinaryNode<T>;
/// Binarizes @p expr into a full binary tree of evaluation nodes.
/// @tparam ExprT node payload type; must be constructible from EvalExpr
///         (defaults to EvalExpr).
/// @param expr expression to binarize
/// @return the tree from impl::binarize(); for ExprT other than EvalExpr,
///         every node payload is converted with ExprT{...}.
template <typename ExprT = EvalExpr>
  requires std::is_constructible_v<ExprT, EvalExpr>
FullBinaryNode<ExprT> binarize(ExprPtr const& expr) {
  if constexpr (std::is_same_v<ExprT, EvalExpr>) {
    return impl::binarize(expr);
  } else {
    // Only instantiated when a conversion is actually needed.
    return transform_node(impl::binarize(expr),
                          [](auto&& val) { return ExprT{val}; });
  }
}
/// Binarizes res.expression() and makes the root node describe the result of
/// @p res.
///
/// If @p res has no bra, ket, or aux indices, the result is a scalar: the
/// root's expression is treated as a Variable and, when res.has_label(), it
/// is given res.label(). Otherwise the root's Tensor is overwritten with
/// res.result_as_tensor(). A tree consisting of a single terminal is first
/// rewritten as `result = terminal * 1`, so that a dedicated root node exists.
/// @tparam ExprT node payload type; must be constructible from EvalExpr
template <typename ExprT = EvalExpr>
  requires std::is_constructible_v<ExprT, EvalExpr>
FullBinaryNode<ExprT> binarize(ResultExpr const& res) {
  FullBinaryNode<ExprT> tree = binarize<ExprT>(res.expression());
  const bool is_scalar =
      res.bra().empty() && res.ket().empty() && res.aux().empty();

  if (tree.size() < 2) {
    // A lone terminal has no node that could carry the result, so build a
    // Product root whose children are the terminal and the constant 1.
    auto const terminal = [&]() -> ExprPtr {
      if (is_scalar) return ex<Variable>(res.result_as_variable());
      return ex<Tensor>(res.result_as_tensor());
    }();
    ExprT root = *binarize<ExprT>(terminal);
    EvalOpSetter{}.set(root, EvalOp::Product);
    tree = FullBinaryNode<ExprT>(std::move(root), std::move(tree),
                                 binarize<ExprT>(ex<Constant>(1)));
  }

  SEQUANT_ASSERT(tree.size() > 1);

  if (!is_scalar) {
    Tensor& root_tensor = tree->expr().template as<Tensor>();
    root_tensor = res.result_as_tensor();
    // if (res.has_label()) {
    //   tensor.set_label(res.label());
    // }
    // SEQUANT_ASSERT(tensor.num_slots() ==
    //   res.bra().size() + res.ket().size() + res.aux().size());
    // tensor.set_bra(res.bra());
    // tensor.set_ket(res.ket());
    // tensor.set_aux(res.aux());
  } else if (res.has_label()) {
    tree->expr().template as<Variable>().set_label(res.label());
  }

  return tree;
}
/// Converts a binary evaluation tree back into a SeQuant expression.
/// @param node root of the tree (satisfies meta::eval_node)
/// @return for a leaf, the stored expression; for a Product node, a Product of
///         the converted children; for a Sum node, a Sum of them
ExprPtr to_expr(meta::eval_node auto const& node) {
  auto const op = node->op_type();
  auto const& evxpr = *node;
  if (node.leaf()) return evxpr.expr();

  if (op == EvalOp::Product) {
    auto prod = Product{};
    ExprPtr lexpr = to_expr(node.left());
    ExprPtr rexpr = to_expr(node.right());
    prod.append(1, lexpr, Product::Flatten::No);
    prod.append(1, rexpr, Product::Flatten::No);
    SEQUANT_ASSERT(!prod.empty());
    // prod can end up with a single factor (presumably when append() folds a
    // Constant child into the scalar). Only a nested Product is spliced into
    // the outer one. Any other lone factor must be kept whole:
    //   - A Variable is a terminal, so iterating it yields nothing and
    //     splicing would drop it (2 * v would become 2).
    //   - Iterating a Sum yields its summands, and splicing would multiply
    //     them.
    // NOTE(review): assumes append(..., Flatten::No) already moved the nested
    // product's scalar into prod.scalar() — confirm in Product::append.
    if (prod.size() == 1 && prod.factor(0)->is<Product>()) {
      return ex<Product>(Product{prod.scalar(), prod.factor(0)->begin(),
                                 prod.factor(0)->end(), Product::Flatten::No});
    } else {
      return ex<Product>(std::move(prod));
    }
  } else {
    SEQUANT_ASSERT(op == EvalOp::Sum && "unsupported operation type");
    return ex<Sum>(Sum{to_expr(node.left()), to_expr(node.right())});
  }
}
} // namespace sequant
#endif // SEQUANT_EVAL_EXPR_HPP