Program Listing for File eval_expr.hpp

Return to documentation for file (SeQuant/core/eval_expr.hpp)

#ifndef SEQUANT_EVAL_EXPR_HPP
#define SEQUANT_EVAL_EXPR_HPP

#include <SeQuant/core/binary_node.hpp>
#include <SeQuant/core/container.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/index.hpp>
#include <SeQuant/core/utility/macros.hpp>
#include <SeQuant/external/bliss/graph.hh>

#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <ranges>
#include <string>
#include <type_traits>
#include <utility>

namespace sequant {

class Tensor;

/// Binary operation attached to an evaluation-tree node.
enum class EvalOp {
  /// The node combines its children by addition.
  Sum,
  /// The node combines its children by multiplication.
  Product
};

/// Kind of value an evaluation node yields.
enum class ResultType {
  /// The node evaluates to a tensor.
  Tensor,
  /// The node evaluates to a scalar.
  Scalar
};

struct EvalOpSetter;

/// Payload type stored in the nodes of a binary evaluation tree (see
/// FullBinaryNode<EvalExpr> returned by binarize()).
///
/// An object bundles an expression with an optional operation, a result kind,
/// a list of indices, a phase, a hash value and an optional connectivity
/// graph. All member functions are only declared here; their definitions are
/// not part of this header, so the notes below describe the interface and
/// hedge anything that depends on the implementation.
class EvalExpr {
 public:
  // Grants EvalOpSetter write access to op_type_.
  friend struct EvalOpSetter;
  /// Container type used for the indices held by this object.
  using index_vector = Index::index_vector;

  /// Builds an EvalExpr from a tensor.
  explicit EvalExpr(Tensor const& tnsr);

  /// Builds an EvalExpr from a constant.
  explicit EvalExpr(Constant const& c);

  /// Builds an EvalExpr from a variable.
  explicit EvalExpr(Variable const& v);

  /// Builds an EvalExpr from explicitly supplied pieces.
  /// NOTE(review): the parameters presumably initialise the like-named data
  /// members (op_type_, result_type_, expr_, canon_indices_, canon_phase_,
  /// hash_value_, connectivity_) -- confirm against the definition.
  EvalExpr(EvalOp op, ResultType res, ExprPtr const& expr, index_vector ixs,
           std::int8_t phase, size_t hash,
           std::shared_ptr<bliss::Graph> connectivity);

  /// @return the operation of this node as an optional; the member it mirrors
  /// defaults to std::nullopt.
  [[nodiscard]] const std::optional<EvalOp>& op_type() const noexcept;

  /// @return the kind of result (tensor or scalar) of this node.
  [[nodiscard]] ResultType result_type() const noexcept;

  /// @return the hash value associated with this object.
  [[nodiscard]] size_t hash_value() const noexcept;

  /// @return the expression held by this object (returned by value).
  [[nodiscard]] ExprPtr expr() const noexcept;

  /// NOTE(review): the name suggests "tensor of tensors" -- confirm the exact
  /// meaning in the implementation.
  [[nodiscard]] bool tot() const noexcept;

  /// @return a LaTeX rendering as a wide string.
  [[nodiscard]] std::wstring to_latex() const noexcept;

  /// Type predicates. Presumably is_tensor()/is_scalar() reflect the result
  /// kind and is_constant()/is_variable() the held expression -- confirm.
  [[nodiscard]] bool is_tensor() const noexcept;

  [[nodiscard]] bool is_scalar() const noexcept;

  [[nodiscard]] bool is_constant() const noexcept;

  [[nodiscard]] bool is_variable() const noexcept;

  /// NOTE(review): presumably true when no operation is set (a leaf of the
  /// evaluation tree) -- confirm.
  [[nodiscard]] bool is_primary() const noexcept;

  /// Operation predicates; presumably compare op_type() against
  /// EvalOp::Product and EvalOp::Sum respectively -- confirm.
  [[nodiscard]] bool is_product() const noexcept;

  [[nodiscard]] bool is_sum() const noexcept;

  /// Typed views of the held expression. These are not noexcept; presumably
  /// they require the expression to be of the requested type -- confirm.
  [[nodiscard]] Tensor const& as_tensor() const;

  [[nodiscard]] Constant const& as_constant() const;

  [[nodiscard]] Variable const& as_variable() const;

  /// @return a narrow-string label for this object.
  [[nodiscard]] std::string label() const noexcept;

  /// @return a narrow-string annotation of the indices; EvalExprTA caches
  /// this value at construction.
  [[nodiscard]] std::string indices_annot() const noexcept;

  /// @return the stored index list; EvalExprBTAS derives its annotation from
  /// it.
  [[nodiscard]] index_vector const& canon_indices() const noexcept;

  /// @return the stored phase.
  [[nodiscard]] std::int8_t canon_phase() const noexcept;

  /// @return whether a connectivity graph is available (presumably whether
  /// connectivity_ is non-null -- confirm).
  [[nodiscard]] bool has_connectivity_graph() const noexcept;

  /// @return the connectivity graph by reference.
  /// NOTE(review): likely requires has_connectivity_graph() -- confirm.
  [[nodiscard]] const bliss::Graph& connectivity_graph() const noexcept;

  /// @return a shared pointer to a connectivity graph.
  /// NOTE(review): the name suggests a fresh copy rather than a shared
  /// handle -- confirm.
  [[nodiscard]] std::shared_ptr<bliss::Graph> copy_connectivity_graph()
      const noexcept;

 protected:
  // Operation of this node; empty by default.
  std::optional<EvalOp> op_type_ = std::nullopt;

  // Result kind; has no default member initializer.
  ResultType result_type_;

  // The expression this object represents.
  ExprPtr expr_;

  // Indices associated with the expression.
  index_vector canon_indices_;

  // Phase associated with the expression; defaults to 1.
  std::int8_t canon_phase_{1};

  // Hash value; has no default member initializer.
  size_t hash_value_;

  // Optional connectivity graph (shared ownership).
  std::shared_ptr<bliss::Graph> connectivity_;
};

/// Friend helper that overwrites the operation stored in an EvalExpr.
/// EvalExpr exposes no public setter for its operation; this struct is the
/// designated way to change it from outside the class.
struct EvalOpSetter {
  /// Stores @p op as the operation of @p expr, replacing any previous value.
  void set(EvalExpr& expr, EvalOp op) {
    expr.op_type_.emplace(op);
  }
};

/// EvalExpr that additionally caches the string returned by indices_annot().
/// NOTE(review): "TA" presumably refers to the TiledArray backend -- confirm.
class EvalExprTA final : public EvalExpr {
 public:
  /// Accepts every argument list EvalExpr can be constructed from, forwards it
  /// to the base class and then stores the base's index annotation.
  template <typename... Args, typename = std::enable_if_t<
                                  std::is_constructible_v<EvalExpr, Args...>>>
  EvalExprTA(Args&&... args)
      : EvalExpr{std::forward<Args>(args)...}, annot_(indices_annot()) {}

  /// @return the annotation string captured at construction.
  [[nodiscard]] std::string const& annot() const noexcept { return annot_; }

 private:
  // Value of indices_annot() at construction time.
  std::string annot_;
};

class EvalExprBTAS final : public EvalExpr {
 public:
  using annot_t = container::svector<long>;

  template <typename Iterable>
  static auto index_hash(Iterable&& bk) {
    return ranges::views::transform(
        std::forward<Iterable>(bk), [](auto const& idx) {
          //
          // WARNING!
          // The BTAS uses long for scalar indexing by default.
          // Hence, here we explicitly cast the size_t values to long
          // Which is a potentially narrowing conversion leading to
          // integral overflow. Hence, the values in the returned
          // container are mixed negative and positive integers (long type)
          //
          return static_cast<long>(sequant::hash::value(Index{idx}.label()));
        });
  }

  template <typename... Args, typename = std::enable_if_t<
                                  std::is_constructible_v<EvalExpr, Args...>>>
  EvalExprBTAS(Args&&... args) : EvalExpr{std::forward<Args>(args)...} {
    annot_ = index_hash(canon_indices()) | ranges::to<annot_t>;
  }

  [[nodiscard]] inline annot_t const& annot() const noexcept { return annot_; }

 private:
  annot_t annot_;
};

namespace meta {

namespace detail {
/// Primary template: a type is not an evaluation expression by default.
/// Declared `inline` so that this header-defined variable template is a
/// single entity across all translation units that include the header.
template <typename, typename = void>
inline constexpr bool is_eval_expr{};

/// True for every type that is convertible to EvalExpr (EvalExpr itself and
/// the types deriving from it publicly).
template <typename T>
inline constexpr bool
    is_eval_expr<T, std::enable_if_t<std::is_convertible_v<T, EvalExpr>>>{true};

/// Primary template: a type is not an evaluation node by default.
template <typename, typename = void>
inline constexpr bool is_eval_node{};

/// True for FullBinaryNode<T> whenever T is an evaluation expression.
template <typename T>
inline constexpr bool
    is_eval_node<FullBinaryNode<T>, std::enable_if_t<is_eval_expr<T>>>{true};

}  // namespace detail

/// Satisfied by types convertible to EvalExpr.
template <typename T>
concept eval_expr = detail::is_eval_expr<T>;

/// Satisfied by (possibly cv/ref-qualified) FullBinaryNode<U> where U is an
/// evaluation expression.
template <typename T>
concept eval_node = detail::is_eval_node<std::remove_cvref_t<T>>;

/// Satisfied by ranges whose value type is an evaluation node.
template <typename Rng>
concept eval_node_range =
    std::ranges::range<Rng> && eval_node<std::ranges::range_value_t<Rng>>;

}  // namespace meta

namespace impl {
/// Non-template worker that turns an expression into a binary tree of
/// EvalExpr nodes. Only declared in this header; the generic binarize<ExprT>()
/// overloads delegate to it.
FullBinaryNode<EvalExpr> binarize(ExprPtr const& expr);
}  // namespace impl

template <meta::eval_expr T>
using EvalNode = FullBinaryNode<T>;

/// Builds the binary evaluation tree of @p expr with payload type ExprT.
///
/// @tparam ExprT node payload type; must be constructible from EvalExpr.
/// @param expr the expression to binarize.
/// @return the tree produced by impl::binarize(); for any ExprT other than
///         EvalExpr every node value is converted via `ExprT{val}`.
template <typename ExprT = EvalExpr>
  requires std::is_constructible_v<ExprT, EvalExpr>
FullBinaryNode<ExprT> binarize(ExprPtr const& expr) {
  if constexpr (std::is_same_v<ExprT, EvalExpr>) {
    // Nothing to convert: the worker already yields EvalExpr nodes.
    return impl::binarize(expr);
  } else {
    // The explicit `else` makes this branch a discarded statement when
    // ExprT is EvalExpr, so the node transformation is neither instantiated
    // nor left behind as unreachable code in that case.
    return transform_node(impl::binarize(expr),
                          [](auto&& val) { return ExprT{val}; });
  }
}

/// Builds the binary evaluation tree of a ResultExpr and makes its root
/// describe the result of @p res.
///
/// @tparam ExprT node payload type; must be constructible from EvalExpr.
/// @param res the result expression to binarize.
/// @return a tree with more than one node whose root expression carries the
///         result tensor of @p res, or its label if the result has no
///         bra/ket/aux indices and a label is set.
template <typename ExprT = EvalExpr>
  requires std::is_constructible_v<ExprT, EvalExpr>
FullBinaryNode<ExprT> binarize(ResultExpr const& res) {
  FullBinaryNode<ExprT> tree = binarize<ExprT>(res.expression());

  // A result without bra, ket and aux indices is treated as a scalar.
  bool const scalar_result =
      res.bra().empty() && res.ket().empty() && res.aux().empty();

  if (tree.size() < 2) {
    // A lone terminal has no node that could represent the result. Wrap it
    // as "terminal * 1" so that a dedicated result node exists on top.
    auto make_root = [&]() -> ExprT {
      if (!scalar_result) {
        return *binarize<ExprT>(ex<Tensor>(res.result_as_tensor()));
      }
      return *binarize<ExprT>(ex<Variable>(res.result_as_variable()));
    };

    ExprT root = make_root();
    EvalOpSetter{}.set(root, EvalOp::Product);

    tree = FullBinaryNode<ExprT>(std::move(root), std::move(tree),
                                 binarize<ExprT>(ex<Constant>(1)));
  }
  SEQUANT_ASSERT(tree.size() > 1);

  if (!scalar_result) {
    // Replace the root tensor wholesale by the result tensor of `res`
    // instead of patching its label and index slots one by one.
    Tensor& root_tensor = tree->expr().template as<Tensor>();
    root_tensor = res.result_as_tensor();
  } else if (res.has_label()) {
    // Scalar result: only the label of the root variable is taken over.
    tree->expr().template as<Variable>().set_label(res.label());
  }

  return tree;
}

/// Reconstructs an expression from an evaluation (sub)tree.
///
/// Leaves yield the expression they hold. Internal nodes are rebuilt
/// recursively as a Product or a Sum of their two children, depending on the
/// node's operation type.
///
/// @param node root of the (sub)tree to convert.
/// @return the rebuilt expression.
ExprPtr to_expr(meta::eval_node auto const& node) {
  auto const operation = node->op_type();
  auto const& payload = *node;

  // Terminal nodes carry their expression directly.
  if (node.leaf()) return payload.expr();

  if (!(operation == EvalOp::Product)) {
    SEQUANT_ASSERT(operation == EvalOp::Sum && "unsupported operation type");
    return ex<Sum>(Sum{to_expr(node.left()), to_expr(node.right())});
  }

  Product product{};

  ExprPtr lhs = to_expr(node.left());
  ExprPtr rhs = to_expr(node.right());

  product.append(1, lhs, Product::Flatten::No);
  product.append(1, rhs, Product::Flatten::No);

  SEQUANT_ASSERT(!product.empty());

  bool const single_non_tensor =
      product.size() == 1 && !product.factor(0)->is<Tensor>();

  if (!single_non_tensor) return ex<Product>(std::move(product));

  // Exactly one factor is left and it is not a tensor: rebuild the product
  // from that factor's own sub-expressions, keeping the accumulated scalar.
  // NOTE(review): if the lone factor has no sub-expressions (e.g. an atomic
  // expression), the rebuilt product contains only the scalar -- confirm this
  // is intended.
  return ex<Product>(Product{product.scalar(), product.factor(0)->begin(),
                             product.factor(0)->end(), Product::Flatten::No});
}

}  // namespace sequant

#endif  // SEQUANT_EVAL_EXPR_HPP