Program Listing for File expr.hpp
Return to documentation for file (SeQuant/core/expressions/expr.hpp)
#ifndef SEQUANT_EXPRESSIONS_EXPR_HPP
#define SEQUANT_EXPRESSIONS_EXPR_HPP
#include <SeQuant/core/expressions/expr_ptr.hpp>
#include <SeQuant/core/options.hpp>
#include <SeQuant/core/utility/macros.hpp>
#include <boost/core/demangle.hpp>
#include <atomic>
#include <memory>
#include <optional>
#include <range/v3/all.hpp>
namespace sequant {
/// U+207A SUPERSCRIPT PLUS SIGN; presumably the marker used when labeling
/// adjoint objects. `inline constexpr` (rather than header-scope
/// `static const`) yields a single definition shared by all translation units
/// and makes the value usable in constant expressions.
inline constexpr wchar_t adjoint_label = L'\u207A';
/// Base class of all expressions.
///
/// An Expr is a range-v3 view (via ranges::view_facade) over its immediate
/// subexpressions, whose elements are ExprPtr; an Expr with no subexpressions
/// is an "atom". The base class exposes an empty range; derived classes expose
/// their subexpressions by overriding the begin_cursor()/end_cursor() hooks.
/// Expr also derives from std::enable_shared_from_this, and some members
/// (visit(), hash_value() with a custom hasher) require the object to be owned
/// by a std::shared_ptr.
class Expr : public std::enable_shared_from_this<Expr>,
             public ranges::view_facade<Expr> {
 public:
  using range_type = ranges::view_facade<Expr>;
  using hash_type = std::size_t;
  using type_id_type = int;  // to speed up comparisons

  Expr() = default;
  virtual ~Expr() = default;

  /// @return true if this expression has no subexpressions
  bool is_atom() const { return ranges::empty(*this); }

  /// @return true if this expression is known to be zero; the base
  ///         implementation always returns false
  virtual bool is_zero() const { return false; }

  /// @return LaTeX representation of this expression (defined out of line)
  virtual std::wstring to_latex() const;

  /// @return Wolfram-language representation of this expression (defined out
  ///         of line)
  virtual std::wstring to_wolfram() const;

  /// @return an ExprPtr to a clone of this expression (defined out of line)
  virtual ExprPtr clone() const;

  /// @return an ExprPtr referring to this expression. If this object is owned
  ///         by a std::shared_ptr, the result shares that ownership;
  ///         otherwise clone() is returned, so the result does NOT alias
  ///         `*this`.
  ExprPtr exprptr_from_this() {
    if (weak_from_this().use_count() == 0)
      return this->clone();
    else
      return static_cast<ExprPtr>(this->shared_from_this());
  }

  /// const overload of exprptr_from_this(). If this object is owned by a
  /// std::shared_ptr, the result shares that ownership (with constness cast
  /// away); otherwise clone() is returned.
  ExprPtr exprptr_from_this() const {
    if (weak_from_this().use_count() == 0)
      return this->clone();
    else
      return static_cast<const ExprPtr>(
          std::const_pointer_cast<Expr>(this->shared_from_this()));
  }

  /// Canonicalizes this expression. The base implementation does nothing and
  /// returns a null ExprPtr; the meaning of a non-null result is defined by
  /// the overriding classes.
  virtual ExprPtr canonicalize(
      CanonicalizeOptions = CanonicalizeOptions::default_options()) {
    return {};  // by default do nothing and return nullptr
  }

  /// Canonicalizes this expression using CanonicalizationMethod::Rapid.
  /// @param opts canonicalization options; every setting is forwarded to
  ///        canonicalize() except the method, which is forced to Rapid
  /// @return whatever canonicalize() returns
  virtual ExprPtr rapid_canonicalize(
      CanonicalizeOptions opts =
          CanonicalizeOptions::default_options().copy_and_set(
              CanonicalizationMethod::Rapid)) {
    // forward the caller's options: building a fresh {.method = Rapid} object
    // here would silently discard all other caller-supplied settings
    return this->canonicalize(
        opts.copy_and_set(CanonicalizationMethod::Rapid));
  }

  /// Applies @p visitor to the subexpressions of this expression, depth-first
  /// and post-order (an expression's subexpressions are visited before it).
  ///
  /// - Atoms are always passed as the ExprPtr slot owned by their parent.
  /// - If @p visitor is invocable with `const ExprPtr&` (i.e. nonmutating),
  ///   each nonatomic expression, including this one, is passed as a
  ///   temporary ExprPtr obtained from exprptr_from_this().
  /// - Otherwise each nonatomic subexpression is passed as the (mutable)
  ///   ExprPtr slot owned by its parent, and this expression itself is not
  ///   visited.
  ///
  /// @param visitor callable accepting `ExprPtr&` or `const ExprPtr&`
  /// @param atoms_only if true, only atoms are passed to @p visitor
  /// @return true if @p visitor was applied to this expression itself
  /// @throws std::invalid_argument if this expression is not owned by a
  ///         std::shared_ptr
  template <typename Visitor>
  bool visit(Visitor &&visitor, const bool atoms_only = false) {
    return visit_impl(*this, std::forward<Visitor>(visitor), atoms_only);
  }

  /// const overload of visit(); same traversal, return value and exceptions.
  template <typename Visitor>
  bool visit(Visitor &&visitor, const bool atoms_only = false) const {
    return visit_impl(*this, std::forward<Visitor>(visitor), atoms_only);
  }

  /// Iterators over the immediate subexpressions (elements are ExprPtr).
  auto begin_subexpr() { return range_type::begin(); }
  auto end_subexpr() { return range_type::end(); }
  auto begin_subexpr() const { return range_type::begin(); }
  auto end_subexpr() const { return range_type::end(); }

  /// @return this object viewed as an Expr
  Expr &expr() { return *this; }
  const Expr &expr() const { return *this; }

  /// Trait: true_type if T is std::shared_ptr<U> with is_expr_v<U>.
  /// ExprPtr is covered by an explicit specialization following the class.
  template <typename T, typename Enabler = void>
  struct is_shared_ptr_of_expr : std::false_type {};
  template <typename T>
  struct is_shared_ptr_of_expr<std::shared_ptr<T>,
                               std::enable_if_t<is_expr_v<T>>>
      : std::true_type {};

  /// Trait: true_type if T is std::shared_ptr<U> with is_an_expr_v<U>.
  /// ExprPtr is covered by an explicit specialization following the class.
  template <typename T, typename Enabler = void>
  struct is_shared_ptr_of_expr_or_derived : std::false_type {};
  template <typename T>
  struct is_shared_ptr_of_expr_or_derived<std::shared_ptr<T>,
                                          std::enable_if_t<is_an_expr_v<T>>>
      : std::true_type {};

  /// @return true if this expression is a c-number; commutes_with() treats
  ///         c-numbers as commuting with everything. By default an atom is a
  ///         c-number, and a composite expression is a c-number iff all of
  ///         its subexpressions are.
  virtual bool is_cnumber() const {
    if (is_atom())
      return true;
    else {
      bool result = true;
      for (auto it = begin_subexpr(); result && it != end_subexpr(); ++it) {
        result &= (*it)->is_cnumber();
      }
      return result;
    }
  }

  /// @return true if this expression commutes with @p that. Two atoms commute
  ///         if either is a c-number or commutes_with_atom() says so; for
  ///         composite operands the test recurses into subexpressions and
  ///         all of those pairwise tests must succeed.
  bool commutes_with(const Expr &that) const {
    auto this_is_atom = is_atom();
    auto that_is_atom = that.is_atom();
    bool result = true;
    if (this_is_atom && that_is_atom) {
      result =
          this->is_cnumber() || that.is_cnumber() || commutes_with_atom(that);
    } else if (this_is_atom) {
      // a c-number atom commutes with anything; otherwise test against each
      // subexpression of `that`
      if (!this->is_cnumber()) {
        for (auto it = that.begin_subexpr();
             result && it != that.end_subexpr(); ++it) {
          result &= this->commutes_with(**it);
        }
      }
    } else {
      for (auto it = this->begin_subexpr();
           result && it != this->end_subexpr(); ++it) {
        result &= (*it)->commutes_with(that);
      }
    }
    return result;
  }

  /// Transforms this expression into its adjoint (defined out of line); it
  /// returns nothing and is non-const, so it presumably acts in place.
  virtual void adjoint();

  /// @param hasher optional custom hash function; if set it is called with
  ///        shared_from_this(), so this expression must then be owned by a
  ///        std::shared_ptr
  /// @return hasher(shared_from_this()) if @p hasher is set, otherwise
  ///         memoizing_hash()
  hash_type hash_value(
      std::function<hash_type(const std::shared_ptr<const Expr> &)> hasher = {})
      const {
    return hasher ? hasher(shared_from_this()) : memoizing_hash();
  }

  /// @return id of the dynamic type of this expression; concrete classes must
  ///         override this. On GNU compilers the base version is not pure
  ///         and aborts if called; elsewhere it is pure virtual.
  virtual type_id_type type_id() const
#if __GNUG__
  {
    abort();
  }
#else
      = 0;
#endif

  friend inline bool operator==(const Expr &a, const Expr &b);

  /// Orders expressions by type id; expressions with equal type ids are
  /// ordered by static_less_than().
  /// N.B. the constraint must use enable_if_t: `std::enable_if<false>` is
  /// still a valid type, so it would never remove this overload.
  template <typename T, typename = std::enable_if_t<is_an_expr_v<T>>>
  bool operator<(const T &that) const {
    if (type_id() == that.type_id()) {
      // same type: use generic (or type-specific, if available) comparison
      return static_less_than(static_cast<const Expr &>(that));
    } else {  // order types by type id
      return type_id() < that.type_id();
    }
  }

  /// @return the type id of T. Ids are drawn lazily, on first access, from a
  ///         global atomic counter whose first value is 1; set_type_id()
  ///         overwrites the stored id.
  template <typename T>
  static type_id_type get_type_id() {
    return type_id_accessor<T>();
  }

  /// Overwrites the type id of T with @p id.
  template <typename T>
  static void set_type_id(type_id_type id) {
    type_id_accessor<T>() = id;
  }

  /// @return true if this expression is a T:
  ///   - if is_expr_v<T> (presumably T is Expr itself): always true;
  ///   - if T derives from Expr: compares type_id() with get_type_id<T>();
  ///   - otherwise (a non-Expr base/interface): uses dynamic_cast.
  template <typename T>
  bool is() const {
    if constexpr (is_expr_v<T>)
      return true;
    else if constexpr (std::is_base_of_v<Expr, T>)
      return this->type_id() == get_type_id<std::remove_cvref_t<T>>();
    else
      return dynamic_cast<const T *>(this) != nullptr;
  }

  /// @return this expression as a T; @pre is<T>() (asserted). Uses
  ///         static_cast for Expr-derived T, dynamic_cast otherwise.
  template <typename T>
  const T &as() const {
    SEQUANT_ASSERT(this->is<T>());
    if constexpr (std::is_base_of_v<Expr, T>) {
      return static_cast<const T &>(*this);
    } else
      return dynamic_cast<const T &>(*this);
  }

  /// non-const overload of as()
  template <typename T>
  T &as() {
    SEQUANT_ASSERT(this->is<T>());
    if constexpr (std::is_base_of_v<Expr, T>) {
      return static_cast<T &>(*this);
    } else
      return dynamic_cast<T &>(*this);
  }

  /// @return demangled name of the dynamic type of this object
  std::string type_name() const {
    return boost::core::demangle(typeid(*this).name());
  }

  /// Compound-assignment operators; implementations live out of line.
  virtual Expr &operator*=(const Expr &that);
  virtual Expr &operator^=(const Expr &that);
  virtual Expr &operator+=(const Expr &that);
  virtual Expr &operator-=(const Expr &that);

 private:
  // range-v3 needs access to the begin_cursor()/end_cursor() hooks
  friend ranges::range_access;

  /// Shared implementation of both visit() overloads (E is Expr or
  /// const Expr). @return true if visitor was applied to @p expr itself, so
  /// the caller does not apply it a second time.
  template <
      typename E, typename Visitor,
      typename = std::enable_if_t<std::is_same_v<std::remove_cvref_t<E>, Expr>>>
  static bool visit_impl(E &&expr, Visitor &&visitor, const bool atoms_only) {
    if (expr.weak_from_this().use_count() == 0)
      throw std::invalid_argument(
          "Expr::visit: cannot visit expressions not managed by shared_ptr");
    for (auto &subexpr_ptr : expr.expr()) {
      const auto subexpr_is_an_atom = subexpr_ptr->is_atom();
      const auto need_to_visit_subexpr = !atoms_only || subexpr_is_an_atom;
      bool visited = false;
      // forwarding the same visitor repeatedly is safe: it is only ever bound
      // to references, never moved from
      if (!subexpr_is_an_atom)  // if not a leaf, recur into it
        visited = visit_impl(*subexpr_ptr, std::forward<Visitor>(visitor),
                             atoms_only);
      // call on the subexpression itself, if not yet done so
      if (need_to_visit_subexpr && !visited) visitor(subexpr_ptr);
    }
    // N.B. can only visit itself if visitor is nonmutating!
    bool this_visited = false;
    if constexpr (std::is_invocable_r_v<void, std::remove_reference_t<Visitor>,
                                        const ExprPtr &>) {
      if (!atoms_only || expr.is_atom()) {
        const ExprPtr this_exprptr = expr.exprptr_from_this();
        visitor(this_exprptr);
        this_visited = true;
      }
    }
    return this_visited;
  }

 protected:
  // copy/move are restricted to derived classes (presumably to prevent
  // slicing through the Expr base)
  Expr(Expr &&) = default;
  Expr(const Expr &) = default;
  Expr &operator=(Expr &&) = default;
  Expr &operator=(const Expr &) = default;

  /// range-v3 cursor over a contiguous array of ExprPtr (the subexpressions).
  struct cursor {
    using value_type = ExprPtr;

    cursor() = default;
    constexpr explicit cursor(ExprPtr *subexpr_ptr) noexcept
        : ptr_{subexpr_ptr} {}
    /// cursor over const data; non-const read() asserts it is not used
    constexpr explicit cursor(const ExprPtr *subexpr_ptr) noexcept
        : ptr_{const_cast<ExprPtr *>(subexpr_ptr)}, const_{true} {}

    bool equal(const cursor &that) const { return ptr_ == that.ptr_; }
    void next() { ++ptr_; }
    void prev() { --ptr_; }
    // TODO figure out why can't return const here if want to be able to assign
    // to *begin(Expr&)
    ExprPtr &read() const {
      RANGES_EXPECT(ptr_);
      return *ptr_;
    }
    ExprPtr &read() {
      RANGES_EXPECT(const_ == false);
      RANGES_EXPECT(ptr_);
      return *ptr_;
    }
    void assign(const ExprPtr &that_ptr) {
      RANGES_EXPECT(ptr_);
      *ptr_ = that_ptr;
    }
    std::ptrdiff_t distance_to(cursor const &that) const {
      return that.ptr_ - ptr_;
    }
    void advance(std::ptrdiff_t n) { ptr_ += n; }

   private:
    // both begin and end will be represented by this, so Expr without
    // subexpressions begin() equals end() automatically
    ExprPtr *ptr_ = nullptr;
    bool const_ = false;  // assert in nonconst ops
  };

  /// Range hooks used by ranges::view_facade; the base versions describe an
  /// empty range (an atom). Derived classes override them to expose their
  /// subexpressions.
  virtual cursor begin_cursor() { return cursor{}; }
  virtual cursor end_cursor() { return cursor{}; }
  virtual cursor begin_cursor() const { return cursor{}; }
  virtual cursor end_cursor() const { return cursor{}; }

  /// cached hash value; empty by default and never filled by the base class
  /// (presumably set by derived classes' memoizing_hash())
  mutable std::optional<hash_type> hash_value_;  // not initialized by default

  /// @return the cached hash value, or 0 if none is cached
  virtual hash_type memoizing_hash() const {
    static const hash_type default_hash_value = 0;
    if (hash_value_)
      return *hash_value_;
    else
      return default_hash_value;
  }

  /// Discards the cached hash value.
  virtual void reset_hash_value() const { hash_value_.reset(); }

  /// Type-specific equality; operator== calls it only when both operands
  /// have the same type_id(). On GNU compilers the base version aborts;
  /// elsewhere it is pure virtual.
  virtual bool static_equal([[maybe_unused]] const Expr &that) const
#if __GNUG__
  {
    abort();
  }
#else
      = 0;
#endif

  /// Type-specific ordering used by operator< for equal type ids; the
  /// fallback compares hash values.
  virtual bool static_less_than(const Expr &that) const {
    return this->hash_value() < that.hash_value();
  }

  /// @return true if this atom commutes with atom @p that; the base version
  ///         always returns true
  virtual bool commutes_with_atom([[maybe_unused]] const Expr &that) const {
    return true;
  }

 private:
  /// @return the next unused type id (first call returns 1)
  static type_id_type get_next_type_id() {
    static std::atomic<type_id_type> grand_type_id = 0;
    return ++grand_type_id;
  }

  /// @return reference to the storage of T's type id, initialized on first
  ///         access from get_next_type_id()
  template <typename T>
  static type_id_type &type_id_accessor() {
    static type_id_type type_id = get_next_type_id();
    return type_id;
  }

  std::logic_error not_implemented(const char *fn) const;
};  // class Expr
// Explicit specializations marking ExprPtr itself as a pointer to Expr for
// both traits. NOTE(review): presumably ExprPtr is a distinct wrapper type
// that the std::shared_ptr partial specializations do not match — confirm in
// expr_ptr.hpp.
template <>
struct Expr::is_shared_ptr_of_expr<ExprPtr, void> : public std::true_type {};

template <>
struct Expr::is_shared_ptr_of_expr_or_derived<ExprPtr, void>
    : public std::true_type {};
/// Two expressions are equal iff they have the same type_id() and the
/// type-specific Expr::static_equal() reports equality; static_equal() is
/// only invoked when the type ids match.
inline bool operator==(const Expr &a, const Expr &b) {
  const bool same_type = a.type_id() == b.type_id();
  return same_type && a.static_equal(b);
}
struct proportional_to {
bool operator()(const ExprPtr &expr1, const ExprPtr &expr2) const;
};
} // namespace sequant
#endif // SEQUANT_EXPRESSIONS_EXPR_HPP