Program Listing for File result.hpp

Return to documentation for file (SeQuant/core/eval/result.hpp)

#ifndef SEQUANT_EVAL_RESULT_HPP
#define SEQUANT_EVAL_RESULT_HPP

#include <SeQuant/core/eval/fwd.hpp>

#include <SeQuant/core/algorithm.hpp>
#include <SeQuant/core/container.hpp>
#include <SeQuant/core/hash.hpp>
#include <SeQuant/core/index.hpp>
#include <SeQuant/core/logger.hpp>
#include <SeQuant/core/utility/macros.hpp>

#include <range/v3/numeric.hpp>
#include <range/v3/view.hpp>

#include <any>
#include <memory>
#include <utility>

namespace sequant {

namespace {

[[maybe_unused]] std::logic_error invalid_operand(
    std::string_view msg = "Invalid operand for binary op") noexcept {
  return std::logic_error{msg.data()};
}

[[maybe_unused]] std::logic_error unimplemented_method(
    std::string_view msg) noexcept {
  using namespace std::string_literals;
  return std::logic_error{"Not implemented in this derived class: "s +
                          msg.data()};
}

// It is an iterator type
template <typename It>
struct IterPair {
  It first, second;
  IterPair(It beg, It end) noexcept : first{beg}, second{end} {};
};

template <typename It>
void swap(IterPair<It>& l, IterPair<It>& r) {
  using std::iter_swap;
  std::iter_swap(l.first, r.first);
  std::iter_swap(l.second, r.second);
}

template <typename It>
bool operator<(IterPair<It> const& l, IterPair<It> const& r) noexcept {
  return *l.first < *r.first;
}

using perm_t = container::svector<size_t>;

struct SymmetricParticleRange {
  perm_t::iterator bra_beg;
  perm_t::iterator ket_beg;
  size_t nparticles;
};

struct ParticleRange {
  perm_t::iterator beg;
  size_t size;
};

inline bool valid_particle_range(SymmetricParticleRange const& rng) {
  using std::distance;
  auto bra_end = rng.bra_beg + rng.nparticles;
  auto ket_end = rng.ket_beg + rng.nparticles;
  return std::is_sorted(rng.bra_beg, bra_end) &&
         std::is_sorted(rng.ket_beg, ket_end) &&
         distance(rng.bra_beg, bra_end) == distance(rng.ket_beg, ket_end);
}

inline auto iter_pairs(SymmetricParticleRange const& rng) {
  using ranges::views::iota;
  using ranges::views::transform;

  return iota(size_t{0}, rng.nparticles)  //
         | transform([b = rng.bra_beg, k = rng.ket_beg](auto i) {
             return IterPair{b + i, k + i};
           });
}

template <typename F, typename = std::enable_if_t<std::is_invocable_v<F, int>>>
void antisymmetric_permutation(ParticleRange const& rng, F call_back) {
  // if the range has 1 or no elements, there is no permutation
  if (rng.size <= 1) {
    call_back(0);
    return;
  }
  int parity = 0;
  auto end = rng.beg + rng.size;
  for (auto yn = true; yn; yn = next_permutation_parity(parity, rng.beg, end)) {
    call_back(parity);
  }
}

template <typename F, typename = std::enable_if_t<std::is_invocable_v<F>>>
void symmetric_permutation(SymmetricParticleRange const& rng, F call_back) {
  auto ips = iter_pairs(rng) | ranges::to_vector;
  do {
    call_back();
  } while (std::next_permutation(ips.begin(), ips.end()));
}

template <typename RngOfOrdinals>
std::string ords_to_annot(RngOfOrdinals const& ords) {
  using ranges::views::intersperse;
  using ranges::views::join;
  using ranges::views::transform;
  auto to_str = [](auto x) { return std::to_string(x); };
  return ords | transform(to_str) | intersperse(std::string{","}) | join |
         ranges::to<std::string>;
}

template <typename... Args>
inline void log_result(Args const&... args) noexcept {
  auto& l = Logger::instance();
  if (l.eval.level > 1) write_log(l, args...);
}

template <typename... Args>
inline void log_constant(Args const&... args) noexcept {
  log_result("[CONST] ", args...);
}

}  // namespace

/******************************************************************************/

template <typename T, typename... Args>
ResultPtr eval_result(Args&&... args) noexcept {
  return std::make_shared<T>(std::forward<Args>(args)...);
}

template <typename T>
struct Annot {
  explicit Annot(std::array<std::any, 3> const& a)
      : lannot(std::any_cast<T>(a[0])),
        rannot(std::any_cast<T>(a[1])),
        this_annot(std::any_cast<T>(a[2])) {}

  T const lannot;

  T const rannot;

  T const this_annot;
};

class Result {
 public:
  using id_t = size_t;

  virtual ~Result() noexcept = default;

  template <typename T>
  [[nodiscard]] bool is() const noexcept {
    return this->type_id() == id_for_type<std::decay_t<T>>();
  }

  template <typename T>
  [[nodiscard]] T const& as() const {
    SEQUANT_ASSERT(this->is<std::decay_t<T>>());
    return static_cast<T const&>(*this);
  }

  [[nodiscard]] virtual ResultPtr sum(Result const&,
                                      std::array<std::any, 3> const&) const = 0;

  [[nodiscard]] virtual ResultPtr prod(Result const&,
                                       std::array<std::any, 3> const&,
                                       DeNest DeNestFlag) const = 0;

  [[nodiscard]] virtual ResultPtr permute(
      std::array<std::any, 2> const&) const = 0;

  virtual void add_inplace(Result const&) = 0;

  [[nodiscard]] virtual ResultPtr symmetrize() const = 0;

  [[nodiscard]] virtual ResultPtr antisymmetrize(size_t bra_rank) const = 0;

  [[nodiscard]] bool has_value() const noexcept;

  [[nodiscard]] virtual ResultPtr mult_by_phase(std::int8_t) const = 0;

  template <typename T>
  [[nodiscard]] T& get() {
    SEQUANT_ASSERT(has_value());
    return *std::any_cast<T>(&value_);
  }

  template <typename T>
  [[nodiscard]] T const& get() const {
    SEQUANT_ASSERT(has_value());
    return *std::any_cast<const T>(&value_);
  }

  [[nodiscard]] virtual std::size_t size_in_bytes() const = 0;

 protected:
  template <typename T,
            typename = std::enable_if_t<!std::is_convertible_v<T, Result>>>
  explicit Result(T&& arg) noexcept
      : value_{std::make_any<std::decay_t<T>>(std::forward<T>(arg))} {}

  [[nodiscard]] virtual id_t type_id() const noexcept = 0;

  template <typename T>
  [[nodiscard]] static id_t id_for_type() noexcept {
    static id_t id = next_id();
    return id;
  }

 private:
  std::any value_;

  [[nodiscard]] static id_t next_id() noexcept;
};

template <typename T>
class ResultScalar final : public Result {
 public:
  using Result::id_t;

  explicit ResultScalar(T v) noexcept : Result{std::move(v)} {}

  [[nodiscard]] T value() const noexcept { return get<T>(); }

  [[nodiscard]] ResultPtr sum(Result const& other,
                              std::array<std::any, 3> const&) const override {
    if (other.is<ResultScalar<T>>()) {
      auto const& o = other.as<ResultScalar<T>>();
      auto s = value() + o.value();

      log_constant(value(), " + ", o.value(), " = ", s, "\n");

      return eval_result<ResultScalar<T>>(s);
    } else {
      throw invalid_operand();
    }
  }

  [[nodiscard]] ResultPtr prod(Result const& other,
                               std::array<std::any, 3> const& maybe_empty,
                               DeNest DeNestFlag) const override {
    if (other.is<ResultScalar<T>>()) {
      auto const& o = other.as<ResultScalar<T>>();
      auto p = value() * o.value();

      log_constant(value(), " * ", o.value(), " = ", p, "\n");

      return eval_result<ResultScalar<T>>(value() * o.value());
    } else {
      auto maybe_empty_ = maybe_empty;
      std::swap(maybe_empty_[0], maybe_empty_[1]);
      return other.prod(*this, maybe_empty_, DeNestFlag);
    }
  }

  [[nodiscard]] ResultPtr permute(
      std::array<std::any, 2> const&) const override {
    throw unimplemented_method("permute");
  }

  void add_inplace(Result const& other) override {
    SEQUANT_ASSERT(other.is<ResultScalar<T>>());
    log_constant(value(), " += ", other.get<T>(), "\n");
    auto& val = get<T>();
    val += other.get<T>();
  }

  [[nodiscard]] ResultPtr symmetrize() const override {
    throw unimplemented_method("symmetrize");
  }

  [[nodiscard]] ResultPtr antisymmetrize(size_t /*bra_rank*/) const override {
    throw unimplemented_method("antisymmetrize");
  }

  [[nodiscard]] ResultPtr mult_by_phase(std::int8_t factor) const override {
    return eval_result<ResultScalar<T>>(value() * T(factor));
  }

 private:
  [[nodiscard]] id_t type_id() const noexcept override {
    return id_for_type<ResultScalar<T>>();
  }

  [[nodiscard]] std::size_t size_in_bytes() const final { return sizeof(T); }
};

}  // namespace sequant

#endif  // SEQUANT_EVAL_RESULT_HPP