Program Listing for File sum.hpp

Return to documentation for file (SeQuant/core/expressions/sum.hpp)

#ifndef SEQUANT_EXPRESSIONS_SUM_HPP
#define SEQUANT_EXPRESSIONS_SUM_HPP

#include <SeQuant/core/container.hpp>
#include <SeQuant/core/expressions/constant.hpp>
#include <SeQuant/core/expressions/expr.hpp>
#include <SeQuant/core/expressions/expr_ptr.hpp>
#include <SeQuant/core/expressions/product.hpp>
#include <SeQuant/core/meta.hpp>
#include <SeQuant/core/runtime.hpp>
#include <SeQuant/core/utility/macros.hpp>

#include <algorithm>
#include <cstddef>
#include <iterator>
#include <mutex>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>

namespace sequant {


class Sum : public Expr {
 public:
  using summands_type = container::svector<ExprPtr, 2>;

  Sum() = default;
  virtual ~Sum() = default;
  Sum(const Sum &) = default;
  Sum(Sum &&) = default;
  Sum &operator=(const Sum &) = default;
  Sum &operator=(Sum &&) = default;

  void swap(Sum &other) {
    Sum tmp = std::move(other);
    other = std::move(*this);
    *this = std::move(tmp);
  }

  Sum(ExprPtrList summands) {
    // use append to flatten out Sum summands
    for (auto &&summand : summands) {
      append(std::forward<decltype(summand)>(summand));
    }
  }

  template <typename Iterator>
  Sum(Iterator begin, Iterator end) {
    // use append to flatten out Sum summands
    for (auto it = begin; it != end; ++it) {
      append(*it);
    }
  }

  template <typename Range>
    requires(meta::is_range_v<std::remove_cvref_t<Range>> &&
             !meta::is_same_v<std::remove_cvref_t<Range>, ExprPtrList>)
  explicit Sum(Range &&rng) {
    // N.B. use append to flatten out Sum summands
    constexpr auto rng_is_expr =
        meta::is_base_of_v<Expr, std::remove_cvref_t<Range>>;
    constexpr auto rng_is_exprptr =
        meta::is_same_v<ExprPtr, std::remove_cvref_t<Range>>;
    if constexpr (rng_is_expr || rng_is_exprptr) {
      ExprPtr rng_as_exprptr;
      if constexpr (rng_is_expr) {
        rng_as_exprptr = rng.exprptr_from_this();
      } else {
        rng_as_exprptr = rng;
      }
      this->append(rng_as_exprptr);
    } else {
      for (auto &&v : rng) {
        append(std::forward<decltype(v)>(v));
      }
    }
  }

  struct move_only_tag {};

  explicit Sum(summands_type &&summands, move_only_tag)
      : summands_(std::move(summands)) {
    std::size_t pos = 0;
    for (auto it = summands_.begin(); it != summands_.end(); ++it) {
      auto &summand = *it;
      bool do_erase = false;
      if (summand->is_zero()) {
        do_erase = true;
      } else if (summand->is<Constant>()) {
        auto summand_constant = summand.as_shared_ptr<Constant>();
        if (constant_summand_idx_) {  // add up to the existing constant ...
          SEQUANT_ASSERT(summands_.at(*constant_summand_idx_)->is<Constant>());
          *summands_[*constant_summand_idx_] += *summand_constant;
          do_erase = true;
        } else {  // or memorize the position of the constant
          constant_summand_idx_ = pos;
        }
      }

      // erase if needed
      if (do_erase) {
        summands_.erase(it);
        it = summands_.begin();
        std::advance(it, pos);
      } else
        ++pos;
    }
  }

  Sum &append(ExprPtr summand) {
    SEQUANT_ASSERT(summand);
    if (!summand->is<Sum>()) {
      if (!summand->is_zero()) {        // exclude zeros
        if (summand->is<Constant>()) {  // add up constants
          // immediately, if possible
          auto summand_constant = summand.as_shared_ptr<Constant>();
          if (constant_summand_idx_) {
            SEQUANT_ASSERT(
                summands_.at(*constant_summand_idx_)->is<Constant>());
            *(summands_[*constant_summand_idx_]) += *summand;
          } else {
            summands_.push_back(summand->clone());
            constant_summand_idx_ = summands_.size() - 1;
          }
        } else {
          summands_.push_back(summand->clone());
        }
        reset_hash_value();
      }
    } else {  // this recursively flattens Sum summands
      for (auto &subsummand : *summand) this->append(subsummand);
    }
    return *this;
  }

  Sum &prepend(ExprPtr summand) {
    SEQUANT_ASSERT(summand);
    if (!summand->is<Sum>()) {
      if (!summand->is_zero()) {
        // exclude zeros
        if (summand->is<Constant>()) {
          auto summand_constant = summand.as_shared_ptr<Constant>();
          if (constant_summand_idx_) {  // add up to the existing constant ...
            SEQUANT_ASSERT(
                summands_.at(*constant_summand_idx_)->is<Constant>());
            *summands_[*constant_summand_idx_] += *summand_constant;
          } else {  // or include the nonzero constant and update
            // constant_summand_idx_
            summands_.insert(summands_.begin(), summand->clone());
            constant_summand_idx_ = 0;
          }
        } else {
          summands_.insert(summands_.begin(), summand->clone());
          if (constant_summand_idx_)  // if have a constant, update its position
            ++*constant_summand_idx_;
        }
        reset_hash_value();
      }
    } else {  // this recursively flattens Sum summands
      for (auto &subsummand : *summand) this->prepend(subsummand);
    }
    return *this;
  }

  const auto &summands() const { return summands_; }

  const ExprPtr &summand(size_t i) const { return summands_.at(i); }

  ExprPtr take_n(size_t count) const {
    const auto e = (count >= summands_.size() ? summands_.end()
                                              : (summands_.begin() + count));
    return ex<Sum>(summands_.begin(), e);
  }

  ExprPtr take_n(size_t offset, size_t count) const {
    const auto offset_plus_count = offset + count;
    const auto b = (offset >= summands_.size() ? summands_.end()
                                               : (summands_.begin() + offset));
    const auto e = (offset_plus_count >= summands_.size()
                        ? summands_.end()
                        : (summands_.begin() + offset_plus_count));
    return ex<Sum>(b, e);
  }

  template <typename Filter>
  ExprPtr filter(Filter &&f) const {
    return ex<Sum>(summands_ | ranges::views::filter(f));
  }

  bool empty() const { return summands_.empty(); }

  std::size_t size() const { return summands_.size(); }

  std::wstring to_latex() const override {
    std::wstring result;
    result = L"{ \\bigl(";
    std::size_t counter = 0;
    for (const auto &i : summands()) {
      const auto i_is_product = i->is<Product>();
      if (!i_is_product) {
        result += (counter == 0) ? i->to_latex() : (L" + " + i->to_latex());
      } else {  // i_is_product
        const auto i_prod = i->as<Product>();
        const auto scalar = i_prod.scalar();
        if (scalar.real() < 0 || (scalar.real() == 0 && scalar.imag() < 0)) {
          result += L" - " + i_prod.to_latex(true);
        } else {
          result += (counter == 0) ? i->to_latex() : (L" + " + i->to_latex());
        }
      }
      ++counter;
    }
    result += L"\\bigr) }";
    return result;
  }

  std::wstring to_wolfram() const override {
    std::wstring result;
    result = L"Plus[";
    std::size_t counter = 0;
    for (const auto &i : summands()) {
      result += i->to_wolfram();
      ++counter;
      if (counter != summands().size()) result += L",";
    }
    result += L"]";
    return result;
  }

  Expr::type_id_type type_id() const override {
    return Expr::get_type_id<Sum>();
  };

  ExprPtr clone() const override {
    auto cloned_summands =
        summands() | ranges::views::transform(
                         [](const ExprPtr &ptr) { return ptr->clone(); });
    return ex<Sum>(ranges::begin(cloned_summands),
                   ranges::end(cloned_summands));
  }

  virtual void adjoint() override;

  virtual Expr &operator+=(const Expr &that) override {
    this->append(const_cast<Expr &>(that).shared_from_this());
    return *this;
  }

  virtual Expr &operator-=(const Expr &that) override {
    if (that.is<Constant>())
      this->append(ex<Constant>(-that.as<Constant>().value()));
    else
      this->append(ex<Product>(
          -1, ExprPtrList{const_cast<Expr &>(that).shared_from_this()}));
    return *this;
  }

 private:
  summands_type summands_{};
  std::optional<size_t>
      constant_summand_idx_{};  // points to the constant summand, if any; used
                                // to sum up constants in append/prepend

  cursor begin_cursor() override {
    return summands_.empty() ? Expr::begin_cursor() : cursor{&summands_[0]};
  };
  cursor end_cursor() override {
    return summands_.empty() ? Expr::end_cursor()
                             : cursor{&summands_[0] + summands_.size()};
  };
  cursor begin_cursor() const override {
    return summands_.empty() ? Expr::begin_cursor() : cursor{&summands_[0]};
  };
  cursor end_cursor() const override {
    return summands_.empty() ? Expr::end_cursor()
                             : cursor{&summands_[0] + summands_.size()};
  };

  hash_type memoizing_hash() const override {
    auto compute_hash = [this]() {
      if (summands_.size() == 1)
        return summands_[0]->hash_value();
      else {
        auto deref_summands =
            summands() |
            ranges::views::transform(
                [](const ExprPtr &ptr) -> const Expr & { return *ptr; });
        auto value = hash::range(ranges::begin(deref_summands),
                                 ranges::end(deref_summands));
        return value;
      }
    };

    if (!hash_value_) {
      hash_value_ = compute_hash();
    } else {
      SEQUANT_ASSERT(*hash_value_ == compute_hash());
    }

    return *hash_value_;
  }

  ExprPtr canonicalize_impl(bool multipass, CanonicalizeOptions opt);

  virtual ExprPtr canonicalize(
      CanonicalizeOptions opt =
          CanonicalizeOptions::default_options()) override {
    return canonicalize_impl(true, opt);
  }
  virtual ExprPtr rapid_canonicalize(
      CanonicalizeOptions opts =
          CanonicalizeOptions::default_options().copy_and_set(
              CanonicalizationMethod::Rapid)) override {
    SEQUANT_ASSERT(opts.method == CanonicalizationMethod::Rapid);
    return canonicalize_impl(false, opts);
  }

  bool static_equal(const Expr &that) const override {
    const auto &that_cast = static_cast<const Sum &>(that);
    if (summands().size() == that_cast.summands().size()) {
      if (this->empty()) return true;
      // compare hash values first
      if (this->hash_value() ==
          that.hash_value())  // hash values agree -> do full comparison
        return std::equal(begin_subexpr(), end_subexpr(), that.begin_subexpr(),
                          expr_ptr_comparer);
      else
        return false;
    } else
      return false;
  }
};  // class Sum

/// Accumulates summands in a hash set and produces a Sum / expression from
/// them on request.
///
/// The set is keyed by `sequant::hash::_<ExprPtr>` and uses `proportional_to`
/// as its equality predicate, so summands that hash alike and compare
/// "proportional" occupy a single slot.
/// NOTE(review): all non-trivial members are only declared here; how
/// proportional summands are merged is defined in the implementation file —
/// confirm there.
class HashingAccumulator {
 public:
  /// Adds @p summand to the accumulator.
  /// @param summand the expression to accumulate
  /// @param flatten presumably controls whether a Sum summand is split into
  ///        its own summands — TODO confirm against the definition
  /// @return `*this` (by the return type), allowing chained calls
  HashingAccumulator &append(ExprPtr summand, bool flatten = true);

  /// Builds a Sum from the accumulated summands (defined elsewhere).
  SumPtr make_sum();

  /// Builds a Sum from the accumulated summands; presumably canonicalized,
  /// per the name — confirm against the definition.
  SumPtr make_canonicalized_sum();

  /// Builds a generic expression from the accumulated summands.
  /// @param canonicalize presumably whether the result is canonicalized —
  ///        confirm against the definition
  ExprPtr make_expr(bool canonicalize = true);

  /// @return true if nothing has been accumulated
  bool empty() const { return summands_.empty(); }

 private:
  /// Shared implementation behind the make_* functions (defined elsewhere).
  SumPtr make_sum_impl(bool canonicalize);

  // accumulated summands; equality is decided by proportional_to
  container::unordered_set<ExprPtr, sequant::hash::_<ExprPtr>, proportional_to>
      summands_;
};

/// Options controlling transform_sum_expr().
struct TransformSumExprOptions {
  // if true, every mapped result is canonicalized before accumulation and
  // the flag is also forwarded to HashingAccumulator::make_expr
  bool canonicalize = true;
  // forwarded as the `flatten` argument of HashingAccumulator::append
  bool flatten = true;
};

/// Maps every element of @p rng through @p map and accumulates the non-null
/// results into a single expression.
///
/// For each input the result of `map(input)` is, if `options.canonicalize`
/// is set, canonicalized in place and multiplied by the byproduct returned
/// by `canonicalize()` when that byproduct is non-null. The result is then
/// appended to a HashingAccumulator under a mutex, so the callable handed to
/// `sequant::for_each` is safe to invoke concurrently.
///
/// @param rng range of input expressions
/// @param map unary callable taking `const ExprPtr&`; a null result is
///        skipped
/// @param options see TransformSumExprOptions
/// @return `HashingAccumulator::make_expr(options.canonicalize)` of the
///         accumulated results
template <typename SizedRange, typename UnaryMapOp>
  requires(meta::is_range_v<std::remove_cvref_t<SizedRange>>)
ExprPtr transform_sum_expr(SizedRange &&rng, const UnaryMapOp &map,
                           const TransformSumExprOptions &options = {}) {
  HashingAccumulator accumulator;
  std::mutex accumulator_mtx;  // guards every update of accumulator

  const bool do_canonicalize = options.canonicalize;
  const bool do_flatten = options.flatten;

  auto process_one = [&accumulator, &accumulator_mtx, &map, do_canonicalize,
                      do_flatten](const ExprPtr &input) {
    auto mapped = map(input);
    if (!mapped) return;  // nothing to accumulate for this input

    if (do_canonicalize) {
      // canonicalize() works in place and may hand back a byproduct factor
      if (auto byproduct = mapped->canonicalize(); byproduct) {
        mapped = byproduct * mapped;
      }
    }

    std::scoped_lock<std::mutex> guard(accumulator_mtx);
    accumulator.append(mapped, do_flatten);
  };

  sequant::for_each(std::forward<SizedRange>(rng), process_one);
  return accumulator.make_expr(options.canonicalize);
}

}  // namespace sequant

#endif  // SEQUANT_EXPRESSIONS_SUM_HPP