Program Listing for File power.hpp

Return to documentation for file (SeQuant/core/expressions/power.hpp)

#ifndef SEQUANT_EXPRESSIONS_POWER_HPP
#define SEQUANT_EXPRESSIONS_POWER_HPP

#include <SeQuant/core/expressions/constant.hpp>
#include <SeQuant/core/expressions/expr.hpp>
#include <SeQuant/core/expressions/expr_ptr.hpp>
#include <SeQuant/core/expressions/variable.hpp>
#include <SeQuant/core/hash.hpp>
#include <SeQuant/core/io/latex/latex.hpp>
#include <SeQuant/core/rational.hpp>
#include <SeQuant/core/utility/macros.hpp>

#include <concepts>
#include <cstdint>
#include <string>
#include <utility>

namespace sequant {

/// Scalar expression `base^exponent` with a rational exponent.
///
/// The base must be a Constant or a Variable. The power as a whole may also
/// carry a complex-conjugation flag (see conjugated()), in which case it
/// represents `(base^exponent)*`.
///
/// The value constructor stores a clone of the supplied base, so later
/// mutation of the caller's expression cannot invalidate this object's
/// memoized hash.
/// NOTE(review): base() returns a `const ExprPtr&`. If ExprPtr allows mutating
/// the pointee through a const handle, callers must not do so, or the memoized
/// hash goes stale. The defaulted copy operations share (do not clone) the
/// base with the source object.
class Power : public Expr {
 public:
  using exponent_type = rational;

  Power() = delete;
  virtual ~Power() = default;
  Power(const Power&) = default;
  Power(Power&&) = default;
  Power& operator=(const Power&) = default;
  Power& operator=(Power&&) = default;

  /// Constructs `base^exponent`.
  ///
  /// @param base the base expression; a clone of it is stored
  /// @param exponent the rational exponent
  /// @pre `base` is non-null and is a Constant or a Variable
  /// @pre if `base` is a zero Constant then `exponent >= 0`
  /// (both preconditions are checked with SEQUANT_ASSERT)
  Power(ExprPtr base, exponent_type exponent)
      : base_{}, exponent_{std::move(exponent)} {
    SEQUANT_ASSERT(base);
    SEQUANT_ASSERT(base->is<Constant>() || base->is<Variable>());
    // clone on construction so that external
    // mutations of the input cannot invalidate our memoized hash
    base_ = base->clone();
    // 0^n is defined only for n >= 0 (0^0 = 1 by convention)
    SEQUANT_ASSERT(!base_->is<Constant>() || !base_->as<Constant>().is_zero() ||
                   exponent_ >= 0);
  }

  /// Constructs `label^exponent`. The base is a Variable built from
  /// @p label, which is anything a std::wstring can be constructed from.
  template <typename L>
    requires std::constructible_from<std::wstring, L> &&
             (!std::convertible_to<L, ExprPtr>)
  Power(L&& label, exponent_type exponent)
      : Power(ex<Variable>(std::forward<L>(label)), std::move(exponent)) {}

  /// Constructs `value^exponent`. The base is a Constant built from @p value,
  /// which is anything Constant::scalar_type can be constructed from that is
  /// neither string-like nor convertible to ExprPtr.
  template <typename V>
    requires(!std::constructible_from<std::wstring, V> &&
             !std::convertible_to<V, ExprPtr> &&
             std::constructible_from<Constant::scalar_type, V>)
  Power(V&& value, exponent_type exponent)
      : Power(ex<Constant>(std::forward<V>(value)), std::move(exponent)) {}

  /// @return the base (a Constant or a Variable)
  const ExprPtr& base() const { return base_; }

  /// @return the rational exponent
  const exponent_type& exponent() const { return exponent_; }

  /// @return true if this represents `(base^exponent)*`
  bool conjugated() const { return conjugated_; }

  /// Toggles the complex-conjugation flag and invalidates the memoized hash.
  void conjugate() {
    conjugated_ = !conjugated_;
    reset_hash_value();
  }

  /// @return true iff the base is a zero Constant and the exponent is positive
  /// (0^0 is 1 by convention, so it is not zero)
  bool is_zero() const override {
    return exponent_ > 0 && base_->is<Constant>() &&
           base_->as<Constant>().is_zero();
  }

  /// Simplifies @p expr in place if it holds a Power. Otherwise, including
  /// when @p expr is null, it does nothing.
  ///
  /// Rules, tried in this order:
  ///  - `b^1` -> `b` (adjointed if the power is conjugated)
  ///  - `b^0` -> Constant 1
  ///  - Constant base `1`, any exponent -> Constant 1
  ///  - Constant base, integer exponent -> Constant (conjugated if flagged)
  ///  - non-negative real rational base `p/q` with perfect-square `p` and `q`,
  ///    half-integer exponent -> Constant
  /// Any other Power is left untouched.
  /// @param expr the expression to simplify; it may be replaced by a new one
  static void flatten(ExprPtr& expr) {
    if (!expr || !expr->is<Power>()) return;
    const auto& pw = expr->as<Power>();

    // b^1 = b and conjugate if needed
    if (pw.exponent_ == 1) {
      auto lifted = pw.base_->clone();
      if (pw.conjugated_) lifted->adjoint();
      expr = std::move(lifted);
      return;
    }
    // b^0 = 1 for any base, including 0^0 = 1 by convention
    // (the ctor rejects 0^(negative))
    if (pw.exponent_ == 0) {
      expr = ex<Constant>(Constant::scalar_type{1});
      return;
    }
    if (!pw.base_->is<Constant>()) return;

    using scalar_type = Constant::scalar_type;
    const auto& base_val = pw.base_->as<Constant>().value();

    // 1^k = 1 for any rational k.
    if (base_val == scalar_type{1}) {
      expr = ex<Constant>(scalar_type{1});
      return;
    }

    // Both remaining fold cases share one shape — `rational base raised to
    // an integer exponent` — so we normalize to that shape and run a single
    // exp-by-squaring loop. Anything else is left untouched.
    //
    // Case A: integer exponent (any Constant base, real or complex).
    //   `base` is just `base_val`.
    // Case B: half-integer exponent on a non-negative real rational base
    //   `p/q` with both `p` and `q` perfect squares. Then
    //     (p/q)^(m/2) = (sqrt(p)/sqrt(q))^m,
    //   so we replace `base` with `sqrt(p)/sqrt(q)` (still a rational) and
    //   keep `exp_int = m`.

    // initialize the base
    scalar_type base{0};
    auto exp_nr = numerator(pw.exponent_);  // numerator of exponent

    if (denominator(pw.exponent_) == 1) {
      base = base_val;
    } else if (denominator(pw.exponent_) == 2 && base_val.imag() == 0 &&
               base_val.real() >= 0) {
      intmax_t p = numerator(base_val.real());
      intmax_t q = denominator(base_val.real());  // > 0 by Boost's convention,
                                                  // sign is with the numerator

      // check for perfect squares
      intmax_t p_rem{0}, q_rem{0};
      intmax_t p_root = boost::multiprecision::sqrt(p, p_rem);
      intmax_t q_root = boost::multiprecision::sqrt(q, q_rem);
      // fold if p and q are perfect squares, else return
      if (p_rem != 0 || q_rem != 0) return;
      base = scalar_type{rational(p_root) / rational(q_root)};
    } else {
      return;
    }

    // Standard exp-by-squaring; for negative exponents we power the
    // magnitude and invert at the end.
    const bool negate = exp_nr < 0;
    if (negate) exp_nr = -exp_nr;
    scalar_type value{1};
    scalar_type b = base;
    while (exp_nr > 0) {
      if (exp_nr % 2 != 0) value *= b;
      exp_nr /= 2;
      if (exp_nr > 0) b *= b;
    }
    if (negate) value = scalar_type{1} / value;

    if (pw.conjugated_) value = conj(value);
    expr = ex<Constant>(std::move(value));
  }

  type_id_type type_id() const override { return get_type_id<Power>(); }

  bool is_scalar() const override { return true; }

  /// @return a new Power with a cloned base (the constructor clones it), the
  /// same exponent, and the same conjugation flag
  ExprPtr clone() const override {
    auto cloned = ex<Power>(base_, exponent_);
    if (conjugated_) cloned->as<Power>().conjugate();
    return cloned;
  }

  /// Adjoint of this scalar power: toggles the conjugation flag.
  void adjoint() override { conjugate(); }

  /// Multiplies in place by @p that when the product is again a power of the
  /// same base:
  ///  - `b^e1 *= b^e2` -> `b^(e1+e2)` (equal bases, equal conjugation flags)
  ///  - Variable base, Variable @p that: `(b^e)* *= b*` -> `(b^(e+1))*` and
  ///    its variants, matched on label and effective conjugation
  ///  - otherwise: `c^e *= c` -> `c^(e+1)`, and, for a conjugated power,
  ///    `(c^e)* *= c*` -> `(c^(e+1))*` (since `a* b* = (a b)*`)
  /// @throws Exception if @p that cannot be absorbed into this power
  Expr& operator*=(const Expr& that) override {
    // b^e1 *= b^e2  ->  b^(e1+e2)
    if (that.is<Power>()) {
      const auto& other = that.as<Power>();
      if (conjugated_ == other.conjugated_ && *base_ == *other.base_) {
        exponent_ += other.exponent_;
        reset_hash_value();
        return *this;
      }
    }
    // (b^e)* *= b*  ->  (b^(e+1))*
    else if (base_->is<Variable>() && that.is<Variable>()) {
      // check effective conjugation of Variable in this and that, if valid
      // operation iff they match
      const auto& base_var = base_->as<Variable>();
      const auto& that_var = that.as<Variable>();
      if (base_var.label() == that_var.label() &&
          (base_var.conjugated() ^ conjugated_) == that_var.conjugated()) {
        exponent_ += rational{1};
        reset_hash_value();
        return *this;
      }
    }
    // C^e *= C  ->  C^(e+1)   and   (C^e)* *= C*  ->  (C^(e+1))*
    else if (effective_base_equals(that)) {
      exponent_ += rational{1};
      reset_hash_value();
      return *this;
    }
    throw Exception("Power::operator*=(that): not valid for that");
  }

 private:
  ExprPtr base_;  // Constant or Variable; cloned by the value constructor
  exponent_type exponent_;
  bool conjugated_ = false;  // true if this is (base_^exponent_)*

  /// @return true if @p that equals the base as it effectively appears in
  /// this power: `base_` itself when unconjugated, or the adjoint of `base_`
  /// when conjugated. Used for `(c^e)* * c* = (c^(e+1))*`.
  bool effective_base_equals(const Expr& that) const {
    if (!conjugated_) return *base_ == that;
    auto adjointed_base = base_->clone();
    adjointed_base->adjoint();
    return *adjointed_base == that;
  }

  /// An unconjugated `b^1` hashes identically to `b`. Otherwise the hash
  /// combines the base, the exponent, and the conjugation flag. A cached value
  /// is cross-checked against a fresh computation via SEQUANT_ASSERT.
  hash_type memoizing_hash() const override {
    auto compute_hash = [this]() {
      if (exponent_ == 1 && !conjugated_) return hash::value(*base_);
      auto val = hash::value(*base_);
      hash::combine(val, hash::value(exponent_));
      hash::combine(val, conjugated_);
      return val;
    };

    if (!hash_value_) {
      hash_value_ = compute_hash();
    } else {
      SEQUANT_ASSERT(*hash_value_ == compute_hash());
    }
    return *hash_value_;
  }

  /// Equal iff exponent, conjugation flag, and base all compare equal.
  bool static_equal(const Expr& that) const override {
    const auto& other = static_cast<const Power&>(that);
    return exponent_ == other.exponent_ && conjugated_ == other.conjugated_ &&
           *base_ == *other.base_;
  }

  /// Orders by base first, then exponent, then conjugation flag.
  bool static_less_than(const Expr& that) const override {
    const auto& other = static_cast<const Power&>(that);
    if (*base_ != *other.base_) return *base_ < *other.base_;
    if (exponent_ != other.exponent_) return exponent_ < other.exponent_;
    return conjugated_ < other.conjugated_;
  }
};
}  // namespace sequant

#endif  // SEQUANT_EXPRESSIONS_POWER_HPP