Program Listing for File expr.hpp

Return to documentation for file (SeQuant/core/expressions/expr.hpp)

#ifndef SEQUANT_EXPRESSIONS_EXPR_HPP
#define SEQUANT_EXPRESSIONS_EXPR_HPP

#include <SeQuant/core/expressions/expr_ptr.hpp>
#include <SeQuant/core/options.hpp>
#include <SeQuant/core/utility/macros.hpp>

#include <boost/core/demangle.hpp>

#include <atomic>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <utility>

#include <range/v3/all.hpp>

namespace sequant {

/// Label character U+207A (SUPERSCRIPT PLUS SIGN); presumably appended to
/// labels to denote the adjoint of an object — TODO confirm against users.
/// `inline constexpr` gives a single program-wide entity instead of one
/// internal-linkage copy per translation unit including this header.
inline constexpr wchar_t adjoint_label = L'\u207A';


/// Base class of SeQuant expressions.
///
/// An Expr is a (possibly empty) range of subexpressions, each held by an
/// ExprPtr; iteration is provided by ranges::view_facade through the virtual
/// begin_cursor()/end_cursor() hooks, which the base class implements as empty
/// ranges. An Expr without subexpressions is an "atom". Some operations
/// (visit(), hash_value() with a custom hasher) require that the object be
/// owned by a std::shared_ptr.
class Expr : public std::enable_shared_from_this<Expr>,
             public ranges::view_facade<Expr> {
 public:
  using range_type = ranges::view_facade<Expr>;
  using hash_type = std::size_t;
  using type_id_type = int;  // to speed up comparisons

  Expr() = default;
  virtual ~Expr() = default;

  /// @return true if this expression has no subexpressions
  bool is_atom() const { return ranges::empty(*this); }

  /// @return whether this expression is zero; the base implementation always
  ///         returns false, derived classes may override
  virtual bool is_zero() const { return false; }

  /// Defined out of line; presumably renders this expression as LaTeX.
  virtual std::wstring to_latex() const;

  /// Defined out of line; presumably renders this expression in Wolfram
  /// Language syntax.
  virtual std::wstring to_wolfram() const;

  /// Defined out of line; presumably returns a copy of this expression.
  virtual ExprPtr clone() const;

  /// @return an ExprPtr sharing ownership of this object if it is owned by a
  ///         std::shared_ptr; otherwise the result of clone()
  ExprPtr exprptr_from_this() {
    if (weak_from_this().use_count() == 0) return this->clone();
    return static_cast<ExprPtr>(this->shared_from_this());
  }

  /// Const overload of exprptr_from_this(). Constness is cast away on the
  /// shared pointer; if this object is not shared-owned, clone() is returned.
  ExprPtr exprptr_from_this() const {
    if (weak_from_this().use_count() == 0) return this->clone();
    return static_cast<ExprPtr>(
        std::const_pointer_cast<Expr>(this->shared_from_this()));
  }

  /// Canonicalizes this expression.
  /// @return the base implementation does nothing and returns a null ExprPtr;
  ///         NOTE(review): overriders presumably return a byproduct of
  ///         canonicalization (e.g. a phase) — confirm
  virtual ExprPtr canonicalize(
      CanonicalizeOptions = CanonicalizeOptions::default_options()) {
    return {};  // by default do nothing and return nullptr
  }

  /// Fast canonicalization. The base implementation delegates to
  /// canonicalize() with a copy of @p opts whose method is set to
  /// CanonicalizationMethod::Rapid.
  /// @param opts canonicalization options; all settings except the method are
  ///        passed through unchanged
  virtual ExprPtr rapid_canonicalize(
      CanonicalizeOptions opts =
          CanonicalizeOptions::default_options().copy_and_set(
              CanonicalizationMethod::Rapid)) {
    // keep the caller's options; only force the rapid method
    return this->canonicalize(opts.copy_and_set(CanonicalizationMethod::Rapid));
  }

  /// Recursively applies @p visitor to the subexpressions of this expression,
  /// in post-order (the subexpressions of a subexpression are visited before
  /// the subexpression itself). If @p visitor is callable with
  /// `const ExprPtr&` (i.e. non-mutating), it is finally applied to this
  /// expression too; a mutating visitor is applied only to subexpressions.
  /// @param visitor callable invoked with ExprPtr arguments
  /// @param atoms_only if true, only atoms are passed to @p visitor
  /// @return true if @p visitor was applied to this expression itself
  /// @throw std::invalid_argument if this object (or any visited
  ///        subexpression) is not owned by a std::shared_ptr
  template <typename Visitor>
  bool visit(Visitor &&visitor, const bool atoms_only = false) {
    return visit_impl(*this, std::forward<Visitor>(visitor), atoms_only);
  }

  /// Const overload of visit(); same traversal and contract.
  template <typename Visitor>
  bool visit(Visitor &&visitor, const bool atoms_only = false) const {
    return visit_impl(*this, std::forward<Visitor>(visitor), atoms_only);
  }

  /// @return iterator to the first subexpression
  auto begin_subexpr() { return range_type::begin(); }

  /// @return iterator past the last subexpression
  auto end_subexpr() { return range_type::end(); }

  /// @return const iterator to the first subexpression
  auto begin_subexpr() const { return range_type::begin(); }

  /// @return const iterator past the last subexpression
  auto end_subexpr() const { return range_type::end(); }

  /// @return this object viewed as an Expr
  Expr &expr() { return *this; }
  /// @return this object viewed as a const Expr
  const Expr &expr() const { return *this; }

  /// Trait: true for std::shared_ptr<T> with is_expr_v<T>
  template <typename T, typename Enabler = void>
  struct is_shared_ptr_of_expr : std::false_type {};
  template <typename T>
  struct is_shared_ptr_of_expr<std::shared_ptr<T>,
                               std::enable_if_t<is_expr_v<T>>>
      : std::true_type {};
  /// Trait: true for std::shared_ptr<T> with is_an_expr_v<T>
  template <typename T, typename Enabler = void>
  struct is_shared_ptr_of_expr_or_derived : std::false_type {};
  template <typename T>
  struct is_shared_ptr_of_expr_or_derived<std::shared_ptr<T>,
                                          std::enable_if_t<is_an_expr_v<T>>>
      : std::true_type {};

  /// @return true for atoms (base behavior; atoms may override), otherwise
  ///         true iff every subexpression is a c-number
  virtual bool is_cnumber() const {
    if (is_atom())
      return true;
    else {
      bool result = true;
      for (auto it = begin_subexpr(); result && it != end_subexpr(); ++it) {
        result &= (*it)->is_cnumber();
      }
      return result;
    }
  }

  /// Determines whether this expression commutes with @p that.
  /// Two atoms commute if either is a c-number or commutes_with_atom() says so;
  /// otherwise composite expressions are decomposed and all pairs of their
  /// constituents must commute.
  bool commutes_with(const Expr &that) const {
    auto this_is_atom = is_atom();
    auto that_is_atom = that.is_atom();
    bool result = true;
    if (this_is_atom && that_is_atom) {
      result =
          this->is_cnumber() || that.is_cnumber() || commutes_with_atom(that);
    } else if (this_is_atom) {
      // a c-number atom commutes with anything
      if (!this->is_cnumber()) {
        for (auto it = that.begin_subexpr(); result && it != that.end_subexpr();
             ++it) {
          result &= this->commutes_with(**it);
        }
      }
    } else {
      for (auto it = this->begin_subexpr(); result && it != this->end_subexpr();
           ++it) {
        result &= (*it)->commutes_with(that);
      }
    }
    return result;
  }

  /// Defined out of line; presumably replaces this expression by its adjoint.
  virtual void adjoint();

  /// @param hasher optional custom hash function applied to
  ///        shared_from_this(); if empty, memoizing_hash() is used
  /// @return hash value of this expression
  /// @throw std::bad_weak_ptr if @p hasher is given and this object is not
  ///        owned by a std::shared_ptr
  hash_type hash_value(
      std::function<hash_type(const std::shared_ptr<const Expr> &)> hasher = {})
      const {
    return hasher ? hasher(shared_from_this()) : memoizing_hash();
  }

  /// @return the runtime type id of this expression; must be overridden
  ///         (aborts under GNU-compatible compilers, pure virtual otherwise)
  virtual type_id_type type_id() const
#if __GNUG__
  {
    std::abort();
  }
#else
      = 0;
#endif

  friend inline bool operator==(const Expr &a, const Expr &b);

  /// Strict weak ordering: first by type id, then (for same-type operands) by
  /// static_less_than(). Participates only if is_an_expr_v<T>.
  template <typename T, typename = std::enable_if_t<is_an_expr_v<T>>>
  bool operator<(const T &that) const {
    // different types are ordered by their type ids
    if (type_id() != that.type_id()) return type_id() < that.type_id();
    // same type: use generic (or type-specific, if available) comparison
    return static_less_than(static_cast<const Expr &>(that));
  }

  /// @return the type id of @c T; assigned on first request from a shared
  ///         atomic counter (first id is 1) unless overwritten by set_type_id()
  template <typename T>
  static type_id_type get_type_id() {
    return type_id_accessor<T>();
  }

  /// Overrides the type id of @c T with @p id
  template <typename T>
  static void set_type_id(type_id_type id) {
    type_id_accessor<T>() = id;
  }

  /// @return true if this object is a @c T: always for Expr itself, by type id
  ///         for Expr-derived @c T, via dynamic_cast for other @c T
  template <typename T>
  bool is() const {
    if constexpr (is_expr_v<T>)
      return true;
    else if constexpr (std::is_base_of_v<Expr, T>)
      return this->type_id() == get_type_id<std::remove_cvref_t<T>>();
    else
      return dynamic_cast<const T *>(this) != nullptr;
  }

  /// @return this object as a const @c T (asserts is<T>())
  template <typename T>
  const T &as() const {
    SEQUANT_ASSERT(this->is<T>());
    if constexpr (std::is_base_of_v<Expr, T>) {
      return static_cast<const T &>(*this);
    } else
      return dynamic_cast<const T &>(*this);
  }

  /// @return this object as a @c T (asserts is<T>())
  template <typename T>
  T &as() {
    SEQUANT_ASSERT(this->is<T>());
    if constexpr (std::is_base_of_v<Expr, T>) {
      return static_cast<T &>(*this);
    } else
      return dynamic_cast<T &>(*this);
  }

  /// @return the demangled name of this object's dynamic type
  std::string type_name() const {
    return boost::core::demangle(typeid(*this).name());
  }

  // Compound operators; defined out of line.
  virtual Expr &operator*=(const Expr &that);

  virtual Expr &operator^=(const Expr &that);

  virtual Expr &operator+=(const Expr &that);

  virtual Expr &operator-=(const Expr &that);

 private:
  friend ranges::range_access;

  /// Implementation of visit() for both const and non-const Expr.
  /// Returns true if @p visitor was applied to @p expr itself.
  template <
      typename E, typename Visitor,
      typename = std::enable_if_t<std::is_same_v<std::remove_cvref_t<E>, Expr>>>
  static bool visit_impl(E &&expr, Visitor &&visitor, const bool atoms_only) {
    if (expr.weak_from_this().use_count() == 0)
      throw std::invalid_argument(
          "Expr::visit: cannot visit expressions not managed by shared_ptr");
    for (auto &subexpr_ptr : expr.expr()) {
      const auto subexpr_is_an_atom = subexpr_ptr->is_atom();
      const auto need_to_visit_subexpr = !atoms_only || subexpr_is_an_atom;
      bool visited = false;
      // if not a leaf, recur into it; the visitor is reused on every
      // iteration, so it is passed on as an lvalue rather than re-forwarded
      if (!subexpr_is_an_atom)
        visited = visit_impl(*subexpr_ptr, visitor, atoms_only);
      // call on the subexpression itself, if not yet done so
      if (need_to_visit_subexpr && !visited) visitor(subexpr_ptr);
    }
    // N.B. can only visit itself if visitor is nonmutating, i.e. callable
    // (through the lvalue used below) with a const ExprPtr&
    bool this_visited = false;
    if constexpr (std::is_invocable_r_v<void, Visitor &, const ExprPtr &>) {
      if (!atoms_only || expr.is_atom()) {
        const ExprPtr this_exprptr = expr.exprptr_from_this();
        visitor(this_exprptr);
        this_visited = true;
      }
    }
    return this_visited;
  }

 protected:
  Expr(Expr &&) = default;
  Expr(const Expr &) = default;
  Expr &operator=(Expr &&) = default;
  Expr &operator=(const Expr &) = default;

  /// range-v3 cursor over a contiguous array of ExprPtr subexpressions
  struct cursor {
    using value_type = ExprPtr;

    cursor() = default;
    constexpr explicit cursor(ExprPtr *subexpr_ptr) noexcept
        : ptr_{subexpr_ptr} {}
    constexpr explicit cursor(const ExprPtr *subexpr_ptr) noexcept
        : ptr_{const_cast<ExprPtr *>(subexpr_ptr)}, const_{true} {}
    bool equal(const cursor &that) const { return ptr_ == that.ptr_; }
    void next() { ++ptr_; }
    void prev() { --ptr_; }
    // TODO figure out why can't return const here if want to be able to assign
    // to *begin(Expr&)
    ExprPtr &read() const {
      RANGES_EXPECT(ptr_);
      return *ptr_;
    }
    ExprPtr &read() {
      RANGES_EXPECT(const_ == false);
      RANGES_EXPECT(ptr_);
      return *ptr_;
    }
    void assign(const ExprPtr &that_ptr) {
      RANGES_EXPECT(ptr_);
      *ptr_ = that_ptr;
    }
    std::ptrdiff_t distance_to(cursor const &that) const {
      return that.ptr_ - ptr_;
    }
    void advance(std::ptrdiff_t n) { ptr_ += n; }

   private:
    ExprPtr *ptr_ =
        nullptr;  // both begin and end will be represented by this, so Expr
                  // without subexpressions begin() equals end() automatically
    bool const_ = false;  // assert in nonconst ops
  };

  // Base class has no subexpressions: begin and end cursors compare equal.
  virtual cursor begin_cursor() { return cursor{}; }
  virtual cursor end_cursor() { return cursor{}; }

  virtual cursor begin_cursor() const { return cursor{}; }
  virtual cursor end_cursor() const { return cursor{}; }

  mutable std::optional<hash_type> hash_value_;  // not initialized by default
  /// Base implementation: returns the cached hash_value_ if set, else 0; it
  /// never computes or stores a hash itself.
  virtual hash_type memoizing_hash() const {
    static const hash_type default_hash_value = 0;
    if (hash_value_)
      return *hash_value_;
    else
      return default_hash_value;
  }
  /// Clears the cached hash value.
  virtual void reset_hash_value() const { hash_value_.reset(); }

  /// Type-specific equality, called by operator== only for same-type operands;
  /// must be overridden (aborts under GNU-compatible compilers).
  virtual bool static_equal([[maybe_unused]] const Expr &that) const
#if __GNUG__
  {
    std::abort();
  }
#else
      = 0;
#endif

  /// Type-specific ordering used by operator< for same-type operands; the base
  /// implementation compares hash_value()s.
  virtual bool static_less_than(const Expr &that) const {
    return this->hash_value() < that.hash_value();
  }

  /// Commutation of two atoms that are not c-numbers; base returns true.
  virtual bool commutes_with_atom([[maybe_unused]] const Expr &that) const {
    return true;
  }

 private:
  static type_id_type get_next_type_id() {
    static std::atomic<type_id_type> grand_type_id = 0;
    return ++grand_type_id;
  }

  template <typename T>
  static type_id_type &type_id_accessor() {
    static type_id_type type_id = get_next_type_id();
    return type_id;
  }

 private:
  std::logic_error not_implemented(const char *fn) const;
};  // class Expr

// Register ExprPtr explicitly with both pointer traits. NOTE(review): ExprPtr
// is presumably a class distinct from std::shared_ptr<T>, so the partial
// specializations declared inside Expr would not match it — confirm against
// expr_ptr.hpp. (The omitted second argument defaults to void.)
template <>
struct Expr::is_shared_ptr_of_expr<ExprPtr> : std::bool_constant<true> {};

template <>
struct Expr::is_shared_ptr_of_expr_or_derived<ExprPtr>
    : std::bool_constant<true> {};

/// Expression equality: @p a and @p b are equal iff they have the same type id
/// and the type-specific static_equal() reports them equal.
inline bool operator==(const Expr &a, const Expr &b) {
  // short-circuit: static_equal() is only consulted for same-type operands
  return a.type_id() == b.type_id() && a.static_equal(b);
}

/// Binary predicate on two expressions; declared here, defined out of line.
/// NOTE(review): by its name it presumably tests whether @p expr1 and @p expr2
/// are equal up to a scalar prefactor — confirm against the implementation.
struct proportional_to {
  auto operator()(const ExprPtr &expr1, const ExprPtr &expr2) const -> bool;
};

}  // namespace sequant

#endif  // SEQUANT_EXPRESSIONS_EXPR_HPP