Program Listing for File expr.hpp
Return to documentation for file (SeQuant/core/expr.hpp)
//
// Created by Eduard Valeyev on 3/23/18.
//
#ifndef SEQUANT_EXPR_HPP
#define SEQUANT_EXPR_HPP
#include <SeQuant/core/expr_fwd.hpp>
#include <SeQuant/core/complex.hpp>
#include <SeQuant/core/container.hpp>
#include <SeQuant/core/hash.hpp>
#include <SeQuant/core/latex.hpp>
#include <SeQuant/core/rational.hpp>
#include <SeQuant/core/wolfram.hpp>
#include <range/v3/all.hpp>
#include <boost/core/demangle.hpp>
#include <boost/numeric/conversion/cast.hpp>
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstdlib>
#include <functional>
#include <initializer_list>
#include <iostream>
#include <iterator>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <type_traits>
#include <typeinfo>
#include <utility>
namespace sequant {
namespace {
// Compile-time predicates over the expression class hierarchy.
//   is_X_v<T>   : T is exactly X (no cv/ref qualifiers are stripped)
//   is_a_X_v<T> : T is X or derived from X (via meta::is_base_of_v,
//                 presumably equivalent to std::is_base_of_v)

// Expr
template <typename T>
constexpr bool is_expr_v = meta::is_same_v<Expr, T>;
template <typename T>
constexpr bool is_an_expr_v = meta::is_base_of_v<Expr, T>;

// Constant
template <typename T>
constexpr bool is_constant_v = meta::is_same_v<Constant, T>;
template <typename T>
constexpr bool is_a_constant_v = meta::is_base_of_v<Constant, T>;

// Variable
template <typename T>
constexpr bool is_variable_v = meta::is_same_v<Variable, T>;
template <typename T>
constexpr bool is_a_variable_v = meta::is_base_of_v<Variable, T>;

// Product
template <typename T>
constexpr bool is_product_v = meta::is_same_v<Product, T>;
template <typename T>
constexpr bool is_a_product_v = meta::is_base_of_v<Product, T>;

// Sum
template <typename T>
constexpr bool is_sum_v = meta::is_same_v<Sum, T>;
template <typename T>
constexpr bool is_a_sum_v = meta::is_base_of_v<Sum, T>;
}  // namespace
/// A std::shared_ptr<Expr> extended with expression-specific conveniences:
/// cloning, type queries/casts (is/as), compound arithmetic, and LaTeX output.
/// Inherits the constructors and operator-> of std::shared_ptr<Expr>.
class ExprPtr : public std::shared_ptr<Expr> {
public:
using base_type = std::shared_ptr<Expr>;
using base_type::operator->;
using base_type::base_type;
ExprPtr() = default;
ExprPtr(const ExprPtr &) = default;
ExprPtr(ExprPtr &&) = default;
/// Converts (by copy) from a shared_ptr to Expr or to a class derived from Expr.
/// NOTE(review): the constraint admits a const-qualified E, but base_type
/// cannot be constructed from shared_ptr<const E>, so such an instantiation
/// would fail to compile; the is_same_v clause is also subsumed by
/// is_base_of_v — confirm intent.
template <typename E, typename = std::enable_if_t<
std::is_same_v<std::remove_const_t<E>, Expr> ||
std::is_base_of_v<Expr, std::remove_const_t<E>>>>
ExprPtr(const std::shared_ptr<E> &other_sptr) : base_type(other_sptr) {}
/// Converts (by move) from a shared_ptr to Expr or to a class derived from Expr.
template <typename E, typename = std::enable_if_t<is_an_expr_v<E>>>
ExprPtr(std::shared_ptr<E> &&other_sptr) : base_type(std::move(other_sptr)) {}
/// Copy-assigns from a shared_ptr to Expr or to a class derived from Expr.
template <typename E, typename = std::enable_if_t<is_an_expr_v<E>>>
ExprPtr &operator=(const std::shared_ptr<E> &other_sptr) {
as_shared_ptr() = other_sptr;
return *this;
}
/// Move-assigns from a shared_ptr to Expr or to a class derived from Expr.
template <typename E, typename = std::enable_if_t<is_an_expr_v<E>>>
ExprPtr &operator=(std::shared_ptr<E> &&other_sptr) {
as_shared_ptr() = std::move(other_sptr);
return *this;
}
ExprPtr &operator=(const ExprPtr &) = default;
ExprPtr &operator=(ExprPtr &&) = default;
~ExprPtr() = default;
/// Cloning; both overloads are defined out of line (not visible here).
[[nodiscard]] ExprPtr clone() const &;
[[nodiscard]] ExprPtr clone() && noexcept;
/// Access to the std::shared_ptr<Expr> base (defined out of line;
/// presumably returns *this viewed as base_type — confirm).
base_type &as_shared_ptr() &;
const base_type &as_shared_ptr() const &;
base_type &&as_shared_ptr() &&;
/// @return a std::shared_ptr<E> sharing ownership with this pointer, via
///         std::static_pointer_cast; asserts (debug builds only) that the
///         pointee is<E>(). Not available for E = Expr.
template <typename E, typename = std::enable_if_t<!is_expr_v<E>>>
std::shared_ptr<E> as_shared_ptr() const {
assert(this->is<E>());
return std::static_pointer_cast<E>(this->as_shared_ptr());
}
/// Dereference; these hide std::shared_ptr::operator* and yield a const Expr
/// for a const ExprPtr (defined out of line).
Expr &operator*() &;
const Expr &operator*() const &;
Expr &&operator*() &&;
/// Compound arithmetic with another expression (defined out of line).
ExprPtr &operator+=(const ExprPtr &other);
ExprPtr &operator-=(const ExprPtr &);
ExprPtr &operator*=(const ExprPtr &);
/// Type query/casts; defined at the end of this header in terms of
/// Expr::is<T>() and Expr::as<T>().
template <typename T>
bool is() const;
template <typename T>
const T &as() const;
template <typename T>
T &as();
/// LaTeX representation (defined out of line).
std::wstring to_latex() const;
}; // class ExprPtr
/// @return true if @p x does not point to an expression
inline bool operator==(const ExprPtr &x, std::nullptr_t) { return !x; }

/// @return true if @p x does not point to an expression (mirrored operands)
inline bool operator==(std::nullptr_t, const ExprPtr &x) { return !x; }
/// Base class of all SeQuant expressions.
///
/// An Expr is a range (via ranges::view_facade) of its subexpressions, each
/// held by an ExprPtr; an Expr without subexpressions is an atom. Derived
/// classes expose their subexpressions by overriding begin_cursor() and
/// end_cursor(). Expressions are meant to be owned by a std::shared_ptr;
/// exprptr_from_this() and visit() rely on that.
class Expr : public std::enable_shared_from_this<Expr>,
             public ranges::view_facade<Expr> {
 public:
  using range_type = ranges::view_facade<Expr>;
  using hash_type = std::size_t;
  using type_id_type = int;  // to speed up comparisons

  Expr() = default;
  virtual ~Expr() = default;

  /// @return true if this expression has no subexpressions
  bool is_atom() const { return ranges::empty(*this); }

  /// @return LaTeX representation of this expression
  virtual std::wstring to_latex() const;

  /// @return Wolfram Language representation of this expression
  virtual std::wstring to_wolfram() const;

  /// @return a copy of this expression
  virtual ExprPtr clone() const;

  /// @return an ExprPtr sharing ownership of this object
  /// @pre this object is managed by a std::shared_ptr (uses
  ///      shared_from_this())
  ExprPtr exprptr_from_this() {
    return static_cast<ExprPtr>(this->shared_from_this());
  }

  /// @return an ExprPtr sharing ownership of this object; constness is cast
  ///         away because ExprPtr points to a non-const Expr
  /// @pre this object is managed by a std::shared_ptr
  ExprPtr exprptr_from_this() const {
    return static_cast<const ExprPtr>(
        std::const_pointer_cast<Expr>(this->shared_from_this()));
  }

  /// Canonicalizes this expression; the base implementation does nothing.
  /// @return a null ExprPtr in the base implementation
  virtual ExprPtr canonicalize() {
    return {};  // by default do nothing and return nullptr
  }

  /// Cheaper variant of canonicalize(); by default calls canonicalize()
  virtual ExprPtr rapid_canonicalize() { return this->canonicalize(); }

  /// Recursively visits subexpressions in post-order (a subexpression's own
  /// subexpressions are visited before it).
  /// @param visitor callable invoked with the ExprPtr of each visited
  ///        subexpression; this expression itself is passed to it only if
  ///        @p visitor is invocable with a `const ExprPtr &`
  /// @param atoms_only if true, only atoms are passed to @p visitor
  /// @return true if this expression itself was passed to @p visitor
  /// @throw std::invalid_argument if this expression is not managed by a
  ///        std::shared_ptr
  template <typename Visitor>
  bool visit(Visitor &&visitor, const bool atoms_only = false) {
    return visit_impl(*this, std::forward<Visitor>(visitor), atoms_only);
  }

  /// const overload of visit()
  template <typename Visitor>
  bool visit(Visitor &&visitor, const bool atoms_only = false) const {
    return visit_impl(*this, std::forward<Visitor>(visitor), atoms_only);
  }

  /// iterators over the subexpressions of this expression
  auto begin_subexpr() { return range_type::begin(); }
  auto end_subexpr() { return range_type::end(); }
  auto begin_subexpr() const { return range_type::begin(); }
  auto end_subexpr() const { return range_type::end(); }

  /// @return this object viewed as Expr
  Expr &expr() { return *this; }
  const Expr &expr() const { return *this; }

  /// trait: T is std::shared_ptr<Expr> (ExprPtr is handled by a
  /// specialization after this class)
  template <typename T, typename Enabler = void>
  struct is_shared_ptr_of_expr : std::false_type {};
  template <typename T>
  struct is_shared_ptr_of_expr<std::shared_ptr<T>,
                               std::enable_if_t<is_expr_v<T>>>
      : std::true_type {};

  /// trait: T is std::shared_ptr to Expr or to a class derived from Expr
  template <typename T, typename Enabler = void>
  struct is_shared_ptr_of_expr_or_derived : std::false_type {};
  template <typename T>
  struct is_shared_ptr_of_expr_or_derived<std::shared_ptr<T>,
                                          std::enable_if_t<is_an_expr_v<T>>>
      : std::true_type {};

  /// @return true if this expression is a c-number. In the base
  ///         implementation every atom is a c-number and a composite is a
  ///         c-number if all of its subexpressions are
  virtual bool is_cnumber() const {
    if (is_atom())
      return true;
    else {
      bool result = true;
      for (auto it = begin_subexpr(); result && it != end_subexpr(); ++it) {
        result &= (*it)->is_cnumber();
      }
      return result;
    }
  }

  /// @return true if this expression commutes with @p that. Two atoms commute
  ///         if either is a c-number or commutes_with_atom() says so;
  ///         composites are reduced to their subexpressions
  bool commutes_with(const Expr &that) const {
    auto this_is_atom = is_atom();
    auto that_is_atom = that.is_atom();
    bool result = true;
    if (this_is_atom && that_is_atom) {
      result =
          this->is_cnumber() || that.is_cnumber() || commutes_with_atom(that);
    } else if (this_is_atom) {
      if (!this->is_cnumber()) {
        for (auto it = that.begin_subexpr(); result && it != that.end_subexpr();
             ++it) {
          result &= this->commutes_with(**it);
        }
      }
    } else {
      for (auto it = this->begin_subexpr(); result && it != this->end_subexpr();
           ++it) {
        result &= (*it)->commutes_with(that);
      }
    }
    return result;
  }

  /// Replaces this expression by its adjoint (defined out of line)
  virtual void adjoint();

  /// @param hasher optional custom hash function; if empty, memoizing_hash()
  ///        is used
  /// @return hash value of this expression
  hash_type hash_value(
      std::function<hash_type(const std::shared_ptr<const Expr> &)> hasher = {})
      const {
    return hasher ? hasher(shared_from_this()) : memoizing_hash();
  }

  /// @return the id of the dynamic type of this expression. With GNU
  ///         compilers the base version is defined (and calls abort())
  ///         instead of being pure virtual
  virtual type_id_type type_id() const
#if __GNUG__
  {
    abort();
  }
#else
      = 0;
#endif

  friend inline bool operator==(const Expr &a, const Expr &b);

  /// Orders expressions first by type id, then, within a type, by
  /// static_less_than(). Only participates for T equal to or derived from Expr.
  template <typename T, typename = std::enable_if_t<is_an_expr_v<T>>>
  bool operator<(const T &that) const {
    if (type_id() ==
        that.type_id()) {  // if same type, use generic (or type-specific, if
                           // available) comparison
      return static_less_than(static_cast<const Expr &>(that));
    } else {  // order types by type id
      return type_id() < that.type_id();
    }
  }

  /// @return the type id of T; ids are assigned lazily on first use
  template <typename T>
  static type_id_type get_type_id() {
    return type_id_accessor<T>();
  }

  /// overrides the type id of T
  template <typename T>
  static void set_type_id(type_id_type id) {
    type_id_accessor<T>() = id;
  }

  /// @return true if this expression's type id equals that of T (always true
  ///         for T = Expr)
  template <typename T>
  bool is() const {
    if constexpr (is_expr_v<T>)
      return true;
    else
      return this->type_id() == get_type_id<meta::remove_cvref_t<T>>();
  }

  /// @return this expression cast to T (asserts is<T>() in debug builds)
  template <typename T>
  const T &as() const {
    assert(this->is<T>());
    return static_cast<const T &>(*this);
  }

  /// @return this expression cast to T (asserts is<T>() in debug builds)
  template <typename T>
  T &as() {
    assert(this->is<T>());
    return static_cast<T &>(*this);
  }

  /// @return demangled name of the dynamic type of this expression
  std::string type_name() const {
    return boost::core::demangle(typeid(*this).name());
  }

  /// compound operations (defined out of line in the base class)
  virtual Expr &operator*=(const Expr &that);
  virtual Expr &operator^=(const Expr &that);
  virtual Expr &operator+=(const Expr &that);
  virtual Expr &operator-=(const Expr &that);

 private:
  friend ranges::range_access;

  /// implementation of visit() shared by the const and non-const overloads
  template <typename E, typename Visitor,
            typename =
                std::enable_if_t<std::is_same_v<meta::remove_cvref_t<E>, Expr>>>
  static bool visit_impl(E &&expr, Visitor &&visitor, const bool atoms_only) {
    if (expr.weak_from_this().use_count() == 0)
      throw std::invalid_argument(
          "Expr::visit: cannot visit expressions not managed by shared_ptr");
    for (auto &subexpr_ptr : expr.expr()) {
      const auto subexpr_is_an_atom = subexpr_ptr->is_atom();
      const auto need_to_visit_subexpr = !atoms_only || subexpr_is_an_atom;
      bool visited = false;
      if (!subexpr_is_an_atom)  // if not a leaf, recur into it
        visited = visit_impl(*subexpr_ptr, std::forward<Visitor>(visitor),
                             atoms_only);
      // call on the subexpression itself, if not yet done so
      if (need_to_visit_subexpr && !visited) visitor(subexpr_ptr);
    }
    // N.B. can only visit itself if visitor is nonmutating!
    bool this_visited = false;
    if constexpr (std::is_invocable_r_v<void, std::remove_reference_t<Visitor>,
                                        const ExprPtr &>) {
      if (!atoms_only || expr.is_atom()) {
        const ExprPtr this_exprptr = expr.exprptr_from_this();
        visitor(this_exprptr);
        this_visited = true;
      }
    }
    return this_visited;
  }

 protected:
  Expr(Expr &&) = default;
  Expr(const Expr &) = default;
  Expr &operator=(Expr &&) = default;
  Expr &operator=(const Expr &) = default;

  /// ranges::view_facade cursor over a contiguous array of ExprPtr
  struct cursor {
    using value_type = ExprPtr;

    cursor() = default;
    constexpr explicit cursor(ExprPtr *subexpr_ptr) noexcept
        : ptr_{subexpr_ptr} {}
    constexpr explicit cursor(const ExprPtr *subexpr_ptr) noexcept
        : ptr_{const_cast<ExprPtr *>(subexpr_ptr)}, const_{true} {}

    bool equal(const cursor &that) const { return ptr_ == that.ptr_; }
    void next() { ++ptr_; }
    void prev() { --ptr_; }
    // TODO figure out why can't return const here if want to be able to assign
    // to *begin(Expr&)
    ExprPtr &read() const {
      RANGES_EXPECT(ptr_);
      return *ptr_;
    }
    ExprPtr &read() {
      RANGES_EXPECT(const_ == false);
      RANGES_EXPECT(ptr_);
      return *ptr_;
    }
    void assign(const ExprPtr &that_ptr) {
      RANGES_EXPECT(ptr_);
      *ptr_ = that_ptr;
    }
    std::ptrdiff_t distance_to(cursor const &that) const {
      return that.ptr_ - ptr_;
    }
    void advance(std::ptrdiff_t n) { ptr_ += n; }

   private:
    ExprPtr *ptr_ =
        nullptr;  // both begin and end will be represented by this, so Expr
                  // without subexpressions begin() equals end() automatically
    bool const_ = false;  // assert in nonconst ops
  };

  /// the base class has no subexpressions: begin and end cursors are equal
  virtual cursor begin_cursor() { return cursor{}; }
  virtual cursor end_cursor() { return cursor{}; }
  virtual cursor begin_cursor() const { return cursor{}; }
  virtual cursor end_cursor() const { return cursor{}; }

  mutable std::optional<hash_type> hash_value_;  // not initialized by default

  /// @return the cached hash value if set, 0 otherwise; derived classes
  ///         override this to compute (and cache) their hash
  virtual hash_type memoizing_hash() const {
    static const hash_type default_hash_value = 0;
    if (hash_value_)
      return *hash_value_;
    else
      return default_hash_value;
  }

  /// invalidates the cached hash value
  virtual void reset_hash_value() const { hash_value_.reset(); }

  /// @return true if this equals @p that, which has the same type id
  virtual bool static_equal(const Expr &that) const
#if __GNUG__
  {
    abort();
  }
#else
      = 0;
#endif

  /// orders two expressions of the same type; by default compares hashes
  virtual bool static_less_than(const Expr &that) const {
    return this->hash_value() < that.hash_value();
  }

  /// @return true if this atom commutes with the atom @p that (true by default)
  virtual bool commutes_with_atom(const Expr &that) const { return true; }

 private:
  /// @return a fresh type id (thread-safe counter)
  static type_id_type get_next_type_id() {
    static std::atomic<type_id_type> grand_type_id = 0;
    return ++grand_type_id;
  }

  /// @return reference to the type id storage for T
  template <typename T>
  static type_id_type &type_id_accessor() {
    static type_id_type type_id = get_next_type_id();
    return type_id;
  }

 private:
  std::logic_error not_implemented(const char *fn) const;
};  // class Expr
// ExprPtr derives from std::shared_ptr<Expr>, so it satisfies both traits
// (the Enabler parameter defaults to void).
template <>
struct Expr::is_shared_ptr_of_expr<ExprPtr> : std::true_type {};
template <>
struct Expr::is_shared_ptr_of_expr_or_derived<ExprPtr> : std::true_type {};
/// Two expressions are equal if they have the same type id and the
/// type-specific static_equal() agrees.
inline bool operator==(const Expr &a, const Expr &b) {
  const bool same_type = a.type_id() == b.type_id();
  return same_type && a.static_equal(b);
}
#if __cplusplus < 202002L
/// negation of operator== (synthesized automatically in C++20)
inline bool operator!=(const Expr &a, const Expr &b) {
  const bool equal = (a == b);
  return !equal;
}
#endif  // __cplusplus < 202002L
/// Creates an expression of type T owned by a new std::shared_ptr.
/// @param args arguments forwarded to T's constructor
/// @return the new expression as an ExprPtr
template <typename T, typename... Args>
ExprPtr ex(Args &&...args) {
  return ExprPtr(std::make_shared<T>(std::forward<Args>(args)...));
}
// A braced list cannot be forwarded as in std::make_shared<X>({ExprPtr,
// ExprPtr}), because a braced-init-list has no type to deduce; spell it out
// instead: std::make_shared<X>(ExprPtrList{ExprPtr, ExprPtr}).
using ExprPtrList = std::initializer_list<ExprPtr>;
/// Compares two (smart) pointers by the values they point to.
/// Declared `inline constexpr` rather than `static`: it is used by inline
/// member functions defined in this header, and a `static` object would give
/// every translation unit its own distinct copy.
inline constexpr auto expr_ptr_comparer = [](const auto &ptr1,
                                             const auto &ptr2) {
  return *ptr1 == *ptr2;
};
/// Vector of ExprPtr (container::svector — presumably a small/short vector
/// type; see SeQuant/core/container.hpp).
using ExprPtrVector = container::svector<ExprPtr>;
/// Declared here, defined out of line; presumably returns the adjoint of
/// @p expr (cf. Expr::adjoint()). NOTE(review): confirm whether @p expr is
/// copied or modified in place.
ExprPtr adjoint(const ExprPtr &expr);
/// U+207A SUPERSCRIPT PLUS SIGN, presumably used to mark adjoints in labels.
/// `inline constexpr` gives a single definition shared by all translation
/// units instead of a per-TU `static` copy.
inline constexpr wchar_t adjoint_label = L'\u207A';
/// Interface for objects that carry a textual label.
class Labeled {
 public:
  /// @return the label of this object
  virtual std::wstring_view label() const = 0;

  virtual ~Labeled() = default;
  Labeled() = default;
};
class Constant : public Expr {
public:
using scalar_type = Complex<sequant::rational>;
private:
scalar_type value_;
public:
Constant() = delete;
virtual ~Constant() = default;
Constant(const Constant &) = default;
Constant(Constant &&) = default;
Constant &operator=(const Constant &) = default;
Constant &operator=(Constant &&) = default;
template <typename U, typename = std::enable_if_t<!is_constant_v<U>>>
explicit Constant(U &&value) : value_(std::forward<U>(value)) {}
private:
template <typename X>
static X numeric_cast(const sequant::rational &r) {
if constexpr (std::is_integral_v<X>) {
assert(denominator(r) == 1);
return boost::numeric_cast<X>(numerator(r));
} else {
return boost::numeric_cast<X>(numerator(r)) /
boost::numeric_cast<X>(denominator(r));
}
};
public:
template <typename T = decltype(value_)>
auto value() const {
if constexpr (std::is_arithmetic_v<T>) {
assert(value_.imag() == 0);
return numeric_cast<T>(value_.real());
} else if constexpr (meta::is_complex_v<T>) {
return T(numeric_cast<typename T::value_type>(value_.real()),
numeric_cast<typename T::value_type>(value_.imag()));
} else
throw std::invalid_argument(
"Constant::value<T>: cannot convert value to type T");
}
std::wstring to_latex() const override {
return L"{" + sequant::to_latex(value()) + L"}";
}
std::wstring to_wolfram() const override {
return sequant::to_wolfram(value());
}
type_id_type type_id() const override { return get_type_id<Constant>(); }
ExprPtr clone() const override { return ex<Constant>(this->value()); }
virtual void adjoint() override;
virtual Expr &operator*=(const Expr &that) override {
if (that.is<Constant>()) {
value_ *= that.as<Constant>().value();
} else {
throw std::logic_error("Constant::operator*=(that): not valid for that");
}
return *this;
}
virtual Expr &operator+=(const Expr &that) override {
if (that.is<Constant>()) {
value_ += that.as<Constant>().value();
} else {
throw std::logic_error("Constant::operator+=(that): not valid for that");
}
return *this;
}
virtual Expr &operator-=(const Expr &that) override {
if (that.is<Constant>()) {
value_ -= that.as<Constant>().value();
} else {
throw std::logic_error("Constant::operator-=(that): not valid for that");
}
return *this;
}
static bool is_zero(scalar_type v) { return v.is_zero(); }
bool is_zero() const { return is_zero(this->value()); }
private:
hash_type memoizing_hash() const override {
hash_value_ = hash::value(value_);
return *hash_value_;
}
bool static_equal(const Expr &that) const override {
return value() == static_cast<const Constant &>(that).value();
}
}; // class Constant
class Variable : public Expr, public Labeled {
public:
Variable() = delete;
virtual ~Variable() = default;
Variable(const Variable &) = default;
Variable(Variable &&) = default;
Variable &operator=(const Variable &) = default;
Variable &operator=(Variable &&) = default;
template <typename U, typename = std::enable_if_t<!is_variable_v<U>>>
explicit Variable(U &&label) : label_(std::forward<U>(label)) {}
Variable(std::wstring label) : label_(std::move(label)), conjugated_(false) {}
std::wstring_view label() const override;
void conjugate();
bool conjugated() const;
std::wstring to_latex() const override;
type_id_type type_id() const override { return get_type_id<Variable>(); }
ExprPtr clone() const override;
virtual void adjoint() override;
private:
std::wstring label_;
bool conjugated_ = false;
hash_type memoizing_hash() const override {
hash_value_ = hash::value(label_);
hash::combine(hash_value_.value(), conjugated_);
return *hash_value_;
}
bool static_equal(const Expr &that) const override {
return label_ == static_cast<const Variable &>(that).label_ &&
conjugated_ == static_cast<const Variable &>(that).conjugated_;
}
}; // class Variable
class Product : public Expr {
public:
enum class Flatten { Once, Recursively, Yes = Recursively, No };
using scalar_type = Constant::scalar_type;
Product() = default;
virtual ~Product() = default;
Product(const Product &) = default;
Product(Product &&) = default;
Product &operator=(const Product &) = default;
Product &operator=(Product &&) = default;
Product(ExprPtrList factors, Flatten flatten_tag = Flatten::Yes) {
using std::begin;
using std::end;
for (auto it = begin(factors); it != end(factors); ++it)
append(1, *it, flatten_tag);
}
template <typename Range,
typename = std::enable_if_t<meta::is_range_v<std::decay_t<Range>> &&
!meta::is_same_v<Range, ExprPtrList> &&
!meta::is_same_v<Range, Product>>>
explicit Product(Range &&rng, Flatten flatten_tag = Flatten::Yes) {
using ranges::begin;
using ranges::end;
for (auto &&v : rng) append(1, std::forward<decltype(v)>(v), flatten_tag);
}
template <typename T, typename Range,
typename = std::enable_if_t<
meta::is_range_v<std::decay_t<Range>> &&
!std::is_same_v<std::remove_reference_t<Range>, ExprPtrList> &&
!std::is_same_v<std::remove_reference_t<Range>, Product>>>
explicit Product(T scalar, Range &&rng, Flatten flatten_tag = Flatten::Yes)
: scalar_(std::move(scalar)) {
using ranges::begin;
using ranges::end;
for (auto &&v : rng) append(1, std::forward<decltype(v)>(v), flatten_tag);
}
template <typename T>
Product(T scalar, ExprPtrList factors, Flatten flatten_tag = Flatten::Yes)
: scalar_(std::move(scalar)) {
using std::begin;
using std::end;
for (auto it = begin(factors); it != end(factors); ++it)
append(1, *it, flatten_tag);
}
template <typename Iterator>
Product(Iterator begin, Iterator end, Flatten flatten_tag = Flatten::Yes) {
for (auto it = begin; it != end; ++it) append(1, *it, flatten_tag);
}
template <typename T, typename Iterator>
Product(T scalar, Iterator begin, Iterator end,
Flatten flatten_tag = Flatten::Yes)
: scalar_(std::move(scalar)) {
for (auto it = begin; it != end; ++it) append(1, *it, flatten_tag);
}
template <typename T>
Product &scale(T scalar) {
scalar_ *= scalar;
return *this;
}
template <typename T>
Product &append(T scalar, ExprPtr factor,
Flatten flatten_tag = Flatten::Yes) {
assert(factor);
scalar_ *= scalar;
if (!factor->is<Product>()) {
if (factor->is<Constant>()) { // factor in Constant
auto factor_constant = factor->as<Constant>();
scalar_ *= factor_constant.value();
// no need to reset the hash since scalar is not hashed!
} else {
factors_.push_back(factor->clone());
reset_hash_value();
}
} else { // factor is a product also ..
if (flatten_tag != Flatten::No) { // flatten, once or recursively
const auto &factor_product = factor->as<Product>();
scalar_ *= factor_product.scalar_;
for (auto &&subfactor : factor_product)
this->append(1, subfactor,
flatten_tag == Flatten::Once ? Flatten::No
: Flatten::Recursively);
} else {
factors_.push_back(factor->clone());
reset_hash_value();
}
}
return *this;
}
template <typename T, typename Factor,
typename = std::enable_if_t<is_an_expr_v<Factor>>>
Product &append(T scalar, Factor &&factor,
Flatten flatten_tag = Flatten::Yes) {
return this->append(scalar,
std::static_pointer_cast<Expr>(
std::forward<Factor>(factor).shared_from_this()),
flatten_tag);
}
template <typename T>
Product &prepend(T scalar, ExprPtr factor,
Flatten flatten_tag = Flatten::Yes) {
assert(factor);
scalar_ *= scalar;
if (!factor->is<Product>()) {
if (factor->is<Constant>()) { // factor in Constant
auto factor_constant = std::static_pointer_cast<Constant>(factor);
scalar_ *= factor_constant->value();
// no need to reset the hash since scalar is not hashed!
} else {
factors_.insert(factors_.begin(), factor->clone());
reset_hash_value();
}
} else { // factor is a product also ... flatten recursively
const auto &factor_product = factor->as<Product>();
scalar_ *= factor_product.scalar_;
if (flatten_tag != Flatten::No) { // flatten, once or recursively
for (auto &&subfactor : factor_product)
this->prepend(1, subfactor,
flatten_tag == Flatten::Once ? Flatten::No
: Flatten::Recursively);
} else {
factors_.insert(factors_.begin(), factor->clone());
reset_hash_value();
}
}
return *this;
}
template <typename T, typename Factor,
typename = std::enable_if_t<is_an_expr_v<Factor>>>
Product &prepend(T scalar, Factor &&factor,
Flatten flatten_tag = Flatten::Yes) {
return this->prepend(scalar,
std::static_pointer_cast<Expr>(
std::forward<Factor>(factor).shared_from_this()),
flatten_tag);
}
const auto &scalar() const { return scalar_; }
bool is_zero() const { return Constant::is_zero(this->scalar()); }
const auto &factors() const { return factors_; }
auto &factors() { return factors_; }
const ExprPtr &factor(size_t i) const { return factors_.at(i); }
bool empty() const { return factors_.empty(); }
virtual bool is_commutative() const;
virtual void adjoint() override;
private:
virtual bool static_commutativity() const { return false; }
public:
std::wstring to_latex() const override { return to_latex(false); }
std::wstring to_latex(bool negate) const {
std::wstring result;
result = L"{";
if (!scalar().is_zero()) {
const auto scal = negate ? -scalar() : scalar();
if (!scal.is_identity()) {
result += sequant::to_latex(scal);
}
for (const auto &i : factors()) {
if (i->is<Product>())
result += L"\\bigl(" + i->to_latex() + L"\\bigr)";
else
result += i->to_latex();
}
}
result += L"}";
return result;
}
std::wstring to_wolfram() const override {
std::wstring result =
is_commutative() ? L"Times[" : L"NonCommutativeMultiply[";
if (scalar() != decltype(scalar_)(1)) {
result += sequant::to_wolfram(scalar()) + L",";
}
const auto nfactors = factors().size();
size_t factor_count = 1;
for (const auto &i : factors()) {
result += i->to_wolfram() + (factor_count == nfactors ? L"" : L",");
++factor_count;
}
result += L"]";
return result;
}
type_id_type type_id() const override { return get_type_id<Product>(); };
ExprPtr clone() const override { return ex<Product>(this->deep_copy()); }
Product deep_copy() const {
auto cloned_factors =
factors() | ranges::views::transform([](const ExprPtr &ptr) {
return ptr ? ptr->clone() : nullptr;
});
Product result(this->scalar(), ExprPtrList{});
ranges::for_each(cloned_factors, [&](const auto &cloned_factor) {
result.append(1, std::move(cloned_factor), Flatten::No);
});
return result;
}
virtual Expr &operator*=(const Expr &that) override {
if (!that.is<Constant>()) {
this->append(1, const_cast<Expr &>(that).shared_from_this());
} else {
scalar_ *= that.as<Constant>().value();
}
return *this;
}
void add_identical(const Product &other) {
assert(this->hash_value() == other.hash_value());
scalar_ += other.scalar_;
}
void add_identical(const std::shared_ptr<Product> &other) {
assert(this->hash_value() == other->hash_value());
scalar_ += other->scalar_;
}
private:
scalar_type scalar_ = {1, 0};
container::svector<ExprPtr, 2> factors_{};
cursor begin_cursor() override {
return factors_.empty() ? Expr::begin_cursor() : cursor{&factors_[0]};
};
cursor end_cursor() override {
return factors_.empty() ? Expr::end_cursor()
: cursor{&factors_[0] + factors_.size()};
};
cursor begin_cursor() const override {
return factors_.empty() ? Expr::begin_cursor() : cursor{&factors_[0]};
};
cursor end_cursor() const override {
return factors_.empty() ? Expr::end_cursor()
: cursor{&factors_[0] + factors_.size()};
};
hash_type memoizing_hash() const override {
auto deref_factors =
factors() |
ranges::views::transform(
[](const ExprPtr &ptr) -> const Expr & { return *ptr; });
hash_value_ =
hash::range(ranges::begin(deref_factors), ranges::end(deref_factors));
return *hash_value_;
}
ExprPtr canonicalize_impl(bool rapid = false);
virtual ExprPtr canonicalize() override;
virtual ExprPtr rapid_canonicalize() override;
bool static_equal(const Expr &that) const override {
const auto &that_cast = static_cast<const Product &>(that);
if (scalar() == that_cast.scalar() &&
factors().size() == that_cast.factors().size()) {
if (this->empty()) return true;
// compare hash values first
if (this->hash_value() ==
that.hash_value()) // hash values agree -> do full comparison
return std::equal(begin_subexpr(), end_subexpr(), that.begin_subexpr(),
expr_ptr_comparer);
else
return false;
} else
return false;
}
}; // class Product
class CProduct : public Product {
public:
using Product::Product;
CProduct(const Product &other) : Product(other) {}
CProduct(Product &&other) : Product(other) {}
bool is_commutative() const override { return true; }
virtual void adjoint() override;
private:
bool static_commutativity() const override { return true; }
}; // class CProduct
class NCProduct : public Product {
public:
using Product::Product;
NCProduct(const Product &other) : Product(other) {}
NCProduct(Product &&other) : Product(other) {}
bool is_commutative() const override { return false; }
virtual void adjoint() override;
private:
bool static_commutativity() const override { return true; }
}; // class NCProduct
class Sum : public Expr {
public:
Sum() = default;
virtual ~Sum() = default;
Sum(const Sum &) = default;
Sum(Sum &&) = default;
Sum &operator=(const Sum &) = default;
Sum &operator=(Sum &&) = default;
Sum(ExprPtrList summands) {
// use append to flatten out Sum summands
for (auto &summand : summands) {
append(std::move(summand));
}
}
template <typename Iterator>
Sum(Iterator begin, Iterator end) {
// use append to flatten out Sum summands
for (auto it = begin; it != end; ++it) {
append(*it);
}
}
template <typename Range,
typename = std::enable_if_t<meta::is_range_v<std::decay_t<Range>> &&
!meta::is_same_v<Range, ExprPtrList>>>
explicit Sum(Range &&rng) {
// use append to flatten out Sum summands
for (auto &&v : rng) {
append(std::forward<decltype(v)>(v));
}
}
Sum &append(ExprPtr summand) {
assert(summand);
if (!summand->is<Sum>()) {
if (summand->is<Constant>()) { // exclude zeros, add up constants
// immediately, if possible
auto summand_constant = std::static_pointer_cast<Constant>(summand);
if (constant_summand_idx_) {
assert(summands_.at(*constant_summand_idx_)->is<Constant>());
*(summands_[*constant_summand_idx_]) += *summand;
} else {
if (!summand_constant->is_zero()) {
summands_.push_back(summand->clone());
constant_summand_idx_ = summands_.size() - 1;
}
}
} else {
summands_.push_back(summand->clone());
}
reset_hash_value();
} else { // this recursively flattens Sum summands
for (auto &subsummand : *summand) this->append(subsummand);
}
return *this;
}
Sum &prepend(ExprPtr summand) {
assert(summand);
if (!summand->is<Sum>()) {
if (summand->is<Constant>()) { // exclude zeros
auto summand_constant = std::static_pointer_cast<Constant>(summand);
if (constant_summand_idx_) { // add up to the existing constant ...
assert(summands_.at(*constant_summand_idx_)->is<Constant>());
*summands_[*constant_summand_idx_] += *summand_constant;
} else { // or include the nonzero constant and update
// constant_summand_idx_
if (!summand_constant->is_zero()) {
summands_.insert(summands_.begin(), summand->clone());
constant_summand_idx_ = 0;
}
}
} else {
summands_.insert(summands_.begin(), summand->clone());
if (constant_summand_idx_) // if have a constant, update its position
++*constant_summand_idx_;
}
reset_hash_value();
} else { // this recursively flattens Sum summands
for (auto &subsummand : *summand) this->prepend(subsummand);
}
return *this;
}
const auto &summands() const { return summands_; }
const ExprPtr &summand(size_t i) const { return summands_.at(i); }
ExprPtr take_n(size_t count) const {
const auto e = (count >= summands_.size() ? summands_.end()
: (summands_.begin() + count));
return ex<Sum>(summands_.begin(), e);
}
ExprPtr take_n(size_t offset, size_t count) const {
const auto offset_plus_count = offset + count;
const auto b = (offset >= summands_.size() ? summands_.end()
: (summands_.begin() + offset));
const auto e = (offset_plus_count >= summands_.size()
? summands_.end()
: (summands_.begin() + offset_plus_count));
return ex<Sum>(b, e);
}
template <typename Filter>
ExprPtr filter(Filter &&f) const {
return ex<Sum>(summands_ | ranges::views::filter(f));
}
bool empty() const { return summands_.empty(); }
std::size_t size() const { return summands_.size(); }
std::wstring to_latex() const override {
std::wstring result;
result = L"{ \\bigl(";
std::size_t counter = 0;
for (const auto &i : summands()) {
const auto i_is_product = i->is<Product>();
if (!i_is_product) {
result += (counter == 0) ? i->to_latex() : (L" + " + i->to_latex());
} else { // i_is_product
const auto i_prod = i->as<Product>();
const auto scalar = i_prod.scalar();
if (scalar.real() < 0 || (scalar.real() == 0 && scalar.imag() < 0)) {
result += L" - " + i_prod.to_latex(true);
} else {
result += (counter == 0) ? i->to_latex() : (L" + " + i->to_latex());
}
}
++counter;
}
result += L"\\bigr) }";
return result;
}
std::wstring to_wolfram() const override {
std::wstring result;
result = L"Plus[";
std::size_t counter = 0;
for (const auto &i : summands()) {
result += i->to_wolfram();
++counter;
if (counter != summands().size()) result += L",";
}
result += L"]";
return result;
}
Expr::type_id_type type_id() const override {
return Expr::get_type_id<Sum>();
};
ExprPtr clone() const override {
auto cloned_summands =
summands() | ranges::views::transform(
[](const ExprPtr &ptr) { return ptr->clone(); });
return ex<Sum>(ranges::begin(cloned_summands),
ranges::end(cloned_summands));
}
virtual void adjoint() override;
virtual Expr &operator+=(const Expr &that) override {
this->append(const_cast<Expr &>(that).shared_from_this());
return *this;
}
virtual Expr &operator-=(const Expr &that) override {
if (that.is<Constant>())
this->append(ex<Constant>(-that.as<Constant>().value()));
else
this->append(ex<Product>(
-1, ExprPtrList{const_cast<Expr &>(that).shared_from_this()}));
return *this;
}
private:
container::svector<ExprPtr, 2> summands_{};
std::optional<size_t>
constant_summand_idx_{}; // points to the constant summand, if any; used
// to sum up constants in append/prepend
cursor begin_cursor() override {
return summands_.empty() ? Expr::begin_cursor() : cursor{&summands_[0]};
};
cursor end_cursor() override {
return summands_.empty() ? Expr::end_cursor()
: cursor{&summands_[0] + summands_.size()};
};
cursor begin_cursor() const override {
return summands_.empty() ? Expr::begin_cursor() : cursor{&summands_[0]};
};
cursor end_cursor() const override {
return summands_.empty() ? Expr::end_cursor()
: cursor{&summands_[0] + summands_.size()};
};
hash_type memoizing_hash() const override {
auto deref_summands =
summands() |
ranges::views::transform(
[](const ExprPtr &ptr) -> const Expr & { return *ptr; });
hash_value_ =
hash::range(ranges::begin(deref_summands), ranges::end(deref_summands));
return *hash_value_;
}
ExprPtr canonicalize_impl(bool multipass);
virtual ExprPtr canonicalize() override { return canonicalize_impl(true); }
virtual ExprPtr rapid_canonicalize() override {
return canonicalize_impl(false);
}
bool static_equal(const Expr &that) const override {
const auto &that_cast = static_cast<const Sum &>(that);
if (summands().size() == that_cast.summands().size()) {
if (this->empty()) return true;
// compare hash values first
if (this->hash_value() ==
that.hash_value()) // hash values agree -> do full comparison
return std::equal(begin_subexpr(), end_subexpr(), that.begin_subexpr(),
expr_ptr_comparer);
else
return false;
} else
return false;
}
}; // class Sum
inline std::wstring to_latex(const ExprPtr &exprptr) {
return exprptr->to_latex();
}
/// Formats an expression as a LaTeX `align` environment.
///
/// For a Sum, the "{ \bigl(" / "\bigr) }" wrapper produced by Sum::to_latex()
/// is replaced by "(" / ")". A line break is inserted before every
/// @p max_terms_per_line -th " + "/" - " separator. If @p max_lines_per_align
/// is nonzero, the align environment is closed and reopened after that many
/// lines. Other expressions are emitted on one line. Nested sums are not
/// accounted for.
/// @param exprptr expression to format (must be non-null)
/// @param max_lines_per_align max lines per align environment; 0 = unlimited
/// @param max_terms_per_line terms per line; 0 disables line breaking
inline std::wstring to_latex_align(const ExprPtr &exprptr,
                                   size_t max_lines_per_align = 0,
                                   size_t max_terms_per_line = 1) {
  std::wstring result = to_latex(exprptr);
  if (exprptr->is<Sum>()) {
    result.erase(0, 7);  // remove leading "{ \bigl"
    result.replace(result.size() - 8, 8,
                   L")");  // replace trailing "\bigr) }" with ")"
    result = std::wstring(L"\\begin{align}\n& ") + result;
    // assume no inner sums
    size_t line_counter = 0;
    size_t term_counter = 0;
    std::wstring::size_type pos = 0;
    std::wstring::size_type plus_pos = 0;
    std::wstring::size_type minus_pos = 0;
    bool last_pos_has_plus = false;
    bool have_next_term = true;
    // inserts str at position `at`, shifting the tracked positions to match
    auto insert_into_result_at = [&](std::wstring::size_type at,
                                     const auto &str) {
      assert(pos != std::wstring::npos);
      result.insert(at, str);
      const auto str_nchar = std::size(str) - 1;  // neglect end-of-string
      pos += str_nchar;
      if (plus_pos != std::wstring::npos) plus_pos += str_nchar;
      if (minus_pos != std::wstring::npos) minus_pos += str_nchar;
      if (pos != plus_pos) assert(plus_pos == result.find(L" + ", plus_pos));
      if (pos != minus_pos) assert(minus_pos == result.find(L" - ", minus_pos));
    };
    while (have_next_term) {
      if (max_lines_per_align > 0 &&
          line_counter == max_lines_per_align) {  // start new align block?
        insert_into_result_at(pos + 1, L"\n\\end{align}\n\\begin{align}\n& ");
        line_counter = 0;
      } else {
        // break the line if needed; max_terms_per_line == 0 means never
        // (and must not be used as a modulus)
        if (max_terms_per_line > 0 && term_counter != 0 &&
            term_counter % max_terms_per_line == 0) {
          insert_into_result_at(pos + 1, L"\\\\\n& ");
          ++line_counter;
        }
      }
      // next term, plz
      if (plus_pos == 0 || last_pos_has_plus)
        plus_pos = result.find(L" + ", plus_pos + 1);
      if (minus_pos == 0 || !last_pos_has_plus)
        minus_pos = result.find(L" - ", minus_pos + 1);
      pos = std::min(plus_pos, minus_pos);
      last_pos_has_plus = (pos == plus_pos);
      if (pos != std::wstring::npos)
        ++term_counter;
      else
        have_next_term = false;
    }
  } else {
    result = std::wstring(L"\\begin{align}\n& ") + result;
  }
  result += L"\n\\end{align}";
  return result;
}
inline std::wstring to_wolfram(const ExprPtr &exprptr) {
return exprptr->to_wolfram();
}
/// @return a sequence of the same (decayed) type holding clones of the
///         elements of @p exprseq; null elements stay null
template <typename Sequence>
std::decay_t<Sequence> clone(Sequence &&exprseq) {
  using result_type = std::decay_t<Sequence>;
  const auto clone_or_null = [](const ExprPtr &ptr) {
    return ptr ? ptr->clone() : nullptr;
  };
  auto clones = exprseq | ranges::views::transform(clone_or_null);
  return result_type(ranges::begin(clones), ranges::end(clones));
}
/// @return number of subexpressions of @p expr
inline std::size_t size(const Expr &expr) {
  const std::size_t nsubexprs = ranges::size(expr);
  return nsubexprs;
}

/// @return number of subexpressions of *exprptr, or 0 if @p exprptr is null
inline std::size_t size(const ExprPtr &exprptr) {
  return exprptr ? size(*exprptr) : std::size_t{0};
}
inline decltype(auto) begin(const ExprPtr &exprptr) {
assert(exprptr);
return ranges::begin(*exprptr);
}
inline decltype(auto) begin(ExprPtr &exprptr) {
assert(exprptr);
return ranges::begin(*exprptr);
}
inline decltype(auto) cbegin(const ExprPtr &exprptr) {
assert(exprptr);
return ranges::cbegin(*exprptr);
}
inline decltype(auto) end(const ExprPtr &exprptr) {
assert(exprptr);
return ranges::end(*exprptr);
}
inline decltype(auto) end(ExprPtr &exprptr) {
assert(exprptr);
return ranges::end(*exprptr);
}
inline decltype(auto) cend(const ExprPtr &exprptr) {
assert(exprptr);
return ranges::cend(*exprptr);
}
// Out-of-line ExprPtr members that need the complete Expr definition; each
// forwards to the same-named Expr member of the pointee.

/// @return true if the pointee is<T>()
template <typename T>
bool ExprPtr::is() const {
  const auto &sptr = as_shared_ptr();
  return sptr->is<T>();
}

/// @return the pointee cast to T
template <typename T>
const T &ExprPtr::as() const {
  const auto &sptr = as_shared_ptr();
  return sptr->as<T>();
}

/// @return the pointee cast to T
template <typename T>
T &ExprPtr::as() {
  auto &sptr = as_shared_ptr();
  return sptr->as<T>();
}
} // namespace sequant
#endif // SEQUANT_EXPR_HPP
#include <SeQuant/core/expr_operator.hpp>
#include <SeQuant/core/expr_algorithm.hpp>