// Program listing for file result.hpp
// (documentation source: SeQuant/core/eval/result.hpp)
#ifndef SEQUANT_EVAL_RESULT_HPP
#define SEQUANT_EVAL_RESULT_HPP
#include <SeQuant/core/algorithm.hpp>
#include <SeQuant/core/container.hpp>
#include <SeQuant/core/eval/eval_fwd.hpp>
#include <SeQuant/core/hash.hpp>
#include <SeQuant/core/index.hpp>
#include <SeQuant/core/logger.hpp>
#include <SeQuant/core/utility/macros.hpp>
#include <TiledArray/einsum/tiledarray.h>
#include <btas/btas.h>
#include <tiledarray.h>
#include <range/v3/numeric.hpp>
#include <range/v3/view.hpp>
#include <any>
#include <memory>
#include <utility>
namespace sequant {
namespace {
/// Builds the exception object used to report an operand that a binary
/// operation cannot handle.
///
/// @param msg text of the error; it does not need to be null-terminated.
/// @return a std::logic_error whose what() is exactly the characters of
///         @p msg.
[[maybe_unused]] std::logic_error invalid_operand(
    std::string_view msg = "Invalid operand for binary op") noexcept {
  // A string_view is not guaranteed to be null-terminated, so copy exactly
  // msg.size() characters instead of treating msg.data() as a C string.
  return std::logic_error{std::string{msg.data(), msg.size()}};
}
/// Builds the exception object used by derived classes to report a method
/// they do not implement.
///
/// @param msg text appended to the fixed prefix (the callers in this file
///            pass the method name); it does not need to be null-terminated.
/// @return a std::logic_error whose what() is
///         "Not implemented in this derived class: " followed by @p msg.
[[maybe_unused]] std::logic_error unimplemented_method(
    std::string_view msg) noexcept {
  std::string what{"Not implemented in this derived class: "};
  // Append exactly msg.size() characters: msg.data() alone may not be
  // null-terminated.
  what.append(msg.data(), msg.size());
  return std::logic_error{what};
}
/// A pair of iterators of one common type.
///
/// @tparam It the iterator type of both members.
template <typename It>
struct IterPair {
  It first;
  It second;

  /// Stores @p beg in `first` and @p end in `second`.
  IterPair(It beg, It end) noexcept : first{beg}, second{end} {}
};
/// Swaps the *values referred to* by two iterator pairs, member by member.
/// The iterators stored in @p l and @p r themselves are left unchanged.
template <typename It>
void swap(IterPair<It>& l, IterPair<It>& r) {
  // exchange what the `first` iterators point at ...
  std::iter_swap(l.first, r.first);
  // ... then what the `second` iterators point at
  std::iter_swap(l.second, r.second);
}
/// Orders two iterator pairs by the values their `first` iterators refer to;
/// the `second` members do not take part in the comparison.
template <typename It>
bool operator<(IterPair<It> const& l, IterPair<It> const& r) noexcept {
  decltype(auto) lhs_value = *l.first;
  decltype(auto) rhs_value = *r.first;
  return lhs_value < rhs_value;
}
// Sequence of size_t ordinals; the helpers in this file fill it with
// 0..rank-1 and permute it in place to enumerate tensor-mode permutations.
using perm_t = container::svector<size_t>;
// Describes two equally long sub-ranges of a perm_t that are handled
// together (see iter_pairs / symmetric_permutation).
struct SymmetricParticleRange {
// start of the first ("bra") sub-range
perm_t::iterator bra_beg;
// start of the second ("ket") sub-range
perm_t::iterator ket_beg;
// number of elements in each of the two sub-ranges
size_t nparticles;
};
// Describes one contiguous sub-range of a perm_t:
// [beg, beg + size).
struct ParticleRange {
// start of the sub-range
perm_t::iterator beg;
// number of elements in the sub-range
size_t size;
};
/// Checks that both sub-ranges described by @p rng are sorted in ascending
/// order and span the same number of elements.
///
/// @return true when the bra and ket sub-ranges are both sorted and equally
///         long, false otherwise.
inline bool valid_particle_range(SymmetricParticleRange const& rng) {
  using std::distance;
  auto bra_last = rng.bra_beg + rng.nparticles;
  auto ket_last = rng.ket_beg + rng.nparticles;

  if (!std::is_sorted(rng.bra_beg, bra_last)) return false;
  if (!std::is_sorted(rng.ket_beg, ket_last)) return false;
  return distance(rng.bra_beg, bra_last) == distance(rng.ket_beg, ket_last);
}
/// Produces a lazy view of `rng.nparticles` IterPair objects; the i-th one
/// holds the iterators `rng.bra_beg + i` and `rng.ket_beg + i`.
inline auto iter_pairs(SymmetricParticleRange const& rng) {
  namespace rv = ranges::views;
  // the begin iterators are copied into the closure, so the view does not
  // refer back to `rng`
  auto pair_at = [bra = rng.bra_beg, ket = rng.ket_beg](auto offset) {
    return IterPair{bra + offset, ket + offset};
  };
  return rv::iota(size_t{0}, rng.nparticles) | rv::transform(pair_at);
}
/// Invokes @p call_back once per permutation of the range described by
/// @p rng, passing the parity value maintained by next_permutation_parity.
/// The range is permuted in place between calls.
///
/// @param rng       sub-range to permute.
/// @param call_back callable taking an int (the current parity); the first
///                  call always receives 0.
template <typename F, typename = std::enable_if_t<std::is_invocable_v<F, int>>>
void antisymmetric_permutation(ParticleRange const& rng, F call_back) {
  if (rng.size > 1) {
    int parity = 0;
    auto last = rng.beg + rng.size;
    // visit the initial arrangement first, then every arrangement that
    // next_permutation_parity reports
    do {
      call_back(parity);
    } while (next_permutation_parity(parity, rng.beg, last));
    return;
  }
  // zero or one element: only the identity arrangement exists
  call_back(0);
}
/// Invokes @p call_back once for the initial arrangement of @p rng and once
/// more after each successful std::next_permutation over its iterator pairs.
/// Swapping IterPair objects swaps the underlying values, so the range that
/// @p rng refers to is permuted in place between calls.
template <typename F, typename = std::enable_if_t<std::is_invocable_v<F>>>
void symmetric_permutation(SymmetricParticleRange const& rng, F call_back) {
  auto pairs = iter_pairs(rng) | ranges::to_vector;
  for (bool more = true; more;
       more = std::next_permutation(pairs.begin(), pairs.end())) {
    call_back();
  }
}
template <typename RngOfOrdinals>
std::string ords_to_annot(RngOfOrdinals const& ords) {
using ranges::views::intersperse;
using ranges::views::join;
using ranges::views::transform;
auto to_str = [](auto x) { return std::to_string(x); };
return ords | transform(to_str) | intersperse(std::string{","}) | join |
ranges::to<std::string>;
}
/// Sums @p arr over the arrangements produced by symmetric_permutation when
/// the first half and the second half of the mode ordinals are permuted
/// together (column-wise).
///
/// @param arr even-ranked TiledArray array.
/// @return the accumulated array, annotated in the original mode order.
/// @throws std::domain_error if the rank of @p arr is odd.
template <typename... Args>
auto column_symmetrize_ta(TA::DistArray<Args...> const& arr) {
  using array_type = TA::DistArray<Args...>;

  size_t const rank = arr.trange().rank();
  if (rank % 2 != 0)
    throw std::domain_error("This function only supports even-ranked tensors");

  array_type result;
  perm_t perm = ranges::views::iota(size_t{0}, rank) | ranges::to<perm_t>;
  // annotation of the result: ordinals in their initial order
  auto const target_annot = ords_to_annot(perm);

  // `perm` is permuted in place by symmetric_permutation; each call reads its
  // current state to annotate the source array
  auto accumulate = [&result, &target_annot, &arr,
                     &current = std::as_const(perm)]() {
    auto const source_annot = ords_to_annot(current);
    if (!result.is_initialized()) {
      result(target_annot) = arr(source_annot);
      return;
    }
    result(target_annot) += arr(source_annot);
  };

  size_t const nparticles = rank / 2;
  symmetric_permutation(
      SymmetricParticleRange{perm.begin(), perm.begin() + nparticles,
                             nparticles},
      accumulate);

  array_type::wait_for_lazy_cleanup(result.world());
  return result;
}
/// Antisymmetrizes @p arr over its first @p bra_rank modes and then,
/// separately, over its remaining modes. Each permutation reported by
/// antisymmetric_permutation contributes with coefficient +1 when its parity
/// value is 0 and -1 otherwise.
///
/// @param arr      TiledArray array to antisymmetrize.
/// @param bra_rank number of leading modes forming the first group; must not
///                 exceed the rank of @p arr.
/// @return the antisymmetrized array; a copy of @p arr when neither group
///         has more than one mode.
template <typename... Args>
auto particle_antisymmetrize_ta(TA::DistArray<Args...> const& arr,
                                size_t bra_rank) {
  using ranges::views::iota;
  using array_type = TA::DistArray<Args...>;

  size_t const rank = arr.trange().rank();
  SEQUANT_ASSERT(bra_rank <= rank);
  size_t const ket_rank = rank - bra_rank;

  bool const anything_to_permute = bra_rank > 1 || ket_rank > 1;
  if (!anything_to_permute) {
    return arr;
  }

  perm_t all_ords = iota(size_t{0}, rank) | ranges::to<perm_t>;
  perm_t bra_ords = iota(size_t{0}, bra_rank) | ranges::to<perm_t>;
  perm_t ket_ords = iota(bra_rank, rank) | ranges::to<perm_t>;
  auto const target_annot = ords_to_annot(all_ords);

  // Antisymmetrizes `input` over one group of modes; the annotation of the
  // other group (`fixed_annot`) is kept as is. `group_ords` is taken by value
  // because it is permuted in place.
  auto antisymmetrize_group =
      [&target_annot](array_type const& input, size_t group_rank,
                      perm_t group_ords, std::string const& fixed_annot,
                      bool group_is_bra) -> array_type {
    if (group_rank <= 1) return input;

    array_type accum;
    auto add_term = [&](int parity) {
      auto const group_annot = ords_to_annot(group_ords);
      auto const source_annot = [&]() -> std::string {
        if (fixed_annot.empty()) return group_annot;
        if (group_is_bra) return group_annot + "," + fixed_annot;
        return fixed_annot + "," + group_annot;
      }();
      typename array_type::numeric_type p_ = parity == 0 ? 1 : -1;
      if (!accum.is_initialized()) {
        accum(target_annot) = p_ * input(source_annot);
        return;
      }
      accum(target_annot) += p_ * input(source_annot);
    };
    antisymmetric_permutation(ParticleRange{group_ords.begin(), group_rank},
                              add_term);
    return accum;
  };

  // first pass: permute the bra group, ket annotation fixed
  std::string const ket_annot =
      ket_rank == 0 ? std::string{} : ords_to_annot(ket_ords);
  auto result =
      antisymmetrize_group(arr, bra_rank, bra_ords, ket_annot, true);

  // second pass: permute the ket group of the intermediate result
  std::string const bra_annot =
      bra_rank == 0 ? std::string{} : ords_to_annot(bra_ords);
  result = antisymmetrize_group(result, ket_rank, ket_ords, bra_annot, false);

  array_type::wait_for_lazy_cleanup(result.world());
  return result;
}
/// BTAS counterpart of the column symmetrizer: sums @p arr over the
/// arrangements produced by symmetric_permutation when the first and second
/// halves of the mode ordinals are permuted together.
///
/// @param arr even-ranked BTAS tensor.
/// @return a tensor with the range of @p arr holding the accumulated sum.
/// @throws std::domain_error if the rank of @p arr is odd.
template <typename... Args>
auto column_symmetrize_btas(btas::Tensor<Args...> const& arr) {
  using tensor_type = btas::Tensor<Args...>;

  size_t const rank = arr.rank();
  if (rank % 2 != 0)
    throw std::domain_error("This function only supports even-ranked tensors");

  perm_t perm = ranges::views::iota(size_t{0}, rank) | ranges::to<perm_t>;
  // snapshot of the initial ordering; `perm` itself is permuted in place
  auto const initial_ords = perm;

  auto result = tensor_type{arr.range()};
  result.fill(0);

  auto accumulate = [&result, &initial_ords, &arr,
                     &current = std::as_const(perm)]() {
    tensor_type permuted;
    btas::permute(arr, initial_ords, permuted, current);
    result += permuted;
  };

  size_t const nparticles = rank / 2;
  symmetric_permutation(
      SymmetricParticleRange{perm.begin(), perm.begin() + nparticles,
                             nparticles},
      accumulate);
  return result;
}
/// Antisymmetrizes @p arr over its first @p bra_rank modes and then,
/// separately, over its remaining modes. Each permutation reported by
/// antisymmetric_permutation contributes with coefficient +1 when its parity
/// value is 0 and -1 otherwise.
///
/// @param arr      BTAS tensor to antisymmetrize.
/// @param bra_rank number of leading modes forming the first group; must not
///                 exceed the rank of @p arr.
/// @return the antisymmetrized tensor.
template <typename... Args>
auto particle_antisymmetrize_btas(btas::Tensor<Args...> const& arr,
                                  size_t bra_rank) {
  using ranges::views::concat;
  using ranges::views::iota;
  size_t const rank = arr.rank();
  SEQUANT_ASSERT(bra_rank <= rank);
  size_t const ket_rank = rank - bra_rank;
  perm_t bra_perm = iota(size_t{0}, bra_rank) | ranges::to<perm_t>;
  perm_t ket_perm = iota(bra_rank, rank) | ranges::to<perm_t>;
  const auto lannot = iota(size_t{0}, rank) | ranges::to<perm_t>;

  // Antisymmetrizes `input_arr` over one group of modes while the ordinals of
  // the other group (`other_perm`) stay fixed. `range_perm` is taken by value
  // because it is permuted in place.
  auto process_permutations = [&lannot](const btas::Tensor<Args...>& input_arr,
                                        size_t range_rank, perm_t range_perm,
                                        const perm_t& other_perm, bool is_bra) {
    if (range_rank <= 1) return input_arr;
    btas::Tensor<Args...> result{input_arr.range()};
    // The accumulator is only ever updated with `+=`, so it must start from
    // zero, exactly as the other BTAS helpers in this file do.
    result.fill(0);
    auto callback = [&](int parity) {
      const auto annot =
          is_bra ? concat(range_perm, other_perm) | ranges::to<perm_t>()
                 : concat(other_perm, range_perm) | ranges::to<perm_t>();
      typename decltype(result)::numeric_type p_ = parity == 0 ? 1 : -1;
      btas::Tensor<Args...> temp;
      btas::permute(input_arr, lannot, temp, annot);
      btas::scal(p_, temp);
      result += temp;
    };
    antisymmetric_permutation(ParticleRange{range_perm.begin(), range_rank},
                              callback);
    return result;
  };

  // Process bra permutations first
  const auto ket_annot = ket_rank == 0 ? perm_t{} : ket_perm;
  auto result = process_permutations(arr, bra_rank, bra_perm, ket_annot, true);

  // Process ket permutations if needed
  const auto bra_annot = bra_rank == 0 ? perm_t{} : bra_perm;
  result = process_permutations(result, ket_rank, ket_perm, bra_annot, false);
  return result;
}
/// Computes `arr - (1 / ket_rank!) * S`, where S is the sum of @p arr over
/// all permutations of its trailing `rank - bra_rank` modes (every
/// permutation enters with coefficient +1; parity is ignored).
///
/// @param arr      TiledArray array to project.
/// @param bra_rank number of leading modes that are kept fixed; must not
///                 exceed the rank of @p arr.
/// @return the projected array; a copy of @p arr when its rank is 4 or less.
template <typename... Args>
auto biorthogonal_nns_project_ta(TA::DistArray<Args...> const& arr,
                                 size_t bra_rank) {
  using ranges::views::iota;
  using array_type = TA::DistArray<Args...>;
  using numeric_type = typename array_type::numeric_type;

  size_t const rank = arr.trange().rank();
  SEQUANT_ASSERT(bra_rank <= rank);
  size_t const ket_rank = rank - bra_rank;

  if (!(rank > 4)) {
    return arr;
  }

  // ket_rank! (1 when ket_rank is 0 or 1)
  size_t ket_factorial = 1;
  for (size_t k = ket_rank; k >= 2; --k) {
    ket_factorial *= k;
  }
  numeric_type norm_factor = numeric_type(1) / numeric_type(ket_factorial);

  array_type result;

  perm_t all_ords = iota(size_t{0}, rank) | ranges::to<perm_t>;
  perm_t bra_ords = iota(size_t{0}, bra_rank) | ranges::to<perm_t>;
  perm_t ket_ords = iota(bra_rank, rank) | ranges::to<perm_t>;
  auto const target_annot = ords_to_annot(all_ords);

  // Sums `input` over all permutations of one group of modes; the annotation
  // of the other group (`fixed_annot`) is kept as is. `group_ords` is taken
  // by value because it is permuted in place.
  auto sum_over_permutations =
      [&target_annot](array_type const& input, size_t group_rank,
                      perm_t group_ords, std::string const& fixed_annot,
                      bool group_is_bra) -> array_type {
    if (group_rank <= 1) return input;

    array_type accum;
    auto add_term = [&]([[maybe_unused]] int parity) {
      auto const group_annot = ords_to_annot(group_ords);
      auto const source_annot = [&]() -> std::string {
        if (fixed_annot.empty()) return group_annot;
        if (group_is_bra) return group_annot + "," + fixed_annot;
        return fixed_annot + "," + group_annot;
      }();
      // parity is deliberately unused: every permutation gets the same
      // coefficient
      numeric_type p_ = 1;
      if (!accum.is_initialized()) {
        accum(target_annot) = p_ * input(source_annot);
        return;
      }
      accum(target_annot) += p_ * input(source_annot);
    };
    antisymmetric_permutation(ParticleRange{group_ords.begin(), group_rank},
                              add_term);
    return accum;
  };

  // identity term, coefficient +1
  result(target_annot) = arr(target_annot);

  // subtract the normalized sum over ket permutations
  if (ket_rank > 1) {
    std::string const bra_annot =
        bra_rank == 0 ? std::string{} : ords_to_annot(bra_ords);
    auto ket_sum =
        sum_over_permutations(arr, ket_rank, ket_ords, bra_annot, false);
    result(target_annot) -= norm_factor * ket_sum(target_annot);
  }

  array_type::wait_for_lazy_cleanup(result.world());
  return result;
}
/// BTAS counterpart of the biorthogonal NNS projector: computes
/// `arr - (1 / ket_rank!) * S`, where S is the sum of @p arr over all
/// permutations of its trailing `rank - bra_rank` modes (every permutation
/// enters with coefficient +1; parity is ignored).
///
/// @param arr      BTAS tensor to project.
/// @param bra_rank number of leading modes that are kept fixed; must not
///                 exceed the rank of @p arr.
/// @return the projected tensor; a copy of @p arr when its rank is 4 or less.
template <typename... Args>
auto biorthogonal_nns_project_btas(btas::Tensor<Args...> const& arr,
                                   size_t bra_rank) {
  using ranges::views::concat;
  using ranges::views::iota;
  using tensor_type = btas::Tensor<Args...>;
  using numeric_type = typename tensor_type::numeric_type;

  size_t const rank = arr.rank();
  SEQUANT_ASSERT(bra_rank <= rank);
  size_t const ket_rank = rank - bra_rank;

  if (!(rank > 4)) {
    return arr;
  }

  // ket_rank! (1 when ket_rank is 0 or 1)
  size_t ket_factorial = 1;
  for (size_t k = ket_rank; k >= 2; --k) {
    ket_factorial *= k;
  }
  numeric_type norm_factor = numeric_type(1) / numeric_type(ket_factorial);

  perm_t bra_ords = iota(size_t{0}, bra_rank) | ranges::to<perm_t>;
  perm_t ket_ords = iota(bra_rank, rank) | ranges::to<perm_t>;
  const auto initial_ords = iota(size_t{0}, rank) | ranges::to<perm_t>;

  // Sums `input` over all permutations of one group of modes while the
  // ordinals of the other group (`fixed_ords`) stay fixed. `group_ords` is
  // taken by value because it is permuted in place.
  auto sum_over_permutations = [&initial_ords](const tensor_type& input,
                                               size_t group_rank,
                                               perm_t group_ords,
                                               const perm_t& fixed_ords,
                                               bool group_is_bra) {
    if (group_rank <= 1) return input;

    tensor_type accum{input.range()};
    accum.fill(0);
    auto add_term = [&]([[maybe_unused]] int parity) {
      const auto source_ords =
          group_is_bra ? concat(group_ords, fixed_ords) | ranges::to<perm_t>()
                       : concat(fixed_ords, group_ords) | ranges::to<perm_t>();
      // parity is deliberately unused: every permutation gets the same
      // coefficient
      numeric_type p_ = 1;
      tensor_type permuted;
      btas::permute(input, initial_ords, permuted, source_ords);
      btas::scal(p_, permuted);
      accum += permuted;
    };
    antisymmetric_permutation(ParticleRange{group_ords.begin(), group_rank},
                              add_term);
    return accum;
  };

  // identity term, coefficient +1
  auto result = arr;

  // subtract the normalized sum over ket permutations
  if (ket_rank > 1) {
    const auto bra_fixed = bra_rank == 0 ? perm_t{} : bra_ords;
    auto ket_sum =
        sum_over_permutations(arr, ket_rank, ket_ords, bra_fixed, false);
    btas::scal(norm_factor, ket_sum);
    result -= ket_sum;
  }
  return result;
}
/// Forwards @p args to write_log on the Logger singleton, but only when the
/// logger's `eval.level` is greater than 1; otherwise does nothing.
template <typename... Args>
inline void log_result(Args const&... args) noexcept {
  auto& logger = Logger::instance();
  bool const enabled = logger.eval.level > 1;
  if (!enabled) return;
  write_log(logger, args...);
}
// Logs `args` through log_result with the "[TA] " prefix placed in front.
template <typename... Args>
inline void log_ta(Args const&... args) noexcept {
log_result("[TA] ", args...);
}
// Logs `args` through log_result with the "[CONST] " prefix placed in front.
template <typename... Args>
inline void log_constant(Args const&... args) noexcept {
log_result("[CONST] ", args...);
}
} // namespace
// Declaration only; the definition is not in this header.
// NOTE(review): judging by the name, this reports the host memory used by
// TiledArray tensors in `world`, tagged with `label` -- confirm against the
// definition.
void log_ta_tensor_host_memory_use(madness::World& world,
std::string_view label = "");
/******************************************************************************/
// Factory for results: constructs a T from `args` (perfectly forwarded) via
// std::make_shared and returns it as a ResultPtr.
// T is expected to be a concrete class derived from Result.
template <typename T, typename... Args>
ResultPtr eval_result(Args&&... args) noexcept {
return std::make_shared<T>(std::forward<Args>(args)...);
}
/// Typed view of the three type-erased annotations passed to binary
/// operations: left operand, right operand and result, in that order.
///
/// @tparam T the concrete annotation type stored in each std::any.
template <typename T>
struct Annot {
  /// Extracts a copy of each annotation from @p a.
  /// @throws std::bad_any_cast if an element does not hold a T.
  explicit Annot(std::array<std::any, 3> const& a)
      : lannot(std::any_cast<T>(std::get<0>(a))),
        rannot(std::any_cast<T>(std::get<1>(a))),
        this_annot(std::any_cast<T>(std::get<2>(a))) {}

  /// annotation of the left operand (element 0)
  T const lannot;
  /// annotation of the right operand (element 1)
  T const rannot;
  /// annotation of the result (element 2)
  T const this_annot;
};
// Abstract base class of all evaluation results. The payload is stored
// type-erased in a std::any; derived classes implement the arithmetic
// operations on it and identify themselves through type_id().
class Result {
public:
using id_t = size_t;
virtual ~Result() noexcept = default;
// True if the dynamic type of this object is T (after std::decay), decided
// by comparing type_id() with the id assigned to T.
template <typename T>
[[nodiscard]] bool is() const noexcept {
return this->type_id() == id_for_type<std::decay_t<T>>();
}
// Downcasts this object to T const&. The type match is only checked by
// SEQUANT_ASSERT; the cast itself is an unchecked static_cast.
template <typename T>
[[nodiscard]] T const& as() const {
SEQUANT_ASSERT(this->is<std::decay_t<T>>());
return static_cast<T const&>(*this);
}
// Sum of this result and another one. The array holds the type-erased
// annotations that the implementations in this file unpack as
// {left, right, result}.
[[nodiscard]] virtual ResultPtr sum(Result const&,
std::array<std::any, 3> const&) const = 0;
// Product of this result and another one; annotations as for sum().
// DeNestFlag is consulted by the tensor-of-tensor implementation.
[[nodiscard]] virtual ResultPtr prod(Result const&,
std::array<std::any, 3> const&,
TA::DeNest DeNestFlag) const = 0;
// Permuted copy of this result. The array holds the type-erased annotations
// {before, after}.
[[nodiscard]] virtual ResultPtr permute(
std::array<std::any, 2> const&) const = 0;
// Adds another result to this one in place.
virtual void add_inplace(Result const&) = 0;
[[nodiscard]] virtual ResultPtr symmetrize() const = 0;
[[nodiscard]] virtual ResultPtr antisymmetrize(size_t bra_rank) const = 0;
[[nodiscard]] virtual ResultPtr biorthogonal_nns_project(
size_t bra_rank) const = 0;
// Declared here, defined elsewhere.
[[nodiscard]] bool has_value() const noexcept;
// This result multiplied by the given integer factor.
[[nodiscard]] virtual ResultPtr mult_by_phase(std::int8_t) const = 0;
// Mutable access to the stored payload. T must be exactly the stored type:
// the pointer form of std::any_cast yields nullptr on a mismatch, which is
// dereferenced here without a check.
template <typename T>
[[nodiscard]] T& get() {
SEQUANT_ASSERT(has_value());
return *std::any_cast<T>(&value_);
}
// Read-only access to the stored payload; same requirement on T as above.
template <typename T>
[[nodiscard]] T const& get() const {
SEQUANT_ASSERT(has_value());
return *std::any_cast<const T>(&value_);
}
[[nodiscard]] virtual std::size_t size_in_bytes() const = 0;
protected:
// Stores `arg` as the payload (as std::decay_t<T>). The enable_if removes
// this constructor for arguments convertible to Result.
template <typename T,
typename = std::enable_if_t<!std::is_convertible_v<T, Result>>>
explicit Result(T&& arg) noexcept
: value_{std::make_any<std::decay_t<T>>(std::forward<T>(arg))} {}
// Id of the most derived type; compared against id_for_type<T>() by is().
[[nodiscard]] virtual id_t type_id() const noexcept = 0;
// Returns the id assigned to T. The id is obtained from next_id() the first
// time this instantiation is called and cached in a function-local static.
template <typename T>
[[nodiscard]] static id_t id_for_type() noexcept {
static id_t id = next_id();
return id;
}
private:
// type-erased payload
std::any value_;
// Declared here, defined elsewhere.
[[nodiscard]] static id_t next_id() noexcept;
};
/// Result holding a single scalar of type T.
///
/// Sums and products with another ResultScalar<T> are evaluated directly;
/// a product with any other kind of result is delegated to that result.
/// Tensor-only operations (permute, symmetrize, antisymmetrize,
/// biorthogonal_nns_project) throw std::logic_error.
template <typename T>
class ResultScalar final : public Result {
 public:
  using Result::id_t;

  explicit ResultScalar(T v) noexcept : Result{std::move(v)} {}

  /// @return a copy of the stored scalar.
  [[nodiscard]] T value() const noexcept { return get<T>(); }

  /// @return the sum of this scalar and @p other as a new ResultScalar<T>.
  /// @throws std::logic_error if @p other is not a ResultScalar<T>.
  [[nodiscard]] ResultPtr sum(Result const& other,
                              std::array<std::any, 3> const&) const override {
    if (other.is<ResultScalar<T>>()) {
      auto const& o = other.as<ResultScalar<T>>();
      auto s = value() + o.value();
      log_constant(value(), " + ", o.value(), " = ", s, "\n");
      return eval_result<ResultScalar<T>>(s);
    } else {
      throw invalid_operand();
    }
  }

  /// @return the product of this scalar and @p other. For a non-scalar
  ///         @p other the call is forwarded to `other.prod` with the left and
  ///         right annotations exchanged.
  [[nodiscard]] ResultPtr prod(Result const& other,
                               std::array<std::any, 3> const& maybe_empty,
                               TA::DeNest DeNestFlag) const override {
    if (other.is<ResultScalar<T>>()) {
      auto const& o = other.as<ResultScalar<T>>();
      auto p = value() * o.value();
      log_constant(value(), " * ", o.value(), " = ", p, "\n");
      // return the value that was logged instead of computing it again
      return eval_result<ResultScalar<T>>(p);
    } else {
      // operands change sides, so their annotations must change sides too
      auto maybe_empty_ = maybe_empty;
      std::swap(maybe_empty_[0], maybe_empty_[1]);
      return other.prod(*this, maybe_empty_, DeNestFlag);
    }
  }

  /// Not meaningful for a scalar. @throws std::logic_error always.
  [[nodiscard]] ResultPtr permute(
      std::array<std::any, 2> const&) const override {
    throw unimplemented_method("permute");
  }

  /// Adds the scalar held by @p other to this one in place.
  /// @p other must be a ResultScalar<T> (checked by assertion only).
  void add_inplace(Result const& other) override {
    SEQUANT_ASSERT(other.is<ResultScalar<T>>());
    auto const& rhs = other.get<T>();
    log_constant(value(), " += ", rhs, "\n");
    auto& val = get<T>();
    val += rhs;
  }

  /// Not meaningful for a scalar. @throws std::logic_error always.
  [[nodiscard]] ResultPtr symmetrize() const override {
    throw unimplemented_method("symmetrize");
  }

  /// Not meaningful for a scalar. @throws std::logic_error always.
  [[nodiscard]] ResultPtr antisymmetrize(size_t /*bra_rank*/) const override {
    throw unimplemented_method("antisymmetrize");
  }

  /// Not meaningful for a scalar. @throws std::logic_error always.
  [[nodiscard]] ResultPtr biorthogonal_nns_project(
      [[maybe_unused]] size_t bra_rank) const override {
    throw unimplemented_method("biorthogonal_nns_project");
  }

  /// @return a new ResultScalar<T> holding `value() * T(factor)`.
  [[nodiscard]] ResultPtr mult_by_phase(std::int8_t factor) const override {
    return eval_result<ResultScalar<T>>(value() * T(factor));
  }

 private:
  [[nodiscard]] id_t type_id() const noexcept override {
    return id_for_type<ResultScalar<T>>();
  }

  /// @return sizeof(T).
  [[nodiscard]] std::size_t size_in_bytes() const final { return sizeof(T); }
};
// Result holding a TiledArray array (ArrayT) whose tiles satisfy
// TA::detail::is_tensor_v. Annotations are std::string.
// All operations are private overrides and are therefore reached through the
// Result interface.
template <typename ArrayT, typename = std::enable_if_t<TA::detail::is_tensor_v<
typename ArrayT::value_type>>>
class ResultTensorTA final : public Result {
public:
using Result::id_t;
using numeric_type = typename ArrayT::numeric_type;
explicit ResultTensorTA(ArrayT arr) : Result{std::move(arr)} {}
private:
using this_type = ResultTensorTA<ArrayT>;
using annot_wrap = Annot<std::string>;
[[nodiscard]] id_t type_id() const noexcept override {
return id_for_type<this_type>();
}
// result(this_annot) = this(lannot) + other(rannot).
// `other` must be of this same type (checked by assertion only).
[[nodiscard]] ResultPtr sum(
Result const& other,
std::array<std::any, 3> const& annot) const override {
SEQUANT_ASSERT(other.is<this_type>());
auto const a = annot_wrap{annot};
log_ta(a.lannot, " + ", a.rannot, " = ", a.this_annot, "\n");
ArrayT result;
result(a.this_annot) =
get<ArrayT>()(a.lannot) + other.get<ArrayT>()(a.rannot);
decltype(result)::wait_for_lazy_cleanup(result.world());
return eval_result<this_type>(std::move(result));
}
// Product with another result. Cases are tried in this order:
//  1. other is a scalar            -> scaled array
//  2. result annotation is empty   -> dot product, returned as a scalar
//  3. other is of a different type -> delegated to other.prod with the left
//                                     and right annotations exchanged
//  4. other is of this same type   -> TA::einsum
[[nodiscard]] ResultPtr prod(Result const& other,
std::array<std::any, 3> const& annot,
TA::DeNest DeNestFlag) const override {
auto const a = annot_wrap{annot};
if (other.is<ResultScalar<numeric_type>>()) {
auto result = get<ArrayT>();
auto scalar = other.get<numeric_type>();
log_ta(a.lannot, " * ", scalar, " = ", a.this_annot, "\n");
result(a.this_annot) = scalar * result(a.lannot);
decltype(result)::wait_for_lazy_cleanup(result.world());
return eval_result<this_type>(std::move(result));
}
if (a.this_annot.empty()) {
// DOT product
SEQUANT_ASSERT(other.is<this_type>());
numeric_type d =
TA::dot(get<ArrayT>()(a.lannot), other.get<ArrayT>()(a.rannot));
ArrayT::wait_for_lazy_cleanup(get<ArrayT>().world());
ArrayT::wait_for_lazy_cleanup(other.get<ArrayT>().world());
log_ta(a.lannot, " * ", a.rannot, " = ", d, "\n");
return eval_result<ResultScalar<numeric_type>>(d);
}
if (!other.is<this_type>()) {
// potential T * ToT
auto annot_swap = annot;
std::swap(annot_swap[0], annot_swap[1]);
return other.prod(*this, annot_swap, DeNestFlag);
}
// confirmed: other.is<this_type>() is true
log_ta(a.lannot, " * ", a.rannot, " = ", a.this_annot, "\n");
ArrayT result;
result = TA::einsum(get<ArrayT>()(a.lannot), other.get<ArrayT>()(a.rannot),
a.this_annot);
decltype(result)::wait_for_lazy_cleanup(result.world());
return eval_result<this_type>(std::move(result));
}
// Returns a result built from a copy of the stored array after
// TA::scale(copy, factor) has been applied to it.
[[nodiscard]] ResultPtr mult_by_phase(std::int8_t factor) const override {
auto pre = get<ArrayT>();
TA::scale(pre, numeric_type(factor));
return eval_result<this_type>(std::move(pre));
}
// result(ann[1]) = this(ann[0]); both annotations must be std::string.
[[nodiscard]] ResultPtr permute(
std::array<std::any, 2> const& ann) const override {
auto const pre_annot = std::any_cast<std::string>(ann[0]);
auto const post_annot = std::any_cast<std::string>(ann[1]);
log_ta(pre_annot, " = ", post_annot, "\n");
ArrayT result;
result(post_annot) = get<ArrayT>()(pre_annot);
ArrayT::wait_for_lazy_cleanup(result.world());
return eval_result<this_type>(std::move(result));
}
// Adds `other` to the stored array. Both arrays are annotated with the same
// dummy annotation, so no permutation takes place; matching type and
// trange are checked by assertion only.
void add_inplace(Result const& other) override {
SEQUANT_ASSERT(other.is<this_type>());
auto& t = get<ArrayT>();
auto const& o = other.get<ArrayT>();
SEQUANT_ASSERT(t.trange() == o.trange());
auto ann = TA::detail::dummy_annotation(t.trange().rank());
log_ta(ann, " += ", ann, "\n");
t(ann) += o(ann);
ArrayT::wait_for_lazy_cleanup(t.world());
}
// The next three operations delegate to the free helper functions
// column_symmetrize_ta, particle_antisymmetrize_ta and
// biorthogonal_nns_project_ta.
[[nodiscard]] ResultPtr symmetrize() const override {
return eval_result<this_type>(column_symmetrize_ta(get<ArrayT>()));
}
[[nodiscard]] ResultPtr antisymmetrize(size_t bra_rank) const override {
return eval_result<this_type>(
particle_antisymmetrize_ta(get<ArrayT>(), bra_rank));
}
[[nodiscard]] ResultPtr biorthogonal_nns_project(
size_t bra_rank) const override {
return eval_result<this_type>(
biorthogonal_nns_project_ta(get<ArrayT>(), bra_rank));
}
private:
// Host-memory size reported by TA::size_of for the stored array, summed
// with world().gop.sum.
// NOTE(review): gop.sum looks like a collective reduction over the ranks of
// the world, in which case every rank has to call this -- confirm.
[[nodiscard]] std::size_t size_in_bytes() const final {
auto& v = get<ArrayT>();
auto local_size = TA::size_of<TA::MemorySpace::Host>(v);
v.world().gop.sum(local_size);
return local_size;
}
};
// Result holding a TiledArray array (ArrayT) whose tiles satisfy
// TA::detail::is_tensor_of_tensor_v ("ToT"). Annotations are std::string.
// All operations are private overrides and are therefore reached through the
// Result interface.
template <typename ArrayT,
typename = std::enable_if_t<
TA::detail::is_tensor_of_tensor_v<typename ArrayT::value_type>>>
class ResultTensorOfTensorTA final : public Result {
public:
using Result::id_t;
using numeric_type = typename ArrayT::numeric_type;
explicit ResultTensorOfTensorTA(ArrayT arr) : Result{std::move(arr)} {}
private:
using this_type = ResultTensorOfTensorTA<ArrayT>;
using annot_wrap = Annot<std::string>;
// tile type of the inner tensors
using _inner_tensor_type = typename ArrayT::value_type::value_type;
// regular (non-nested) array with the inner tensor as tile and the same
// policy as ArrayT
using compatible_regular_distarray_type =
TA::DistArray<_inner_tensor_type, typename ArrayT::policy_type>;
// Only @c that_type type is allowed for ToT * T computation
using that_type = ResultTensorTA<compatible_regular_distarray_type>;
[[nodiscard]] id_t type_id() const noexcept override {
return id_for_type<this_type>();
}
// result(this_annot) = this(lannot) + other(rannot).
// `other` must be of this same type (checked by assertion only).
[[nodiscard]] ResultPtr sum(
Result const& other,
std::array<std::any, 3> const& annot) const override {
SEQUANT_ASSERT(other.is<this_type>());
auto const a = annot_wrap{annot};
log_ta(a.lannot, " + ", a.rannot, " = ", a.this_annot, "\n");
ArrayT result;
result(a.this_annot) =
get<ArrayT>()(a.lannot) + other.get<ArrayT>()(a.rannot);
decltype(result)::wait_for_lazy_cleanup(result.world());
return eval_result<this_type>(std::move(result));
}
// Product with another result. Cases are tried in this order:
//  1. other is a scalar                     -> scaled ToT
//  2. result annotation is empty            -> dot product, as a scalar
//  3. other is the compatible regular array -> ToT * T   -> ToT
//  4. other is a ToT, DeNestFlag is True    -> ToT * ToT -> T
//  5. other is a ToT, DeNestFlag is False   -> ToT * ToT -> ToT
//  6. anything else                         -> throws std::logic_error
[[nodiscard]] ResultPtr prod(Result const& other,
std::array<std::any, 3> const& annot,
TA::DeNest DeNestFlag) const override {
auto const a = annot_wrap{annot};
if (other.is<ResultScalar<numeric_type>>()) {
auto result = get<ArrayT>();
auto scalar = other.get<numeric_type>();
log_ta(a.lannot, " * ", scalar, " = ", a.this_annot, "\n");
result(a.this_annot) = scalar * result(a.lannot);
decltype(result)::wait_for_lazy_cleanup(result.world());
return eval_result<this_type>(std::move(result));
} else if (a.this_annot.empty()) {
// DOT product
SEQUANT_ASSERT(other.is<this_type>());
numeric_type d =
TA::dot(get<ArrayT>()(a.lannot), other.get<ArrayT>()(a.rannot));
ArrayT::wait_for_lazy_cleanup(get<ArrayT>().world());
ArrayT::wait_for_lazy_cleanup(other.get<ArrayT>().world());
log_ta(a.lannot, " * ", a.rannot, " = ", d, "\n");
return eval_result<ResultScalar<numeric_type>>(d);
}
log_ta(a.lannot, " * ", a.rannot, " = ", a.this_annot, "\n");
if (other.is<that_type>()) {
// ToT * T -> ToT
auto result =
TA::einsum(get<ArrayT>()(a.lannot),
other.get<compatible_regular_distarray_type>()(a.rannot),
a.this_annot);
return eval_result<this_type>(std::move(result));
} else if (other.is<this_type>() && DeNestFlag == TA::DeNest::True) {
// ToT * ToT -> T
auto result = TA::einsum<TA::DeNest::True>(
get<ArrayT>()(a.lannot), other.get<ArrayT>()(a.rannot), a.this_annot);
return eval_result<that_type>(std::move(result));
} else if (other.is<this_type>() && DeNestFlag == TA::DeNest::False) {
// ToT * ToT -> ToT
auto result = TA::einsum(get<ArrayT>()(a.lannot),
other.get<ArrayT>()(a.rannot), a.this_annot);
return eval_result<this_type>(std::move(result));
} else {
throw invalid_operand();
}
}
// Returns a result built from a copy of the stored array after
// TA::scale(copy, factor) has been applied to it.
[[nodiscard]] ResultPtr mult_by_phase(std::int8_t factor) const override {
auto pre = get<ArrayT>();
TA::scale(pre, numeric_type(factor));
return eval_result<this_type>(std::move(pre));
}
// result(ann[1]) = this(ann[0]); both annotations must be std::string.
[[nodiscard]] ResultPtr permute(
std::array<std::any, 2> const& ann) const override {
auto const pre_annot = std::any_cast<std::string>(ann[0]);
auto const post_annot = std::any_cast<std::string>(ann[1]);
log_ta(pre_annot, " = ", post_annot, "\n");
ArrayT result;
result(post_annot) = get<ArrayT>()(pre_annot);
ArrayT::wait_for_lazy_cleanup(result.world());
return eval_result<this_type>(std::move(result));
}
// Adds `other` to the stored array. Both arrays are annotated with the same
// dummy annotation, so no permutation takes place; matching type and
// trange are checked by assertion only.
void add_inplace(Result const& other) override {
SEQUANT_ASSERT(other.is<this_type>());
auto& t = get<ArrayT>();
auto const& o = other.get<ArrayT>();
SEQUANT_ASSERT(t.trange() == o.trange());
auto ann = TA::detail::dummy_annotation(t.trange().rank());
log_ta(ann, " += ", ann, "\n");
t(ann) += o(ann);
ArrayT::wait_for_lazy_cleanup(t.world());
}
// NOTE: the three operations below are not implemented for tensors of
// tensors. They return a null ResultPtr instead of throwing, so callers
// must check the returned pointer before using it.
[[nodiscard]] ResultPtr symmetrize() const override {
// not implemented yet
return nullptr;
}
[[nodiscard]] ResultPtr antisymmetrize(size_t /*bra_rank*/) const override {
// not implemented yet
return nullptr;
}
[[nodiscard]] ResultPtr biorthogonal_nns_project(
[[maybe_unused]] size_t bra_rank) const override {
// or? throw unimplemented_method("biorthogonal_nns_project");
// not implemented yet, I think I need it for CSV
return nullptr;
}
private:
// Host-memory size reported by TA::size_of for the stored array, summed
// with world().gop.sum.
// NOTE(review): gop.sum looks like a collective reduction over the ranks of
// the world, in which case every rank has to call this -- confirm.
[[nodiscard]] std::size_t size_in_bytes() const final {
auto& v = get<ArrayT>();
auto local_size = TA::size_of<TA::MemorySpace::Host>(v);
v.world().gop.sum(local_size);
return local_size;
}
};
template <typename T>
class ResultTensorBTAS final : public Result {
public:
using Result::id_t;
using numeric_type = typename T::numeric_type;
explicit ResultTensorBTAS(T arr) : Result{std::move(arr)} {}
private:
// TODO make it same as that used by EvalExprBTAS class from eval.hpp file
using annot_t = container::svector<long>;
using annot_wrap = Annot<annot_t>;
[[nodiscard]] id_t type_id() const noexcept override {
return id_for_type<ResultTensorBTAS<T>>();
}
[[nodiscard]] ResultPtr sum(
Result const& other,
std::array<std::any, 3> const& annot) const override {
SEQUANT_ASSERT(other.is<ResultTensorBTAS<T>>());
auto const a = annot_wrap{annot};
T lres, rres;
btas::permute(get<T>(), a.lannot, lres, a.this_annot);
btas::permute(other.get<T>(), a.rannot, rres, a.this_annot);
return eval_result<ResultTensorBTAS<T>>(lres + rres);
}
[[nodiscard]] ResultPtr prod(Result const& other,
std::array<std::any, 3> const& annot,
TA::DeNest /*DeNestFlag*/) const override {
auto const a = annot_wrap{annot};
if (other.is<ResultScalar<numeric_type>>()) {
T result;
btas::permute(get<T>(), a.lannot, result, a.this_annot);
btas::scal(other.as<ResultScalar<numeric_type>>().value(), result);
return eval_result<ResultTensorBTAS<T>>(std::move(result));
}
SEQUANT_ASSERT(other.is<ResultTensorBTAS<T>>());
if (a.this_annot.empty()) {
T rres;
btas::permute(other.get<T>(), a.rannot, rres, a.lannot);
return eval_result<ResultScalar<numeric_type>>(btas::dot(get<T>(), rres));
}
T result;
btas::contract(numeric_type{1}, //
get<T>(), a.lannot, //
other.get<T>(), a.rannot, //
numeric_type{0}, //
result, a.this_annot);
return eval_result<ResultTensorBTAS<T>>(std::move(result));
}
[[nodiscard]] ResultPtr mult_by_phase(std::int8_t factor) const override {
auto pre = get<T>();
btas::scal(numeric_type(factor), pre);
return eval_result<ResultTensorBTAS<T>>(std::move(pre));
}
[[nodiscard]] ResultPtr permute(
std::array<std::any, 2> const& ann) const override {
auto const pre_annot = std::any_cast<annot_t>(ann[0]);
auto const post_annot = std::any_cast<annot_t>(ann[1]);
T result;
btas::permute(get<T>(), pre_annot, result, post_annot);
return eval_result<ResultTensorBTAS<T>>(std::move(result));
}
void add_inplace(Result const& other) override {
auto& t = get<T>();
auto const& o = other.get<T>();
SEQUANT_ASSERT(t.range() == o.range());
t += o;
}
[[nodiscard]] ResultPtr symmetrize() const override {
return eval_result<ResultTensorBTAS<T>>(column_symmetrize_btas(get<T>()));
}
[[nodiscard]] ResultPtr antisymmetrize(size_t bra_rank) const override {
return eval_result<ResultTensorBTAS<T>>(
particle_antisymmetrize_btas(get<T>(), bra_rank));
}
[[nodiscard]] ResultPtr biorthogonal_nns_project(
[[maybe_unused]] size_t bra_rank) const override {
return eval_result<ResultTensorBTAS<T>>(
biorthogonal_nns_project_btas(get<T>(), bra_rank));
}
private:
[[nodiscard]] std::size_t size_in_bytes() const final {
static_assert(std::is_arithmetic_v<typename T::value_type>);
const auto& tensor = get<T>();
// only count data
return tensor.range().volume() * sizeof(T);
}
};
} // namespace sequant
#endif // SEQUANT_EVAL_RESULT_HPP