Program Listing for File context.cpp

Return to documentation for file (SeQuant/domain/mbpt/context.cpp)

#include <SeQuant/core/utility/macros.hpp>
#include <SeQuant/domain/mbpt/context.hpp>

#ifdef SEQUANT_CONTEXT_MANIPULATION_THREADSAFE
#include <mutex>
#endif

namespace sequant::mbpt {

#ifdef SEQUANT_CONTEXT_MANIPULATION_THREADSAFE
namespace {
// Serializes access to the default (implicit) MBPT context. It is locked by
// get_default_mbpt_context, set_default_mbpt_context(const Context&),
// reset_default_mbpt_context and set_scoped_default_mbpt_context(const
// Context&) in this file. The unnamed namespace gives it internal linkage, so
// it is private to this translation unit.
std::recursive_mutex mbpt_ctx_mtx;
}  // namespace
#endif

/// Builds a Context from @p options.
///
/// The operator registry is selected with the following precedence:
///  1. `options.op_registry_ptr`, if non-null (shared, not copied);
///  2. otherwise a newly allocated registry move-constructed from
///     `options.op_registry`, if it holds a value;
///  3. otherwise no registry at all (op_registry() will then assert).
Context::Context(Options options)
    : csv_(options.csv), op_registry_(std::move(options.op_registry_ptr)) {
  // Fall back to the by-value registry only when no pointer was supplied.
  if (!op_registry_ && options.op_registry) {
    op_registry_ =
        std::make_shared<OpRegistry>(std::move(options.op_registry.value()));
  }
}

/// Returns a copy of this context whose operator registry (if any) is a
/// separate object produced by OpRegistry::clone(). Mutating the clone's
/// registry therefore does not affect this context.
Context Context::clone() const {
  Context result = *this;  // copies csv_ and the registry pointer
  if (!result.op_registry_) return result;  // nothing to detach
  result.op_registry_ = std::make_shared<OpRegistry>(op_registry_->clone());
  return result;
}

/// Returns the CSV setting stored in this context.
CSV Context::csv() const {
  return this->csv_;
}

/// Returns a read-only, shared handle to this context's operator registry.
/// Asserts (via SEQUANT_ASSERT) that a registry is present.
std::shared_ptr<const OpRegistry> Context::op_registry() const {
  SEQUANT_ASSERT(op_registry_, "mbpt::Context has null OpRegistry");
  // Converting constructor adds const to the pointee; ownership stays shared.
  std::shared_ptr<const OpRegistry> readonly{op_registry_};
  return readonly;
}

/// Returns a mutable, shared handle to this context's operator registry.
/// Asserts (via SEQUANT_ASSERT) that a registry is present.
/// Although this member function is const, the returned pointer permits
/// modifying the registry. The registry object is shared with any other
/// Context that holds the same pointer.
std::shared_ptr<OpRegistry> Context::mutable_op_registry() const {
  SEQUANT_ASSERT(op_registry_, "mbpt::Context has null OpRegistry");
  std::shared_ptr<OpRegistry> handle = op_registry_;
  return handle;
}

/// Replaces this context's registry with a newly allocated copy of
/// @p op_registry (copy-constructed, not OpRegistry::clone()d).
/// NOTE(review): whether the copy shares internal state with the argument
/// depends on OpRegistry's copy semantics; Context::clone uses
/// OpRegistry::clone() instead. Confirm this difference is intended.
/// @return *this, for chaining.
Context& Context::set(const OpRegistry& op_registry) {
  auto owned = std::make_shared<OpRegistry>(op_registry);
  op_registry_ = std::move(owned);
  return *this;
}

/// Makes this context share @p op_registry (no copy is made). A null pointer
/// is accepted, but op_registry()/mutable_op_registry() will then assert.
/// @return *this, for chaining.
Context& Context::set(std::shared_ptr<OpRegistry> op_registry) {
  this->op_registry_ = std::move(op_registry);
  return *this;
}

/// Overwrites the CSV setting of this context.
/// @return *this, for chaining.
Context& Context::set(CSV csv) {
  this->csv_ = csv;
  return *this;
}

/// Two contexts are equal when their CSV settings match and their registries
/// are either both absent or both present with equal contents (compared by
/// value, not by pointer identity).
bool operator==(Context const& left, Context const& right) {
  if (left.csv() != right.csv()) return false;

  const auto& lhs_reg = left.op_registry_;
  const auto& rhs_reg = right.op_registry_;

  // Presence must agree: exactly one missing registry means "not equal".
  if (static_cast<bool>(lhs_reg) != static_cast<bool>(rhs_reg)) return false;
  // Here both are present or both are absent; two absent registries are equal.
  if (!lhs_reg) return true;

  return *lhs_reg == *rhs_reg;
}

/// Logical negation of operator==.
bool operator!=(Context const& left, Context const& right) {
  const bool same = (left == right);
  return !same;
}

/// Returns a reference to the current default (implicit) MBPT context.
/// When SEQUANT_CONTEXT_MANIPULATION_THREADSAFE is defined, mbpt_ctx_mtx is
/// held only for the lookup itself. The lock is released on return, so later
/// use of the returned reference is not protected by it.
const Context& get_default_mbpt_context() {
#ifdef SEQUANT_CONTEXT_MANIPULATION_THREADSAFE
  const std::scoped_lock guard{mbpt_ctx_mtx};
#endif
  const Context& current = sequant::detail::get_implicit_context<Context>();
  return current;
}

/// Installs @p ctx as the default (implicit) MBPT context, holding
/// mbpt_ctx_mtx when SEQUANT_CONTEXT_MANIPULATION_THREADSAFE is defined.
void set_default_mbpt_context(const Context& ctx) {
#ifdef SEQUANT_CONTEXT_MANIPULATION_THREADSAFE
  const std::scoped_lock guard{mbpt_ctx_mtx};
#endif
  sequant::detail::set_implicit_context(ctx);
}

/// Convenience overload: builds a Context from @p options and installs it as
/// the default MBPT context via the Context overload, which does the locking.
void set_default_mbpt_context(const Context::Options& options) {
  const Context ctx(options);
  set_default_mbpt_context(ctx);
}

/// Resets the default (implicit) MBPT context through
/// sequant::detail::reset_implicit_context, holding mbpt_ctx_mtx when
/// SEQUANT_CONTEXT_MANIPULATION_THREADSAFE is defined.
void reset_default_mbpt_context() {
#ifdef SEQUANT_CONTEXT_MANIPULATION_THREADSAFE
  const std::scoped_lock guard{mbpt_ctx_mtx};
#endif
  sequant::detail::reset_implicit_context<Context>();
}

/// Installs @p ctx as the default MBPT context and returns a resetter object.
/// The resetter presumably restores the previous context when it is destroyed
/// (see sequant::detail::ImplicitContextResetter). Discarding it is flagged by
/// [[nodiscard]].
/// mbpt_ctx_mtx (when enabled) is held only while the context is installed,
/// not when the resetter later runs.
[[nodiscard]] sequant::detail::ImplicitContextResetter<Context>
set_scoped_default_mbpt_context(const Context& ctx) {
#ifdef SEQUANT_CONTEXT_MANIPULATION_THREADSAFE
  const std::scoped_lock guard{mbpt_ctx_mtx};
#endif
  // Returned as a prvalue so no copy/move of the resetter is required.
  return sequant::detail::set_scoped_implicit_context(ctx);
}

/// Convenience overload: builds a Context from @p options and forwards to the
/// Context overload, returning its resetter.
[[nodiscard]] sequant::detail::ImplicitContextResetter<Context>
set_scoped_default_mbpt_context(const Context::Options& options) {
  const Context ctx(options);
  return set_scoped_default_mbpt_context(ctx);
}

std::shared_ptr<OpRegistry> make_minimal_registry() {
  auto registry = std::make_shared<OpRegistry>();

  registry
      ->add(L"h", OpClass::gen)
      .add(L"g", OpClass::gen)
      .add(L"f", OpClass::gen)
      .add(L"θ", OpClass::gen)
      .add(L"t", OpClass::ex)
      .add(L"λ", OpClass::deex)
      .add(L"R", OpClass::ex)
      .add(L"L", OpClass::deex);

  return registry;
}

std::shared_ptr<OpRegistry> make_legacy_registry() {
  auto registry = std::make_shared<OpRegistry>();

  registry->add(L"h", OpClass::gen)
      .add(L"f", OpClass::gen)
      .add(L"f̃", OpClass::gen)
      .add(L"g", OpClass::gen)
      .add(L"θ", OpClass::gen)
      .add(L"t", OpClass::ex)
      .add(L"λ", OpClass::deex)
      .add(L"R", OpClass::ex)
      .add(L"L", OpClass::deex)
      .add(L"F", OpClass::gen)
      .add(L"GR", OpClass::gen)
      .add(L"C", OpClass::gen)
      .add(L"γ", OpClass::gen)
      .add(L"κ", OpClass::gen);

  return registry;
}

/// Classifies operator label @p op.
/// Labels in reserved::labels() are always reported as OpClass::gen. All other
/// labels are resolved through the operator registry of the current default
/// MBPT context.
OpClass to_op_class(const std::wstring& op) {
  // Reserved labels take precedence over whatever the registry says.
  if (ranges::contains(reserved::labels(), op)) return OpClass::gen;

  const auto registry = get_default_mbpt_context().op_registry();
  return registry->to_class(op);
}

}  // namespace sequant::mbpt