Program Listing for File wick.impl.hpp
Return to documentation for file (SeQuant/core/wick.impl.hpp)
//
// Created by Eduard Valeyev on 3/31/18.
//
#ifndef SEQUANT_WICK_IMPL_HPP
#define SEQUANT_WICK_IMPL_HPP
#include <SeQuant/core/bliss.hpp>
#include <SeQuant/core/logger.hpp>
#include <SeQuant/core/tensor_canonicalizer.hpp>
#include <SeQuant/core/tensor_network.hpp>
#include <SeQuant/core/tensor_network/vertex.hpp>
#include <SeQuant/core/utility/debug.hpp>
#include <SeQuant/core/utility/indices.hpp>
#include <SeQuant/core/utility/macros.hpp>
#ifdef SEQUANT_HAS_EXECUTION_HEADER
#include <execution>
#endif
namespace sequant {
namespace detail {
/// A replacement target ("destination") shared by one or more source indices.
///
/// compute_index_replacement_rules maps many source indices onto one
/// destination by letting all of them point (via shared_ptr) to the same
/// index_repl_dst_t; changing the destination via update_dst() therefore
/// changes it for every source at once.
class index_repl_dst_t {
 public:
  /// @param dst the destination index
  /// @param src the first source index mapped to @p dst
  explicit index_repl_dst_t(Index dst, Index src) : dst_(std::move(dst)) {
    // N.B. move src into place first and log it from src_, never from the
    // (by then moved-from) parameter
    src_.emplace_back(std::move(src));
    if (Logger::instance().wick_reduce) {
      sequant::wprintf("index_repl_dst_t: ctor, src=", src_.back().to_latex(),
                       " dst=", dst_.to_latex(), "\n");
    }
  }

  /// @return the destination index
  const Index &dst() const { return dst_; }

  /// Replaces the destination index; every source mapped to this object
  /// follows.
  /// @param idx the new destination index
  void update_dst(Index idx) {
    if (Logger::instance().wick_reduce) {
      sequant::wprintf("index_repl_dst_t: changing dst=", dst_.to_latex(),
                       " to dst=", idx.to_latex(), "\n");
    }
    dst_ = std::move(idx);
  }

  /// @return the source indices mapped to this destination, in the order they
  ///         were added
  const container::svector<Index> &src() const { return src_; }

  /// Adds another source index mapped to this destination.
  /// @param src the source index to add; must not already be in src()
  /// @return reference to this object
  index_repl_dst_t &append_src(Index src) {
    SEQUANT_ASSERT(ranges::contains(src_, src) == false);
    if (Logger::instance().wick_reduce) {
      sequant::wprintf("index_repl_dst_t: appended src=", src.to_latex(),
                       " -> dst=", dst_.to_latex(), "\n");
    }
    src_.emplace_back(std::move(src));
    return *this;
  }

 private:
  // may carry protoindices inherited from a source index (see the `proto`
  // helper in compute_index_replacement_rules)
  Index dst_;
  container::svector<Index> src_;  // never empty
};
/// Computes the index replacement rules implied by the Kronecker deltas and
/// overlaps of a Product.
///
/// For every Kronecker delta / overlap that is not skipped (self-deltas, and
/// overlaps whose bra and ket are both noncovariant), its non-external
/// index (or both indices, if neither is external) is mapped to a common
/// destination index whose space is the intersection of the spaces involved.
/// Several source indices may share one destination; destinations are
/// narrowed and merged as more deltas are processed. Deltas between two
/// external indices produce no rules.
///
/// @tparam S particle statistics; only used to select the default index space
///         registry
/// @param product the product whose factors are scanned (only traversed here)
/// @param external_indices indices that are never used as replacement sources
/// @param noncovariant_indices an overlap whose bra and ket are both in this
///        set is skipped
/// @param all_indices the indices currently used in @p product; the index
///        factory is given a validator that accepts only indices NOT in this
///        set
/// @param isr index space registry used to intersect index spaces
/// @return std::nullopt if @p product is found to vanish (a delta/overlap
///         between non-overlapping spaces, or destination spaces whose
///         intersection is empty); otherwise the pair {src->dst map, whether
///         any Kronecker delta or overlap was encountered}
template <Statistics S>
std::optional<std::pair<container::map<Index, Index>, bool>>
compute_index_replacement_rules(
    std::shared_ptr<Product> &product,
    const container::set<Index> &external_indices,
    const container::set<Index> &noncovariant_indices,
    const std::set<Index, Index::LabelCompare> &all_indices,
    const std::shared_ptr<const IndexSpaceRegistry> &isr =
        get_default_context(S).index_space_registry()) {
  // set by zero_result() when the product is found to vanish
  bool zero_result_status = false;
  auto zero_result = [&zero_result_status]() -> void {
    zero_result_status = true;
  };
  // Helper macros: evaluate x, then leave the enclosing function (returning an
  // empty value, i.e. std::nullopt here, or void inside the lambdas) if x
  // flagged the product as zero. Undefined again at the end of this function.
#define SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_IF_ZERO_RESULT(x) \
  { x; }                                                          \
  if (zero_result_status) return {};
#define SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_VOID_IF_ZERO_RESULT(x) \
  { x; }                                                               \
  if (zero_result_status) return;
  expr_range exrng(product);
  bool have_kroneckers = false;
  // returns true iff idx is not among the indices already used in the product
  auto index_validator = [&all_indices](const Index &idx) {
    return all_indices.find(idx) == all_indices.end();
  };
  IndexFactory idxfac(index_validator);
  container::map<Index /* src */, std::shared_ptr<index_repl_dst_t> /* dst */>
      src2dst;  // src->dst (destinations shared between sources), will be
                // converted to a 1-to-1 src->dst Index map at the end
  container::svector<std::shared_ptr<index_repl_dst_t>>
      dst_list;  // unsorted list of dst indices
  // transfers proto indices from src (if any) to dst
  // which proto-indices should dst inherit from src? a dst index without
  // proto indices will inherit its src counterpart's indices, unless it
  // already has its own protoindices: <a_ij|p> = <a_ij|a_ij> (hence replace p
  // with a_ij), but <a_ij|p_kl> = <a_ij|a_kl> != <a_ij|a_ij> (hence replace
  // p_kl with a_kl)
  auto proto = [](const Index &dst, const Index &src) {
    if (src.has_proto_indices()) {
      if (dst.has_proto_indices()) {
        SEQUANT_ASSERT(dst.proto_indices() == src.proto_indices());
        return dst;
      } else
        return Index(dst, src.proto_indices());
    } else {
      return dst;
    }
  };
  // maps src to the (shared) destination already referenced by srd2dst_it;
  // no proto indices are assigned here
  auto add_src_to_existing_dst = [&src2dst](const Index &src, auto srd2dst_it) {
    SEQUANT_ASSERT(srd2dst_it != ranges::end(src2dst));
    srd2dst_it->second->append_src(src);
    src2dst.emplace(src, srd2dst_it->second);
  };
  // change dst index referenced by srd2dst_it (all sources sharing it follow)
  auto replace_dst_index = [&src2dst](const Index &new_dst, auto srd2dst_it) {
    SEQUANT_ASSERT(srd2dst_it != ranges::end(src2dst));
    srd2dst_it->second->update_dst(new_dst);
  };
  // merges dst2 into dst1: every source of dst2 is repointed to dst1 and dst2
  // is dropped from dst_list
  auto merge_dst2_into_dst1 = [&src2dst, &dst_list](auto srd2dst_it1,
                                                    auto srd2dst_it2) {
    SEQUANT_ASSERT(
        srd2dst_it1->second !=
        srd2dst_it2->second);  // caller should ensure no self-merges,
                               // indicates faulty logic upstream
    SEQUANT_ASSERT(srd2dst_it1 != ranges::end(src2dst));
    SEQUANT_ASSERT(srd2dst_it2 != ranges::end(src2dst));
    const auto dst1 = srd2dst_it1->second;
    // N.B. this local copy of the shared_ptr keeps dst2 (whose src() is
    // iterated below) alive while its entries in src2dst are repointed
    const auto dst2 = srd2dst_it2->second;
    // repoint all source indices of dst2 to dst1
    for (const auto &src2 : srd2dst_it2->second->src()) {
      auto it = src2dst.find(src2);
      SEQUANT_ASSERT(it != src2dst.end());
      it->second = dst1;
      dst1->append_src(src2);
    }
    // there should be no refs to dst2 in src2dst
    SEQUANT_ASSERT(ranges::contains(src2dst, dst2, [](const auto &it) {
                     return it.second;
                   }) == false);
    // erase dst2 from dst_list
    auto it2 = ranges::find(dst_list, dst2);
    SEQUANT_ASSERT(it2 != ranges::end(dst_list));
    dst_list.erase(it2);
    // there should be no "viewers" of dst2
    SEQUANT_ASSERT(dst2.use_count() == 1);
  };
  // adds src->dst, optionally assigning proto indices from protosrc (default:
  // from src); reuses an existing destination object whose index equals the
  // proto-adjusted dst, otherwise creates a new one
  auto add_rule = [&src2dst, &dst_list, &proto](
                      const Index &src, const Index &dst,
                      std::optional<const Index> protosrc = std::nullopt) {
    auto real_dst = protosrc ? proto(dst, protosrc.value()) : proto(dst, src);
    auto it = ranges::find_if(
        dst_list, [&real_dst](const auto &d) { return d->dst() == real_dst; });
    if (it != ranges::end(dst_list)) {
      (*it)->append_src(src);
      src2dst.emplace(src, *it);
    } else {
      [[maybe_unused]] auto insertion_result = src2dst.emplace(
          src, std::make_shared<index_repl_dst_t>(real_dst, src));
      SEQUANT_ASSERT(insertion_result.second);
      dst_list.emplace_back(insertion_result.first->second);
    }
  };
  // changes src->current_dst to src->intersection(dst,current_dst);
  // calls zero_result() if that intersection is empty
  auto update_rule = [&src2dst, &proto, &isr, &idxfac, &zero_result](
                         auto src_it, const Index &src, const Index &dst,
                         std::optional<const Index> protosrc = std::nullopt) {
    SEQUANT_ASSERT(src_it != src2dst.end());
    auto &old_dst = src_it->second->dst();
    // do we need to change space of dst?
    const bool change_dst_space = (dst.space() != old_dst.space());
    const IndexSpace &new_dst_space =
        change_dst_space ? isr->intersection(old_dst.space(), dst.space())
                         : dst.space();
    if (!new_dst_space) return zero_result();
    // do we need to change protoindices?
    bool change_dst_protoindices = false;
    // already have protoindices?
    if (old_dst.has_proto_indices()) {
      // same index should never be mapped to indices with
      // different protoindices, this is what the logic dealing with
      // noncovariant indices is meant to avoid
      if (protosrc.value_or(src).has_proto_indices()) {
        SEQUANT_ASSERT(protosrc.value_or(src).proto_indices() ==
                       old_dst.proto_indices());
      }
    } else {
      change_dst_protoindices = protosrc.value_or(src).has_proto_indices();
    }
    // if anything changes, a fresh index in the new space replaces old_dst
    if (change_dst_space || change_dst_protoindices) {
      auto plain_dst = idxfac.make(new_dst_space);
      const auto real_dst = change_dst_protoindices
                                ? proto(plain_dst, protosrc.value_or(src))
                                : proto(plain_dst, old_dst);
      src_it->second->update_dst(real_dst);
    }
  };
  // adds src->dst or changes src->current_dst to
  // src->intersection(dst,current_dst)
  auto add_or_update_rule = [&add_rule, &update_rule, &src2dst,
                             &zero_result_status](const Index &src,
                                                  const Index &dst) {
    auto src_it = src2dst.find(src);
    if (src_it == src2dst.end()) {
      add_rule(src, dst);
    } else {
      SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_VOID_IF_ZERO_RESULT(
          update_rule(src_it, src, dst));
    }
  };
  // adds src1->dst and src2->dst; if src1->dst1 and/or src2->dst2 already
  // exist the existing rules are updated to map to the intersection of dst1,
  // dst2 and dst
  auto add_or_update_rules = [&add_rule, &update_rule, &add_src_to_existing_dst,
                              &replace_dst_index, &merge_dst2_into_dst1,
                              &src2dst, &idxfac, &proto, &zero_result,
                              &zero_result_status,
                              &isr](const Index &src1, const Index &src2,
                                    const Index &dst) {
    // are there replacement rules already for src{1,2}?
    auto src1_it = src2dst.find(src1);
    auto src2_it = src2dst.find(src2);
    const auto has_src1_rule = src1_it != src2dst.end();
    const auto has_src2_rule = src2_it != src2dst.end();
    // which proto-indices should dst1 and dst2 inherit? a destination index
    // without proto indices will inherit its source counterpart's indices,
    // unless it already has its own protoindices: <a_ij|p> = <a_ij|a_ij> (hence
    // replace p with a_ij), but <a_ij|p_kl> = <a_ij|a_kl> != <a_ij|a_ij> (hence
    // replace p_kl with a_kl)
    const auto &dst1_proto =
        !src1.has_proto_indices() && src2.has_proto_indices() ? src2 : src1;
    const auto &dst2_proto =
        !src2.has_proto_indices() && src1.has_proto_indices() ? src1 : src2;
    if (!has_src1_rule && !has_src2_rule) {  // if brand new, add the rules
      add_rule(src1, dst, dst1_proto);
      add_rule(src2, dst, dst2_proto);
    } else if (has_src1_rule && !has_src2_rule) {
      // update the existing rule for src1
      SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_VOID_IF_ZERO_RESULT(
          update_rule(src1_it, src1, dst, dst1_proto));
      // create new rule: src2->dst1
      add_src_to_existing_dst(src2, src2dst.find(src1));
    } else if (!has_src1_rule && has_src2_rule) {
      // update the existing rule for src2
      SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_VOID_IF_ZERO_RESULT(
          update_rule(src2_it, src2, dst, dst2_proto));
      // create new rule: src1->dst2
      add_src_to_existing_dst(src1, src2dst.find(src2));
    } else {
      // merge the existing rules
      // - compute new target index space
      // - compute new target index
      // - repoint old rules to the new target
      const auto &old_dst1 = src1_it->second->dst();
      const auto &old_dst2 = src2_it->second->dst();
      const auto new_dst_space =
          (dst.space() != old_dst1.space() || dst.space() != old_dst2.space())
              ? isr->intersection(
                    isr->intersection(old_dst1.space(), old_dst2.space()),
                    dst.space())
              : dst.space();
      if (!new_dst_space) return zero_result();
      // prefer reusing an existing index (the smallest one) that already lives
      // in the new space; only otherwise generate a fresh one
      Index new_dst;
      if (new_dst_space == old_dst1.space()) {
        new_dst = old_dst1;
        if (new_dst_space == old_dst2.space() && old_dst2 < new_dst) {
          new_dst = old_dst2;
        }
        if (new_dst_space == dst.space() && dst < new_dst) {
          new_dst = dst;
        }
      } else if (new_dst_space == old_dst2.space()) {
        new_dst = old_dst2;
        if (new_dst_space == dst.space() && dst < new_dst) {
          new_dst = dst;
        }
      } else if (new_dst_space == dst.space()) {
        new_dst = dst;
      } else
        new_dst = idxfac.make(new_dst_space);
      // update dst1 and dst2 with new_dst, then merge them
      auto new_real_dst = proto(new_dst, dst1_proto);
      SEQUANT_ASSERT(
          new_real_dst ==
          proto(new_dst, dst2_proto));  // don't know how to handle this yet
      auto src2dst_it1 = src2dst.find(src1);
      replace_dst_index(new_real_dst, src2dst_it1);
      // if src1 and src2 were pointing to same destination, we are done, else
      // merge their destinations
      auto src2dst_it2 = src2dst.find(src2);
      if (src2dst_it1->second != src2dst_it2->second) {
        replace_dst_index(new_real_dst, src2dst_it2);
        merge_dst2_into_dst1(src2dst_it1, src2dst_it2);
      }
    }
  };
  // scan the factors for Kronecker deltas and overlaps
  for (auto it = ranges::begin(exrng); it != ranges::end(exrng); ++it) {
    const auto &factor = *it;
    if (factor.is<Tensor>()) {
      const auto &tensor = factor.as<Tensor>();
      const auto is_overlap = tensor.label() == overlap_label();
      const auto is_kronecker = tensor.label() == kronecker_label();
      if (is_overlap || is_kronecker) {
        have_kroneckers = true;
        SEQUANT_ASSERT(tensor.bra().size() == 1);
        SEQUANT_ASSERT(tensor.ket().size() == 1);
        const auto &bra = tensor.bra().at(0);
        const auto &ket = tensor.ket().at(0);
        // skip if
        // - self-kronecker (apply_index_replacement_rules replaces it by 1)
        bool do_skip = bra == ket;
        // - nontrivial overlap between 2 noncovariant modes
        if (is_overlap) {
          // N.B. noncovariant bra or ket is OK because we can always rotate it
          // to match the basis of the other
          do_skip = do_skip || (noncovariant_indices.contains(bra) &&
                                noncovariant_indices.contains(ket));
        }
        if (!do_skip) {
          const auto bra_is_ext = ranges::find(external_indices, bra) !=
                                  ranges::end(external_indices);
          const auto ket_is_ext = ranges::find(external_indices, ket) !=
                                  ranges::end(external_indices);
          const auto intersection_space =
              isr->intersection(bra.space(), ket.space());
          // if overlap's indices are from non-overlapping spaces, return zero
          if (!intersection_space) {
            return std::nullopt;
          }
          if (!bra_is_ext && !ket_is_ext) {
            // int + int: map both onto a fresh index in the intersection space
            const auto new_dummy = idxfac.make(intersection_space);
            SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_IF_ZERO_RESULT(
                add_or_update_rules(bra, ket, new_dummy));
          } else if (bra_is_ext && !ket_is_ext) {  // ext + int
            // the internal index can be replaced by the external one only if
            // the internal space includes the external space
            if (includes(ket.space(), bra.space())) {
              SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_IF_ZERO_RESULT(
                  add_or_update_rule(ket, bra));
            } else {
              SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_IF_ZERO_RESULT(
                  add_or_update_rule(ket, idxfac.make(intersection_space)));
            }
          } else if (!bra_is_ext && ket_is_ext) {  // int + ext
            if (includes(bra.space(), ket.space())) {
              SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_IF_ZERO_RESULT(
                  add_or_update_rule(bra, ket));
            } else {
              SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_IF_ZERO_RESULT(
                  add_or_update_rule(bra, idxfac.make(intersection_space)));
            }
          }
          // ext + ext => leave overlap as is
        }
      }
    }
  }
  // make 1-to-1 version of src->dst
  container::map<Index /* src */, Index /* dst */> result;
  for (auto &&[src, d] : src2dst) {
    result.emplace(src, d->dst());
  }
  return std::make_pair(result, have_kroneckers);
#undef SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_IF_ZERO_RESULT
#undef SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_VOID_IF_ZERO_RESULT
}
/// Applies index replacement rules to every tensor of a Product and replaces
/// each Kronecker delta / overlap that becomes trivial (bra == ket) by 1.
///
/// @param[in,out] product the product to transform in place
/// @param[in] const_replrules the src->dst index replacement map (e.g. as
///            computed by compute_index_replacement_rules)
/// @param[in,out] all_indices the indices used in @p product; on return every
///                index that is a key of @p const_replrules is replaced by its
///                destination
/// @return true if at least one Kronecker delta / overlap was replaced by 1;
///         renaming indices alone is not reported as a mutation
inline bool apply_index_replacement_rules(
    std::shared_ptr<Product> &product,
    const container::map<Index, Index> &const_replrules,
    std::set<Index, Index::LabelCompare> &all_indices) {
  expr_range exrng(product);
#ifdef SEQUANT_ASSERT_ENABLED
  // assert that tensors' indices are not tagged since going to tag indices
  {
    for (auto it = ranges::begin(exrng); it != ranges::end(exrng); ++it) {
      const auto &factor = *it;
      if (factor->is<AbstractTensor>()) {
        auto &tensor = factor->as<AbstractTensor>();
        SEQUANT_ASSERT(ranges::none_of(tensor._slots(), [](const Index &idx) {
          return idx.tag().has_value();
        }));
      }
    }
  }
#endif
  bool mutated = false;
  bool pass_mutated = false;
  do {
    pass_mutated = false;
    for (auto it = ranges::begin(exrng); it != ranges::end(exrng); ++it) {
      const auto &factor = *it;
      if (factor->is<AbstractTensor>()) {
        auto &tensor = factor->as<AbstractTensor>();
        // N.B. only the elimination of a delta counts as a mutation. Renaming
        // must not: a delta whose internal index cannot absorb the external
        // one (see the int+ext / ext+int cases of
        // compute_index_replacement_rules) gets its internal index renamed to
        // a fresh index on every reduction pass without ever becoming trivial,
        // so counting renames would make the fixed-point iteration in
        // reduce_wick_impl never terminate. Counting eliminations only
        // guarantees termination, since every mutating pass removes a delta.
        // The result must also not be combined into pass_mutated with `&=`,
        // which would clear a flag already set by an earlier delta.
        [[maybe_unused]] const bool indices_renamed =
            tensor._transform_indices(const_replrules);
        if (tensor._label() == overlap_label() ||
            tensor._label() == kronecker_label()) {
          const auto bra = tensor._bra().at(0);
          const auto ket = tensor._ket().at(0);
          if (bra == ket) {
            pass_mutated = true;
            *it = ex<Constant>(1);
          }
        }  // Kronecker delta
      }
    }
    mutated |= pass_mutated;
  } while (pass_mutated);  // keep replacing til fixed point

  // clear the index tags (presumably attached by _transform_indices, cf. the
  // assertion at the top) so that subsequent calls start from untagged indices
  for (auto it = ranges::begin(exrng); it != ranges::end(exrng); ++it) {
    const auto &factor = *it;
    if (factor->is<AbstractTensor>()) {
      factor->as<AbstractTensor>()._reset_tags();
    }
  }

  // update all_indices: every replaced index is swapped for its destination
  std::set<Index, Index::LabelCompare> all_indices_new;
  for (const auto &idx : all_indices) {
    const auto dst_it = const_replrules.find(idx);
    all_indices_new.emplace(dst_it != const_replrules.end() ? dst_it->second
                                                            : idx);
  }
  all_indices = std::move(all_indices_new);
  return mutated;
}
/// Eliminates Kronecker deltas / overlaps from a Product by repeatedly
/// computing and applying index replacement rules until a pass eliminates no
/// further delta.
///
/// @tparam S particle statistics
/// @param[in,out] expr the product to reduce in place
/// @param external_indices indices that must never be replaced
/// @param noncovariant_indices if nonempty, the set of noncovariant indices is
///        recomputed from @p expr at the beginning of every pass (the contents
///        of this argument are only used to decide whether to do so)
/// @param ctx the context whose index space registry and metric are used
/// @return false if @p expr was found to be zero, true otherwise
template <Statistics S>
bool reduce_wick_impl(std::shared_ptr<Product> &expr,
                      const container::set<Index> &external_indices,
                      const container::set<Index> &noncovariant_indices,
                      const Context &ctx) {
  // if have noncovariant indices, will need to update them at the beginning of
  // every pass
  const auto have_noncovariant_indices = !noncovariant_indices.empty();
  if (Logger::instance().wick_reduce) {
    sequant::wprintf(
        "reduce_wick_impl(expr, external_indices):\n input expr = ",
        expr->to_latex(), "\n external_indices = ");
    ranges::for_each(external_indices, [](auto &index) {
      sequant::wprintf(index.full_label(), " ");
    });
    sequant::wprintf("\n");
  }
  std::int64_t pass = -1;
  bool pass_mutated = false;
  do {
    pass_mutated = false;
    ++pass;
    // extract current indices
    auto idx_counter = get_used_indices_with_counts(expr);
    auto all_indices =
        idx_counter |
        ranges::views::transform([](const auto &v) { return v.first; }) |
        ranges::to<std::set<Index, Index::LabelCompare>>;
    // update list of noncovariant indices every iteration
    container::set<Index> all_noncovariant_indices;
    if (have_noncovariant_indices) {
      // see extract_indices
      all_noncovariant_indices =
          idx_counter |
          ranges::views::filter([&external_indices](const auto &v) {
            return (v.first.has_proto_indices() == true ||
                    v.second.proto != 0 || v.second.nonproto() != 2) &&
                   !external_indices.contains(v.first);
          }) |
          ranges::views::transform([](const auto &v) { return v.first; }) |
          ranges::to<container::set<Index>>;
      // augment list of noncovariant indices:
      // - any index with protoindices is noncovariant
      // - if kronecker delta has a noncovariant index, the other index is also
      //   noncovariant
      // N.B. the metric comes from ctx, the same context that supplies the
      // index space registry used below, not from the default context
      expr->visit([&all_noncovariant_indices, &ctx](const ExprPtr &ex) {
        if (ex.is<AbstractTensor>()) {
          auto &t = ex.template as<AbstractTensor>();
          for (auto &idx : slots(t)) {
            if (idx.has_proto_indices()) all_noncovariant_indices.emplace(idx);
          }
          const auto tlabel = t._label();
          // with a unit metric an overlap between indices with equal
          // protoindices behaves like a Kronecker delta
          const auto is_kronecker =
              tlabel == kronecker_label() ||
              (tlabel == overlap_label() &&
               ctx.metric() == IndexSpaceMetric::Unit &&
               t._bra()[0].proto_indices() == t._ket()[0].proto_indices());
          if (is_kronecker) {
            SEQUANT_ASSERT(t._bra_rank() == 1);
            Index b = t._bra()[0];
            SEQUANT_ASSERT(t._ket_rank() == 1);
            Index k = t._ket()[0];
            if (all_noncovariant_indices.contains(b)) {
              auto it = all_noncovariant_indices.find(k);
              if (it == all_noncovariant_indices.end()) {
                all_noncovariant_indices.emplace_hint(it, std::move(k));
              }
            } else if (all_noncovariant_indices.contains(k)) {
              auto it = all_noncovariant_indices.find(b);
              if (it == all_noncovariant_indices.end()) {
                all_noncovariant_indices.emplace_hint(it, std::move(b));
              }
            }
          }
        }
      });
      if (Logger::instance().wick_reduce) {
        sequant::wprintf("all_noncovariant_indices = ");
        for (auto &idx : all_noncovariant_indices) {
          sequant::wprintf(" ", idx.to_latex());
        }
        sequant::wprintf("\n");
      }
    }
    auto nonnull_result_opt = compute_index_replacement_rules<S>(
        expr, external_indices, all_noncovariant_indices, all_indices,
        ctx.index_space_registry());
    // no rules => the product vanishes
    if (!nonnull_result_opt) return false;
    const auto &[replacement_rules, found_kroneckers] = *nonnull_result_opt;
    if (Logger::instance().wick_reduce) {
      sequant::wprintf("reduce_wick_impl(expr, external_indices) pass=", pass,
                       ":\n expr = ", expr->to_latex(),
                       "\n external_indices = ");
      ranges::for_each(external_indices, [](auto &index) {
        sequant::wprintf(index.full_label(), " ");
      });
      sequant::wprintf("\n replrules = ");
      ranges::for_each(replacement_rules, [](auto &index) {
        sequant::wprintf(to_latex(index.first), "\\to", to_latex(index.second),
                         "\\,");
      });
    }
    // N.B. even if replacement list is empty, but have trivial kroneckers
    // invoke apply_index_replacement_rules
    if (found_kroneckers) {
      pass_mutated =
          apply_index_replacement_rules(expr, replacement_rules, all_indices);
    }
    if (Logger::instance().wick_reduce) {
      sequant::wprintf("\n result = ", expr->to_latex(), "\n");
    }
  } while (pass_mutated);  // keep reducing until stop changing
  return true;
}
/// Deleter for the RAII guard in WickTheorem<S>::compute: when invoked, it
/// deregisters the tensor canonicalizers registered for the first two
/// NormalOperator<S> labels. The pointer argument is ignored.
template <Statistics S>
struct NullNormalOperatorCanonicalizerDeregister {
  void operator()(void * /* ignored */) {
    const auto labels = NormalOperator<S>::labels();
    auto deregister = [](const auto &label) {
      TensorCanonicalizer::deregister_instance(label);
    };
    deregister(labels[0]);
    deregister(labels[1]);
  }
};
} // namespace detail
/// Classifies the indices used in @p expr and caches the result in
/// all_indices_, external_indices_ (unless user-defined) and
/// noncovariant_indices_.
///
/// @param expr the expression whose indices are classified
/// @param force_external if true (and external indices are neither
///        user-defined nor given as named indices by the default context's
///        canonicalization options), every index is treated as external
template <Statistics S>
void WickTheorem<S>::extract_indices(const Expr &expr,
                                     bool force_external) const {
  auto idx_counter = get_used_indices_with_counts(expr);

  // every index that appears in expr
  container::set<Index> used;
  for (const auto &entry : idx_counter) used.insert(entry.first);
  all_indices_ = std::move(used);

  if (!user_defined_external_indices_) {
    const auto &copts = get_default_context().canonicalization_options();
    if (copts && copts->named_indices) {
      external_indices_ = copts->named_indices.value();
    } else {
      // an index is external if it occupies at most one nonproto slot (this
      // includes pure protoindices, which occupy none), or unconditionally
      // when force_external is set
      container::set<Index> ext;
      for (const auto &entry : idx_counter) {
        if (entry.second.nonproto() <= 1 || force_external)
          ext.insert(entry.first);
      }
      external_indices_ = std::move(ext);
    }
  }

  // An index is covariant iff it carries no protoindices, is not a protoindex
  // of another index, occupies exactly two nonproto slots (i.e. is summed
  // over), and is not external. Every other index is noncovariant.
  container::set<Index> noncov;
  for (const auto &entry : idx_counter) {
    const auto &idx = entry.first;
    const auto &counts = entry.second;
    const bool is_plain_dummy = !idx.has_proto_indices() &&
                                counts.proto == 0 && counts.nonproto() == 2;
    if (!is_plain_dummy && !external_indices_->contains(idx))
      noncov.insert(idx);
  }
  noncovariant_indices_ = std::move(noncov);
}
template <Statistics S>
ExprPtr WickTheorem<S>::compute(const bool count_only,
const bool skip_input_canonicalization) {
// need to avoid recanonicalization of operators produced by WickTheorem
// by rapid canonicalization to avoid undoing all the good
// the NormalOperator<S>::normalize did ... use RAII
// 1. detail::NullNormalOperatorCanonicalizerDeregister<S> will restore state
// of tensor canonicalizer
// 2. this is the RAII object whose destruction will restore state of
// the tensor canonicalizer
std::unique_ptr<void, detail::NullNormalOperatorCanonicalizerDeregister<S>>
raii_null_nop_canonicalizer;
// 3. this makes the RAII object ... NOT reentrant, only to be called in
// top-level WickTheorem after initial canonicalization
auto disable_nop_canonicalization = [&raii_null_nop_canonicalizer]() {
if (!raii_null_nop_canonicalizer) {
const auto nop_labels = NormalOperator<S>::labels();
SEQUANT_ASSERT(nop_labels.size() == 2);
TensorCanonicalizer::try_register_instance(
std::make_shared<NullTensorCanonicalizer>(), nop_labels[0]);
TensorCanonicalizer::try_register_instance(
std::make_shared<NullTensorCanonicalizer>(), nop_labels[1]);
raii_null_nop_canonicalizer = decltype(raii_null_nop_canonicalizer)(
(void *)&raii_null_nop_canonicalizer, {});
}
};
// have an Expr as input? Apply recursively ...
if (expr_input_) {
if (Logger::instance().wick_harness)
std::wcout << "WickTheorem<S>::compute: input (before expand) = "
<< to_latex_align(expr_input_) << std::endl;
expand(expr_input_);
if (Logger::instance().wick_harness)
std::wcout << "WickTheorem<S>::compute: input (after expand) = "
<< to_latex_align(expr_input_) << std::endl;
// if sum, canonicalize and apply to each summand ...
if (expr_input_->is<Sum>()) {
if (!skip_input_canonicalization) {
// initial full canonicalization
canonicalize(expr_input_);
SEQUANT_ASSERT(!expr_input_->as<Sum>().empty());
}
// NOW disable canonicalization of normal operators
// N.B. even if skipped initial input canonicalization need to disable
// subsequent nop canonicalization
disable_nop_canonicalization();
// parallelize over summands
HashingAccumulator result_acc;
std::mutex result_mtx; // serializes updates of result
auto summands = expr_input_->as<Sum>().summands();
// find external_indices if don't have them
if (!external_indices_) {
const auto &copts = get_default_context().canonicalization_options();
if (copts && copts->named_indices) {
external_indices_ = copts->named_indices.value();
} else {
ranges::find_if(summands, [this](const auto &summand) {
if (summand.template is<Sum>()) // summands must not be a Sum
throw std::invalid_argument(
"WickTheorem<S>::compute(expr): expr is a Sum with one of "
"the "
"summands also a Sum, WickTheorem can only accept a fully "
"expanded Sum");
else if (summand.template is<Product>()) {
extract_indices(*(summand.template as_shared_ptr<Product>()));
return true;
} else
return false;
});
}
}
if (Logger::instance().wick_harness)
std::wcout << "WickTheorem<S>::compute: input (after canonicalize) has "
<< summands.size()
<< " terms = " << to_latex_align(expr_input_) << std::endl;
auto wick_task = [&result_acc, &result_mtx, this,
&count_only](const ExprPtr &input) {
WickTheorem wt(input->clone(), *this);
auto task_result = wt.compute(
count_only, /* definitely skip input canonicalization */ true);
stats() += wt.stats();
if (task_result) {
std::scoped_lock<std::mutex> lock(result_mtx);
result_acc.append(task_result);
}
};
sequant::for_each(summands, wick_task);
// if the sum is empty return zero
// if the sum has 1 summand, return it directly
return result_acc.make_expr();
}
// ... else if a product, find NormalOperatorSequence, if any, and compute
// ...
else if (expr_input_->is<Product>()) {
if (!skip_input_canonicalization) { // canonicalize, unless told to skip
auto canon_byproduct = expr_input_->rapid_canonicalize();
SEQUANT_ASSERT(
canon_byproduct ==
nullptr); // canonicalization of Product always returns nullptr
}
// NOW disable canonicalization of normal operators
// N.B. even if skipped initial input canonicalization need to disable
// subsequent nop canonicalization
disable_nop_canonicalization();
if (!all_indices_) {
extract_indices(*(expr_input_.as_shared_ptr<Product>()));
}
// split off NormalOperators into input_
auto first_nop_it = ranges::find_if(
*expr_input_,
[](const ExprPtr &expr) { return expr->is<NormalOperator<S>>(); });
// if have ops, split into nop sequence and cnumber "prefactor"
if (first_nop_it != ranges::end(*expr_input_)) {
// extract into prefactor and op sequence
ExprPtr prefactor =
ex<CProduct>(expr_input_->as<Product>().scalar(), ExprPtrList{});
auto nopseq = std::make_shared<NormalOperatorSequence<S>>();
for (const auto &factor : *expr_input_) {
if (factor->template is<NormalOperator<S>>()) {
nopseq->push_back(factor->template as<NormalOperator<S>>());
} else {
SEQUANT_ASSERT(factor->is_cnumber());
*prefactor *= *factor;
}
}
init_input(nopseq);
// compute and record/analyze topological NormalOperator and Index
// partitions
if (use_topology_) {
if (Logger::instance().wick_topology)
std::wcout
<< "WickTheorem<S>::compute: input to topology computation = "
<< to_latex(expr_input_) << std::endl;
// construct graph representation of the tensor product
using TN = TensorNetwork;
TN tn(expr_input_->as<Product>().factors());
auto g = tn.create_graph({.distinct_named_indices = true});
const auto &graph = g.bliss_graph;
const auto &vlabels = g.vertex_labels;
[[maybe_unused]] const auto &vcolors = g.vertex_colors;
const auto &vtypes = g.vertex_types;
const auto n = vtypes.size();
SEQUANT_ASSERT(vcolors.size() == n);
SEQUANT_ASSERT(vlabels.size() == n);
const auto &tn_edges = tn.edges();
const auto &tn_tensors = tn.tensors();
auto idx_vertex_to_edge_ptr =
[&](const auto idx_vertex) -> const TN::Edge * {
SEQUANT_ASSERT(idx_vertex < n);
const auto edge_idx = g.vertex_to_index_idx(idx_vertex);
if (edge_idx < tn_edges.size())
return &tn_edges[edge_idx];
else // indices without matching edges are pure protoindices
return nullptr;
};
if (Logger::instance().wick_topology) {
std::basic_ostringstream<wchar_t> oss;
graph->write_dot(oss, {.labels = vlabels});
std::wcout
<< "WickTheorem<S>::compute: colored graph produced from TN = "
<< std::endl
<< oss.str() << std::endl;
}
// identify vertex indices of NormalOperator objects and Indices
// 1. list of vertex indices corresponding to NormalOperator objects
// on the TN graph and their ordinals in NormalOperatorSequence
// N.B. for NormalOperators the vertex indices coincide with
// the ordinals
container::map<size_t, size_t> nop_vidx_ord;
// 2. list of vertex indices corresponding to Index objects on the TN
// graph that appear in NormalOperatorsSequence and
// their ordinals therein
// N.B. for Index objects the vertex indices do NOT coincide with
// the ordinals
container::map<size_t, size_t> index_vidx_ord;
{
const auto &nop_labels = NormalOperator<S>::labels();
const auto nop_labels_begin = begin(nop_labels);
const auto nop_labels_end = end(nop_labels);
using opseq_view_type =
flattened_rangenest<NormalOperatorSequence<S>>;
auto opseq_view = opseq_view_type(input_.get());
const auto opseq_view_begin = ranges::begin(opseq_view);
const auto opseq_view_end = ranges::end(opseq_view);
// NormalOperators are not reordered by canonicalization, hence the
// ordinal can be computed by counting
std::size_t nop_ord = 0;
for (size_t v = 0; v != n; ++v) {
if (vtypes[v] == VertexType::TensorCore &&
(std::find(nop_labels_begin, nop_labels_end, vlabels[v]) !=
nop_labels_end)) {
[[maybe_unused]] auto insertion_result =
nop_vidx_ord.emplace(v, nop_ord++);
SEQUANT_ASSERT(insertion_result.second);
}
if (vtypes[v] == VertexType::Index && !input_->empty()) {
auto *edge_ptr = idx_vertex_to_edge_ptr(v);
if (edge_ptr) { // do not consider pure protoindices
auto &idx = edge_ptr->idx();
auto idx_it_in_opseq = ranges::find_if(
opseq_view,
[&idx](const auto &v) { return v.index() == idx; });
if (idx_it_in_opseq != opseq_view_end) {
const auto ord =
ranges::distance(opseq_view_begin, idx_it_in_opseq);
[[maybe_unused]] auto insertion_result =
index_vidx_ord.emplace(v, ord);
SEQUANT_ASSERT(insertion_result.second);
}
}
}
}
}
// compute and save graph automorphism generators
std::vector<std::vector<unsigned int>> aut_generators;
{
bliss::Stats stats;
graph->set_splitting_heuristic(bliss::Graph::shs_fsm);
auto save_aut = [&aut_generators](const unsigned int n,
const unsigned int *aut) {
aut_generators.emplace_back(aut, aut + n);
};
graph->find_automorphisms(
stats, &bliss::aut_hook<decltype(save_aut)>, &save_aut);
if (Logger::instance().wick_topology) {
std::basic_ostringstream<wchar_t> oss2;
bliss::print_auts(aut_generators, oss2, vlabels);
std::wcout << "WickTheorem<S>::compute: colored graph "
"automorphism generators = \n"
<< oss2.str() << std::endl;
}
}
// Use automorphisms to determine groups of topologically equivalent
// NormalOperator and Op objects.
// @param vertices maps vertex indices of the objects to their
// ordinals in the sequence of such objects within
// the NormalOperatorSequence
// @param nontrivial_partitions_only if true, only partitions with
// more than one element, are reported, else even trivial
// partitions with a single partition will be reported
// @param vertex_pair_exclude a callable that accepts 2 vertex
// indices and returns true if the automorphism of this pair
// of indices is to be ignored; this is used to disregard
// automorphisms of Index objects unless connected to same bra/ket
// of an (anti)symmetric NormalOperator.
// @return the \c {vertex_to_partition_idx,npartitions} pair in
// which \c vertex_to_partition_idx maps vertex indices that are
// part of nontrivial partitions to their (1-based) partition indices
auto compute_partitions = [&aut_generators](
const container::map<size_t, size_t>
&vertices,
bool nontrivial_partitions_only,
auto &&vertex_pair_exclude) {
container::map<size_t, size_t> vertex_to_partition_idx;
int next_partition_idx = -1;
// using each automorphism generator
for (auto &&aut : aut_generators) {
// skip automorphism generators that do not involve vertices
// in `vertices` list
bool aut_contains_other_vertices = true;
for (auto &&[v, ord] : vertices) {
(void)ord;
const auto v_is_in_aut = v != aut[v];
if (v_is_in_aut) {
aut_contains_other_vertices = false;
break;
}
}
if (aut_contains_other_vertices) continue;
// update partitions
for (auto &&[v1, ord1] : vertices) {
const auto v2 = aut[v1];
if (v2 != v1 &&
!vertex_pair_exclude(
v1, v2)) { // if the automorphism maps this vertex to
// another ... they both must be in the same
// partition
SEQUANT_ASSERT(vertices.find(v2) != vertices.end());
auto v1_partition_it = vertex_to_partition_idx.find(v1);
auto v2_partition_it = vertex_to_partition_idx.find(v2);
const bool v1_has_partition =
v1_partition_it != vertex_to_partition_idx.end();
const bool v2_has_partition =
v2_partition_it != vertex_to_partition_idx.end();
if (v1_has_partition &&
v2_has_partition) { // both are in partitions? make sure
// they are in the same partition.
// N.B. this may leave gaps in
// partition indices ... no biggie
const auto v1_part_idx = v1_partition_it->second;
const auto v2_part_idx = v2_partition_it->second;
if (v1_part_idx !=
v2_part_idx) { // if they have different partition
// indices, change the larger of the two
// indices to match the lower
const auto target_part_idx =
std::min(v1_part_idx, v2_part_idx);
for (auto &v : vertex_to_partition_idx) {
if (v.second == v1_part_idx || v.second == v2_part_idx)
v.second = target_part_idx;
}
}
} else if (v1_has_partition) { // only v1 is in a partition?
// place v2 in it
const auto v1_part_idx = v1_partition_it->second;
vertex_to_partition_idx.emplace(v2, v1_part_idx);
} else if (v2_has_partition) { // only v2 is in a partition?
// place v1 in it
const auto v2_part_idx = v2_partition_it->second;
vertex_to_partition_idx.emplace(v1, v2_part_idx);
} else { // neither is in a partition? place both in the next
// available partition
const size_t target_part_idx = ++next_partition_idx;
vertex_to_partition_idx.emplace(v1, target_part_idx);
vertex_to_partition_idx.emplace(v2, target_part_idx);
}
}
}
}
if (!nontrivial_partitions_only) {
ranges::for_each(vertices, [&](const auto &vidx_ord) {
auto &&[vidx, ord] = vidx_ord;
if (vertex_to_partition_idx.find(vidx) ==
vertex_to_partition_idx.end()) {
vertex_to_partition_idx.emplace(vidx, ++next_partition_idx);
}
});
}
const auto npartitions = next_partition_idx + 1;
return std::make_tuple(vertex_to_partition_idx, npartitions);
};
// compute NormalOperator->partition map, convert to partition lists
// (if any), and register via set_nop_partitions to be used in full
// contractions
auto do_not_skip_elements = [](size_t, size_t) { return false; };
auto [nop_vidx2pidx, nop_npartitions] = compute_partitions(
nop_vidx_ord, /* nontrivial_partitions_only = */ true,
do_not_skip_elements);
// converts vertex ordinal to partition key map into a sequence of
// partitions, each composed of the corresponding ordinals of the
// vertices in the vertex_list sequence
// @param vidx2pidx a map from vertex index (in TN) to its
// (1-based) partition index
// @param npartitions the total number of partitions
// @param vidx_ord ordered sequence of vertex indices, object
// with vertex index `vidx` will be mapped to ordinal
// `vidx_ord[vidx]`
// @return sequence of partitions, sorted by the smallest ordinal
auto extract_partitions = [](const auto &vidx2pidx,
const auto npartitions,
const auto &vidx_ord) {
container::svector<container::svector<size_t>> partitions;
SEQUANT_ASSERT(npartitions > -1);
const size_t max_pidx = npartitions;
partitions.reserve(max_pidx);
// iterate over all partition indices ... note that there may be
// gaps so count the actual partitions
size_t partition_cnt = 0;
for (size_t p = 0; p <= max_pidx; ++p) {
bool p_found = false;
for (const auto &[vidx, pidx] : vidx2pidx) {
if (pidx == p) {
// !!remember to map the vertex index into the operator
// index!!
SEQUANT_ASSERT(vidx_ord.find(vidx) != vidx_ord.end());
const auto ordinal = vidx_ord.find(vidx)->second;
if (p_found == false) { // first time this is found
partitions.emplace_back(container::svector<size_t>{
static_cast<size_t>(ordinal)});
} else
partitions[partition_cnt].emplace_back(ordinal);
p_found = true;
}
}
if (p_found) ++partition_cnt;
}
// sort each partition
for (auto &partition : partitions) {
ranges::sort(partition);
}
// sort partitions in the order of increasing first element
ranges::sort(partitions, [](const auto &p1, const auto &p2) {
return p1.front() < p2.front();
});
return partitions;
};
if (!nop_vidx2pidx.empty()) {
container::svector<container::svector<size_t>> nop_partitions;
nop_partitions = extract_partitions(nop_vidx2pidx, nop_npartitions,
nop_vidx_ord);
if (Logger::instance().wick_topology) {
std::wcout
<< "WickTheorem<S>::compute: topological nop partitions:{\n";
ranges::for_each(nop_partitions, [](auto &&part) {
std::wcout << "{";
ranges::for_each(part,
[](auto &&p) { std::wcout << p << " "; });
std::wcout << "}";
});
std::wcout << "}" << std::endl;
}
this->set_nop_partitions(nop_partitions);
}
// compute Index->partition map, and convert to partition lists (if
// any), and check that use_topology_ is compatible with index
// partitions
// Index partitions are constructed to *only* include Index
// objects attached to the bra/ket of any NormalOperator! hence
// need to use filter in computing partitions
auto exclude_index_vertex_pair = [&tn_tensors,
&idx_vertex_to_edge_ptr](
size_t v1, size_t v2) {
const auto *edge1_ptr = idx_vertex_to_edge_ptr(v1);
const auto *edge2_ptr = idx_vertex_to_edge_ptr(v2);
if (!edge1_ptr || !edge2_ptr) return true;
const auto &edge1 = *edge1_ptr;
const auto &edge2 = *edge2_ptr;
auto connected_to_bra_or_ket_of_same_symmetric_nop =
[&tn_tensors](const auto &edge1, const auto &edge2) -> bool {
const auto nt1 = edge1.vertex_count();
SEQUANT_ASSERT(nt1 <= 2);
const auto nt2 = edge2.vertex_count();
SEQUANT_ASSERT(nt2 <= 2);
for (auto i1 = 0; i1 != nt1; ++i1) {
const auto tensor1_ord = edge1.vertex(i1).getTerminalIndex();
for (auto i2 = 0; i2 != nt2; ++i2) {
const auto tensor2_ord = edge2.vertex(i2).getTerminalIndex();
// do not skip if connected to same ...
if (tensor1_ord == tensor2_ord) {
auto tensor_ord = tensor1_ord;
const std::shared_ptr<AbstractTensor> &tensor_ptr =
tn_tensors.at(tensor_ord);
// ... (anti)symmetric ...
if (tensor_ptr->_symmetry() != Symmetry::Nonsymm) {
const auto tensor1_slot_type =
edge1.vertex(i1).getOrigin();
const auto tensor2_slot_type =
edge2.vertex(i2).getOrigin();
// ... bra/ket of ...
if (tensor1_slot_type == tensor2_slot_type) {
// ... NormalOperator!
if (std::dynamic_pointer_cast<NormalOperator<S>>(
tensor_ptr)) {
return true;
}
}
}
}
}
}
return false;
};
const bool exclude =
!connected_to_bra_or_ket_of_same_symmetric_nop(edge1, edge2);
return exclude;
};
// index_vidx2pidx maps vertex index (see
// index_vidx_ord) to partition index
container::map<size_t, size_t> index_vidx2pidx;
int index_npartitions = -1;
std::tie(index_vidx2pidx, index_npartitions) = compute_partitions(
index_vidx_ord, /* nontrivial_partitions_only = */ false,
/* this is to ensure that each index partition only involves
indices attached to bra or to ket of same
symmetric/antisymmetric nop.*/
exclude_index_vertex_pair);
if (!index_vidx2pidx.empty()) {
container::svector<container::svector<size_t>> index_partitions;
index_partitions = extract_partitions(
index_vidx2pidx, index_npartitions, index_vidx_ord);
if (Logger::instance().wick_topology) {
std::wcout << "WickTheorem<S>::compute: topological index "
"partitions:{\n";
ranges::for_each(
index_vidx2pidx, [&idx_vertex_to_edge_ptr](auto &&vidx_pidx) {
auto &&[vidx, pidx] = vidx_pidx;
auto *edge_ptr = idx_vertex_to_edge_ptr(vidx);
// skip pure proto indices
if (edge_ptr) {
auto &idx = edge_ptr->idx();
std::wcout << "Index " << idx.full_label()
<< " -> partition " << pidx << "\n";
}
});
std::wcout << "}" << std::endl;
}
this->set_op_partitions(index_partitions);
// TODO determine partitions of braket index pairs to be able to
// exploit topology for spin-free WT note that right now indices
// attached to bra/ket of spin-free normal operators are excluded
// from index partitions above
}
}
if (!input_->empty()) {
if (Logger::instance().wick_contract) {
std::wcout
<< "WickTheorem<S>::compute: input to compute_nopseq = {\n";
for (auto &&nop : input_) std::wcout << to_latex(nop) << "\n";
std::wcout << "}" << std::endl;
}
prefactor_ = prefactor;
auto result = compute_nopseq(count_only);
prefactor_.reset();
if (result) { // simplify if obtained nonzero ...
// rapid_simplify(result);
// canonicalize(result);
rapid_simplify(
result); // rapid_simplify again since canonization may produce
// new opportunities (e.g. terms cancel, etc.)
} else
result = ex<Constant>(0);
return result;
}
} else { // product does not include ops
return expr_input_;
}
} // expr_input_->is<Product>()
// ... else if NormalOperatorSequence already, compute ...
else if (expr_input_->is<NormalOperatorSequence<S>>()) {
SEQUANT_ABORT(
"expr_input_ should no longer be nonnull if constructed with an "
"expression that's a NormalOperatorSequence<S>");
} else // ... else do nothing
return expr_input_;
} else {
// given a NormalOperatorSequence instead of an expression
auto result = compute_nopseq(count_only);
if (result) { // simplify if obtained nonzero ...
this->reduce(result);
// N.B. DO NOT CANONICALIZE to preserve index pairings if doing partial
// contraction
rapid_simplify(result);
} else
result = ex<Constant>(0);
return result;
}
SEQUANT_UNREACHABLE;
}
template <Statistics S>
void WickTheorem<S>::reduce(ExprPtr &expr) const {
  if (Logger::instance().wick_reduce) {
    std::wcout << "WickTheorem<S>::reduce: input = "
               << to_latex_align(expr, 20, 1) << std::endl;
  }
  // if the index sets have not been set up (all_indices_ is empty), extract
  // them from expr for the duration of this call; they are reset at the end
  const bool extracted_indices = !all_indices_;
  if (extracted_indices) {
    extract_indices(*expr);
  }

  // Reduces the Product held by product_expr in place via
  // detail::reduce_wick_impl; when that returns false the Product is replaced
  // by Constant(0). Both the single-Product and the Sum-of-Products paths go
  // through this helper so that they always pass the same index sets
  // (external AND noncovariant) to reduce_wick_impl. Previously the
  // single-Product path passed the external indices in place of the
  // noncovariant ones.
  auto reduce_product = [this](auto &product_expr) {
    SEQUANT_ASSERT(external_indices_);
    SEQUANT_ASSERT(noncovariant_indices_);
    auto product = std::static_pointer_cast<Product>(product_expr);
    if (detail::reduce_wick_impl<S>(product, *external_indices_,
                                    *noncovariant_indices_,
                                    get_default_context(S))) {
      product_expr = product;
    } else {
      product_expr = std::make_shared<Constant>(0);
    }
  };

  // there are 2 possibilities: expr is a single Product, or it's a Sum of
  // Products; any other expression is left untouched
  if (expr.is<Product>()) {
    reduce_product(expr);
  } else if (expr.is<Sum>()) {
    for (auto &&subexpr : *expr) {
      SEQUANT_ASSERT(subexpr->is<Product>());
      reduce_product(subexpr);
    }
  }

  if (Logger::instance().wick_reduce) {
    sequant::wprintf(
        "WickTheorem<S>::reduce: result = ", to_latex_align(expr, 20, 1), "\n");
  }
  if (extracted_indices) reset_indices();
}
/// Destructor: performs no work beyond destroying the members, so the
/// compiler-generated definition is requested explicitly.
template <Statistics S>
WickTheorem<S>::~WickTheorem() = default;
} // namespace sequant
#endif // SEQUANT_WICK_IMPL_HPP