Program Listing for File expr.hpp

Return to documentation for file (SeQuant/core/expr.hpp)

//
// Created by Eduard Valeyev on 3/23/18.
//

#ifndef SEQUANT_EXPR_HPP
#define SEQUANT_EXPR_HPP

#include <SeQuant/core/expr_fwd.hpp>

#include <SeQuant/core/complex.hpp>
#include <SeQuant/core/container.hpp>
#include <SeQuant/core/hash.hpp>
#include <SeQuant/core/latex.hpp>
#include <SeQuant/core/rational.hpp>
#include <SeQuant/core/wolfram.hpp>

#include <range/v3/all.hpp>

#include <boost/core/demangle.hpp>
#include <boost/numeric/conversion/cast.hpp>

#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstdlib>
#include <functional>
#include <initializer_list>
#include <iostream>
#include <iterator>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <type_traits>
#include <typeinfo>
#include <utility>

namespace sequant {

// Compile-time classification helpers for the expression class hierarchy.
// Each "is_a_X_v" trait tests "X or derived from X"; each "is_X_v" trait tests
// "exactly X". T is passed through as-is (no cv/reference stripping is applied
// here), so callers that hold a reference type must strip it themselves.
namespace {

/// true if @c T is @c Expr or a class derived from it
template <typename T>
constexpr bool is_an_expr_v = meta::is_base_of_v<Expr, T>;
/// true if @c T is exactly @c Expr
template <typename T>
constexpr bool is_expr_v = meta::is_same_v<Expr, T>;

/// true if @c T is @c Constant or a class derived from it
template <typename T>
constexpr bool is_a_constant_v = meta::is_base_of_v<Constant, T>;
/// true if @c T is exactly @c Constant
template <typename T>
constexpr bool is_constant_v = meta::is_same_v<Constant, T>;

/// true if @c T is @c Sum or a class derived from it
template <typename T>
constexpr bool is_a_sum_v = meta::is_base_of_v<Sum, T>;
/// true if @c T is exactly @c Sum
template <typename T>
constexpr bool is_sum_v = meta::is_same_v<Sum, T>;

/// true if @c T is @c Product or a class derived from it
template <typename T>
constexpr bool is_a_product_v = meta::is_base_of_v<Product, T>;
/// true if @c T is exactly @c Product
template <typename T>
constexpr bool is_product_v = meta::is_same_v<Product, T>;

/// true if @c T is @c Variable or a class derived from it
template <typename T>
constexpr bool is_a_variable_v = meta::is_base_of_v<Variable, T>;
/// true if @c T is exactly @c Variable
template <typename T>
constexpr bool is_variable_v = meta::is_same_v<Variable, T>;

}  // namespace


/// A shared pointer to Expr that adds expression-specific conveniences
/// (cloning, type queries, compound arithmetic, LaTeX output) on top of
/// std::shared_ptr<Expr>. Most members are only declared here; their
/// definitions live outside this listing.
class ExprPtr : public std::shared_ptr<Expr> {
 public:
  using base_type = std::shared_ptr<Expr>;
  using base_type::operator->;
  // inherit all std::shared_ptr<Expr> constructors
  using base_type::base_type;

  ExprPtr() = default;
  ExprPtr(const ExprPtr &) = default;
  ExprPtr(ExprPtr &&) = default;
  /// Copy-converts from a shared_ptr to Expr or to a class derived from Expr;
  /// constness of the pointee type @c E is ignored by the constraint.
  template <typename E, typename = std::enable_if_t<
                            std::is_same_v<std::remove_const_t<E>, Expr> ||
                            std::is_base_of_v<Expr, std::remove_const_t<E>>>>
  ExprPtr(const std::shared_ptr<E> &other_sptr) : base_type(other_sptr) {}
  /// Move-converts from a shared_ptr to Expr or to a class derived from Expr.
  template <typename E, typename = std::enable_if_t<is_an_expr_v<E>>>
  ExprPtr(std::shared_ptr<E> &&other_sptr) : base_type(std::move(other_sptr)) {}
  /// Copy-assigns from a shared_ptr to Expr or to a class derived from Expr.
  template <typename E, typename = std::enable_if_t<is_an_expr_v<E>>>
  ExprPtr &operator=(const std::shared_ptr<E> &other_sptr) {
    as_shared_ptr() = other_sptr;
    return *this;
  }
  /// Move-assigns from a shared_ptr to Expr or to a class derived from Expr.
  template <typename E, typename = std::enable_if_t<is_an_expr_v<E>>>
  ExprPtr &operator=(std::shared_ptr<E> &&other_sptr) {
    as_shared_ptr() = std::move(other_sptr);
    return *this;
  }

  ExprPtr &operator=(const ExprPtr &) = default;
  ExprPtr &operator=(ExprPtr &&) = default;

  ~ExprPtr() = default;

  // Both overloads are defined elsewhere; presumably the lvalue overload
  // copies the pointee and the rvalue overload can reuse it -- TODO confirm
  // against the definitions.
  [[nodiscard]] ExprPtr clone() const &;
  [[nodiscard]] ExprPtr clone() && noexcept;

  // Access to this object viewed as the underlying std::shared_ptr<Expr>,
  // preserving the value category of *this (defined elsewhere).
  base_type &as_shared_ptr() &;
  const base_type &as_shared_ptr() const &;
  base_type &&as_shared_ptr() &&;

  /// @tparam E any type other than Expr itself
  /// @return this pointer statically cast to std::shared_ptr<E>
  /// @pre `this->is<E>()` (checked by assert only)
  template <typename E, typename = std::enable_if_t<!is_expr_v<E>>>
  std::shared_ptr<E> as_shared_ptr() const {
    assert(this->is<E>());
    return std::static_pointer_cast<E>(this->as_shared_ptr());
  }

  // Dereference operators, overloaded on the value category of *this
  // (defined elsewhere).
  Expr &operator*() &;

  const Expr &operator*() const &;

  Expr &&operator*() &&;

  // Compound arithmetic on the pointed-to expressions (defined elsewhere).
  ExprPtr &operator+=(const ExprPtr &other);

  ExprPtr &operator-=(const ExprPtr &);

  ExprPtr &operator*=(const ExprPtr &);

  // Type query / checked-cast helpers (defined elsewhere); presumably these
  // forward to Expr::is<T>() and Expr::as<T>() -- TODO confirm.
  template <typename T>
  bool is() const;

  template <typename T>
  const T &as() const;

  template <typename T>
  T &as();

  // LaTeX representation (defined elsewhere).
  std::wstring to_latex() const;
};  // class ExprPtr

/// @return true if @p x does not point to an expression
inline bool operator==(const ExprPtr &x, std::nullptr_t) {
  // shared_ptr's explicit bool conversion is equivalent to `get() != nullptr`
  return !static_cast<bool>(x);
}

/// @return true if @p x does not point to an expression
inline bool operator==(std::nullptr_t, const ExprPtr &x) {
  // shared_ptr's explicit bool conversion is equivalent to `get() != nullptr`
  return !static_cast<bool>(x);
}


/// Base class of the expression hierarchy.
///
/// An Expr is also a range (via ranges::view_facade) over its immediate
/// subexpressions, each held as an ExprPtr; an Expr whose range is empty is an
/// "atom". Derived classes expose their subexpressions by overriding
/// begin_cursor()/end_cursor(), and customize hashing and comparison by
/// overriding memoizing_hash(), static_equal() and static_less_than().
class Expr : public std::enable_shared_from_this<Expr>,
             public ranges::view_facade<Expr> {
 public:
  using range_type = ranges::view_facade<Expr>;
  using hash_type = std::size_t;
  using type_id_type = int;  // to speed up comparisons

  Expr() = default;
  virtual ~Expr() = default;

  /// @return true if this expression has no subexpressions
  bool is_atom() const { return ranges::empty(*this); }

  /// @return the LaTeX representation (defined elsewhere)
  virtual std::wstring to_latex() const;

  /// @return the Wolfram representation (defined elsewhere)
  virtual std::wstring to_wolfram() const;

  /// @return a copy of this expression (defined elsewhere)
  virtual ExprPtr clone() const;

  /// @return an ExprPtr sharing ownership of this object
  /// @note relies on shared_from_this(), hence requires this object to be
  ///       managed by a std::shared_ptr
  ExprPtr exprptr_from_this() {
    return static_cast<ExprPtr>(this->shared_from_this());
  }

  /// Const overload: constness is cast away so that the result can be held
  /// in an ExprPtr (which points to non-const Expr).
  ExprPtr exprptr_from_this() const {
    return static_cast<const ExprPtr>(
        std::const_pointer_cast<Expr>(this->shared_from_this()));
  }

  /// Canonicalizes this expression.
  /// @return by default a null ExprPtr; see overrides for their conventions
  virtual ExprPtr canonicalize() {
    return {};  // by default do nothing and return nullptr
  }

  /// By default identical to canonicalize().
  virtual ExprPtr rapid_canonicalize() { return this->canonicalize(); }

  /// Recursively visits the subexpressions of this expression (children
  /// before their parent), then, if @p visitor is callable with
  /// `const ExprPtr&`, this expression itself.
  /// @param visitor callable invoked on each visited (sub)expression pointer
  /// @param atoms_only if true, @p visitor is invoked on atoms only
  /// @return true if @p visitor was invoked on this expression
  /// @throw std::invalid_argument if this object is not managed by a
  ///        std::shared_ptr
  template <typename Visitor>
  bool visit(Visitor &&visitor, const bool atoms_only = false) {
    return visit_impl(*this, std::forward<Visitor>(visitor), atoms_only);
  }

  /// Const overload of visit().
  template <typename Visitor>
  bool visit(Visitor &&visitor, const bool atoms_only = false) const {
    return visit_impl(*this, std::forward<Visitor>(visitor), atoms_only);
  }

  // Iterators over the immediate subexpressions (the view_facade range).
  auto begin_subexpr() { return range_type::begin(); }

  auto end_subexpr() { return range_type::end(); }

  auto begin_subexpr() const { return range_type::begin(); }

  auto end_subexpr() const { return range_type::end(); }

  /// @return reference to this object typed as Expr
  Expr &expr() { return *this; }
  const Expr &expr() const { return *this; }

  /// true_type if @c T is std::shared_ptr<Expr>
  template <typename T, typename Enabler = void>
  struct is_shared_ptr_of_expr : std::false_type {};
  template <typename T>
  struct is_shared_ptr_of_expr<std::shared_ptr<T>,
                               std::enable_if_t<is_expr_v<T>>>
      : std::true_type {};
  /// true_type if @c T is std::shared_ptr<E> with E being Expr or derived
  template <typename T, typename Enabler = void>
  struct is_shared_ptr_of_expr_or_derived : std::false_type {};
  template <typename T>
  struct is_shared_ptr_of_expr_or_derived<std::shared_ptr<T>,
                                          std::enable_if_t<is_an_expr_v<T>>>
      : std::true_type {};

  /// @return true if this is a c-number; by default an atom is a c-number and
  ///         a composite is a c-number if all of its subexpressions are
  virtual bool is_cnumber() const {
    if (is_atom())
      return true;
    else {
      bool result = true;
      for (auto it = begin_subexpr(); result && it != end_subexpr(); ++it) {
        result &= (*it)->is_cnumber();
      }
      return result;
    }
  }

  /// Checks commutativity with @p that.
  ///
  /// Two atoms commute if either is a c-number or commutes_with_atom() says
  /// so; otherwise the test recurses over the subexpressions of whichever
  /// operand is composite and requires every pairing to commute.
  bool commutes_with(const Expr &that) const {
    auto this_is_atom = is_atom();
    auto that_is_atom = that.is_atom();
    bool result = true;
    if (this_is_atom && that_is_atom) {
      result =
          this->is_cnumber() || that.is_cnumber() || commutes_with_atom(that);
    } else if (this_is_atom) {
      if (!this->is_cnumber()) {
        for (auto it = that.begin_subexpr(); result && it != that.end_subexpr();
             ++it) {
          result &= this->commutes_with(**it);
        }
      }
    } else {
      for (auto it = this->begin_subexpr(); result && it != this->end_subexpr();
           ++it) {
        result &= (*it)->commutes_with(that);
      }
    }
    return result;
  }

  /// In-place adjoint (defined elsewhere).
  virtual void adjoint();

  /// @param hasher optional custom hasher; if non-empty it is applied to
  ///        shared_from_this() (so this object must then be managed by a
  ///        std::shared_ptr)
  /// @return `hasher(shared_from_this())` if @p hasher is set, else
  ///         memoizing_hash()
  hash_type hash_value(
      std::function<hash_type(const std::shared_ptr<const Expr> &)> hasher = {})
      const {
    return hasher ? hasher(shared_from_this()) : memoizing_hash();
  }

  /// @return the type id of the dynamic type of this object
  /// @note with GNU-compatible compilers this has an aborting body instead of
  ///       being pure virtual; NOTE(review): presumably a compiler workaround
  ///       -- confirm before changing
  virtual type_id_type type_id() const
#if __GNUG__
  {
    abort();
  }
#else
      = 0;
#endif

  friend inline bool operator==(const Expr &a, const Expr &b);

  /// Strict ordering: expressions with equal type ids are ordered by
  /// static_less_than(), expressions of different types by their type ids.
  /// @tparam T Expr or a class derived from it
  // N.B. the constraint uses enable_if_t; a plain enable_if<> names a type
  // even when the condition is false and hence would never disable anything
  template <typename T, typename = std::enable_if_t<is_an_expr_v<T>>>
  bool operator<(const T &that) const {
    if (type_id() ==
        that.type_id()) {  // if same type, use generic (or type-specific, if
                           // available) comparison
      return static_less_than(static_cast<const Expr &>(that));
    } else {  // order types by type id
      return type_id() < that.type_id();
    }
  }

  /// @return the id assigned to type @c T (assigned on first use)
  template <typename T>
  static type_id_type get_type_id() {
    return type_id_accessor<T>();
  }

  /// Overrides the id assigned to type @c T.
  template <typename T>
  static void set_type_id(type_id_type id) {
    type_id_accessor<T>() = id;
  }

  /// @return true if @c T is Expr, or if the type id of this object equals
  ///         the id of @c T (cv/ref qualifiers of @c T are ignored)
  template <typename T>
  bool is() const {
    if constexpr (is_expr_v<T>)
      return true;
    else
      return this->type_id() == get_type_id<meta::remove_cvref_t<T>>();
  }

  /// @return this object statically cast to `const T&`
  /// @pre `this->is<T>()` (checked by assert only)
  template <typename T>
  const T &as() const {
    assert(this->is<T>());
    return static_cast<const T &>(*this);
  }

  /// @return this object statically cast to `T&`
  /// @pre `this->is<T>()` (checked by assert only)
  template <typename T>
  T &as() {
    assert(this->is<T>());
    return static_cast<T &>(*this);
  }

  /// @return the demangled name of the dynamic type of this object
  std::string type_name() const {
    return boost::core::demangle(typeid(*this).name());
  }

  // In-place arithmetic; the base-class versions are defined elsewhere.
  virtual Expr &operator*=(const Expr &that);

  virtual Expr &operator^=(const Expr &that);

  virtual Expr &operator+=(const Expr &that);

  virtual Expr &operator-=(const Expr &that);

 private:
  friend ranges::range_access;

  /// Implementation of visit() shared by its const and non-const overloads.
  /// @return true if @p visitor was invoked on @p expr itself
  template <typename E, typename Visitor,
            typename =
                std::enable_if_t<std::is_same_v<meta::remove_cvref_t<E>, Expr>>>
  static bool visit_impl(E &&expr, Visitor &&visitor, const bool atoms_only) {
    if (expr.weak_from_this().use_count() == 0)
      throw std::invalid_argument(
          "Expr::visit: cannot visit expressions not managed by shared_ptr");
    for (auto &subexpr_ptr : expr.expr()) {
      const auto subexpr_is_an_atom = subexpr_ptr->is_atom();
      const auto need_to_visit_subexpr = !atoms_only || subexpr_is_an_atom;
      bool visited = false;
      if (!subexpr_is_an_atom)  // if not a leaf, recur into it
        visited = visit_impl(*subexpr_ptr, std::forward<Visitor>(visitor),
                             atoms_only);
      // call on the subexpression itself, if not yet done so
      if (need_to_visit_subexpr && !visited) visitor(subexpr_ptr);
    }
    // N.B. can only visit itself if visitor is nonmutating!
    bool this_visited = false;
    if constexpr (std::is_invocable_r_v<void, std::remove_reference_t<Visitor>,
                                        const ExprPtr &>) {
      if (!atoms_only || expr.is_atom()) {
        const ExprPtr this_exprptr = expr.exprptr_from_this();
        visitor(this_exprptr);
        this_visited = true;
      }
    }
    return this_visited;
  }

 protected:
  // copying/moving is reserved for derived classes
  Expr(Expr &&) = default;
  Expr(const Expr &) = default;
  Expr &operator=(Expr &&) = default;
  Expr &operator=(const Expr &) = default;

  /// Cursor over a contiguous array of ExprPtr, used by ranges::view_facade
  /// to synthesize the subexpression iterators.
  struct cursor {
    using value_type = ExprPtr;

    cursor() = default;
    constexpr explicit cursor(ExprPtr *subexpr_ptr) noexcept
        : ptr_{subexpr_ptr} {}
    // remembers that it was made from a pointer-to-const so that the
    // non-const read() can check for misuse
    constexpr explicit cursor(const ExprPtr *subexpr_ptr) noexcept
        : ptr_{const_cast<ExprPtr *>(subexpr_ptr)}, const_{true} {}
    bool equal(const cursor &that) const { return ptr_ == that.ptr_; }
    void next() { ++ptr_; }
    void prev() { --ptr_; }
    // TODO figure out why can't return const here if want to be able to assign
    // to *begin(Expr&)
    ExprPtr &read() const {
      RANGES_EXPECT(ptr_);
      return *ptr_;
    }
    ExprPtr &read() {
      RANGES_EXPECT(const_ == false);
      RANGES_EXPECT(ptr_);
      return *ptr_;
    }
    void assign(const ExprPtr &that_ptr) {
      RANGES_EXPECT(ptr_);
      *ptr_ = that_ptr;
    }
    std::ptrdiff_t distance_to(cursor const &that) const {
      return that.ptr_ - ptr_;
    }
    void advance(std::ptrdiff_t n) { ptr_ += n; }

   private:
    ExprPtr *ptr_ =
        nullptr;  // both begin and end will be represented by this, so Expr
                  // without subexpressions begin() equals end() automatically
    bool const_ = false;  // assert in nonconst ops
  };

  // By default an Expr has no subexpressions: begin and end cursors are both
  // default-constructed (null) and hence compare equal.
  virtual cursor begin_cursor() { return cursor{}; }
  virtual cursor end_cursor() { return cursor{}; }

  virtual cursor begin_cursor() const { return cursor{}; }
  virtual cursor end_cursor() const { return cursor{}; }

  mutable std::optional<hash_type> hash_value_;  // not initialized by default
  /// @return the cached hash value if set, else 0; the base-class version
  ///         never computes a hash itself
  virtual hash_type memoizing_hash() const {
    static const hash_type default_hash_value = 0;
    if (hash_value_)
      return *hash_value_;
    else
      return default_hash_value;
  }
  /// Discards the cached hash value.
  virtual void reset_hash_value() const { hash_value_.reset(); }

  /// Equality comparison against an Expr whose type id equals that of this
  /// object (see operator==); same GNU-specific aborting body as type_id().
  virtual bool static_equal(const Expr &that) const
#if __GNUG__
  {
    abort();
  }
#else
      = 0;
#endif

  /// Ordering among expressions of the same type; by default compares the
  /// hash values.
  virtual bool static_less_than(const Expr &that) const {
    return this->hash_value() < that.hash_value();
  }

  /// Atom-vs-atom commutativity test used by commutes_with(); by default
  /// everything commutes.
  virtual bool commutes_with_atom(const Expr &that) const { return true; }

 private:
  /// @return a fresh type id; ids start at 1 and grow by 1 on each call
  static type_id_type get_next_type_id() {
    static std::atomic<type_id_type> grand_type_id = 0;
    return ++grand_type_id;
  }

  /// @return reference to the per-type id, initialized from
  ///         get_next_type_id() on the first call for each @c T
  template <typename T>
  static type_id_type &type_id_accessor() {
    static type_id_type type_id = get_next_type_id();
    return type_id;
  }

 private:
  // helper that builds a "not implemented" exception (defined elsewhere)
  std::logic_error not_implemented(const char *fn) const;
};  // class Expr

// ExprPtr derives from std::shared_ptr<Expr> and hence is not matched by the
// std::shared_ptr<T> partial specializations; mark it explicitly as a shared
// pointer to Expr for both traits.
template <>
struct Expr::is_shared_ptr_of_expr<ExprPtr, void> : std::true_type {};
template <>
struct Expr::is_shared_ptr_of_expr_or_derived<ExprPtr, void> : std::true_type {
};

/// Two expressions are equal if their type ids agree and the type-specific
/// comparison (Expr::static_equal) succeeds; static_equal is only consulted
/// when the type ids match.
inline bool operator==(const Expr &a, const Expr &b) {
  const bool same_type = a.type_id() == b.type_id();
  return same_type && a.static_equal(b);
}

#if __cplusplus < 202002L
/// Negation of operator==(const Expr&, const Expr&); only provided when
/// compiling for a standard older than C++20.
inline bool operator!=(const Expr &a, const Expr &b) {
  const bool equal = (a == b);
  return !equal;
}
#endif  // __cplusplus < 202002L

/// Makes a new expression of type @c T, managed by an ExprPtr.
/// @tparam T the expression type to construct
/// @param args arguments perfectly forwarded to the constructor of @c T
/// @return an ExprPtr owning the new object
template <typename T, typename... Args>
ExprPtr ex(Args &&...args) {
  ExprPtr result = std::make_shared<T>(std::forward<Args>(args)...);
  return result;
}

// A named initializer-list type is needed because a bare braced list cannot be
// passed through std::make_shared<X>({ExprPtr,ExprPtr}); one must instead
// write std::make_shared<X>(ExprPtrList{ExprPtr,ExprPtr})
using ExprPtrList = std::initializer_list<ExprPtr>;
// Compares two pointer-like objects by comparing the objects they point to
// (both must be dereferenceable).
static auto expr_ptr_comparer = [](const auto &ptr1, const auto &ptr2) {
  return *ptr1 == *ptr2;
};

// A sequence of expression pointers stored in a container::svector.
using ExprPtrVector = container::svector<ExprPtr>;

// Free-function adjoint (defined elsewhere); presumably returns the adjoint
// of @p expr without modifying it -- TODO confirm against the definition.
ExprPtr adjoint(const ExprPtr &expr);

// U+207A (SUPERSCRIPT PLUS SIGN), the character used to label adjoints.
static const wchar_t adjoint_label = L'\u207A';

/// Interface for objects that carry a (wide-character) label.
class Labeled {
 public:
  Labeled() = default;
  virtual ~Labeled() = default;

  /// @return a view of this object's label
  virtual std::wstring_view label() const = 0;
};


class Constant : public Expr {
 public:
  using scalar_type = Complex<sequant::rational>;

 private:
  scalar_type value_;

 public:
  Constant() = delete;
  virtual ~Constant() = default;
  Constant(const Constant &) = default;
  Constant(Constant &&) = default;
  Constant &operator=(const Constant &) = default;
  Constant &operator=(Constant &&) = default;
  template <typename U, typename = std::enable_if_t<!is_constant_v<U>>>
  explicit Constant(U &&value) : value_(std::forward<U>(value)) {}

 private:
  template <typename X>
  static X numeric_cast(const sequant::rational &r) {
    if constexpr (std::is_integral_v<X>) {
      assert(denominator(r) == 1);
      return boost::numeric_cast<X>(numerator(r));
    } else {
      return boost::numeric_cast<X>(numerator(r)) /
             boost::numeric_cast<X>(denominator(r));
    }
  };

 public:
  template <typename T = decltype(value_)>
  auto value() const {
    if constexpr (std::is_arithmetic_v<T>) {
      assert(value_.imag() == 0);
      return numeric_cast<T>(value_.real());
    } else if constexpr (meta::is_complex_v<T>) {
      return T(numeric_cast<typename T::value_type>(value_.real()),
               numeric_cast<typename T::value_type>(value_.imag()));
    } else
      throw std::invalid_argument(
          "Constant::value<T>: cannot convert value to type T");
  }

  std::wstring to_latex() const override {
    return L"{" + sequant::to_latex(value()) + L"}";
  }

  std::wstring to_wolfram() const override {
    return sequant::to_wolfram(value());
  }

  type_id_type type_id() const override { return get_type_id<Constant>(); }

  ExprPtr clone() const override { return ex<Constant>(this->value()); }

  virtual void adjoint() override;

  virtual Expr &operator*=(const Expr &that) override {
    if (that.is<Constant>()) {
      value_ *= that.as<Constant>().value();
    } else {
      throw std::logic_error("Constant::operator*=(that): not valid for that");
    }
    return *this;
  }

  virtual Expr &operator+=(const Expr &that) override {
    if (that.is<Constant>()) {
      value_ += that.as<Constant>().value();
    } else {
      throw std::logic_error("Constant::operator+=(that): not valid for that");
    }
    return *this;
  }

  virtual Expr &operator-=(const Expr &that) override {
    if (that.is<Constant>()) {
      value_ -= that.as<Constant>().value();
    } else {
      throw std::logic_error("Constant::operator-=(that): not valid for that");
    }
    return *this;
  }

  static bool is_zero(scalar_type v) { return v.is_zero(); }

  bool is_zero() const { return is_zero(this->value()); }

 private:
  hash_type memoizing_hash() const override {
    hash_value_ = hash::value(value_);
    return *hash_value_;
  }

  bool static_equal(const Expr &that) const override {
    return value() == static_cast<const Constant &>(that).value();
  }
};  // class Constant

class Variable : public Expr, public Labeled {
 public:
  Variable() = delete;
  virtual ~Variable() = default;
  Variable(const Variable &) = default;
  Variable(Variable &&) = default;
  Variable &operator=(const Variable &) = default;
  Variable &operator=(Variable &&) = default;
  template <typename U, typename = std::enable_if_t<!is_variable_v<U>>>
  explicit Variable(U &&label) : label_(std::forward<U>(label)) {}

  Variable(std::wstring label) : label_(std::move(label)), conjugated_(false) {}

  std::wstring_view label() const override;

  void conjugate();

  bool conjugated() const;

  std::wstring to_latex() const override;

  type_id_type type_id() const override { return get_type_id<Variable>(); }

  ExprPtr clone() const override;

  virtual void adjoint() override;

 private:
  std::wstring label_;
  bool conjugated_ = false;

  hash_type memoizing_hash() const override {
    hash_value_ = hash::value(label_);
    hash::combine(hash_value_.value(), conjugated_);
    return *hash_value_;
  }

  bool static_equal(const Expr &that) const override {
    return label_ == static_cast<const Variable &>(that).label_ &&
           conjugated_ == static_cast<const Variable &>(that).conjugated_;
  }
};  // class Variable

class Product : public Expr {
 public:
  enum class Flatten { Once, Recursively, Yes = Recursively, No };

  using scalar_type = Constant::scalar_type;

  Product() = default;
  virtual ~Product() = default;
  Product(const Product &) = default;
  Product(Product &&) = default;
  Product &operator=(const Product &) = default;
  Product &operator=(Product &&) = default;

  Product(ExprPtrList factors, Flatten flatten_tag = Flatten::Yes) {
    using std::begin;
    using std::end;
    for (auto it = begin(factors); it != end(factors); ++it)
      append(1, *it, flatten_tag);
  }

  template <typename Range,
            typename = std::enable_if_t<meta::is_range_v<std::decay_t<Range>> &&
                                        !meta::is_same_v<Range, ExprPtrList> &&
                                        !meta::is_same_v<Range, Product>>>
  explicit Product(Range &&rng, Flatten flatten_tag = Flatten::Yes) {
    using ranges::begin;
    using ranges::end;
    for (auto &&v : rng) append(1, std::forward<decltype(v)>(v), flatten_tag);
  }

  template <typename T, typename Range,
            typename = std::enable_if_t<
                meta::is_range_v<std::decay_t<Range>> &&
                !std::is_same_v<std::remove_reference_t<Range>, ExprPtrList> &&
                !std::is_same_v<std::remove_reference_t<Range>, Product>>>
  explicit Product(T scalar, Range &&rng, Flatten flatten_tag = Flatten::Yes)
      : scalar_(std::move(scalar)) {
    using ranges::begin;
    using ranges::end;
    for (auto &&v : rng) append(1, std::forward<decltype(v)>(v), flatten_tag);
  }

  template <typename T>
  Product(T scalar, ExprPtrList factors, Flatten flatten_tag = Flatten::Yes)
      : scalar_(std::move(scalar)) {
    using std::begin;
    using std::end;
    for (auto it = begin(factors); it != end(factors); ++it)
      append(1, *it, flatten_tag);
  }

  template <typename Iterator>
  Product(Iterator begin, Iterator end, Flatten flatten_tag = Flatten::Yes) {
    for (auto it = begin; it != end; ++it) append(1, *it, flatten_tag);
  }

  template <typename T, typename Iterator>
  Product(T scalar, Iterator begin, Iterator end,
          Flatten flatten_tag = Flatten::Yes)
      : scalar_(std::move(scalar)) {
    for (auto it = begin; it != end; ++it) append(1, *it, flatten_tag);
  }

  template <typename T>
  Product &scale(T scalar) {
    scalar_ *= scalar;
    return *this;
  }

  template <typename T>
  Product &append(T scalar, ExprPtr factor,
                  Flatten flatten_tag = Flatten::Yes) {
    assert(factor);
    scalar_ *= scalar;
    if (!factor->is<Product>()) {
      if (factor->is<Constant>()) {  // factor in Constant
        auto factor_constant = factor->as<Constant>();
        scalar_ *= factor_constant.value();
        // no need to reset the hash since scalar is not hashed!
      } else {
        factors_.push_back(factor->clone());
        reset_hash_value();
      }
    } else {                             // factor is a product also ..
      if (flatten_tag != Flatten::No) {  // flatten, once or recursively
        const auto &factor_product = factor->as<Product>();
        scalar_ *= factor_product.scalar_;
        for (auto &&subfactor : factor_product)
          this->append(1, subfactor,
                       flatten_tag == Flatten::Once ? Flatten::No
                                                    : Flatten::Recursively);
      } else {
        factors_.push_back(factor->clone());
        reset_hash_value();
      }
    }
    return *this;
  }

  template <typename T, typename Factor,
            typename = std::enable_if_t<is_an_expr_v<Factor>>>
  Product &append(T scalar, Factor &&factor,
                  Flatten flatten_tag = Flatten::Yes) {
    return this->append(scalar,
                        std::static_pointer_cast<Expr>(
                            std::forward<Factor>(factor).shared_from_this()),
                        flatten_tag);
  }

  template <typename T>
  Product &prepend(T scalar, ExprPtr factor,
                   Flatten flatten_tag = Flatten::Yes) {
    assert(factor);
    scalar_ *= scalar;
    if (!factor->is<Product>()) {
      if (factor->is<Constant>()) {  // factor in Constant
        auto factor_constant = std::static_pointer_cast<Constant>(factor);
        scalar_ *= factor_constant->value();
        // no need to reset the hash since scalar is not hashed!
      } else {
        factors_.insert(factors_.begin(), factor->clone());
        reset_hash_value();
      }
    } else {  // factor is a product also  ... flatten recursively
      const auto &factor_product = factor->as<Product>();
      scalar_ *= factor_product.scalar_;
      if (flatten_tag != Flatten::No) {  // flatten, once or recursively
        for (auto &&subfactor : factor_product)
          this->prepend(1, subfactor,
                        flatten_tag == Flatten::Once ? Flatten::No
                                                     : Flatten::Recursively);
      } else {
        factors_.insert(factors_.begin(), factor->clone());
        reset_hash_value();
      }
    }
    return *this;
  }

  template <typename T, typename Factor,
            typename = std::enable_if_t<is_an_expr_v<Factor>>>
  Product &prepend(T scalar, Factor &&factor,
                   Flatten flatten_tag = Flatten::Yes) {
    return this->prepend(scalar,
                         std::static_pointer_cast<Expr>(
                             std::forward<Factor>(factor).shared_from_this()),
                         flatten_tag);
  }

  const auto &scalar() const { return scalar_; }

  bool is_zero() const { return Constant::is_zero(this->scalar()); }

  const auto &factors() const { return factors_; }
  auto &factors() { return factors_; }

  const ExprPtr &factor(size_t i) const { return factors_.at(i); }

  bool empty() const { return factors_.empty(); }

  virtual bool is_commutative() const;

  virtual void adjoint() override;

 private:
  virtual bool static_commutativity() const { return false; }

 public:
  std::wstring to_latex() const override { return to_latex(false); }

  std::wstring to_latex(bool negate) const {
    std::wstring result;
    result = L"{";
    if (!scalar().is_zero()) {
      const auto scal = negate ? -scalar() : scalar();
      if (!scal.is_identity()) {
        result += sequant::to_latex(scal);
      }
      for (const auto &i : factors()) {
        if (i->is<Product>())
          result += L"\\bigl(" + i->to_latex() + L"\\bigr)";
        else
          result += i->to_latex();
      }
    }
    result += L"}";
    return result;
  }

  std::wstring to_wolfram() const override {
    std::wstring result =
        is_commutative() ? L"Times[" : L"NonCommutativeMultiply[";
    if (scalar() != decltype(scalar_)(1)) {
      result += sequant::to_wolfram(scalar()) + L",";
    }
    const auto nfactors = factors().size();
    size_t factor_count = 1;
    for (const auto &i : factors()) {
      result += i->to_wolfram() + (factor_count == nfactors ? L"" : L",");
      ++factor_count;
    }
    result += L"]";
    return result;
  }

  type_id_type type_id() const override { return get_type_id<Product>(); };

  ExprPtr clone() const override { return ex<Product>(this->deep_copy()); }

  Product deep_copy() const {
    auto cloned_factors =
        factors() | ranges::views::transform([](const ExprPtr &ptr) {
          return ptr ? ptr->clone() : nullptr;
        });
    Product result(this->scalar(), ExprPtrList{});
    ranges::for_each(cloned_factors, [&](const auto &cloned_factor) {
      result.append(1, std::move(cloned_factor), Flatten::No);
    });
    return result;
  }

  virtual Expr &operator*=(const Expr &that) override {
    if (!that.is<Constant>()) {
      this->append(1, const_cast<Expr &>(that).shared_from_this());
    } else {
      scalar_ *= that.as<Constant>().value();
    }
    return *this;
  }

  void add_identical(const Product &other) {
    assert(this->hash_value() == other.hash_value());
    scalar_ += other.scalar_;
  }

  void add_identical(const std::shared_ptr<Product> &other) {
    assert(this->hash_value() == other->hash_value());
    scalar_ += other->scalar_;
  }

 private:
  scalar_type scalar_ = {1, 0};
  container::svector<ExprPtr, 2> factors_{};

  cursor begin_cursor() override {
    return factors_.empty() ? Expr::begin_cursor() : cursor{&factors_[0]};
  };
  cursor end_cursor() override {
    return factors_.empty() ? Expr::end_cursor()
                            : cursor{&factors_[0] + factors_.size()};
  };

  cursor begin_cursor() const override {
    return factors_.empty() ? Expr::begin_cursor() : cursor{&factors_[0]};
  };
  cursor end_cursor() const override {
    return factors_.empty() ? Expr::end_cursor()
                            : cursor{&factors_[0] + factors_.size()};
  };

  hash_type memoizing_hash() const override {
    auto deref_factors =
        factors() |
        ranges::views::transform(
            [](const ExprPtr &ptr) -> const Expr & { return *ptr; });
    hash_value_ =
        hash::range(ranges::begin(deref_factors), ranges::end(deref_factors));
    return *hash_value_;
  }

  ExprPtr canonicalize_impl(bool rapid = false);
  virtual ExprPtr canonicalize() override;
  virtual ExprPtr rapid_canonicalize() override;

  bool static_equal(const Expr &that) const override {
    const auto &that_cast = static_cast<const Product &>(that);
    if (scalar() == that_cast.scalar() &&
        factors().size() == that_cast.factors().size()) {
      if (this->empty()) return true;
      // compare hash values first
      if (this->hash_value() ==
          that.hash_value())  // hash values agree -> do full comparison
        return std::equal(begin_subexpr(), end_subexpr(), that.begin_subexpr(),
                          expr_ptr_comparer);
      else
        return false;
    } else
      return false;
  }
};  // class Product

/// A Product whose is_commutative() always reports true.
class CProduct : public Product {
 public:
  using Product::Product;
  /// Copy-converts from a Product.
  CProduct(const Product &other) : Product(other) {}
  /// Move-converts from a Product.
  // N.B. `other` is an lvalue inside this constructor; without std::move the
  // base would be initialized from a non-const lvalue Product, which copies
  // instead of moving and can select Product's range-of-factors constructor
  // rather than a copy/move constructor.
  CProduct(Product &&other) : Product(std::move(other)) {}

  bool is_commutative() const override { return true; }

  /// In-place adjoint (defined elsewhere).
  virtual void adjoint() override;

 private:
  // NOTE(review): presumably signals that commutativity is fixed by the type
  // -- confirm against Product::is_commutative()
  bool static_commutativity() const override { return true; }
};  // class CProduct

/// A Product whose is_commutative() always reports false.
class NCProduct : public Product {
 public:
  using Product::Product;
  /// Copy-converts from a Product.
  NCProduct(const Product &other) : Product(other) {}
  /// Move-converts from a Product.
  // N.B. `other` is an lvalue inside this constructor; without std::move the
  // base would be initialized from a non-const lvalue Product, which copies
  // instead of moving and can select Product's range-of-factors constructor
  // rather than a copy/move constructor.
  NCProduct(Product &&other) : Product(std::move(other)) {}

  bool is_commutative() const override { return false; }

  /// In-place adjoint (defined elsewhere).
  virtual void adjoint() override;

 private:
  // NOTE(review): returns true just like CProduct; presumably signals that
  // commutativity is fixed by the type (not that the product commutes) --
  // confirm against Product::is_commutative()
  bool static_commutativity() const override { return true; }
};  // class NCProduct


class Sum : public Expr {
 public:
  Sum() = default;
  virtual ~Sum() = default;
  Sum(const Sum &) = default;
  Sum(Sum &&) = default;
  Sum &operator=(const Sum &) = default;
  Sum &operator=(Sum &&) = default;

  Sum(ExprPtrList summands) {
    // use append to flatten out Sum summands
    for (auto &summand : summands) {
      append(std::move(summand));
    }
  }

  template <typename Iterator>
  Sum(Iterator begin, Iterator end) {
    // use append to flatten out Sum summands
    for (auto it = begin; it != end; ++it) {
      append(*it);
    }
  }

  template <typename Range,
            typename = std::enable_if_t<meta::is_range_v<std::decay_t<Range>> &&
                                        !meta::is_same_v<Range, ExprPtrList>>>
  explicit Sum(Range &&rng) {
    // use append to flatten out Sum summands
    for (auto &&v : rng) {
      append(std::forward<decltype(v)>(v));
    }
  }

  Sum &append(ExprPtr summand) {
    assert(summand);
    if (!summand->is<Sum>()) {
      if (summand->is<Constant>()) {  // exclude zeros, add up constants
                                      // immediately, if possible
        auto summand_constant = std::static_pointer_cast<Constant>(summand);
        if (constant_summand_idx_) {
          assert(summands_.at(*constant_summand_idx_)->is<Constant>());
          *(summands_[*constant_summand_idx_]) += *summand;
        } else {
          if (!summand_constant->is_zero()) {
            summands_.push_back(summand->clone());
            constant_summand_idx_ = summands_.size() - 1;
          }
        }
      } else {
        summands_.push_back(summand->clone());
      }
      reset_hash_value();
    } else {  // this recursively flattens Sum summands
      for (auto &subsummand : *summand) this->append(subsummand);
    }
    return *this;
  }

  Sum &prepend(ExprPtr summand) {
    assert(summand);
    if (!summand->is<Sum>()) {
      if (summand->is<Constant>()) {  // exclude zeros
        auto summand_constant = std::static_pointer_cast<Constant>(summand);
        if (constant_summand_idx_) {  // add up to the existing constant ...
          assert(summands_.at(*constant_summand_idx_)->is<Constant>());
          *summands_[*constant_summand_idx_] += *summand_constant;
        } else {  // or include the nonzero constant and update
                  // constant_summand_idx_
          if (!summand_constant->is_zero()) {
            summands_.insert(summands_.begin(), summand->clone());
            constant_summand_idx_ = 0;
          }
        }
      } else {
        summands_.insert(summands_.begin(), summand->clone());
        if (constant_summand_idx_)  // if have a constant, update its position
          ++*constant_summand_idx_;
      }
      reset_hash_value();
    } else {  // this recursively flattens Sum summands
      for (auto &subsummand : *summand) this->prepend(subsummand);
    }
    return *this;
  }

  const auto &summands() const { return summands_; }

  const ExprPtr &summand(size_t i) const { return summands_.at(i); }

  ExprPtr take_n(size_t count) const {
    const auto e = (count >= summands_.size() ? summands_.end()
                                              : (summands_.begin() + count));
    return ex<Sum>(summands_.begin(), e);
  }

  ExprPtr take_n(size_t offset, size_t count) const {
    const auto offset_plus_count = offset + count;
    const auto b = (offset >= summands_.size() ? summands_.end()
                                               : (summands_.begin() + offset));
    const auto e = (offset_plus_count >= summands_.size()
                        ? summands_.end()
                        : (summands_.begin() + offset_plus_count));
    return ex<Sum>(b, e);
  }

  template <typename Filter>
  ExprPtr filter(Filter &&f) const {
    return ex<Sum>(summands_ | ranges::views::filter(f));
  }

  bool empty() const { return summands_.empty(); }

  std::size_t size() const { return summands_.size(); }

  std::wstring to_latex() const override {
    std::wstring result;
    result = L"{ \\bigl(";
    std::size_t counter = 0;
    for (const auto &i : summands()) {
      const auto i_is_product = i->is<Product>();
      if (!i_is_product) {
        result += (counter == 0) ? i->to_latex() : (L" + " + i->to_latex());
      } else {  // i_is_product
        const auto i_prod = i->as<Product>();
        const auto scalar = i_prod.scalar();
        if (scalar.real() < 0 || (scalar.real() == 0 && scalar.imag() < 0)) {
          result += L" - " + i_prod.to_latex(true);
        } else {
          result += (counter == 0) ? i->to_latex() : (L" + " + i->to_latex());
        }
      }
      ++counter;
    }
    result += L"\\bigr) }";
    return result;
  }

  std::wstring to_wolfram() const override {
    std::wstring result;
    result = L"Plus[";
    std::size_t counter = 0;
    for (const auto &i : summands()) {
      result += i->to_wolfram();
      ++counter;
      if (counter != summands().size()) result += L",";
    }
    result += L"]";
    return result;
  }

  Expr::type_id_type type_id() const override {
    return Expr::get_type_id<Sum>();
  };

  ExprPtr clone() const override {
    auto cloned_summands =
        summands() | ranges::views::transform(
                         [](const ExprPtr &ptr) { return ptr->clone(); });
    return ex<Sum>(ranges::begin(cloned_summands),
                   ranges::end(cloned_summands));
  }

  virtual void adjoint() override;

  virtual Expr &operator+=(const Expr &that) override {
    this->append(const_cast<Expr &>(that).shared_from_this());
    return *this;
  }

  virtual Expr &operator-=(const Expr &that) override {
    if (that.is<Constant>())
      this->append(ex<Constant>(-that.as<Constant>().value()));
    else
      this->append(ex<Product>(
          -1, ExprPtrList{const_cast<Expr &>(that).shared_from_this()}));
    return *this;
  }

 private:
  container::svector<ExprPtr, 2> summands_{};
  std::optional<size_t>
      constant_summand_idx_{};  // points to the constant summand, if any; used
                                // to sum up constants in append/prepend

  cursor begin_cursor() override {
    return summands_.empty() ? Expr::begin_cursor() : cursor{&summands_[0]};
  };
  cursor end_cursor() override {
    return summands_.empty() ? Expr::end_cursor()
                             : cursor{&summands_[0] + summands_.size()};
  };
  cursor begin_cursor() const override {
    return summands_.empty() ? Expr::begin_cursor() : cursor{&summands_[0]};
  };
  cursor end_cursor() const override {
    return summands_.empty() ? Expr::end_cursor()
                             : cursor{&summands_[0] + summands_.size()};
  };

  hash_type memoizing_hash() const override {
    auto deref_summands =
        summands() |
        ranges::views::transform(
            [](const ExprPtr &ptr) -> const Expr & { return *ptr; });
    hash_value_ =
        hash::range(ranges::begin(deref_summands), ranges::end(deref_summands));
    return *hash_value_;
  }

  ExprPtr canonicalize_impl(bool multipass);

  virtual ExprPtr canonicalize() override { return canonicalize_impl(true); }
  virtual ExprPtr rapid_canonicalize() override {
    return canonicalize_impl(false);
  }

  bool static_equal(const Expr &that) const override {
    const auto &that_cast = static_cast<const Sum &>(that);
    if (summands().size() == that_cast.summands().size()) {
      if (this->empty()) return true;
      // compare hash values first
      if (this->hash_value() ==
          that.hash_value())  // hash values agree -> do full comparison
        return std::equal(begin_subexpr(), end_subexpr(), that.begin_subexpr(),
                          expr_ptr_comparer);
      else
        return false;
    } else
      return false;
  }
};  // class Sum

/// @param exprptr a non-null pointer to an expression
/// @return the LaTeX form of the pointed-to expression, as produced by its
///         to_latex() member
inline std::wstring to_latex(const ExprPtr &exprptr) {
  std::wstring latex = exprptr->to_latex();
  return latex;
}

/// Produces the LaTeX form of an expression wrapped in `align` environment(s).
///
/// For a Sum the outer `{ \bigl` ... `\bigr) }` decoration emitted by
/// Sum::to_latex() is trimmed, then the text is scanned for the " + " and
/// " - " term separators so that line breaks (`\\` followed by `& `) and,
/// optionally, new `align` blocks can be inserted between terms. Any other
/// expression is emitted on a single line of a single `align` block.
/// @param exprptr a non-null pointer to the expression
/// @param max_lines_per_align if nonzero, a new `align` block is started once
///        this many line breaks have been emitted in the current block;
///        0 means never start a new block
/// @param max_terms_per_line a line break is inserted whenever the number of
///        terms seen so far is a nonzero multiple of this value
/// @return the resulting LaTeX text
/// @warning @c max_terms_per_line is used as a modulo divisor, hence passing 0
///          for a Sum with more than one term is undefined behavior
inline std::wstring to_latex_align(const ExprPtr &exprptr,
                                   size_t max_lines_per_align = 0,
                                   size_t max_terms_per_line = 1) {
  std::wstring result = to_latex(exprptr);
  if (exprptr->is<Sum>()) {
    result.erase(0, 7);  // remove leading  "{ \bigl"
    result.replace(result.size() - 8, 8,
                   L")");  // replace trailing "\bigr) }" with ")"
    result = std::wstring(L"\\begin{align}\n& ") + result;
    // assume no inner sums
    size_t line_counter = 0;  // line breaks emitted in the current align block
    size_t term_counter = 0;  // term separators found so far
    // pos is the position of the most recently found separator; plus_pos and
    // minus_pos are the positions of the most recently found " + " and " - "
    // (npos once none is left)
    std::wstring::size_type pos = 0;
    std::wstring::size_type plus_pos = 0;
    std::wstring::size_type minus_pos = 0;
    bool last_pos_has_plus = false;  // true if pos refers to a " + "
    bool have_next_term = true;
    // inserts str into result at position `at` and shifts all tracked
    // positions by the length of the inserted text so that they keep pointing
    // at the same separators
    auto insert_into_result_at = [&](std::wstring::size_type at,
                                     const auto &str) {
      assert(pos != std::wstring::npos);
      result.insert(at, str);
      const auto str_nchar = std::size(str) - 1;  // neglect end-of-string
      pos += str_nchar;
      if (plus_pos != std::wstring::npos) plus_pos += str_nchar;
      if (minus_pos != std::wstring::npos) minus_pos += str_nchar;
      // sanity checks: the shifted positions must still point at separators
      if (pos != plus_pos) assert(plus_pos == result.find(L" + ", plus_pos));
      if (pos != minus_pos) assert(minus_pos == result.find(L" - ", minus_pos));
    };
    while (have_next_term) {
      if (max_lines_per_align > 0 &&
          line_counter == max_lines_per_align) {  // start new align block?
        insert_into_result_at(pos + 1, L"\n\\end{align}\n\\begin{align}\n& ");
        line_counter = 0;
      } else {
        // break the line if needed
        if (term_counter != 0 && term_counter % max_terms_per_line == 0) {
          insert_into_result_at(pos + 1, L"\\\\\n& ");
          ++line_counter;
        }
      }
      // find the next separator: only the search whose previous hit was
      // consumed (or that has not been started yet) is advanced
      if (plus_pos == 0 || last_pos_has_plus)
        plus_pos = result.find(L" + ", plus_pos + 1);
      if (minus_pos == 0 || !last_pos_has_plus)
        minus_pos = result.find(L" - ", minus_pos + 1);
      pos = std::min(plus_pos, minus_pos);  // the nearer separator comes next
      last_pos_has_plus = (pos == plus_pos);
      if (pos != std::wstring::npos)
        ++term_counter;
      else
        have_next_term = false;  // neither separator found -> done
    }
  } else {
    result = std::wstring(L"\\begin{align}\n& ") + result;
  }
  result += L"\n\\end{align}";
  return result;
}

/// @param exprptr a non-null pointer to an expression
/// @return the Wolfram form of the pointed-to expression, as produced by its
///         to_wolfram() member
inline std::wstring to_wolfram(const ExprPtr &exprptr) {
  std::wstring wolfram = exprptr->to_wolfram();
  return wolfram;
}

/// Clones a sequence of expression pointers element by element.
/// @tparam Sequence a sequence of ExprPtr that can be constructed from an
///         iterator pair
/// @param exprseq the sequence to clone
/// @return a sequence of the same (decayed) type whose elements are clones of
///         the non-null elements of @c exprseq; null elements stay null
template <typename Sequence>
std::decay_t<Sequence> clone(Sequence &&exprseq) {
  using result_type = std::decay_t<Sequence>;
  auto clones = exprseq |
                ranges::views::transform([](const ExprPtr &ptr) -> ExprPtr {
                  if (ptr) return ptr->clone();
                  return nullptr;
                });
  return result_type(ranges::begin(clones), ranges::end(clones));
}

/// @param expr an expression
/// @return the size of @c expr viewed as a range, as reported by
///         ranges::size
inline std::size_t size(const Expr &expr) {
  const std::size_t nelements = ranges::size(expr);
  return nelements;
}

/// @param exprptr a (possibly null) pointer to an expression
/// @return 0 if @c exprptr is null, else the size of the pointed-to expression
inline std::size_t size(const ExprPtr &exprptr) {
  return exprptr ? size(*exprptr) : 0;
}

/// @param exprptr a non-null pointer to an expression
/// @return the result of ranges::begin applied to the pointed-to expression
inline decltype(auto) begin(const ExprPtr &exprptr) {
  assert(exprptr);
  auto &&pointee = *exprptr;
  return ranges::begin(pointee);
}

/// @param exprptr a non-null pointer to an expression
/// @return the result of ranges::begin applied to the pointed-to expression
inline decltype(auto) begin(ExprPtr &exprptr) {
  assert(exprptr);
  auto &&pointee = *exprptr;
  return ranges::begin(pointee);
}

/// @param exprptr a non-null pointer to an expression
/// @return the result of ranges::cbegin applied to the pointed-to expression
inline decltype(auto) cbegin(const ExprPtr &exprptr) {
  assert(exprptr);
  auto &&pointee = *exprptr;
  return ranges::cbegin(pointee);
}

/// @param exprptr a non-null pointer to an expression
/// @return the result of ranges::end applied to the pointed-to expression
inline decltype(auto) end(const ExprPtr &exprptr) {
  assert(exprptr);
  auto &&pointee = *exprptr;
  return ranges::end(pointee);
}

/// @param exprptr a non-null pointer to an expression
/// @return the result of ranges::end applied to the pointed-to expression
inline decltype(auto) end(ExprPtr &exprptr) {
  assert(exprptr);
  auto &&pointee = *exprptr;
  return ranges::end(pointee);
}

/// @param exprptr a non-null pointer to an expression
/// @return the result of ranges::cend applied to the pointed-to expression
inline decltype(auto) cend(const ExprPtr &exprptr) {
  assert(exprptr);
  auto &&pointee = *exprptr;
  return ranges::cend(pointee);
}

// finish off ExprPtr members that depend on Expr

/// @tparam T the type to test for
/// @return the result of Expr::is<T>() invoked on the pointed-to expression
template <typename T>
bool ExprPtr::is() const {
  auto &&self = as_shared_ptr();
  return self->is<T>();
}

/// @tparam T the type to view the pointed-to expression as
/// @return the result of Expr::as<T>() invoked on the pointed-to expression
template <typename T>
const T &ExprPtr::as() const {
  auto &&self = as_shared_ptr();
  return self->as<T>();
}

/// @tparam T the type to view the pointed-to expression as
/// @return the result of Expr::as<T>() invoked on the pointed-to expression
template <typename T>
T &ExprPtr::as() {
  auto &&self = as_shared_ptr();
  return self->as<T>();
}

}  // namespace sequant

#endif  // SEQUANT_EXPR_HPP

#include <SeQuant/core/expr_operator.hpp>

#include <SeQuant/core/expr_algorithm.hpp>