.. _program_listing_file_SeQuant_domain_eval_eval_result.hpp: Program Listing for File eval_result.hpp ======================================== |exhale_lsh| :ref:`Return to documentation for file ` (``SeQuant/domain/eval/eval_result.hpp``) .. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS .. code-block:: cpp /* NOTE(review): this listing appears to have been stripped of everything that was written between angle brackets (include targets, template parameter and argument lists, cast targets), and its line breaks were collapsed, so each // comment now swallows the code that follows it on the same line. The code is not compilable as shown; consult SeQuant/domain/eval/eval_result.hpp for the exact declarations. */ #ifndef SEQUANT_EVAL_RESULT_HPP #define SEQUANT_EVAL_RESULT_HPP #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace sequant { /* NOTE(review): an unnamed namespace in a header gives every translation unit that includes it its own internal-linkage copy of these helpers and of the lambda objects defined below. */ namespace { /* The two factories below build (they do not throw) std::logic_error objects; callers throw the result, e.g. throw invalid_operand(). NOTE(review): msg.data() assumes msg is null-terminated, which std::string_view does not guarantee, and building the exception may allocate, so a std::bad_alloc here would reach noexcept and call std::terminate. */ [[maybe_unused]] std::logic_error invalid_operand( std::string_view msg = "Invalid operand for binary op") noexcept { return std::logic_error{msg.data()}; } [[maybe_unused]] std::logic_error unimplemented_method( std::string_view msg) noexcept { using namespace std::string_literals; return std::logic_error{"Not implemented in this derived class: "s + msg.data()}; } /* Unpacks a three-element annotation array whose elements are std::any: element 0 becomes lannot (left operand), 1 becomes rannot (right operand) and 2 becomes this_annot (result), each via std::any_cast. */ template struct Annot { explicit Annot(std::array const& a) : lannot(std::any_cast(a[0])), rannot(std::any_cast(a[1])), this_annot(std::any_cast(a[2])) {} T const lannot; T const rannot; T const this_annot; }; // It is an iterator type template struct IterPair { It first, second; IterPair(It beg, It end) noexcept : first{beg}, second{end} {}; }; /* Swapping two IterPair values swaps the elements their iterators point to (at both the first and the second position) rather than the iterators themselves, so a permutation algorithm that swaps IterPair elements, such as the std::next_permutation call in symmetric_permutation, permutes the first and second positions of the underlying sequence in lockstep. NOTE(review): the using-declaration is unused because std::iter_swap is called with explicit qualification. */ template void swap(IterPair& l, IterPair& r) { using std::iter_swap; std::iter_swap(l.first, r.first); std::iter_swap(l.second, r.second); } /* Orders IterPair values by the element at their first position. */ template bool operator<(IterPair const& l, IterPair const& r) noexcept { return *l.first < *r.first; } /* NOTE(review): distance(b1, b1 + l) and distance(b2, b2 + l) both evaluate to l, so this predicate is always true; it does not check that either range lies inside the permutation. */ auto valid_particle_range = [](auto const& tpl) -> bool { using std::distance; auto [b1, b2, l] = tpl; return distance(b1, b1 + l) == distance(b2, b2 + l); }; /* Given a (b1, b2, l) tuple, returns a vector of IterPair{b1 + i, b2 + i} for i in [0, l). */ auto iter_pairs = [](auto&& tpl) { using ranges::views::iota; using ranges::views::transform; using std::get; auto b1 = get<0>(tpl); auto b2 = get<1>(tpl); auto l = get<2>(tpl); return iota(size_t{0}, l) | transform([b1, b2](auto i) { return IterPair{b1 + i, b2 + i}; }) | ranges::to_vector; }; using perm_t = 
container::svector; using particle_range_t = std::array; /* Reports permutations of the grouped positions to call_back. Calls call_back(0) once for the starting ordering, then, for each group from last to first, call_back((bra_p + ket_p) % 2) for every other combination of a bra ordering and a ket ordering of that group, where bra_p and ket_p are parities presumably maintained by next_permutation_parity (defined elsewhere). The caller in this file passes iterators into one shared permutation vector, which is permuted in place. NOTE(review): only one group is varied at a time rather than the cross product of all groups; this assumes next_permutation_parity, like std::next_permutation, restores the first ordering when it returns false. Confirm this is the intended set of permutations. NOTE(review): an empty groups list returns before any callback, so not even the starting ordering is reported. */ template , bool> = true> void antisymmetric_permutation( container::svector< std::tuple> const& groups, F const& call_back) { using ranges::views::transform; auto const n = groups.size(); if (n == 0) return; assert(ranges::all_of(groups, valid_particle_range)); call_back(0); for (int i = n - 1; i >= 0; --i) { auto [bra_beg, ket_beg, len] = groups[i]; auto bra_end = bra_beg + len; auto ket_end = ket_beg + len; int bra_p = 0; auto outer = 0; for (auto bra_yn = true; bra_yn; bra_yn = next_permutation_parity(bra_p, bra_beg, bra_end), ++outer) { auto inner = 0; int ket_p = 0; for (auto ket_yn = true; ket_yn; ket_yn = next_permutation_parity(ket_p, ket_beg, ket_end), ++inner) { if (!(outer == 0 && inner == 0)) call_back((bra_p + ket_p) % 2); } } } } /* Calls call_back() once for the starting ordering, then once for every further ordering that std::next_permutation produces on each group (taken from last to first), where a group is the IterPair sequence built by iter_pairs. std::next_permutation returns a group to its sorted ordering when it returns false, so only one group is varied at a time. NOTE(review): an empty groups list returns before any callback, so not even the starting ordering is reported. */ template , bool> = true> void symmetric_permutation( container::svector< std::tuple> const& groups, F const& call_back) { using ranges::views::transform; auto const n = groups.size(); if (n == 0) return; assert(ranges::all_of(groups, valid_particle_range)); auto groups_vec = groups | transform(iter_pairs) | ranges::to_vector; call_back(); // using reverse iterator (instead of indices) not allowed for some reason // iter from the end group for (int i = n - 1; i >= 0; --i) { auto beg = groups_vec[i].begin(); auto end = groups_vec[i].end(); auto yn = std::next_permutation(beg, end); for (; yn; yn = std::next_permutation(beg, end)) call_back(); } } /* Builds the identity permutation 0 .. rank-1, turns each group g into (perm.begin() + g[0], perm.begin() + g[1], g[2]), i.e. bra start, ket start and length, and forwards every (parity, perm) pair produced by antisymmetric_permutation to call_back. perm is permuted in place between calls and is passed on by const reference. */ template < typename F, std::enable_if_t, bool> = true> void antisymmetrize_backend(size_t rank, container::svector const& groups, F const& call_back) { using ranges::views::iota; auto perm = iota(size_t{0}, rank) | ranges::to; auto groups_vec = container::svector< std::tuple>{}; groups_vec.reserve(groups.size()); auto beg = perm.begin(); for (auto&& g : groups) { groups_vec.emplace_back(beg + g[0], beg + g[1], g[2]); } antisymmetric_permutation( groups_vec, [&call_back, &perm = std::as_const(perm)](int p) { call_back(p, perm); 
}); } /* Builds the identity permutation 0 .. rank-1, turns each group g into (perm.begin() + g[0], perm.begin() + g[1], g[2]), and forwards every permutation produced by symmetric_permutation to call_back(perm); perm is permuted in place between calls. */ template , bool> = true> void symmetrize_backend(size_t rank, container::svector const& groups, F const& call_back) { using ranges::views::iota; auto perm = iota(size_t{0}, rank) | ranges::to; auto groups_vec = container::svector< std::tuple>{}; groups_vec.reserve(groups.size()); auto beg = perm.begin(); for (auto&& g : groups) { groups_vec.emplace_back(beg + g[0], beg + g[1], g[2]); } symmetric_permutation( groups_vec, [&call_back, &perm = std::as_const(perm)]() { call_back(perm); }); } /* Joins a range of ordinals into a comma-separated annotation string; for example the ordinals 0, 1, 2 produce 0,1,2. */ template std::string ords_to_annot(RngOfOrdinals const& ords) { using ranges::views::intersperse; using ranges::views::join; using ranges::views::transform; auto to_str = [](auto x) { return std::to_string(x); }; return ords | transform(to_str) | intersperse(std::string{","}) | join | ranges::to; } /* Returns a lazy view mapping each element idx of bk to sequant::hash::value(Index{idx}.label()), cast as described in the warning below. */ template auto index_hash(Iterable const& bk) { return ranges::views::transform(bk, [](auto const& idx) { // // WARNING! // The BTAS expects index types to be long by default. // There is no straight-forward way to turn the default. // Hence, here we explicitly cast the size_t values to long // Which is a potentially narrowing conversion leading to // integral overflow. 
Hence, the values in the returned // container are mixed negative and positive integers (long type) // return static_cast(sequant::hash::value(Index{idx}.label())); }); } /* Returns the sum of arr over every permutation reported by symmetrize_backend: the first term is assigned and later terms are accumulated into result, with the identity annotation on result and the permuted annotation on arr. No normalization factor is applied. NOTE(review): with an empty groups list no term is added, so result stays uninitialized and result.world() is then called on it; confirm callers never pass an empty list. */ template auto symmetrize_ta(TA::DistArray const& arr, container::svector const& groups) { using ranges::views::iota; size_t const rank = arr.trange().rank(); TA::DistArray result; auto const lannot = ords_to_annot(iota(size_t{0}, rank)); auto call_back = [&result, &lannot, &arr](perm_t const& perm) { auto const rannot = ords_to_annot(perm); if (result.is_initialized()) { result(lannot) += arr(rannot); } else { result(lannot) = arr(rannot); } }; symmetrize_backend(rank, groups, call_back); TA::DistArray::wait_for_lazy_cleanup(result.world()); return result; } /* Returns the signed sum of arr over every permutation reported by antisymmetrize_backend, each term weighted by +1 for parity 0 and -1 otherwise; no normalization factor is applied. NOTE(review): groups defaults to an empty list, and with an empty list no term is added, so result stays uninitialized and result.world() is then called on it. */ template auto antisymmetrize_ta( TA::DistArray const& arr, container::svector const& groups = {}) { using ranges::views::iota; size_t const rank = arr.trange().rank(); TA::DistArray result; auto const lannot = ords_to_annot(iota(size_t{0}, rank)); auto call_back = [&lannot, &arr, &result](int p, perm_t const& perm) { typename decltype(result)::numeric_type p_ = p == 0 ? 
1 : -1; if (result.is_initialized()) result(lannot) += p_ * arr(ords_to_annot(perm)); else result(lannot) = p_ * arr(ords_to_annot(perm)); }; antisymmetrize_backend(rank, groups, call_back); TA::DistArray::wait_for_lazy_cleanup(result.world()); return result; } /* Returns the sum, over every permutation reported by symmetrize_backend, of arr permuted from the identity order lannot to that permutation, accumulated into a zero-filled tensor with the range of arr; no normalization factor is applied. NOTE(review): with an empty groups list no term is added and a zero tensor is returned. NOTE(review): auto i = 0 deduces int, so i < rank compares a signed int with a size_t. */ template auto symmetrize_btas(btas::Tensor const& arr, container::svector const& groups) { using ranges::views::iota; size_t const rank = arr.rank(); // Caveat: // clang-format off // auto const lannot = iota(size_t{0}, rank) | ranges::to; // clang-format on auto const lannot = [rank]() { auto p = perm_t(rank); for (auto i = 0; i < rank; ++i) p[i] = i; return p; }(); auto result = btas::Tensor{arr.range()}; result.fill(0); auto call_back = [&result, &lannot, &arr](auto const& permutation) { auto const& rannot = permutation; btas::Tensor temp; btas::permute(arr, lannot, temp, rannot); result += temp; }; symmetrize_backend(rank, groups, call_back); return result; } /* Same as symmetrize_btas except that each permuted term is scaled by +1 for parity 0 and -1 otherwise before accumulation. NOTE(review): groups defaults to an empty list, for which no term is added and a zero tensor, not arr, is returned. NOTE(review): auto i = 0 deduces int, so i < rank compares a signed int with a size_t. */ template auto antisymmetrize_btas( btas::Tensor const& arr, container::svector const& groups = {}) { using ranges::views::iota; size_t const rank = arr.rank(); // Caveat: // auto const lannot = iota(size_t{0}, rank) | ranges::to; // auto const lannot = [rank]() { auto p = perm_t(rank); for (auto i = 0; i < rank; ++i) p[i] = i; return p; }(); auto result = btas::Tensor{arr.range()}; result.fill(0); auto call_back = [&result, &lannot, &arr](int p, perm_t const& perm) { typename decltype(result)::numeric_type p_ = p == 0 ? 1 : -1; auto const& rannot = perm; btas::Tensor temp; btas::permute(arr, lannot, temp, rannot); btas::scal(p_, temp); result += temp; }; antisymmetrize_backend(rank, groups, call_back); return result; } /* Logging helpers: they compile to no-ops unless SEQUANT_EVAL_TRACE is defined; log_result then writes args through write_log when the logger field log_level_eval exceeds 1, and log_ta and log_constant prefix the message with [TA] or [CONST]. */ template inline void log_result(Args const&... args) noexcept { #ifdef SEQUANT_EVAL_TRACE auto l = Logger::instance(); if (l->log_level_eval > 1) write_log(l, args...); #endif } template inline void log_ta(Args const&... 
args) noexcept { #ifdef SEQUANT_EVAL_TRACE log_result("[TA] ", args...); #endif } template inline void log_constant(Args const&... args) noexcept { #ifdef SEQUANT_EVAL_TRACE log_result("[CONST] ", args...); #endif } } // namespace void log_ta_tensor_host_memory_use(madness::World& world, std::string_view label = ""); class EvalResult; using ERPtr = std::shared_ptr; /* Creates the requested EvalResult-derived object with std::make_shared, forwarding args to its constructor. NOTE(review): std::make_shared may throw std::bad_alloc, and the forwarded constructor may throw, either of which calls std::terminate because this function is noexcept. */ template ERPtr eval_result(Args&&... args) noexcept { return std::make_shared(std::forward(args)...); } /* Type-erased base for evaluation results. The payload is stored in value_ (a std::any); each derived class reports its id through type_id(), and is() compares that id with id_for_type() for the requested type. The virtual operations receive their annotations as std::array elements holding std::any, which the derived classes unpack with std::any_cast. */ class EvalResult { public: using id_t = size_t; virtual ~EvalResult() noexcept = default; template [[nodiscard]] bool is() const noexcept { return this->type_id() == id_for_type>(); } template [[nodiscard]] T const& as() const { assert(this->is>()); return static_cast(*this); } [[nodiscard]] virtual ERPtr sum(EvalResult const&, std::array const&) const = 0; [[nodiscard]] virtual ERPtr prod(EvalResult const&, std::array const&, TA::DeNest DeNestFlag) const = 0; [[nodiscard]] virtual ERPtr permute(std::array const&) const = 0; virtual void add_inplace(EvalResult const&) = 0; [[nodiscard]] virtual ERPtr symmetrize( container::svector> const&) const = 0; [[nodiscard]] virtual ERPtr antisymmetrize( container::svector> const&) const = 0; [[nodiscard]] bool has_value() const noexcept; /* NOTE(review): only has_value() is asserted; if the stored type is not T, the pointer form of std::any_cast returns nullptr and the dereference below is undefined behaviour. */ template [[nodiscard]] T& get() { assert(has_value()); return *std::any_cast(&value_); } template [[nodiscard]] T const& get() const { return const_cast(*this).get(); } protected: /* NOTE(review): std::make_any may allocate and the payload copy or move may throw, which would call std::terminate because this constructor is noexcept. */ template >> explicit EvalResult(T&& arg) noexcept : value_{std::make_any>(std::forward(arg))} {} [[nodiscard]] virtual id_t type_id() const noexcept = 0; /* Assigns each type an id on first use through a function-local static initialized from next_id(), which is declared here and defined out of line. */ template [[nodiscard]] static id_t id_for_type() noexcept { static id_t id = next_id(); return id; } private: std::any value_; [[nodiscard]] static id_t next_id() noexcept; }; /* EvalResult holding a single scalar value of type T. */ template class EvalScalar final : public EvalResult { public: using EvalResult::id_t; explicit EvalScalar(T v) noexcept : EvalResult{std::move(v)} {} [[nodiscard]] T value() const noexcept { return get(); } [[nodiscard]] ERPtr 
sum(EvalResult const& other, std::array const&) const override { /* Only a scalar operand is supported; any other operand throws invalid_operand(). */ if (other.is>()) { auto const& o = other.as>(); auto s = value() + o.value(); log_constant(value(), " + ", o.value(), " = ", s, "\n"); return eval_result>(s); } else { throw invalid_operand(); } } /* A scalar operand is multiplied directly. For any other operand the first two annotations are swapped and the product is delegated to other.prod(*this, ...). NOTE(review): value() * o.value() is computed twice (once for the log, once for the result). NOTE(review): if the scalar check requires the same T, a scalar of a different type would delegate back to this object and the two calls would recurse without end. */ [[nodiscard]] ERPtr prod(EvalResult const& other, std::array const& maybe_empty, TA::DeNest DeNestFlag) const override { if (other.is>()) { auto const& o = other.as>(); auto p = value() * o.value(); log_constant(value(), " * ", o.value(), " = ", p, "\n"); return eval_result>(value() * o.value()); } else { auto maybe_empty_ = maybe_empty; std::swap(maybe_empty_[0], maybe_empty_[1]); return other.prod(*this, maybe_empty_, DeNestFlag); } } [[nodiscard]] ERPtr permute(std::array const&) const override { throw unimplemented_method("permute"); } void add_inplace(EvalResult const& other) override { assert(other.is>()); log_constant(value(), " += ", other.get(), "\n"); auto& val = get(); val += other.get(); } [[nodiscard]] ERPtr symmetrize( container::svector> const&) const override { throw unimplemented_method("symmetrize"); } [[nodiscard]] ERPtr antisymmetrize( container::svector> const&) const override { throw unimplemented_method("antisymmetrize"); } private: [[nodiscard]] id_t type_id() const noexcept override { return id_for_type>(); } }; /* EvalResult wrapping a TiledArray array of type ArrayT. Each operation evaluates a TiledArray expression built from the unpacked annotations and calls wait_for_lazy_cleanup afterwards. */ template >> class EvalTensorTA final : public EvalResult { public: using EvalResult::id_t; using numeric_type = typename ArrayT::numeric_type; explicit EvalTensorTA(ArrayT arr) : EvalResult{std::move(arr)} {} private: using this_type = EvalTensorTA; using annot_wrap = Annot; [[nodiscard]] id_t type_id() const noexcept override { return id_for_type(); } /* Requires an operand of the same type and evaluates result(this_annot) = this(lannot) + other(rannot). */ [[nodiscard]] ERPtr sum(EvalResult const& other, std::array const& annot) const override { assert(other.is()); auto const a = annot_wrap{annot}; log_ta(a.lannot, " + ", a.rannot, " = ", a.this_annot, "\n"); ArrayT result; result(a.this_annot) = get()(a.lannot) + other.get()(a.rannot); 
decltype(result)::wait_for_lazy_cleanup(result.world()); return eval_result(std::move(result)); } /* Dispatch: a scalar operand scales a copy of this array; an empty result annotation computes TA::dot and returns it wrapped as a scalar result; an operand of any other type gets the first two annotations swapped and is asked to compute other.prod(*this, ...) (the T * ToT case); an operand of the same type is contracted with TA::einsum. NOTE(review): if other is a tensor type whose prod also delegates back for this type, for example an EvalTensorTA over a different ArrayT, the two calls recurse without end. */ [[nodiscard]] ERPtr prod(EvalResult const& other, std::array const& annot, TA::DeNest DeNestFlag) const override { auto const a = annot_wrap{annot}; if (other.is>()) { auto result = get(); auto scalar = other.get(); log_ta(a.lannot, " * ", scalar, " = ", a.this_annot, "\n"); result(a.this_annot) = scalar * result(a.lannot); decltype(result)::wait_for_lazy_cleanup(result.world()); return eval_result(std::move(result)); } if (a.this_annot.empty()) { // DOT product assert(other.is()); numeric_type d = TA::dot(get()(a.lannot), other.get()(a.rannot)); ArrayT::wait_for_lazy_cleanup(get().world()); ArrayT::wait_for_lazy_cleanup(other.get().world()); log_ta(a.lannot, " * ", a.rannot, " = ", d, "\n"); return eval_result>(d); } if (!other.is()) { // potential T * ToT auto annot_swap = annot; std::swap(annot_swap[0], annot_swap[1]); return other.prod(*this, annot_swap, DeNestFlag); } // confirmed: other.is() is true log_ta(a.lannot, " * ", a.rannot, " = ", a.this_annot, "\n"); ArrayT result; result = TA::einsum(get()(a.lannot), other.get()(a.rannot), a.this_annot); decltype(result)::wait_for_lazy_cleanup(result.world()); return eval_result(std::move(result)); } /* Treats ann[0] as the source annotation and ann[1] as the target annotation and returns result(post_annot) = this(pre_annot). */ [[nodiscard]] ERPtr permute( std::array const& ann) const override { auto const pre_annot = std::any_cast(ann[0]); auto const post_annot = std::any_cast(ann[1]); log_ta(pre_annot, " = ", post_annot, "\n"); ArrayT result; result(post_annot) = get()(pre_annot); ArrayT::wait_for_lazy_cleanup(result.world()); return eval_result(std::move(result)); } /* Adds other to this array in place using a dummy annotation of matching rank; asserts that other has the same type and the same trange. */ void add_inplace(EvalResult const& other) override { assert(other.is()); auto& t = get(); auto const& o = other.get(); assert(t.trange() == o.trange()); auto ann = TA::detail::dummy_annotation(t.trange().rank()); log_ta(ann, " += ", ann, "\n"); t(ann) += o(ann); ArrayT::wait_for_lazy_cleanup(t.world()); } [[nodiscard]] ERPtr symmetrize( container::svector> const& groups) const 
override { return eval_result(symmetrize_ta(get(), groups)); } [[nodiscard]] ERPtr antisymmetrize( container::svector> const& groups) const override { return eval_result(antisymmetrize_ta(get(), groups)); } }; /* EvalResult wrapping a TiledArray tensor-of-tensors array of type ArrayT. that_type wraps the regular DistArray over the inner tensor type and, per the comment on it, is the only type allowed for ToT * T products. */ template >> class EvalTensorOfTensorTA final : public EvalResult { public: using EvalResult::id_t; using numeric_type = typename ArrayT::numeric_type; explicit EvalTensorOfTensorTA(ArrayT arr) : EvalResult{std::move(arr)} {} private: using this_type = EvalTensorOfTensorTA; using annot_wrap = Annot; using _inner_tensor_type = typename ArrayT::value_type::value_type; using compatible_regular_distarray_type = TA::DistArray<_inner_tensor_type, typename ArrayT::policy_type>; // Only @c that_type type is allowed for ToT * T computation using that_type = EvalTensorTA; [[nodiscard]] id_t type_id() const noexcept override { return id_for_type(); } /* Requires an operand of the same type and evaluates result(this_annot) = this(lannot) + other(rannot). */ [[nodiscard]] ERPtr sum(EvalResult const& other, std::array const& annot) const override { assert(other.is()); auto const a = annot_wrap{annot}; log_ta(a.lannot, " + ", a.rannot, " = ", a.this_annot, "\n"); ArrayT result; result(a.this_annot) = get()(a.lannot) + other.get()(a.rannot); decltype(result)::wait_for_lazy_cleanup(result.world()); return eval_result(std::move(result)); } /* Dispatch: a scalar operand scales a copy of this array; an empty result annotation computes TA::dot and returns it wrapped as a scalar result; otherwise TA::einsum computes ToT * T, ToT * ToT -> T when DeNestFlag is TA::DeNest::True, or ToT * ToT -> ToT when it is TA::DeNest::False; any other operand throws invalid_operand(). */ [[nodiscard]] ERPtr prod(EvalResult const& other, std::array const& annot, TA::DeNest DeNestFlag) const override { auto const a = annot_wrap{annot}; if (other.is>()) { auto result = get(); auto scalar = other.get(); log_ta(a.lannot, " * ", scalar, " = ", a.this_annot, "\n"); result(a.this_annot) = scalar * result(a.lannot); decltype(result)::wait_for_lazy_cleanup(result.world()); return eval_result(std::move(result)); } else if (a.this_annot.empty()) { // DOT product assert(other.is()); numeric_type d = TA::dot(get()(a.lannot), other.get()(a.rannot)); ArrayT::wait_for_lazy_cleanup(get().world()); ArrayT::wait_for_lazy_cleanup(other.get().world()); log_ta(a.lannot, " * ", a.rannot, " = ", d, "\n"); return eval_result>(d); } log_ta(a.lannot, " * ", 
a.rannot, " = ", a.this_annot, "\n"); /* NOTE(review): unlike the other TiledArray branches, the three einsum results below are returned without a wait_for_lazy_cleanup call. */ if (other.is()) { // ToT * T -> ToT auto result = TA::einsum(get()(a.lannot), other.get()(a.rannot), a.this_annot); return eval_result(std::move(result)); } else if (other.is() && DeNestFlag == TA::DeNest::True) { // ToT * ToT -> T auto result = TA::einsum( get()(a.lannot), other.get()(a.rannot), a.this_annot); return eval_result(std::move(result)); } else if (other.is() && DeNestFlag == TA::DeNest::False) { // ToT * ToT -> ToT auto result = TA::einsum(get()(a.lannot), other.get()(a.rannot), a.this_annot); return eval_result(std::move(result)); } else { throw invalid_operand(); } } /* Treats ann[0] as the source annotation and ann[1] as the target annotation and returns result(post_annot) = this(pre_annot). */ [[nodiscard]] ERPtr permute( std::array const& ann) const override { auto const pre_annot = std::any_cast(ann[0]); auto const post_annot = std::any_cast(ann[1]); log_ta(pre_annot, " = ", post_annot, "\n"); ArrayT result; result(post_annot) = get()(pre_annot); ArrayT::wait_for_lazy_cleanup(result.world()); return eval_result(std::move(result)); } /* Adds other to this array in place using a dummy annotation of matching rank; asserts that other has the same type and the same trange. */ void add_inplace(EvalResult const& other) override { assert(other.is()); auto& t = get(); auto const& o = other.get(); assert(t.trange() == o.trange()); auto ann = TA::detail::dummy_annotation(t.trange().rank()); log_ta(ann, " += ", ann, "\n"); t(ann) += o(ann); ArrayT::wait_for_lazy_cleanup(t.world()); } /* NOTE(review): symmetrize and antisymmetrize are not implemented for tensors of tensors and return a null ERPtr, which a caller would then dereference; EvalScalar throws unimplemented_method in the same situation. */ [[nodiscard]] ERPtr symmetrize( container::svector> const& groups) const override { // todo // return eval_result(symmetrize_ta(get(), groups)); return nullptr; } [[nodiscard]] ERPtr antisymmetrize( container::svector> const& groups) const override { // todo // return eval_result(antisymmetrize_ta(get(), groups)); return nullptr; } }; /* EvalResult wrapping a BTAS tensor of type T; annotations are unpacked as annot_t. */ template class EvalTensorBTAS final : public EvalResult { public: using EvalResult::id_t; using numeric_type = typename T::numeric_type; explicit EvalTensorBTAS(T arr) : EvalResult{std::move(arr)} {} private: // TODO make it same as that used by EvalExprBTAS class from eval.hpp file using annot_t = container::svector; using annot_wrap = Annot; [[nodiscard]] id_t type_id() const 
noexcept override { return id_for_type>(); } /* Permutes both operands into this_annot order and returns their elementwise sum. */ [[nodiscard]] ERPtr sum(EvalResult const& other, std::array const& annot) const override { assert(other.is>()); auto const a = annot_wrap{annot}; T lres, rres; btas::permute(get(), a.lannot, lres, a.this_annot); btas::permute(other.get(), a.rannot, rres, a.this_annot); return eval_result>(lres + rres); } /* A scalar operand permutes this tensor into this_annot order and scales it. Otherwise other must have the same type: an empty result annotation permutes other into lannot order and returns btas::dot, presumably wrapped as a scalar result, and anything else is a btas::contract with alpha 1 and beta 0. DeNestFlag is not used. */ [[nodiscard]] ERPtr prod(EvalResult const& other, std::array const& annot, TA::DeNest DeNestFlag) const override { auto const a = annot_wrap{annot}; if (other.is>()) { T result; btas::permute(get(), a.lannot, result, a.this_annot); btas::scal(other.as>().value(), result); return eval_result>(std::move(result)); } assert(other.is>()); if (a.this_annot.empty()) { T rres; btas::permute(other.get(), a.rannot, rres, a.lannot); return eval_result>(btas::dot(get(), rres)); } T result; btas::contract(numeric_type{1}, // get(), a.lannot, // other.get(), a.rannot, // numeric_type{0}, // result, a.this_annot); return eval_result>(std::move(result)); } /* Permutes this tensor from the ann[0] order to the ann[1] order. */ [[nodiscard]] ERPtr permute( std::array const& ann) const override { auto const pre_annot = std::any_cast(ann[0]); auto const post_annot = std::any_cast(ann[1]); T result; btas::permute(get(), pre_annot, result, post_annot); return eval_result>(std::move(result)); } /* NOTE(review): unlike the other add_inplace overrides, no assert checks that other holds the same type before other.get() is read. */ void add_inplace(EvalResult const& other) override { auto& t = get(); auto const& o = other.get(); assert(t.range() == o.range()); t += o; } [[nodiscard]] ERPtr symmetrize( container::svector> const& groups) const override { return eval_result>(symmetrize_btas(get(), groups)); } [[nodiscard]] ERPtr antisymmetrize( container::svector> const& groups) const override { return eval_result>( antisymmetrize_btas(get(), groups)); } }; } // namespace sequant #endif // SEQUANT_EVAL_RESULT_HPP