Program Listing for File wick.impl.hpp

Return to documentation for file (SeQuant/core/wick.impl.hpp)

//
// Created by Eduard Valeyev on 3/31/18.
//

#ifndef SEQUANT_WICK_IMPL_HPP
#define SEQUANT_WICK_IMPL_HPP

#include <SeQuant/core/bliss.hpp>
#include <SeQuant/core/logger.hpp>
#include <SeQuant/core/tensor_canonicalizer.hpp>
#include <SeQuant/core/tensor_network.hpp>
#include <SeQuant/core/tensor_network/vertex.hpp>
#include <SeQuant/core/utility/debug.hpp>
#include <SeQuant/core/utility/indices.hpp>
#include <SeQuant/core/utility/macros.hpp>

#ifdef SEQUANT_HAS_EXECUTION_HEADER
#include <execution>
#endif

namespace sequant {

namespace detail {

/// Destination of an index replacement rule: a single destination index plus
/// the list of source indices that map to it. In
/// compute_index_replacement_rules one instance is shared (via
/// std::shared_ptr) by every src->dst entry with the same destination, so
/// updating the destination here retargets all of those sources at once.
class index_repl_dst_t {
 public:
  /// @param dst destination index
  /// @param src first source index mapped to @p dst
  explicit index_repl_dst_t(Index dst, Index src) : dst_(std::move(dst)) {
    // N.B. build src_ explicitly (not via a braced initializer) so that `src`
    // is moved exactly once and the log below reads the stored copy rather
    // than the moved-from parameter
    src_.emplace_back(std::move(src));
    if (Logger::instance().wick_reduce) {
      sequant::wprintf("index_repl_dst_t: ctor, src=", src_.back().to_latex(),
                       " dst=", dst_.to_latex(), "\n");
    }
  }

  /// @return the current destination index
  const Index &dst() const { return dst_; }

  /// Replaces the destination index (all sources follow automatically).
  void update_dst(Index idx) {
    if (Logger::instance().wick_reduce) {
      sequant::wprintf("index_repl_dst_t: changing dst=", dst_.to_latex(),
                       " to dst=", idx.to_latex(), "\n");
    }
    dst_ = std::move(idx);
  }

  /// @return the source indices mapped to dst()
  const container::svector<Index> &src() const { return src_; }

  /// Adds another source index; it must not already be present.
  index_repl_dst_t &append_src(Index src) {
    SEQUANT_ASSERT(ranges::contains(src_, src) == false);
    if (Logger::instance().wick_reduce) {
      sequant::wprintf("index_repl_dst_t: appended src=", src.to_latex(),
                       " -> dst=", dst_.to_latex(), "\n");
    }
    src_.emplace_back(std::move(src));
    return *this;
  }

 private:
  // may carry protoindices inherited from a source index (see the `proto`
  // helper in compute_index_replacement_rules)
  Index dst_;
  container::svector<Index> src_;
};


/// Computes index replacement rules that eliminate overlaps and Kronecker
/// deltas from @p product.
///
/// Every overlap/Kronecker tensor must have exactly one bra and one ket index.
/// A tensor is skipped when bra == ket, or when it is an overlap between two
/// noncovariant indices. For the others:
///  - internal + internal: both indices are mapped to a common (new) index in
///    the intersection of their spaces;
///  - external + internal: the internal index is mapped to the external one
///    if its space includes the external index's space, otherwise to a new
///    index in the intersection space;
///  - external + external: no rule.
/// Rules that touch an already-mapped source are merged/updated so that each
/// source maps to exactly one destination.
///
/// @tparam S statistics; only used to select the default @p isr
/// @param product product to scan; it is only read by this function
/// @param external_indices external indices; they are never rule sources
/// @param noncovariant_indices noncovariant indices (see skip rule above)
/// @param all_indices indices in use; new indices produced by the internal
///        IndexFactory are validated against these (compared by label)
/// @param isr registry used to intersect index spaces
/// @return std::nullopt if the product vanishes (an overlap/delta, or a
///         merge/update of rules, requires an empty index space); otherwise
///         the src->dst map and a flag that is true if any overlap or
///         Kronecker tensor was encountered (including skipped ones)
template <Statistics S>
std::optional<std::pair<container::map<Index, Index>, bool>>
compute_index_replacement_rules(
    std::shared_ptr<Product> &product,
    const container::set<Index> &external_indices,
    const container::set<Index> &noncovariant_indices,
    const std::set<Index, Index::LabelCompare> &all_indices,
    const std::shared_ptr<const IndexSpaceRegistry> &isr =
        get_default_context(S).index_space_registry()) {
  // set by zero_result() when a rule update discovers the product vanishes;
  // inspected by the early-return macros below
  bool zero_result_status = false;
  auto zero_result = [&zero_result_status]() -> void {
    zero_result_status = true;
  };
  // evaluate x, then return ({} == nullopt, or void) if x flagged zero result
#define SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_IF_ZERO_RESULT(x) \
  { x; }                                                          \
  if (zero_result_status) return {};
#define SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_VOID_IF_ZERO_RESULT(x) \
  { x; }                                                               \
  if (zero_result_status) return;

  expr_range exrng(product);

  bool have_kroneckers = false;

  // freshly made indices must not reuse a label already in use
  auto index_validator = [&all_indices](const Index &idx) {
    return all_indices.find(idx) == all_indices.end();
  };
  IndexFactory idxfac(index_validator);
  // src->dst; several sources may share one dst object. Converted to a plain
  // src->Index map on return.
  container::map<Index /* src */, std::shared_ptr<index_repl_dst_t> /* dst */>
      src2dst;
  container::svector<std::shared_ptr<index_repl_dst_t>>
      dst_list;  // unsorted list of dst indices

  // transfers proto indices from src (if any) to dst
  // which proto-indices should dst inherit from src? a dst index without
  // proto indices will inherit its src counterpart's indices, unless it
  // already has its own protoindices: <a_ij|p> = <a_ij|a_ij> (hence replace p
  // with a_ij), but <a_ij|p_kl> = <a_ij|a_kl> != <a_ij|a_ij> (hence replace
  // p_kl with a_kl)
  auto proto = [](const Index &dst, const Index &src) {
    if (src.has_proto_indices()) {
      if (dst.has_proto_indices()) {
        SEQUANT_ASSERT(dst.proto_indices() == src.proto_indices());
        return dst;
      } else
        return Index(dst, src.proto_indices());
    } else {
      return dst;
    }
  };

  // adds src->dst, where dst is the destination srd2dst_it already maps to
  auto add_src_to_existing_dst = [&src2dst](const Index &src, auto srd2dst_it) {
    SEQUANT_ASSERT(srd2dst_it != ranges::end(src2dst));
    srd2dst_it->second->append_src(src);
    src2dst.emplace(src, srd2dst_it->second);
  };

  // change dst index
  auto replace_dst_index = [&src2dst](const Index &new_dst, auto srd2dst_it) {
    SEQUANT_ASSERT(srd2dst_it != ranges::end(src2dst));
    srd2dst_it->second->update_dst(new_dst);
  };

  // merges dst2 into dst1
  auto merge_dst2_into_dst1 = [&src2dst, &dst_list](auto srd2dst_it1,
                                                    auto srd2dst_it2) {
    // validate the iterators before dereferencing them
    SEQUANT_ASSERT(srd2dst_it1 != ranges::end(src2dst));
    SEQUANT_ASSERT(srd2dst_it2 != ranges::end(src2dst));
    SEQUANT_ASSERT(
        srd2dst_it1->second !=
        srd2dst_it2->second);  // caller should ensure no self-merges,
                               // indicates faulty logic upstream
    const auto dst1 = srd2dst_it1->second;
    const auto dst2 = srd2dst_it2->second;

    // repoint all source indices of dst2 to dst1
    // (dst2 is kept alive by the local copy above while src2dst is updated)
    for (const auto &src2 : dst2->src()) {
      auto it = src2dst.find(src2);
      SEQUANT_ASSERT(it != src2dst.end());
      it->second = dst1;
      dst1->append_src(src2);
    }

    // there should be no refs to dst2 in src2dst
    SEQUANT_ASSERT(ranges::contains(src2dst, dst2, [](const auto &it) {
                     return it.second;
                   }) == false);

    // erase dst2 from dst_list
    auto it2 = ranges::find(dst_list, dst2);
    SEQUANT_ASSERT(it2 != ranges::end(dst_list));
    dst_list.erase(it2);
    // there should be no "viewers" of dst2
    SEQUANT_ASSERT(dst2.use_count() == 1);
  };

  // adds src->dst, optionally assigning proto indices from protosrc
  auto add_rule = [&src2dst, &dst_list, &proto](
                      const Index &src, const Index &dst,
                      std::optional<const Index> protosrc = std::nullopt) {
    auto real_dst = protosrc ? proto(dst, protosrc.value()) : proto(dst, src);
    auto it = ranges::find_if(
        dst_list, [&real_dst](const auto &d) { return d->dst() == real_dst; });
    if (it != ranges::end(dst_list)) {
      (*it)->append_src(src);
      src2dst.emplace(src, *it);
    } else {
      [[maybe_unused]] auto insertion_result = src2dst.emplace(
          src, std::make_shared<index_repl_dst_t>(real_dst, src));
      SEQUANT_ASSERT(insertion_result.second);
      dst_list.emplace_back(insertion_result.first->second);
    }
  };

  // changes src->current_dst to src->intersection(dst,current_dst)
  auto update_rule = [&src2dst, &proto, &isr, &idxfac, &zero_result](
                         auto src_it, const Index &src, const Index &dst,
                         std::optional<const Index> protosrc = std::nullopt) {
    SEQUANT_ASSERT(src_it != src2dst.end());
    auto &old_dst = src_it->second->dst();

    // do we need to change space of dst?
    const bool change_dst_space = (dst.space() != old_dst.space());
    const IndexSpace &new_dst_space =
        change_dst_space ? isr->intersection(old_dst.space(), dst.space())
                         : dst.space();
    if (!new_dst_space) return zero_result();

    // do we need to change protoindices?
    bool change_dst_protoindices = false;
    // already have protoindices?
    if (old_dst.has_proto_indices()) {
      // same index should never be mapped to indices with
      // different protoindices, this is what the logic dealing with
      // noncovariant indices is meant to avoid
      if (protosrc.value_or(src).has_proto_indices()) {
        SEQUANT_ASSERT(protosrc.value_or(src).proto_indices() ==
                       old_dst.proto_indices());
      }
    } else {
      change_dst_protoindices = protosrc.value_or(src).has_proto_indices();
    }

    if (change_dst_space || change_dst_protoindices) {
      auto plain_dst = idxfac.make(new_dst_space);
      const auto real_dst = change_dst_protoindices
                                ? proto(plain_dst, protosrc.value_or(src))
                                : proto(plain_dst, old_dst);
      src_it->second->update_dst(real_dst);
    }
  };

  // adds src->dst or changes src->current_dst to
  // src->intersection(dst,current_dst)
  auto add_or_update_rule = [&add_rule, &update_rule, &src2dst,
                             &zero_result_status](const Index &src,
                                                  const Index &dst) {
    auto src_it = src2dst.find(src);
    if (src_it == src2dst.end()) {
      add_rule(src, dst);
    } else {
      SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_VOID_IF_ZERO_RESULT(
          update_rule(src_it, src, dst));
    }
  };

  // adds src1->dst and src2->dst; if src1->dst1 and/or src2->dst2 already
  // exist the existing rules are updated to map to the intersection of dst1,
  // dst2 and dst
  auto add_or_update_rules = [&add_rule, &update_rule, &add_src_to_existing_dst,
                              &replace_dst_index, &merge_dst2_into_dst1,
                              &src2dst, &idxfac, &proto, &zero_result,
                              &zero_result_status,
                              &isr](const Index &src1, const Index &src2,
                                    const Index &dst) {
    // are there replacement rules already for src{1,2}?
    auto src1_it = src2dst.find(src1);
    auto src2_it = src2dst.find(src2);
    const auto has_src1_rule = src1_it != src2dst.end();
    const auto has_src2_rule = src2_it != src2dst.end();

    // which proto-indices should dst1 and dst2 inherit? a destination index
    // without proto indices will inherit its source counterpart's indices,
    // unless it already has its own protoindices: <a_ij|p> = <a_ij|a_ij> (hence
    // replace p with a_ij), but <a_ij|p_kl> = <a_ij|a_kl> != <a_ij|a_ij> (hence
    // replace p_kl with a_kl)
    const auto &dst1_proto =
        !src1.has_proto_indices() && src2.has_proto_indices() ? src2 : src1;
    const auto &dst2_proto =
        !src2.has_proto_indices() && src1.has_proto_indices() ? src1 : src2;

    if (!has_src1_rule && !has_src2_rule) {  // if brand new, add the rules
      add_rule(src1, dst, dst1_proto);
      add_rule(src2, dst, dst2_proto);
    } else if (has_src1_rule && !has_src2_rule) {
      // update the existing rule for src1
      SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_VOID_IF_ZERO_RESULT(
          update_rule(src1_it, src1, dst, dst1_proto));
      // create new rule: src2->dst1
      add_src_to_existing_dst(src2, src2dst.find(src1));
    } else if (!has_src1_rule && has_src2_rule) {
      // update the existing rule for src2
      SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_VOID_IF_ZERO_RESULT(
          update_rule(src2_it, src2, dst, dst2_proto));
      // create new rule: src1->dst2
      add_src_to_existing_dst(src1, src2dst.find(src2));
    } else {
      // merge the existing rules
      // - compute new target index space
      // - compute new target index
      // - repoint old rules to the new target
      const auto &old_dst1 = src1_it->second->dst();
      const auto &old_dst2 = src2_it->second->dst();
      const auto new_dst_space =
          (dst.space() != old_dst1.space() || dst.space() != old_dst2.space())
              ? isr->intersection(
                    isr->intersection(old_dst1.space(), old_dst2.space()),
                    dst.space())
              : dst.space();
      if (!new_dst_space) return zero_result();
      // prefer reusing an existing candidate index in the target space (the
      // smallest one), else make a new one
      Index new_dst;
      if (new_dst_space == old_dst1.space()) {
        new_dst = old_dst1;
        if (new_dst_space == old_dst2.space() && old_dst2 < new_dst) {
          new_dst = old_dst2;
        }
        if (new_dst_space == dst.space() && dst < new_dst) {
          new_dst = dst;
        }
      } else if (new_dst_space == old_dst2.space()) {
        new_dst = old_dst2;
        if (new_dst_space == dst.space() && dst < new_dst) {
          new_dst = dst;
        }
      } else if (new_dst_space == dst.space()) {
        new_dst = dst;
      } else
        new_dst = idxfac.make(new_dst_space);

      // update dst1 and dst2 with new_dst, then merge them
      auto new_real_dst = proto(new_dst, dst1_proto);
      SEQUANT_ASSERT(
          new_real_dst ==
          proto(new_dst, dst2_proto));  // don't know how to handle this yet
      auto src2dst_it1 = src2dst.find(src1);
      replace_dst_index(new_real_dst, src2dst_it1);
      // if src1 and src2 were pointing to same destination, we are done, else
      // merge their destinations
      auto src2dst_it2 = src2dst.find(src2);
      if (src2dst_it1->second != src2dst_it2->second) {
        replace_dst_index(new_real_dst, src2dst_it2);
        merge_dst2_into_dst1(src2dst_it1, src2dst_it2);
      }
    }
  };

  for (auto it = ranges::begin(exrng); it != ranges::end(exrng); ++it) {
    const auto &factor = *it;
    if (factor.is<Tensor>()) {
      const auto &tensor = factor.as<Tensor>();
      const auto is_overlap = tensor.label() == overlap_label();
      const auto is_kronecker = tensor.label() == kronecker_label();
      if (is_overlap || is_kronecker) {
        have_kroneckers = true;
        SEQUANT_ASSERT(tensor.bra().size() == 1);
        SEQUANT_ASSERT(tensor.ket().size() == 1);
        const auto &bra = tensor.bra().at(0);
        const auto &ket = tensor.ket().at(0);

        // skip if
        // - self-kronecker (will be replaced by 1 already)
        bool do_skip = bra == ket;
        // - nontrivial overlap between 2 noncovariant modes
        if (is_overlap) {
          // N.B. noncovariant bra or ket is OK because we can always rotate it
          // to match the basis of the other
          do_skip = do_skip || (noncovariant_indices.contains(bra) &&
                                noncovariant_indices.contains(ket));
        }
        if (!do_skip) {
          // set lookup, consistent with noncovariant_indices above
          const auto bra_is_ext = external_indices.contains(bra);
          const auto ket_is_ext = external_indices.contains(ket);

          const auto intersection_space =
              isr->intersection(bra.space(), ket.space());

          // if overlap's indices are from non-overlapping spaces, return zero
          if (!intersection_space) {
            return std::nullopt;
          }

          if (!bra_is_ext && !ket_is_ext) {
            // int + int
            const auto new_dummy = idxfac.make(intersection_space);
            SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_IF_ZERO_RESULT(
                add_or_update_rules(bra, ket, new_dummy));
          } else if (bra_is_ext && !ket_is_ext) {  // ext + int
            if (includes(ket.space(), bra.space())) {
              SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_IF_ZERO_RESULT(
                  add_or_update_rule(ket, bra));
            } else {
              SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_IF_ZERO_RESULT(
                  add_or_update_rule(ket, idxfac.make(intersection_space)));
            }
          } else if (!bra_is_ext && ket_is_ext) {  // int + ext
            if (includes(bra.space(), ket.space())) {
              SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_IF_ZERO_RESULT(
                  add_or_update_rule(bra, ket));
            } else {
              SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_IF_ZERO_RESULT(
                  add_or_update_rule(bra, idxfac.make(intersection_space)));
            }
          }
          // ext + ext => leave overlap as is
        }
      }
    }
  }

  // make 1-to-1 version of src->dst
  container::map<Index /* src */, Index /* dst */> result;
  for (auto &&[src, d] : src2dst) {
    result.emplace(src, d->dst());
  }
  return std::make_pair(std::move(result), have_kroneckers);

#undef SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_IF_ZERO_RESULT
#undef SEQUANT_WICK_IMPL_HPP_CIRR_EARLY_RETURN_VOID_IF_ZERO_RESULT
}

/// Applies index replacement rules to every tensor of @p product. Any
/// overlap/Kronecker tensor whose bra and ket indices become identical is
/// replaced by the constant 1.
///
/// @param product product to transform in place
/// @param const_replrules src->dst index map (e.g. the output of
///        compute_index_replacement_rules)
/// @param all_indices indices in use; on return each member that is a key of
///        @p const_replrules has been replaced by its destination
/// @return true if any pass ended with pass_mutated set (see the NOTE below on
///         how that flag is computed), false otherwise
inline bool apply_index_replacement_rules(
    std::shared_ptr<Product> &product,
    const container::map<Index, Index> &const_replrules,
    std::set<Index, Index::LabelCompare> &all_indices) {
  expr_range exrng(product);

#ifdef SEQUANT_ASSERT_ENABLED
  // assert that tensors_ indices are not tagged since going to tag indices
  {
    for (auto it = ranges::begin(exrng); it != ranges::end(exrng); ++it) {
      const auto &factor = *it;
      if (factor->is<AbstractTensor>()) {
        auto &tensor = factor->as<AbstractTensor>();
        SEQUANT_ASSERT(ranges::none_of(tensor._slots(), [](const Index &idx) {
          return idx.tag().has_value();
        }));
      }
    }
  }
#endif
  bool mutated = false;
  bool pass_mutated = false;
  do {
    pass_mutated = false;

    for (auto it = ranges::begin(exrng); it != ranges::end(exrng);) {
      const auto &factor = *it;
      if (factor->is<AbstractTensor>()) {
        auto &tensor = factor->as<AbstractTensor>();

        // NOTE(review): `a &= b` is `a = a && b`, and pass_mutated starts each
        // pass as false. This line therefore never sets the flag. It also
        // CLEARS it when an earlier factor in this pass was a delta replaced
        // by 1 but this tensor's transform returns false, so the reported
        // mutation depends on factor order. `|=` looks like the intent.
        // However, reduce_wick_impl keeps iterating while this function
        // returns true, so verify termination (e.g. repeated renaming in
        // partially-overlapping ext+int deltas) before changing it.
        pass_mutated &= tensor._transform_indices(const_replrules);

        if (tensor._label() == overlap_label() ||
            tensor._label() == kronecker_label()) {
          const auto bra = tensor._bra().at(0);
          const auto ket = tensor._ket().at(0);
          // a delta/overlap whose indices now coincide is replaced by 1
          if (bra == ket) {
            pass_mutated = true;
            *it = ex<Constant>(1);
          }
        }  // Kronecker delta
      }
      ++it;
    }
    mutated |= pass_mutated;
  } while (pass_mutated);  // keep replacing til fixed point

  // clear the index tags left by the replacement pass(es) (the assertion at
  // the top checks that no tags were present on entry)
  {
    for (auto it = ranges::begin(exrng); it != ranges::end(exrng); ++it) {
      const auto &factor = *it;
      if (factor->is<AbstractTensor>()) {
        factor->as<AbstractTensor>()._reset_tags();
      }
    }
  }

  // update all_indices: map every index through the rules (indices without a
  // rule are kept as is; duplicates collapse since the set compares labels)
  std::set<Index, Index::LabelCompare> all_indices_new;
  ranges::for_each(
      all_indices, [&const_replrules, &all_indices_new](const Index &idx) {
        auto dst_it = const_replrules.find(idx);
        // insertion result intentionally unused
        [[maybe_unused]] auto insertion_result = all_indices_new.emplace(
            dst_it != const_replrules.end() ? dst_it->second : idx);
      });
  std::swap(all_indices_new, all_indices);

  return mutated;
}

/// Reduces overlaps/Kronecker deltas in a product by index renaming. Passes
/// are repeated until apply_index_replacement_rules reports no mutation.
///
/// Each pass:
///  1. collects the indices currently used in @p expr;
///  2. if @p noncovariant_indices is non-empty, recomputes the noncovariant
///     indices of the current expression;
///  3. computes replacement rules with compute_index_replacement_rules<S>;
///  4. applies them with apply_index_replacement_rules if any overlap or
///     Kronecker tensor was found.
///
/// @param expr product to reduce; modified in place
/// @param external_indices external indices of @p expr
/// @param noncovariant_indices only its emptiness is used: if non-empty, the
///        noncovariant set is recomputed from @p expr on every pass
/// @param ctx context supplying the index space registry and the metric
/// @return false if a pass found the product to vanish, true otherwise
template <Statistics S>
bool reduce_wick_impl(std::shared_ptr<Product> &expr,
                      const container::set<Index> &external_indices,
                      const container::set<Index> &noncovariant_indices,
                      const Context &ctx) {
  // if have noncovariant indices, will need to update them at the beginning of
  // every pass
  const auto have_noncovariant_indices = !noncovariant_indices.empty();

  if (Logger::instance().wick_reduce) {
    sequant::wprintf(
        "reduce_wick_impl(expr, external_indices):\n input expr = ",
        expr->to_latex(), "\n  external_indices = ");
    ranges::for_each(external_indices, [](auto &index) {
      sequant::wprintf(index.full_label(), " ");
    });
    sequant::wprintf("\n");
  }

  std::int64_t pass = -1;
  bool pass_mutated = false;
  do {
    pass_mutated = false;
    ++pass;

    // extract current indices
    auto idx_counter = get_used_indices_with_counts(expr);

    auto all_indices =
        idx_counter |
        ranges::views::transform([](const auto &v) { return v.first; }) |
        ranges::to<std::set<Index, Index::LabelCompare>>;

    // update list of noncovariant indices every iteration
    container::set<Index> all_noncovariant_indices;
    if (have_noncovariant_indices) {
      // same criterion as WickTheorem::extract_indices: a non-external index
      // is noncovariant unless it has no protoindices, is not a protoindex of
      // another index, and appears exactly twice in nonproto slots
      all_noncovariant_indices =
          idx_counter |
          ranges::views::filter([&external_indices](const auto &v) {
            return (v.first.has_proto_indices() || v.second.proto != 0 ||
                    v.second.nonproto() != 2) &&
                   !external_indices.contains(v.first);
          }) |
          ranges::views::transform([](const auto &v) { return v.first; }) |
          ranges::to<container::set<Index>>;
      // augment list of noncovariant indices:
      // - any index with protoindices is noncovariant
      // - if kronecker delta has a noncovariant index, the other index is also
      //   noncovariant (an overlap is treated as a delta when the metric is
      //   unit and bra/ket have the same protoindices)
      expr->visit([&all_noncovariant_indices, &ctx](const ExprPtr &ex) {
        if (ex.is<AbstractTensor>()) {
          auto &t = ex.template as<AbstractTensor>();
          for (auto &idx : slots(t)) {
            if (idx.has_proto_indices()) all_noncovariant_indices.emplace(idx);
          }
          const auto tlabel = t._label();
          // use the metric of the context we were given, consistent with the
          // index space registry passed to compute_index_replacement_rules
          const auto is_kronecker =
              tlabel == kronecker_label() ||
              (tlabel == overlap_label() &&
               ctx.metric() == IndexSpaceMetric::Unit &&
               t._bra()[0].proto_indices() == t._ket()[0].proto_indices());
          if (is_kronecker) {
            SEQUANT_ASSERT(t._bra_rank() == 1);
            Index b = t._bra()[0];
            SEQUANT_ASSERT(t._ket_rank() == 1);
            Index k = t._ket()[0];
            // N.B. set::emplace is a no-op if the index is already present
            if (all_noncovariant_indices.contains(b)) {
              all_noncovariant_indices.emplace(std::move(k));
            } else if (all_noncovariant_indices.contains(k)) {
              all_noncovariant_indices.emplace(std::move(b));
            }
          }
        }
      });
      if (Logger::instance().wick_reduce) {
        sequant::wprintf("all_noncovariant_indices = ");
        for (auto &idx : all_noncovariant_indices) {
          sequant::wprintf(" ", idx.to_latex());
        }
        sequant::wprintf("\n");
      }
    }

    auto nonnull_result_opt = compute_index_replacement_rules<S>(
        expr, external_indices, all_noncovariant_indices, all_indices,
        ctx.index_space_registry());
    if (!nonnull_result_opt) return false;  // product vanishes
    const auto &[replacement_rules, found_kroneckers] = *nonnull_result_opt;

    if (Logger::instance().wick_reduce) {
      sequant::wprintf("reduce_wick_impl(expr, external_indices) pass=", pass,
                       ":\n  expr = ", expr->to_latex(),
                       "\n  external_indices = ");
      ranges::for_each(external_indices, [](auto &index) {
        sequant::wprintf(index.full_label(), " ");
      });
      sequant::wprintf("\n  replrules = ");
      ranges::for_each(replacement_rules, [](auto &rule) {
        sequant::wprintf(to_latex(rule.first), "\\to", to_latex(rule.second),
                         "\\,");
      });
    }

    // N.B. even if replacement list is empty, but have trivial kroneckers
    // invoke apply_index_replacement_rules
    if (found_kroneckers) {
      pass_mutated =
          apply_index_replacement_rules(expr, replacement_rules, all_indices);
    }

    if (Logger::instance().wick_reduce) {
      sequant::wprintf("\n  result = ", expr->to_latex(), "\n");
    }
  } while (pass_mutated);  // keep reducing until stop changing

  return true;
}

/// Deleter used as the RAII guard in WickTheorem::compute. When invoked, it
/// deregisters the tensor canonicalizers registered for the first two
/// NormalOperator<S> labels. The pointer argument is not used.
template <Statistics S>
struct NullNormalOperatorCanonicalizerDeregister {
  void operator()(void * /* unused */) {
    const auto labels = NormalOperator<S>::labels();
    for (std::size_t which = 0; which != 2; ++which)
      TensorCanonicalizer::deregister_instance(labels[which]);
  }
};

}  // namespace detail

/// Scans @p expr and fills all_indices_ (every index used), external_indices_
/// (unless user-defined), and noncovariant_indices_.
///
/// @param expr expression to scan
/// @param force_external when external indices are deduced from @p expr (not
///        user-defined and no named indices in the default canonicalization
///        options), treat every index as external
template <Statistics S>
void WickTheorem<S>::extract_indices(const Expr &expr,
                                     bool force_external) const {
  auto idx_counter = get_used_indices_with_counts(expr);
  // projects a (index, counts) entry onto its index
  const auto index_of = [](const auto &entry) { return entry.first; };

  all_indices_ = idx_counter | ranges::views::transform(index_of) |
                 ranges::to<container::set<Index>>;

  if (!user_defined_external_indices_) {
    const auto &copts = get_default_context().canonicalization_options();
    if (!(copts && copts->named_indices)) {
      // deduce: an index is external if it appears at most once in a nonproto
      // slot (zero such appearances means it is a pure protoindex)
      const auto is_external = [force_external](const auto &entry) {
        return entry.second.nonproto() <= 1 || force_external;
      };
      external_indices_ = idx_counter | ranges::views::filter(is_external) |
                          ranges::views::transform(index_of) |
                          ranges::to<container::set<Index>>;
    } else {
      external_indices_ = copts->named_indices.value();
    }
  }

  // a "plain dummy" has no protoindices of its own, is not a protoindex of
  // another index, and appears exactly twice in nonproto slots; every
  // non-external index that is not a plain dummy is noncovariant
  const auto is_noncovariant = [this](const auto &entry) {
    const bool plain_dummy = !entry.first.has_proto_indices() &&
                             entry.second.proto == 0 &&
                             entry.second.nonproto() == 2;
    return !plain_dummy && !external_indices_->contains(entry.first);
  };
  noncovariant_indices_ = idx_counter |
                          ranges::views::filter(is_noncovariant) |
                          ranges::views::transform(index_of) |
                          ranges::to<container::set<Index>>;
}

template <Statistics S>
ExprPtr WickTheorem<S>::compute(const bool count_only,
                                const bool skip_input_canonicalization) {
  // need to avoid recanonicalization of operators produced by WickTheorem
  // by rapid canonicalization to avoid undoing all the good
  // the NormalOperator<S>::normalize did ... use RAII
  // 1. detail::NullNormalOperatorCanonicalizerDeregister<S> will restore state
  // of tensor canonicalizer
  // 2. this is the RAII object whose destruction will restore state of
  // the tensor canonicalizer
  std::unique_ptr<void, detail::NullNormalOperatorCanonicalizerDeregister<S>>
      raii_null_nop_canonicalizer;
  // 3. this makes the RAII object  ... NOT reentrant, only to be called in
  // top-level WickTheorem after initial canonicalization
  auto disable_nop_canonicalization = [&raii_null_nop_canonicalizer]() {
    if (!raii_null_nop_canonicalizer) {
      const auto nop_labels = NormalOperator<S>::labels();
      SEQUANT_ASSERT(nop_labels.size() == 2);
      TensorCanonicalizer::try_register_instance(
          std::make_shared<NullTensorCanonicalizer>(), nop_labels[0]);
      TensorCanonicalizer::try_register_instance(
          std::make_shared<NullTensorCanonicalizer>(), nop_labels[1]);
      raii_null_nop_canonicalizer = decltype(raii_null_nop_canonicalizer)(
          (void *)&raii_null_nop_canonicalizer, {});
    }
  };

  // have an Expr as input? Apply recursively ...
  if (expr_input_) {
    if (Logger::instance().wick_harness)
      std::wcout << "WickTheorem<S>::compute: input (before expand) = "
                 << to_latex_align(expr_input_) << std::endl;
    expand(expr_input_);
    if (Logger::instance().wick_harness)
      std::wcout << "WickTheorem<S>::compute: input (after expand) = "
                 << to_latex_align(expr_input_) << std::endl;
    // if sum, canonicalize and apply to each summand ...
    if (expr_input_->is<Sum>()) {
      if (!skip_input_canonicalization) {
        // initial full canonicalization
        canonicalize(expr_input_);
        SEQUANT_ASSERT(!expr_input_->as<Sum>().empty());
      }

      // NOW disable canonicalization of normal operators
      // N.B. even if skipped initial input canonicalization need to disable
      // subsequent nop canonicalization
      disable_nop_canonicalization();

      // parallelize over summands
      HashingAccumulator result_acc;
      std::mutex result_mtx;  // serializes updates of result
      auto summands = expr_input_->as<Sum>().summands();

      // find external_indices if don't have them
      if (!external_indices_) {
        const auto &copts = get_default_context().canonicalization_options();
        if (copts && copts->named_indices) {
          external_indices_ = copts->named_indices.value();
        } else {
          ranges::find_if(summands, [this](const auto &summand) {
            if (summand.template is<Sum>())  // summands must not be a Sum
              throw std::invalid_argument(
                  "WickTheorem<S>::compute(expr): expr is a Sum with one of "
                  "the "
                  "summands also a Sum, WickTheorem can only accept a fully "
                  "expanded Sum");
            else if (summand.template is<Product>()) {
              extract_indices(*(summand.template as_shared_ptr<Product>()));
              return true;
            } else
              return false;
          });
        }
      }

      if (Logger::instance().wick_harness)
        std::wcout << "WickTheorem<S>::compute: input (after canonicalize) has "
                   << summands.size()
                   << " terms = " << to_latex_align(expr_input_) << std::endl;

      auto wick_task = [&result_acc, &result_mtx, this,
                        &count_only](const ExprPtr &input) {
        WickTheorem wt(input->clone(), *this);
        auto task_result = wt.compute(
            count_only, /* definitely skip input canonicalization */ true);
        stats() += wt.stats();
        if (task_result) {
          std::scoped_lock<std::mutex> lock(result_mtx);
          result_acc.append(task_result);
        }
      };
      sequant::for_each(summands, wick_task);

      // if the sum is empty return zero
      // if the sum has 1 summand, return it directly
      return result_acc.make_expr();
    }
    // ... else if a product, find NormalOperatorSequence, if any, and compute
    // ...
    else if (expr_input_->is<Product>()) {
      if (!skip_input_canonicalization) {  // canonicalize, unless told to skip
        auto canon_byproduct = expr_input_->rapid_canonicalize();
        SEQUANT_ASSERT(
            canon_byproduct ==
            nullptr);  // canonicalization of Product always returns nullptr
      }
      // NOW disable canonicalization of normal operators
      // N.B. even if skipped initial input canonicalization need to disable
      // subsequent nop canonicalization
      disable_nop_canonicalization();

      if (!all_indices_) {
        extract_indices(*(expr_input_.as_shared_ptr<Product>()));
      }

      // split off NormalOperators into input_
      auto first_nop_it = ranges::find_if(
          *expr_input_,
          [](const ExprPtr &expr) { return expr->is<NormalOperator<S>>(); });
      // if have ops, split into nop sequence and cnumber "prefactor"
      if (first_nop_it != ranges::end(*expr_input_)) {
        // extract into prefactor and op sequence
        ExprPtr prefactor =
            ex<CProduct>(expr_input_->as<Product>().scalar(), ExprPtrList{});
        auto nopseq = std::make_shared<NormalOperatorSequence<S>>();
        for (const auto &factor : *expr_input_) {
          if (factor->template is<NormalOperator<S>>()) {
            nopseq->push_back(factor->template as<NormalOperator<S>>());
          } else {
            SEQUANT_ASSERT(factor->is_cnumber());
            *prefactor *= *factor;
          }
        }
        init_input(nopseq);

        // compute and record/analyze topological NormalOperator and Index
        // partitions
        if (use_topology_) {
          if (Logger::instance().wick_topology)
            std::wcout
                << "WickTheorem<S>::compute: input to topology computation = "
                << to_latex(expr_input_) << std::endl;

          // construct graph representation of the tensor product
          using TN = TensorNetwork;
          TN tn(expr_input_->as<Product>().factors());
          auto g = tn.create_graph({.distinct_named_indices = true});
          const auto &graph = g.bliss_graph;
          const auto &vlabels = g.vertex_labels;
          [[maybe_unused]] const auto &vcolors = g.vertex_colors;
          const auto &vtypes = g.vertex_types;
          const auto n = vtypes.size();
          SEQUANT_ASSERT(vcolors.size() == n);
          SEQUANT_ASSERT(vlabels.size() == n);
          const auto &tn_edges = tn.edges();
          const auto &tn_tensors = tn.tensors();
          auto idx_vertex_to_edge_ptr =
              [&](const auto idx_vertex) -> const TN::Edge * {
            SEQUANT_ASSERT(idx_vertex < n);
            const auto edge_idx = g.vertex_to_index_idx(idx_vertex);
            if (edge_idx < tn_edges.size())
              return &tn_edges[edge_idx];
            else  // indices without matching edges are pure protoindices
              return nullptr;
          };

          if (Logger::instance().wick_topology) {
            std::basic_ostringstream<wchar_t> oss;
            graph->write_dot(oss, {.labels = vlabels});
            std::wcout
                << "WickTheorem<S>::compute: colored graph produced from TN = "
                << std::endl
                << oss.str() << std::endl;
          }

          // identify vertex indices of NormalOperator objects and Indices
          // 1. list of vertex indices corresponding to NormalOperator objects
          //    on the TN graph and their ordinals in NormalOperatorSequence
          //    N.B. for NormalOperators the vertex indices coincide with
          //    the ordinals
          container::map<size_t, size_t> nop_vidx_ord;
          // 2. list of vertex indices corresponding to Index objects on the TN
          //    graph that appear in NormalOperatorsSequence and
          //    their ordinals therein
          //    N.B. for Index objects the vertex indices do NOT coincide with
          //         the ordinals
          container::map<size_t, size_t> index_vidx_ord;
          {
            const auto &nop_labels = NormalOperator<S>::labels();
            const auto nop_labels_begin = begin(nop_labels);
            const auto nop_labels_end = end(nop_labels);

            using opseq_view_type =
                flattened_rangenest<NormalOperatorSequence<S>>;
            auto opseq_view = opseq_view_type(input_.get());
            const auto opseq_view_begin = ranges::begin(opseq_view);
            const auto opseq_view_end = ranges::end(opseq_view);

            // NormalOperators are not reordered by canonicalization, hence the
            // ordinal can be computed by counting
            std::size_t nop_ord = 0;
            for (size_t v = 0; v != n; ++v) {
              if (vtypes[v] == VertexType::TensorCore &&
                  (std::find(nop_labels_begin, nop_labels_end, vlabels[v]) !=
                   nop_labels_end)) {
                [[maybe_unused]] auto insertion_result =
                    nop_vidx_ord.emplace(v, nop_ord++);
                SEQUANT_ASSERT(insertion_result.second);
              }
              if (vtypes[v] == VertexType::Index && !input_->empty()) {
                auto *edge_ptr = idx_vertex_to_edge_ptr(v);
                if (edge_ptr) {  // do not consider pure protoindices
                  auto &idx = edge_ptr->idx();
                  auto idx_it_in_opseq = ranges::find_if(
                      opseq_view,
                      [&idx](const auto &v) { return v.index() == idx; });
                  if (idx_it_in_opseq != opseq_view_end) {
                    const auto ord =
                        ranges::distance(opseq_view_begin, idx_it_in_opseq);
                    [[maybe_unused]] auto insertion_result =
                        index_vidx_ord.emplace(v, ord);
                    SEQUANT_ASSERT(insertion_result.second);
                  }
                }
              }
            }
          }

          // compute and save graph automorphism generators
          std::vector<std::vector<unsigned int>> aut_generators;
          {
            bliss::Stats stats;
            graph->set_splitting_heuristic(bliss::Graph::shs_fsm);

            auto save_aut = [&aut_generators](const unsigned int n,
                                              const unsigned int *aut) {
              aut_generators.emplace_back(aut, aut + n);
            };

            graph->find_automorphisms(
                stats, &bliss::aut_hook<decltype(save_aut)>, &save_aut);

            if (Logger::instance().wick_topology) {
              std::basic_ostringstream<wchar_t> oss2;
              bliss::print_auts(aut_generators, oss2, vlabels);
              std::wcout << "WickTheorem<S>::compute: colored graph "
                            "automorphism generators = \n"
                         << oss2.str() << std::endl;
            }
          }

          // Use automorphisms to determine groups of topologically equivalent
          // NormalOperator and Op objects.
          // @param vertices maps vertex indices of the objects to their
          //        ordinals in the sequence of such objects within
          //        the NormalOperatorSequence
          // @param nontrivial_partitions_only if true, only partitions with
          // more than one element, are reported, else even trivial
          // partitions with a single partition will be reported
          // @param vertex_pair_exclude a callable that accepts 2 vertex
          // indices and returns true if the automorphism of this pair
          // of indices is to be ignored; this is used to disregard
          // automorphisms of Index objects unless connected to same bra/ket
          // of an (anti)symmetric NormalOperator.
          // @return the \c {vertex_to_partition_idx,npartitions} pair in
          // which \c vertex_to_partition_idx maps vertex indices that are
          // part of nontrivial partitions to their (1-based) partition indices
          auto compute_partitions = [&aut_generators](
                                        const container::map<size_t, size_t>
                                            &vertices,
                                        bool nontrivial_partitions_only,
                                        auto &&vertex_pair_exclude) {
            container::map<size_t, size_t> vertex_to_partition_idx;
            int next_partition_idx = -1;

            // using each automorphism generator
            for (auto &&aut : aut_generators) {
              // skip automorphism generators that do not involve vertices
              // in `vertices` list
              bool aut_contains_other_vertices = true;
              for (auto &&[v, ord] : vertices) {
                (void)ord;
                const auto v_is_in_aut = v != aut[v];
                if (v_is_in_aut) {
                  aut_contains_other_vertices = false;
                  break;
                }
              }
              if (aut_contains_other_vertices) continue;

              // update partitions
              for (auto &&[v1, ord1] : vertices) {
                const auto v2 = aut[v1];
                if (v2 != v1 &&
                    !vertex_pair_exclude(
                        v1, v2)) {  // if the automorphism maps this vertex to
                                    // another ... they both must be in the same
                                    // partition
                  SEQUANT_ASSERT(vertices.find(v2) != vertices.end());
                  auto v1_partition_it = vertex_to_partition_idx.find(v1);
                  auto v2_partition_it = vertex_to_partition_idx.find(v2);
                  const bool v1_has_partition =
                      v1_partition_it != vertex_to_partition_idx.end();
                  const bool v2_has_partition =
                      v2_partition_it != vertex_to_partition_idx.end();
                  if (v1_has_partition &&
                      v2_has_partition) {  // both are in partitions? make sure
                                           // they are in the same partition.
                                           // N.B. this may leave gaps in
                                           // partition indices ... no biggie
                    const auto v1_part_idx = v1_partition_it->second;
                    const auto v2_part_idx = v2_partition_it->second;
                    if (v1_part_idx !=
                        v2_part_idx) {  // if they have different partition
                                        // indices, change the larger of the two
                                        // indices to match the lower
                      const auto target_part_idx =
                          std::min(v1_part_idx, v2_part_idx);
                      for (auto &v : vertex_to_partition_idx) {
                        if (v.second == v1_part_idx || v.second == v2_part_idx)
                          v.second = target_part_idx;
                      }
                    }
                  } else if (v1_has_partition) {  // only v1 is in a partition?
                                                  // place v2 in it
                    const auto v1_part_idx = v1_partition_it->second;
                    vertex_to_partition_idx.emplace(v2, v1_part_idx);
                  } else if (v2_has_partition) {  // only v2 is in a partition?
                                                  // place v1 in it
                    const auto v2_part_idx = v2_partition_it->second;
                    vertex_to_partition_idx.emplace(v1, v2_part_idx);
                  } else {  // neither is in a partition? place both in the next
                            // available partition
                    const size_t target_part_idx = ++next_partition_idx;
                    vertex_to_partition_idx.emplace(v1, target_part_idx);
                    vertex_to_partition_idx.emplace(v2, target_part_idx);
                  }
                }
              }
            }
            if (!nontrivial_partitions_only) {
              ranges::for_each(vertices, [&](const auto &vidx_ord) {
                auto &&[vidx, ord] = vidx_ord;
                if (vertex_to_partition_idx.find(vidx) ==
                    vertex_to_partition_idx.end()) {
                  vertex_to_partition_idx.emplace(vidx, ++next_partition_idx);
                }
              });
            }
            const auto npartitions = next_partition_idx + 1;
            return std::make_tuple(vertex_to_partition_idx, npartitions);
          };

          // compute NormalOperator->partition map, convert to partition lists
          // (if any), and register via set_nop_partitions to be used in full
          // contractions
          auto do_not_skip_elements = [](size_t, size_t) { return false; };
          auto [nop_vidx2pidx, nop_npartitions] = compute_partitions(
              nop_vidx_ord, /* nontrivial_partitions_only = */ true,
              do_not_skip_elements);

          // converts vertex ordinal to partition key map into a sequence of
          // partitions, each composed of the corresponding ordinals of the
          // vertices in the vertex_list sequence
          // @param vidx2pidx a map from vertex index (in TN) to its
          //        (1-based) partition index
          // @param npartitions the total number of partitions
          // @param vidx_ord ordered sequence of vertex indices, object
          // with vertex index `vidx` will be mapped to ordinal
          // `vidx_ord[vidx]`
          // @return sequence of partitions, sorted by the smallest ordinal
          auto extract_partitions = [](const auto &vidx2pidx,
                                       const auto npartitions,
                                       const auto &vidx_ord) {
            container::svector<container::svector<size_t>> partitions;

            SEQUANT_ASSERT(npartitions > -1);
            const size_t max_pidx = npartitions;
            partitions.reserve(max_pidx);

            // iterate over all partition indices ... note that there may be
            // gaps so count the actual partitions
            size_t partition_cnt = 0;
            for (size_t p = 0; p <= max_pidx; ++p) {
              bool p_found = false;
              for (const auto &[vidx, pidx] : vidx2pidx) {
                if (pidx == p) {
                  // !!remember to map the vertex index into the operator
                  // index!!
                  SEQUANT_ASSERT(vidx_ord.find(vidx) != vidx_ord.end());
                  const auto ordinal = vidx_ord.find(vidx)->second;
                  if (p_found == false) {  // first time this is found
                    partitions.emplace_back(container::svector<size_t>{
                        static_cast<size_t>(ordinal)});
                  } else
                    partitions[partition_cnt].emplace_back(ordinal);
                  p_found = true;
                }
              }
              if (p_found) ++partition_cnt;
            }

            // sort each partition
            for (auto &partition : partitions) {
              ranges::sort(partition);
            }

            // sort partitions in the order of increasing first element
            ranges::sort(partitions, [](const auto &p1, const auto &p2) {
              return p1.front() < p2.front();
            });

            return partitions;
          };

          if (!nop_vidx2pidx.empty()) {
            container::svector<container::svector<size_t>> nop_partitions;

            nop_partitions = extract_partitions(nop_vidx2pidx, nop_npartitions,
                                                nop_vidx_ord);

            if (Logger::instance().wick_topology) {
              std::wcout
                  << "WickTheorem<S>::compute: topological nop partitions:{\n";
              ranges::for_each(nop_partitions, [](auto &&part) {
                std::wcout << "{";
                ranges::for_each(part,
                                 [](auto &&p) { std::wcout << p << " "; });
                std::wcout << "}";
              });
              std::wcout << "}" << std::endl;
            }

            this->set_nop_partitions(nop_partitions);
          }

          // compute Index->partition map, and convert to partition lists (if
          // any), and check that use_topology_ is compatible with index
          // partitions
          // Index partitions are constructed to *only* include Index
          // objects attached to the bra/ket of any NormalOperator! hence
          // need to use filter in computing partitions
          auto exclude_index_vertex_pair = [&tn_tensors,
                                            &idx_vertex_to_edge_ptr](
                                               size_t v1, size_t v2) {
            const auto *edge1_ptr = idx_vertex_to_edge_ptr(v1);
            const auto *edge2_ptr = idx_vertex_to_edge_ptr(v2);
            if (!edge1_ptr || !edge2_ptr) return true;
            const auto &edge1 = *edge1_ptr;
            const auto &edge2 = *edge2_ptr;
            auto connected_to_bra_or_ket_of_same_symmetric_nop =
                [&tn_tensors](const auto &edge1, const auto &edge2) -> bool {
              const auto nt1 = edge1.vertex_count();
              SEQUANT_ASSERT(nt1 <= 2);
              const auto nt2 = edge2.vertex_count();
              SEQUANT_ASSERT(nt2 <= 2);
              for (auto i1 = 0; i1 != nt1; ++i1) {
                const auto tensor1_ord = edge1.vertex(i1).getTerminalIndex();
                for (auto i2 = 0; i2 != nt2; ++i2) {
                  const auto tensor2_ord = edge2.vertex(i2).getTerminalIndex();

                  // do not skip if connected to same ...
                  if (tensor1_ord == tensor2_ord) {
                    auto tensor_ord = tensor1_ord;
                    const std::shared_ptr<AbstractTensor> &tensor_ptr =
                        tn_tensors.at(tensor_ord);

                    // ... (anti)symmetric ...
                    if (tensor_ptr->_symmetry() != Symmetry::Nonsymm) {
                      const auto tensor1_slot_type =
                          edge1.vertex(i1).getOrigin();
                      const auto tensor2_slot_type =
                          edge2.vertex(i2).getOrigin();

                      // ... bra/ket of ...
                      if (tensor1_slot_type == tensor2_slot_type) {
                        // ... NormalOperator!
                        if (std::dynamic_pointer_cast<NormalOperator<S>>(
                                tensor_ptr)) {
                          return true;
                        }
                      }
                    }
                  }
                }
              }
              return false;
            };
            const bool exclude =
                !connected_to_bra_or_ket_of_same_symmetric_nop(edge1, edge2);
            return exclude;
          };

          // index_vidx2pidx maps vertex index (see
          // index_vidx_ord) to partition index
          container::map<size_t, size_t> index_vidx2pidx;
          int index_npartitions = -1;
          std::tie(index_vidx2pidx, index_npartitions) = compute_partitions(
              index_vidx_ord, /* nontrivial_partitions_only = */ false,
              /* this is to ensure that each index partition only involves
                 indices attached to bra or to ket of same
                 symmetric/antisymmetric nop.*/
              exclude_index_vertex_pair);

          if (!index_vidx2pidx.empty()) {
            container::svector<container::svector<size_t>> index_partitions;

            index_partitions = extract_partitions(
                index_vidx2pidx, index_npartitions, index_vidx_ord);

            if (Logger::instance().wick_topology) {
              std::wcout << "WickTheorem<S>::compute: topological index "
                            "partitions:{\n";
              ranges::for_each(
                  index_vidx2pidx, [&idx_vertex_to_edge_ptr](auto &&vidx_pidx) {
                    auto &&[vidx, pidx] = vidx_pidx;
                    auto *edge_ptr = idx_vertex_to_edge_ptr(vidx);
                    // skip pure proto indices
                    if (edge_ptr) {
                      auto &idx = edge_ptr->idx();
                      std::wcout << "Index " << idx.full_label()
                                 << " -> partition " << pidx << "\n";
                    }
                  });
              std::wcout << "}" << std::endl;
            }

            this->set_op_partitions(index_partitions);

            // TODO determine partitions of braket index pairs to be able to
            // exploit topology for spin-free WT note that right now indices
            // attached to bra/ket of spin-free normal operators are excluded
            // from index partitions above
          }
        }

        if (!input_->empty()) {
          if (Logger::instance().wick_contract) {
            std::wcout
                << "WickTheorem<S>::compute: input to compute_nopseq = {\n";
            for (auto &&nop : input_) std::wcout << to_latex(nop) << "\n";
            std::wcout << "}" << std::endl;
          }

          prefactor_ = prefactor;
          auto result = compute_nopseq(count_only);
          prefactor_.reset();

          if (result) {  // simplify if obtained nonzero ...
            // rapid_simplify(result);
            // canonicalize(result);
            rapid_simplify(
                result);  // rapid_simplify again since canonization may produce
                          // new opportunities (e.g. terms cancel, etc.)
          } else
            result = ex<Constant>(0);
          return result;
        }
      } else {  // product does not include ops
        return expr_input_;
      }
    }  // expr_input_->is<Product>()
    // ... else if NormalOperatorSequence already, compute ...
    else if (expr_input_->is<NormalOperatorSequence<S>>()) {
      SEQUANT_ABORT(
          "expr_input_ should no longer be nonnull if constructed with an "
          "expression that's a NormalOperatorSequence<S>");
    } else  // ... else do nothing
      return expr_input_;
  } else {
    // given a NormalOperatorSequence instead of an expression
    auto result = compute_nopseq(count_only);
    if (result) {  // simplify if obtained nonzero ...
      this->reduce(result);
      // N.B. DO NOT CANONICALIZE to preserve index pairings if doing partial
      // contraction
      rapid_simplify(result);
    } else
      result = ex<Constant>(0);
    return result;
  }
  SEQUANT_UNREACHABLE;
}

template <Statistics S>
void WickTheorem<S>::reduce(ExprPtr &expr) const {
  // Reduces `expr` in place via detail::reduce_wick_impl.
  // `expr` is expected to be either a single Product or a Sum whose summands
  // are all Products; any other kind of expression is left unchanged.
  // Each Product for which reduce_wick_impl returns false is replaced by
  // Constant(0).
  if (Logger::instance().wick_reduce) {
    std::wcout << "WickTheorem<S>::reduce: input = "
               << to_latex_align(expr, 20, 1) << std::endl;
  }

  // if the index sets are not available yet, extract them from expr for the
  // duration of this call and reset them before returning
  const bool extracted_indices = !all_indices_;
  if (extracted_indices) {
    extract_indices(*expr);
  }

  // Reduces one Product-valued expression in place.
  // Both the single-Product and the Sum-of-Products paths go through this
  // helper. That way reduce_wick_impl always receives the external index set
  // AND the noncovariant index set, never the external set twice.
  auto reduce_product = [this](auto &product_expr) {
    SEQUANT_ASSERT(external_indices_);
    SEQUANT_ASSERT(noncovariant_indices_);
    auto product = std::static_pointer_cast<Product>(product_expr);
    if (detail::reduce_wick_impl<S>(product, *external_indices_,
                                    *noncovariant_indices_,
                                    get_default_context(S)))
      product_expr = product;
    else
      product_expr = std::make_shared<Constant>(0);
  };

  // there are 2 possibilities: expr is a single Product, or it's a Sum of
  // Products
  if (expr.is<Product>()) {
    reduce_product(expr);
  } else if (expr.is<Sum>()) {
    for (auto &&subexpr : *expr) {
      SEQUANT_ASSERT(subexpr->is<Product>());
      reduce_product(subexpr);
    }
  }

  if (Logger::instance().wick_reduce) {
    sequant::wprintf(
        "WickTheorem<S>::reduce: result = ", to_latex_align(expr, 20, 1), "\n");
  }
  if (extracted_indices) reset_indices();
}
// Destructor: WickTheorem<S> performs no special cleanup, so the compiler
// generated member-wise destruction is used (out-of-line defaulted
// definition).
template <Statistics S>
WickTheorem<S>::~WickTheorem() = default;

}  // namespace sequant

#endif  // SEQUANT_WICK_IMPL_HPP