Program Listing for File eval_result.hpp¶
↰ Return to documentation for file (SeQuant/domain/eval/eval_result.hpp
)
#ifndef SEQUANT_EVAL_RESULT_HPP
#define SEQUANT_EVAL_RESULT_HPP
#include <SeQuant/core/algorithm.hpp>
#include <SeQuant/core/container.hpp>
#include <SeQuant/core/hash.hpp>
#include <SeQuant/core/index.hpp>
#include <SeQuant/core/logger.hpp>
#include <TiledArray/einsum/tiledarray.h>
#include <btas/btas.h>
#include <tiledarray.h>
#include <range/v3/algorithm.hpp>
#include <range/v3/numeric.hpp>
#include <range/v3/view.hpp>
#include <any>
#include <memory>
#include <utility>
namespace sequant {
namespace {
[[maybe_unused]] std::logic_error invalid_operand(
std::string_view msg = "Invalid operand for binary op") noexcept {
return std::logic_error{msg.data()};
}
[[maybe_unused]] std::logic_error unimplemented_method(
std::string_view msg) noexcept {
using namespace std::string_literals;
return std::logic_error{"Not implemented in this derived class: "s +
msg.data()};
}
template <typename T>
struct Annot {
explicit Annot(std::array<std::any, 3> const& a)
: lannot(std::any_cast<T>(a[0])),
rannot(std::any_cast<T>(a[1])),
this_annot(std::any_cast<T>(a[2])) {}
T const lannot;
T const rannot;
T const this_annot;
};
// It is an iterator type
template <typename It>
struct IterPair {
It first, second;
IterPair(It beg, It end) noexcept : first{beg}, second{end} {};
};
template <typename It>
void swap(IterPair<It>& l, IterPair<It>& r) {
using std::iter_swap;
std::iter_swap(l.first, r.first);
std::iter_swap(l.second, r.second);
}
template <typename It>
bool operator<(IterPair<It> const& l, IterPair<It> const& r) noexcept {
return *l.first < *r.first;
}
auto valid_particle_range = [](auto const& tpl) -> bool {
using std::distance;
auto [b1, b2, l] = tpl;
return distance(b1, b1 + l) == distance(b2, b2 + l);
};
auto iter_pairs = [](auto&& tpl) {
using ranges::views::iota;
using ranges::views::transform;
using std::get;
auto b1 = get<0>(tpl);
auto b2 = get<1>(tpl);
auto l = get<2>(tpl);
return iota(size_t{0}, l) | transform([b1, b2](auto i) {
return IterPair{b1 + i, b2 + i};
}) |
ranges::to_vector;
};
using perm_t = container::svector<size_t>;
using particle_range_t = std::array<size_t, 3>;
template <typename F,
std::enable_if_t<std::is_invocable_v<F, int>, bool> = true>
void antisymmetric_permutation(
container::svector<
std::tuple<perm_t::iterator, perm_t::iterator, size_t>> const& groups,
F const& call_back) {
using ranges::views::transform;
auto const n = groups.size();
if (n == 0) return;
assert(ranges::all_of(groups, valid_particle_range));
call_back(0);
for (int i = n - 1; i >= 0; --i) {
auto [bra_beg, ket_beg, len] = groups[i];
auto bra_end = bra_beg + len;
auto ket_end = ket_beg + len;
int bra_p = 0;
auto outer = 0;
for (auto bra_yn = true; bra_yn;
bra_yn = next_permutation_parity(bra_p, bra_beg, bra_end), ++outer) {
auto inner = 0;
int ket_p = 0;
for (auto ket_yn = true; ket_yn;
ket_yn = next_permutation_parity(ket_p, ket_beg, ket_end), ++inner) {
if (!(outer == 0 && inner == 0)) call_back((bra_p + ket_p) % 2);
}
}
}
}
template <typename F, std::enable_if_t<std::is_invocable_v<F>, bool> = true>
void symmetric_permutation(
container::svector<
std::tuple<perm_t::iterator, perm_t::iterator, size_t>> const& groups,
F const& call_back) {
using ranges::views::transform;
auto const n = groups.size();
if (n == 0) return;
assert(ranges::all_of(groups, valid_particle_range));
auto groups_vec = groups | transform(iter_pairs) | ranges::to_vector;
call_back();
// using reverse iterator (instead of indices) not allowed for some reason
// iter from the end group
for (int i = n - 1; i >= 0; --i) {
auto beg = groups_vec[i].begin();
auto end = groups_vec[i].end();
auto yn = std::next_permutation(beg, end);
for (; yn; yn = std::next_permutation(beg, end)) call_back();
}
}
template <
typename F,
std::enable_if_t<std::is_invocable_v<F, int, perm_t const&>, bool> = true>
void antisymmetrize_backend(size_t rank,
container::svector<particle_range_t> const& groups,
F const& call_back) {
using ranges::views::iota;
auto perm = iota(size_t{0}, rank) | ranges::to<perm_t>;
auto groups_vec = container::svector<
std::tuple<perm_t::iterator, perm_t::iterator, size_t>>{};
groups_vec.reserve(groups.size());
auto beg = perm.begin();
for (auto&& g : groups) {
groups_vec.emplace_back(beg + g[0], beg + g[1], g[2]);
}
antisymmetric_permutation(
groups_vec,
[&call_back, &perm = std::as_const(perm)](int p) { call_back(p, perm); });
}
template <typename F,
std::enable_if_t<std::is_invocable_v<F, perm_t const&>, bool> = true>
void symmetrize_backend(size_t rank,
container::svector<particle_range_t> const& groups,
F const& call_back) {
using ranges::views::iota;
auto perm = iota(size_t{0}, rank) | ranges::to<perm_t>;
auto groups_vec = container::svector<
std::tuple<perm_t::iterator, perm_t::iterator, size_t>>{};
groups_vec.reserve(groups.size());
auto beg = perm.begin();
for (auto&& g : groups) {
groups_vec.emplace_back(beg + g[0], beg + g[1], g[2]);
}
symmetric_permutation(
groups_vec,
[&call_back, &perm = std::as_const(perm)]() { call_back(perm); });
}
template <typename RngOfOrdinals>
std::string ords_to_annot(RngOfOrdinals const& ords) {
using ranges::views::intersperse;
using ranges::views::join;
using ranges::views::transform;
auto to_str = [](auto x) { return std::to_string(x); };
return ords | transform(to_str) | intersperse(std::string{","}) | join |
ranges::to<std::string>;
}
template <typename Iterable>
auto index_hash(Iterable const& bk) {
return ranges::views::transform(bk, [](auto const& idx) {
//
// WARNING!
// The BTAS expects index types to be long by default.
// There is no straight-forward way to turn the default.
// Hence, here we explicitly cast the size_t values to long
// Which is a potentially narrowing conversion leading to
// integral overflow. Hence, the values in the returned
// container are mixed negative and positive integers (long type)
//
return static_cast<long>(sequant::hash::value(Index{idx}.label()));
});
}
template <typename... Args>
auto symmetrize_ta(TA::DistArray<Args...> const& arr,
container::svector<particle_range_t> const& groups) {
using ranges::views::iota;
size_t const rank = arr.trange().rank();
TA::DistArray<Args...> result;
auto const lannot = ords_to_annot(iota(size_t{0}, rank));
auto call_back = [&result, &lannot, &arr](perm_t const& perm) {
auto const rannot = ords_to_annot(perm);
if (result.is_initialized()) {
result(lannot) += arr(rannot);
} else {
result(lannot) = arr(rannot);
}
};
symmetrize_backend(rank, groups, call_back);
TA::DistArray<Args...>::wait_for_lazy_cleanup(result.world());
return result;
}
template <typename... Args>
auto antisymmetrize_ta(
TA::DistArray<Args...> const& arr,
container::svector<particle_range_t> const& groups = {}) {
using ranges::views::iota;
size_t const rank = arr.trange().rank();
TA::DistArray<Args...> result;
auto const lannot = ords_to_annot(iota(size_t{0}, rank));
auto call_back = [&lannot, &arr, &result](int p, perm_t const& perm) {
typename decltype(result)::numeric_type p_ = p == 0 ? 1 : -1;
if (result.is_initialized())
result(lannot) += p_ * arr(ords_to_annot(perm));
else
result(lannot) = p_ * arr(ords_to_annot(perm));
};
antisymmetrize_backend(rank, groups, call_back);
TA::DistArray<Args...>::wait_for_lazy_cleanup(result.world());
return result;
}
template <typename... Args>
auto symmetrize_btas(btas::Tensor<Args...> const& arr,
container::svector<particle_range_t> const& groups) {
using ranges::views::iota;
size_t const rank = arr.rank();
// Caveat:
// clang-format off
// auto const lannot = iota(size_t{0}, rank) | ranges::to<perm_t>;
// clang-format on
auto const lannot = [rank]() {
auto p = perm_t(rank);
for (auto i = 0; i < rank; ++i) p[i] = i;
return p;
}();
auto result = btas::Tensor<Args...>{arr.range()};
result.fill(0);
auto call_back = [&result, &lannot, &arr](auto const& permutation) {
auto const& rannot = permutation;
btas::Tensor<Args...> temp;
btas::permute(arr, lannot, temp, rannot);
result += temp;
};
symmetrize_backend(rank, groups, call_back);
return result;
}
template <typename... Args>
auto antisymmetrize_btas(
btas::Tensor<Args...> const& arr,
container::svector<particle_range_t> const& groups = {}) {
using ranges::views::iota;
size_t const rank = arr.rank();
// Caveat:
// auto const lannot = iota(size_t{0}, rank) | ranges::to<perm_t>;
//
auto const lannot = [rank]() {
auto p = perm_t(rank);
for (auto i = 0; i < rank; ++i) p[i] = i;
return p;
}();
auto result = btas::Tensor<Args...>{arr.range()};
result.fill(0);
auto call_back = [&result, &lannot, &arr](int p, perm_t const& perm) {
typename decltype(result)::numeric_type p_ = p == 0 ? 1 : -1;
auto const& rannot = perm;
btas::Tensor<Args...> temp;
btas::permute(arr, lannot, temp, rannot);
btas::scal(p_, temp);
result += temp;
};
antisymmetrize_backend(rank, groups, call_back);
return result;
}
template <typename... Args>
inline void log_result(Args const&... args) noexcept {
#ifdef SEQUANT_EVAL_TRACE
auto l = Logger::instance();
if (l.log_level_eval > 1) write_log(l, args...);
#endif
}
template <typename... Args>
inline void log_ta(Args const&... args) noexcept {
#ifdef SEQUANT_EVAL_TRACE
log_result("[TA] ", args...);
#endif
}
template <typename... Args>
inline void log_constant(Args const&... args) noexcept {
#ifdef SEQUANT_EVAL_TRACE
log_result("[CONST] ", args...);
#endif
}
} // namespace
void log_ta_tensor_host_memory_use(madness::World& world,
std::string_view label = "");
class EvalResult;
using ERPtr = std::shared_ptr<EvalResult>;
template <typename T, typename... Args>
ERPtr eval_result(Args&&... args) noexcept {
return std::make_shared<T>(std::forward<Args>(args)...);
}
class EvalResult {
public:
using id_t = size_t;
virtual ~EvalResult() noexcept = default;
template <typename T>
[[nodiscard]] bool is() const noexcept {
return this->type_id() == id_for_type<std::decay_t<T>>();
}
template <typename T>
[[nodiscard]] T const& as() const {
assert(this->is<std::decay_t<T>>());
return static_cast<T const&>(*this);
}
[[nodiscard]] virtual ERPtr sum(EvalResult const&,
std::array<std::any, 3> const&) const = 0;
[[nodiscard]] virtual ERPtr prod(EvalResult const&,
std::array<std::any, 3> const&,
TA::DeNest DeNestFlag) const = 0;
[[nodiscard]] virtual ERPtr permute(std::array<std::any, 2> const&) const = 0;
virtual void add_inplace(EvalResult const&) = 0;
[[nodiscard]] virtual ERPtr symmetrize(
container::svector<std::array<size_t, 3>> const&) const = 0;
[[nodiscard]] virtual ERPtr antisymmetrize(
container::svector<std::array<size_t, 3>> const&) const = 0;
[[nodiscard]] bool has_value() const noexcept;
template <typename T>
[[nodiscard]] T& get() {
assert(has_value());
return *std::any_cast<T>(&value_);
}
template <typename T>
[[nodiscard]] T const& get() const {
return const_cast<EvalResult&>(*this).get<T>();
}
protected:
template <typename T,
typename = std::enable_if_t<!std::is_convertible_v<T, EvalResult>>>
explicit EvalResult(T&& arg) noexcept
: value_{std::make_any<std::decay_t<T>>(std::forward<T>(arg))} {}
[[nodiscard]] virtual id_t type_id() const noexcept = 0;
template <typename T>
[[nodiscard]] static id_t id_for_type() noexcept {
static id_t id = next_id();
return id;
}
private:
std::any value_;
[[nodiscard]] static id_t next_id() noexcept;
};
template <typename T>
class EvalScalar final : public EvalResult {
public:
using EvalResult::id_t;
explicit EvalScalar(T v) noexcept : EvalResult{std::move(v)} {}
[[nodiscard]] T value() const noexcept { return get<T>(); }
[[nodiscard]] ERPtr sum(EvalResult const& other,
std::array<std::any, 3> const&) const override {
if (other.is<EvalScalar<T>>()) {
auto const& o = other.as<EvalScalar<T>>();
auto s = value() + o.value();
log_constant(value(), " + ", o.value(), " = ", s, "\n");
return eval_result<EvalScalar<T>>(s);
} else {
throw invalid_operand();
}
}
[[nodiscard]] ERPtr prod(EvalResult const& other,
std::array<std::any, 3> const& maybe_empty,
TA::DeNest DeNestFlag) const override {
if (other.is<EvalScalar<T>>()) {
auto const& o = other.as<EvalScalar<T>>();
auto p = value() * o.value();
log_constant(value(), " * ", o.value(), " = ", p, "\n");
return eval_result<EvalScalar<T>>(value() * o.value());
} else {
auto maybe_empty_ = maybe_empty;
std::swap(maybe_empty_[0], maybe_empty_[1]);
return other.prod(*this, maybe_empty_, DeNestFlag);
}
}
[[nodiscard]] ERPtr permute(std::array<std::any, 2> const&) const override {
throw unimplemented_method("permute");
}
void add_inplace(EvalResult const& other) override {
assert(other.is<EvalScalar<T>>());
log_constant(value(), " += ", other.get<T>(), "\n");
auto& val = get<T>();
val += other.get<T>();
}
[[nodiscard]] ERPtr symmetrize(
container::svector<std::array<size_t, 3>> const&) const override {
throw unimplemented_method("symmetrize");
}
[[nodiscard]] ERPtr antisymmetrize(
container::svector<std::array<size_t, 3>> const&) const override {
throw unimplemented_method("antisymmetrize");
}
private:
[[nodiscard]] id_t type_id() const noexcept override {
return id_for_type<EvalScalar<T>>();
}
};
template <typename ArrayT, typename = std::enable_if_t<TA::detail::is_tensor_v<
typename ArrayT::value_type>>>
class EvalTensorTA final : public EvalResult {
public:
using EvalResult::id_t;
using numeric_type = typename ArrayT::numeric_type;
explicit EvalTensorTA(ArrayT arr) : EvalResult{std::move(arr)} {}
private:
using this_type = EvalTensorTA<ArrayT>;
using annot_wrap = Annot<std::string>;
[[nodiscard]] id_t type_id() const noexcept override {
return id_for_type<this_type>();
}
[[nodiscard]] ERPtr sum(EvalResult const& other,
std::array<std::any, 3> const& annot) const override {
assert(other.is<this_type>());
auto const a = annot_wrap{annot};
log_ta(a.lannot, " + ", a.rannot, " = ", a.this_annot, "\n");
ArrayT result;
result(a.this_annot) =
get<ArrayT>()(a.lannot) + other.get<ArrayT>()(a.rannot);
decltype(result)::wait_for_lazy_cleanup(result.world());
return eval_result<this_type>(std::move(result));
}
[[nodiscard]] ERPtr prod(EvalResult const& other,
std::array<std::any, 3> const& annot,
TA::DeNest DeNestFlag) const override {
auto const a = annot_wrap{annot};
if (other.is<EvalScalar<numeric_type>>()) {
auto result = get<ArrayT>();
auto scalar = other.get<numeric_type>();
log_ta(a.lannot, " * ", scalar, " = ", a.this_annot, "\n");
result(a.this_annot) = scalar * result(a.lannot);
decltype(result)::wait_for_lazy_cleanup(result.world());
return eval_result<this_type>(std::move(result));
}
if (a.this_annot.empty()) {
// DOT product
assert(other.is<this_type>());
numeric_type d =
TA::dot(get<ArrayT>()(a.lannot), other.get<ArrayT>()(a.rannot));
ArrayT::wait_for_lazy_cleanup(get<ArrayT>().world());
ArrayT::wait_for_lazy_cleanup(other.get<ArrayT>().world());
log_ta(a.lannot, " * ", a.rannot, " = ", d, "\n");
return eval_result<EvalScalar<numeric_type>>(d);
}
if (!other.is<this_type>()) {
// potential T * ToT
auto annot_swap = annot;
std::swap(annot_swap[0], annot_swap[1]);
return other.prod(*this, annot_swap, DeNestFlag);
}
// confirmed: other.is<this_type>() is true
log_ta(a.lannot, " * ", a.rannot, " = ", a.this_annot, "\n");
ArrayT result;
result = TA::einsum(get<ArrayT>()(a.lannot), other.get<ArrayT>()(a.rannot),
a.this_annot);
decltype(result)::wait_for_lazy_cleanup(result.world());
return eval_result<this_type>(std::move(result));
}
[[nodiscard]] ERPtr permute(
std::array<std::any, 2> const& ann) const override {
auto const pre_annot = std::any_cast<std::string>(ann[0]);
auto const post_annot = std::any_cast<std::string>(ann[1]);
log_ta(pre_annot, " = ", post_annot, "\n");
ArrayT result;
result(post_annot) = get<ArrayT>()(pre_annot);
ArrayT::wait_for_lazy_cleanup(result.world());
return eval_result<this_type>(std::move(result));
}
void add_inplace(EvalResult const& other) override {
assert(other.is<this_type>());
auto& t = get<ArrayT>();
auto const& o = other.get<ArrayT>();
assert(t.trange() == o.trange());
auto ann = TA::detail::dummy_annotation(t.trange().rank());
log_ta(ann, " += ", ann, "\n");
t(ann) += o(ann);
ArrayT::wait_for_lazy_cleanup(t.world());
}
[[nodiscard]] ERPtr symmetrize(
container::svector<std::array<size_t, 3>> const& groups) const override {
return eval_result<this_type>(symmetrize_ta(get<ArrayT>(), groups));
}
[[nodiscard]] ERPtr antisymmetrize(
container::svector<std::array<size_t, 3>> const& groups) const override {
return eval_result<this_type>(antisymmetrize_ta(get<ArrayT>(), groups));
}
};
template <typename ArrayT,
typename = std::enable_if_t<
TA::detail::is_tensor_of_tensor_v<typename ArrayT::value_type>>>
class EvalTensorOfTensorTA final : public EvalResult {
public:
using EvalResult::id_t;
using numeric_type = typename ArrayT::numeric_type;
explicit EvalTensorOfTensorTA(ArrayT arr) : EvalResult{std::move(arr)} {}
private:
using this_type = EvalTensorOfTensorTA<ArrayT>;
using annot_wrap = Annot<std::string>;
using _inner_tensor_type = typename ArrayT::value_type::value_type;
using compatible_regular_distarray_type =
TA::DistArray<_inner_tensor_type, typename ArrayT::policy_type>;
// Only @c that_type type is allowed for ToT * T computation
using that_type = EvalTensorTA<compatible_regular_distarray_type>;
[[nodiscard]] id_t type_id() const noexcept override {
return id_for_type<this_type>();
}
[[nodiscard]] ERPtr sum(EvalResult const& other,
std::array<std::any, 3> const& annot) const override {
assert(other.is<this_type>());
auto const a = annot_wrap{annot};
log_ta(a.lannot, " + ", a.rannot, " = ", a.this_annot, "\n");
ArrayT result;
result(a.this_annot) =
get<ArrayT>()(a.lannot) + other.get<ArrayT>()(a.rannot);
decltype(result)::wait_for_lazy_cleanup(result.world());
return eval_result<this_type>(std::move(result));
}
[[nodiscard]] ERPtr prod(EvalResult const& other,
std::array<std::any, 3> const& annot,
TA::DeNest DeNestFlag) const override {
auto const a = annot_wrap{annot};
if (other.is<EvalScalar<numeric_type>>()) {
auto result = get<ArrayT>();
auto scalar = other.get<numeric_type>();
log_ta(a.lannot, " * ", scalar, " = ", a.this_annot, "\n");
result(a.this_annot) = scalar * result(a.lannot);
decltype(result)::wait_for_lazy_cleanup(result.world());
return eval_result<this_type>(std::move(result));
} else if (a.this_annot.empty()) {
// DOT product
assert(other.is<this_type>());
numeric_type d =
TA::dot(get<ArrayT>()(a.lannot), other.get<ArrayT>()(a.rannot));
ArrayT::wait_for_lazy_cleanup(get<ArrayT>().world());
ArrayT::wait_for_lazy_cleanup(other.get<ArrayT>().world());
log_ta(a.lannot, " * ", a.rannot, " = ", d, "\n");
return eval_result<EvalScalar<numeric_type>>(d);
}
log_ta(a.lannot, " * ", a.rannot, " = ", a.this_annot, "\n");
if (other.is<that_type>()) {
// ToT * T -> ToT
auto result =
TA::einsum(get<ArrayT>()(a.lannot),
other.get<compatible_regular_distarray_type>()(a.rannot),
a.this_annot);
return eval_result<this_type>(std::move(result));
} else if (other.is<this_type>() && DeNestFlag == TA::DeNest::True) {
// ToT * ToT -> T
auto result = TA::einsum<TA::DeNest::True>(
get<ArrayT>()(a.lannot), other.get<ArrayT>()(a.rannot), a.this_annot);
return eval_result<that_type>(std::move(result));
} else if (other.is<this_type>() && DeNestFlag == TA::DeNest::False) {
// ToT * ToT -> ToT
auto result = TA::einsum(get<ArrayT>()(a.lannot),
other.get<ArrayT>()(a.rannot), a.this_annot);
return eval_result<this_type>(std::move(result));
} else {
throw invalid_operand();
}
}
[[nodiscard]] ERPtr permute(
std::array<std::any, 2> const& ann) const override {
auto const pre_annot = std::any_cast<std::string>(ann[0]);
auto const post_annot = std::any_cast<std::string>(ann[1]);
log_ta(pre_annot, " = ", post_annot, "\n");
ArrayT result;
result(post_annot) = get<ArrayT>()(pre_annot);
ArrayT::wait_for_lazy_cleanup(result.world());
return eval_result<this_type>(std::move(result));
}
void add_inplace(EvalResult const& other) override {
assert(other.is<this_type>());
auto& t = get<ArrayT>();
auto const& o = other.get<ArrayT>();
assert(t.trange() == o.trange());
auto ann = TA::detail::dummy_annotation(t.trange().rank());
log_ta(ann, " += ", ann, "\n");
t(ann) += o(ann);
ArrayT::wait_for_lazy_cleanup(t.world());
}
[[nodiscard]] ERPtr symmetrize(
container::svector<std::array<size_t, 3>> const& groups) const override {
// todo
// return eval_result<this_type>(symmetrize_ta(get<ArrayT>(), groups));
return nullptr;
}
[[nodiscard]] ERPtr antisymmetrize(
container::svector<std::array<size_t, 3>> const& groups) const override {
// todo
// return eval_result<this_type>(antisymmetrize_ta(get<ArrayT>(), groups));
return nullptr;
}
};
template <typename T>
class EvalTensorBTAS final : public EvalResult {
public:
using EvalResult::id_t;
using numeric_type = typename T::numeric_type;
explicit EvalTensorBTAS(T arr) : EvalResult{std::move(arr)} {}
private:
// TODO make it same as that used by EvalExprBTAS class from eval.hpp file
using annot_t = container::svector<long>;
using annot_wrap = Annot<annot_t>;
[[nodiscard]] id_t type_id() const noexcept override {
return id_for_type<EvalTensorBTAS<T>>();
}
[[nodiscard]] ERPtr sum(EvalResult const& other,
std::array<std::any, 3> const& annot) const override {
assert(other.is<EvalTensorBTAS<T>>());
auto const a = annot_wrap{annot};
T lres, rres;
btas::permute(get<T>(), a.lannot, lres, a.this_annot);
btas::permute(other.get<T>(), a.rannot, rres, a.this_annot);
return eval_result<EvalTensorBTAS<T>>(lres + rres);
}
[[nodiscard]] ERPtr prod(EvalResult const& other,
std::array<std::any, 3> const& annot,
TA::DeNest DeNestFlag) const override {
auto const a = annot_wrap{annot};
if (other.is<EvalScalar<numeric_type>>()) {
T result;
btas::permute(get<T>(), a.lannot, result, a.this_annot);
btas::scal(other.as<EvalScalar<numeric_type>>().value(), result);
return eval_result<EvalTensorBTAS<T>>(std::move(result));
}
assert(other.is<EvalTensorBTAS<T>>());
if (a.this_annot.empty()) {
T rres;
btas::permute(other.get<T>(), a.rannot, rres, a.lannot);
return eval_result<EvalScalar<numeric_type>>(btas::dot(get<T>(), rres));
}
T result;
btas::contract(numeric_type{1}, //
get<T>(), a.lannot, //
other.get<T>(), a.rannot, //
numeric_type{0}, //
result, a.this_annot);
return eval_result<EvalTensorBTAS<T>>(std::move(result));
}
[[nodiscard]] ERPtr permute(
std::array<std::any, 2> const& ann) const override {
auto const pre_annot = std::any_cast<annot_t>(ann[0]);
auto const post_annot = std::any_cast<annot_t>(ann[1]);
T result;
btas::permute(get<T>(), pre_annot, result, post_annot);
return eval_result<EvalTensorBTAS<T>>(std::move(result));
}
void add_inplace(EvalResult const& other) override {
auto& t = get<T>();
auto const& o = other.get<T>();
assert(t.range() == o.range());
t += o;
}
[[nodiscard]] ERPtr symmetrize(
container::svector<std::array<size_t, 3>> const& groups) const override {
return eval_result<EvalTensorBTAS<T>>(symmetrize_btas(get<T>(), groups));
}
[[nodiscard]] ERPtr antisymmetrize(
container::svector<std::array<size_t, 3>> const& groups) const override {
return eval_result<EvalTensorBTAS<T>>(
antisymmetrize_btas(get<T>(), groups));
}
};
} // namespace sequant
#endif // SEQUANT_EVAL_RESULT_HPP