Program Listing for File wick.impl.hpp


//
// Created by Eduard Valeyev on 3/31/18.
//

#ifndef SEQUANT_WICK_IMPL_HPP
#define SEQUANT_WICK_IMPL_HPP

#include <SeQuant/core/bliss.hpp>
#include <SeQuant/core/logger.hpp>
#include <SeQuant/core/tensor_network.hpp>

#ifdef SEQUANT_HAS_EXECUTION_HEADER
#include <execution>
#endif

namespace sequant {

namespace detail {

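// thrown by the reduction code (e.g. when two index spaces have a null
// intersection) to signal that the result is identically zero; caught by
// WickTheorem<S>::reduce, which converts it to Constant(0)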
struct zero_result : public std::exception {};


template <Statistics S>
container::map<Index, Index> compute_index_replacement_rules(
    std::shared_ptr<Product> &product,
    const container::set<Index> &external_indices,
    const std::set<Index, Index::LabelCompare> &all_indices) {
  expr_range exrng(product);

  auto index_validator = [&all_indices](const Index &idx) {
    return all_indices.find(idx) == all_indices.end();
  };
  IndexFactory idxfac(index_validator);
  container::map<Index /* src */, Index /* dst */> result;  // src->dst

  // computes an index in the intersection of space1 and space2
  auto make_intersection_index = [&idxfac](const IndexSpace &space1,
                                           const IndexSpace &space2) {
    auto isr = sequant::get_default_context(S).index_space_registry();
    const auto intersection_space = isr->intersection(space1, space2);
    if (!intersection_space) throw zero_result{};
    return idxfac.make(intersection_space);
  };
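  // e.g. (a schematic illustration, assuming a conventional registry with
  // occupied space i, unoccupied space a, and full space p): intersecting p
  // with a yields a, so a fresh unoccupied dummy index is minted; intersecting
  // i with a is empty, hence zero_result is thrown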

  // transfers proto indices from idx (if any) to img
  auto proto = [](const Index &img, const Index &idx) {
    if (idx.has_proto_indices()) {
      if (img.has_proto_indices()) {
        assert(img.proto_indices() == idx.proto_indices());
        return img;
      } else
        return Index(img, idx.proto_indices());
    } else {
      assert(!img.has_proto_indices());
      return img;
    }
  };
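  // e.g. (schematically): if idx carries proto indices, as in a_1<i_1>, then
  // proto(a_2, a_1<i_1>) yields a_2<i_1>, i.e. the image inherits idx's proto
  // indices; for a proto-free idx the image is returned unchanged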

  // adds src->dst or src->intersection(dst,current_dst)
  auto add_rule = [&result, &proto, &make_intersection_index](
                      const Index &src, const Index &dst) {
    auto src_it = result.find(src);
    if (src_it == result.end()) {  // if brand new, add the rule
      auto insertion_result = result.emplace(src, proto(dst, src));
      assert(insertion_result.second);
    } else {  // else modify the destination of the existing rule to the
      // intersection
      const auto &old_dst = src_it->second;
      assert(old_dst.proto_indices() == src.proto_indices());
      if (dst.space() != old_dst.space()) {
        src_it->second =
            proto(make_intersection_index(old_dst.space(), dst.space()), src);
      }
    }
  };

  // adds src1->dst and src2->dst; if src1->dst1 and/or src2->dst2 already
  // exist the existing rules are updated to map to the intersection of dst1,
  // dst2 and dst
  auto add_rules = [&result, &idxfac, &proto, &make_intersection_index](
                       const Index &src1, const Index &src2, const Index &dst) {
    auto isr = get_default_context(S).index_space_registry();
    // are there replacement rules already for src{1,2}?
    auto src1_it = result.find(src1);
    auto src2_it = result.find(src2);
    const auto has_src1_rule = src1_it != result.end();
    const auto has_src2_rule = src2_it != result.end();

    // which proto-indices should dst1 and dst2 inherit? a source index
    // without proto indices inherits its counterpart's proto indices:
    // <a_ij|p> = <a_ij|a_ij> (hence replace p with a_ij), but
    // <a_ij|p_kl> = <a_ij|a_kl> != <a_ij|a_ij> (hence replace
    // p_kl with a_kl)
    const auto &dst1_proto =
        !src1.has_proto_indices() && src2.has_proto_indices() ? src2 : src1;
    const auto &dst2_proto =
        !src2.has_proto_indices() && src1.has_proto_indices() ? src1 : src2;

    if (!has_src1_rule && !has_src2_rule) {  // if brand new, add the rules
      auto insertion_result1 = result.emplace(src1, proto(dst, dst1_proto));
      assert(insertion_result1.second);
      auto insertion_result2 = result.emplace(src2, proto(dst, dst2_proto));
      assert(insertion_result2.second);
    } else if (has_src1_rule &&
               !has_src2_rule) {  // update the existing rule for src1
      const auto &old_dst1 = src1_it->second;
      assert(old_dst1.proto_indices() == dst1_proto.proto_indices());
      if (dst.space() != old_dst1.space()) {
        src1_it->second = proto(
            make_intersection_index(old_dst1.space(), dst.space()), dst1_proto);
      }
      result.emplace(src2, src1_it->second);
    } else if (!has_src1_rule &&
               has_src2_rule) {  // update the existing rule for src2
      const auto &old_dst2 = src2_it->second;
      assert(old_dst2.proto_indices() == dst2_proto.proto_indices());
      if (dst.space() != old_dst2.space()) {
        src2_it->second = proto(
            make_intersection_index(old_dst2.space(), dst.space()), dst2_proto);
      }
      result.emplace(src1, src2_it->second);
    } else {  // update both of the existing rules
      const auto &old_dst1 = src1_it->second;
      const auto &old_dst2 = src2_it->second;
      const auto new_dst_space =
          (dst.space() != old_dst1.space() || dst.space() != old_dst2.space())
              ? isr->intersection(
                    isr->intersection(old_dst1.space(), old_dst2.space()),
                    dst.space())
              : dst.space();
      if (!new_dst_space) throw zero_result{};
      Index new_dst;
      if (new_dst_space == old_dst1.space()) {
        new_dst = old_dst1;
        if (new_dst_space == old_dst2.space() && old_dst2 < new_dst) {
          new_dst = old_dst2;
        }
        if (new_dst_space == dst.space() && dst < new_dst) {
          new_dst = dst;
        }
      } else if (new_dst_space == old_dst2.space()) {
        new_dst = old_dst2;
        if (new_dst_space == dst.space() && dst < new_dst) {
          new_dst = dst;
        }
      } else if (new_dst_space == dst.space()) {
        new_dst = dst;
      } else
        new_dst = idxfac.make(new_dst_space);
      result.emplace(src1, proto(new_dst, dst1_proto));
      result.emplace(src2, proto(new_dst, dst2_proto));
    }
  };
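  // e.g. if rules src1->x1 and src2->x2 already exist, add_rules(src1, src2,
  // dst) redirects both rules to a single index in
  // intersection(x1.space, x2.space, dst.space) (or throws zero_result if
  // that intersection is empty)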

  auto isr = get_default_context(S).index_space_registry();
  for (auto it = ranges::begin(exrng); it != ranges::end(exrng); ++it) {
    const auto &factor = *it;
    if (factor->type_id() == Expr::get_type_id<Tensor>()) {
      const auto &tensor = static_cast<const Tensor &>(*factor);
      if (tensor.label() == overlap_label()) {
        assert(tensor.bra().size() == 1);
        assert(tensor.ket().size() == 1);
        const auto &bra = tensor.bra().at(0);
        const auto &ket = tensor.ket().at(0);
        assert(bra != ket);

        const auto bra_is_ext = ranges::find(external_indices, bra) !=
                                ranges::end(external_indices);
        const auto ket_is_ext = ranges::find(external_indices, ket) !=
                                ranges::end(external_indices);

        const auto intersection_space =
            isr->intersection(bra.space(), ket.space());

        // if overlap's indices are from non-overlapping spaces, return zero
        if (!intersection_space) {
          throw zero_result{};
        }

        if (!bra_is_ext && !ket_is_ext) {  // int + int
          const auto new_dummy = idxfac.make(intersection_space);
          add_rules(bra, ket, new_dummy);
        } else if (bra_is_ext && !ket_is_ext) {  // ext + int
          if (includes(ket.space(), bra.space())) {
            add_rule(ket, bra);
          } else {
            add_rule(ket, idxfac.make(intersection_space));
          }
        } else if (!bra_is_ext && ket_is_ext) {  // int + ext
          if (includes(bra.space(), ket.space())) {
            add_rule(bra, ket);
          } else {
            add_rule(bra, idxfac.make(intersection_space));
          }
        }
      }
    }
  }

  return result;
}
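// e.g. (schematically) for a product containing an overlap s{p_1;a_1} with
// both p_1 and a_1 internal (dummy) indices, a single new dummy index is
// minted in intersection(space(p_1), space(a_1)) and both p_1 and a_1 are
// mapped to it, which later allows the overlap to be eliminated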

inline bool apply_index_replacement_rules(
    std::shared_ptr<Product> &product,
    const container::map<Index, Index> &const_replrules,
    const container::set<Index> &external_indices,
    std::set<Index, Index::LabelCompare> &all_indices,
    const std::shared_ptr<const IndexSpaceRegistry> &isr) {
  // to be able to use map[]
  auto &replrules = const_cast<container::map<Index, Index> &>(const_replrules);

  expr_range exrng(product);

#ifndef NDEBUG
  // assert that the tensors' indices are not tagged, since we are about to
  // tag indices
  {
    for (auto it = ranges::begin(exrng); it != ranges::end(exrng); ++it) {
      const auto &factor = *it;
      if (factor->is<Tensor>()) {
        auto &tensor = factor->as<Tensor>();
        assert(ranges::none_of(tensor.const_braket(), [](const Index &idx) {
          return idx.tag().has_value();
        }));
      }
    }
  }
#endif
  bool mutated = false;
  bool pass_mutated = false;
  do {
    pass_mutated = false;

    for (auto it = ranges::begin(exrng); it != ranges::end(exrng);) {
      const auto &factor = *it;
      if (factor->is<Tensor>()) {
        bool erase_it = false;
        auto &tensor = factor->as<Tensor>();

        pass_mutated |= tensor.transform_indices(const_replrules);

        if (tensor.label() == overlap_label()) {
          const auto &bra = tensor.bra().at(0);
          const auto &ket = tensor.ket().at(0);

          if (bra.proto_indices() == ket.proto_indices()) {
            const auto bra_is_ext = ranges::find(external_indices, bra) !=
                                    ranges::end(external_indices);
            const auto ket_is_ext = ranges::find(external_indices, ket) !=
                                    ranges::end(external_indices);

#ifndef NDEBUG
            const auto intersection_space =
                isr->intersection(bra.space(), ket.space());
#endif

            if (!bra_is_ext && !ket_is_ext) {  // int + int
#ifndef NDEBUG
              if (replrules.find(bra) != replrules.end() &&
                  replrules.find(ket) != replrules.end())
                assert(replrules[bra].space() == replrules[ket].space());
#endif
              erase_it = true;
            } else if (bra_is_ext && !ket_is_ext) {  // ext + int
              if (isr->intersection(ket.space(), bra.space()) !=
                  IndexSpace::null) {
#ifndef NDEBUG
                if (replrules.find(ket) != replrules.end())
                  assert(replrules[ket].space() == bra.space());
#endif
                erase_it = true;
              } else {
#ifndef NDEBUG
                if (replrules.find(ket) != replrules.end())
                  assert(replrules[ket].space() == intersection_space);
#endif
              }
            } else if (!bra_is_ext && ket_is_ext) {  // int + ext
              if (isr->intersection(bra.space(), ket.space()) !=
                  IndexSpace::null) {
#ifndef NDEBUG
                if (replrules.find(bra) != replrules.end())
                  assert(replrules[bra].space() == ket.space());
#endif
                erase_it = true;
              } else {
#ifndef NDEBUG
                if (replrules.find(bra) != replrules.end())
                  assert(replrules[bra].space() == intersection_space);
#endif
              }
            } else {  // ext + ext
              if (bra == ket) erase_it = true;
            }

            if (erase_it) {
              pass_mutated = true;
              *it = ex<Constant>(1);
            }
          }  // matching proto indices
        }    // Kronecker delta
      }
      ++it;
    }
    mutated |= pass_mutated;
  } while (pass_mutated);  // keep replacing until a fixed point is reached

  // untag the indices that were tagged during replacement
  {
    for (auto it = ranges::begin(exrng); it != ranges::end(exrng); ++it) {
      const auto &factor = *it;
      if (factor->is<Tensor>()) {
        factor->as<Tensor>().reset_tags();
      }
    }
  }

  // update all_indices
  std::set<Index, Index::LabelCompare> all_indices_new;
  ranges::for_each(
      all_indices, [&const_replrules, &all_indices_new](const Index &idx) {
        auto dst_it = const_replrules.find(idx);
        [[maybe_unused]] auto insertion_result = all_indices_new.emplace(
            dst_it != const_replrules.end() ? dst_it->second : idx);
      });
  std::swap(all_indices_new, all_indices);

  return mutated;
}
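// e.g. given rules {p_1 -> a_2, a_1 -> a_2} the loop above rewrites
// s{p_1;a_1} t{a_1;i_1} into s{a_2;a_2} t{a_2;i_1} and then erases the
// now-diagonal overlap by replacing it with Constant(1) (a schematic sketch;
// actual dummy labels are produced by IndexFactory)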

template <Statistics S>
void reduce_wick_impl(std::shared_ptr<Product> &expr,
                      const container::set<Index> &external_indices) {
  if (get_default_context(S).metric() == IndexSpaceMetric::Unit) {
    bool pass_mutated = false;
    do {
      pass_mutated = false;

      // extract current indices
      std::set<Index, Index::LabelCompare> all_indices;
      ranges::for_each(*expr, [&all_indices](const auto &factor) {
        if (factor->template is<Tensor>()) {
          ranges::for_each(factor->template as<const Tensor>().braket(),
                           [&all_indices](const Index &idx) {
                             [[maybe_unused]] auto result =
                                 all_indices.insert(idx);
                           });
        }
      });

      const auto replacement_rules = compute_index_replacement_rules<S>(
          expr, external_indices, all_indices);

      if (Logger::instance().wick_reduce) {
        std::wcout << "reduce_wick_impl(expr, external_indices):\n  expr = "
                   << expr->to_latex() << "\n  external_indices = ";
        ranges::for_each(external_indices, [](auto &index) {
          std::wcout << index.label() << " ";
        });
        std::wcout << "\n  replrules = ";
        ranges::for_each(replacement_rules, [](auto &index) {
          std::wcout << to_latex(index.first) << "\\to"
                     << to_latex(index.second) << "\\,";
        });
        std::wcout.flush();
      }

      if (!replacement_rules.empty()) {
        auto isr = get_default_context(S).index_space_registry();
        pass_mutated = apply_index_replacement_rules(
            expr, replacement_rules, external_indices, all_indices, isr);
      }

      if (Logger::instance().wick_reduce) {
        std::wcout << "\n  result = " << expr->to_latex() << std::endl;
      }

    } while (pass_mutated);  // keep reducing until the expression stops
                             // changing
  } else
    abort();  // reduction is implemented only for unit-metric contexts
}

template <Statistics S>
struct NullNormalOperatorCanonicalizerDeregister {
  void operator()(void *) {
    const auto nop_labels = NormalOperator<S>::labels();
    TensorCanonicalizer::deregister_instance(nop_labels[0]);
    TensorCanonicalizer::deregister_instance(nop_labels[1]);
  }
};
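// the deleter above is paired with std::unique_ptr<void, ...> to form a scope
// guard; a minimal sketch of the idiom (names hypothetical):
//
//   std::unique_ptr<void, Deregister> guard;               // disarmed: no-op
//   guard = decltype(guard)((void *)&guard, Deregister{});  // armed
//
// the pointee is never dereferenced; unique_ptr merely guarantees that the
// deleter runs (once) when the guard goes out of scope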

}  // namespace detail

inline container::set<Index> extract_external_indices(const Expr &expr) {
  if (ranges::any_of(expr, [](auto &e) { return e.template is<Sum>(); }))
    throw std::invalid_argument(
        "extract_external_indices(expr): expr must be expanded (i.e. no "
        "subexpression can be a Sum)");

  container::map<Index, int64_t> idx_counter;
  auto visitor = [&idx_counter](const auto &expr) {
    auto expr_as_abstract_tensor =
        std::dynamic_pointer_cast<AbstractTensor>(expr);
    if (expr_as_abstract_tensor) {
      ranges::for_each(expr_as_abstract_tensor->_braket(),
                       [&idx_counter](const auto &v) {
                         auto it = idx_counter.find(v);
                         if (it == idx_counter.end()) {
                           idx_counter.emplace(v, 1);
                         } else {
                           it->second++;
                         }
                       });
    }
  };
  expr.visit(visitor);

  return idx_counter |
         ranges::views::filter([](const auto &v) { return v.second == 1; }) |
         ranges::views::transform([](const auto &v) { return v.first; }) |
         ranges::to<container::set<Index>>;
}
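// e.g. (schematically) for a single term t{a_1;i_1} f{i_1;i_2}: i_1 appears
// twice (contracted, hence internal), while a_1 and i_2 each appear once, so
// {a_1, i_2} is returned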

template <Statistics S>
ExprPtr WickTheorem<S>::compute(const bool count_only,
                                const bool skip_input_canonicalization) {
  // need to avoid recanonicalization of operators produced by WickTheorem:
  // rapid canonicalization would undo all the good that
  // NormalOperator<S>::normalize did ... use RAII
  // 1. detail::NullNormalOperatorCanonicalizerDeregister<S> will restore the
  //    state of the tensor canonicalizer
  // 2. this is the RAII object whose destruction will restore the state of
  //    the tensor canonicalizer
  std::unique_ptr<void, detail::NullNormalOperatorCanonicalizerDeregister<S>>
      raii_null_nop_canonicalizer;
  // 3. this lambda arms the RAII object ... it is NOT reentrant, and is only
  //    to be called in the top-level WickTheorem after initial
  //    canonicalization
  auto disable_nop_canonicalization = [&raii_null_nop_canonicalizer]() {
    if (!raii_null_nop_canonicalizer) {
      const auto nop_labels = NormalOperator<S>::labels();
      assert(nop_labels.size() == 2);
      TensorCanonicalizer::try_register_instance(
          std::make_shared<NullTensorCanonicalizer>(), nop_labels[0]);
      TensorCanonicalizer::try_register_instance(
          std::make_shared<NullTensorCanonicalizer>(), nop_labels[1]);
      raii_null_nop_canonicalizer = decltype(raii_null_nop_canonicalizer)(
          (void *)&raii_null_nop_canonicalizer, {});
    }
  };

  // have an Expr as input? Apply recursively ...
  if (expr_input_) {
    if (Logger::instance().wick_harness)
      std::wcout << "WickTheorem<S>::compute: input (before expand) = "
                 << to_latex_align(expr_input_) << std::endl;
    expand(expr_input_);
    if (Logger::instance().wick_harness)
      std::wcout << "WickTheorem<S>::compute: input (after expand) = "
                 << to_latex_align(expr_input_) << std::endl;
    // if sum, canonicalize and apply to each summand ...
    if (expr_input_->is<Sum>()) {
      if (!skip_input_canonicalization) {
        // initial full canonicalization
        canonicalize(expr_input_);
        assert(!expr_input_->as<Sum>().empty());
      }

      // NOW disable canonicalization of normal operators
      // N.B. even if the initial input canonicalization was skipped, we still
      // need to disable subsequent nop canonicalization
      disable_nop_canonicalization();

      // parallelize over summands
      auto result = std::make_shared<Sum>();
      std::mutex result_mtx;  // serializes updates of result
      auto summands = expr_input_->as<Sum>().summands();

      // find external_indices if we don't have them yet
      if (!external_indices_) {
        ranges::find_if(summands, [this](const auto &summand) {
          if (summand.template is<Sum>())  // summands must not be a Sum
            throw std::invalid_argument(
                "WickTheorem<S>::compute(expr): expr is a Sum with one of the "
                "summands also a Sum, WickTheorem can only accept a fully "
                "expanded Sum");
          else if (summand.template is<Product>()) {
            external_indices_ = extract_external_indices(
                *(summand.template as_shared_ptr<Product>()));
            return true;
          } else
            return false;
        });
      }

      if (Logger::instance().wick_harness)
        std::wcout << "WickTheorem<S>::compute: input (after canonicalize) has "
                   << summands.size()
                   << " terms = " << to_latex_align(expr_input_) << std::endl;

      auto wick_task = [&result, &result_mtx, this,
                        &count_only](const ExprPtr &input) {
        WickTheorem wt(input->clone(), *this);
        auto task_result = wt.compute(
            count_only, /* definitely skip input canonicalization */ true);
        stats() += wt.stats();
        if (task_result) {
          std::scoped_lock<std::mutex> lock(result_mtx);
          result->append(task_result);
        }
      };
      sequant::for_each(summands, wick_task);

      // if the sum is empty return zero
      // if the sum has 1 summand, return it directly
      ExprPtr result_expr = result;
      if (result->summands().size() == 0) {
        result_expr = ex<Constant>(0);
      } else if (result->summands().size() == 1) {
        result_expr = std::move(result->summands()[0]);
      }

      return result_expr;
    }
    // ... else if a product, find the NormalOperatorSequence, if any, and
    // compute ...
    else if (expr_input_->is<Product>()) {
      if (!skip_input_canonicalization) {  // canonicalize, unless told to skip
        auto canon_byproduct = expr_input_->rapid_canonicalize();
        assert(canon_byproduct ==
               nullptr);  // canonicalization of Product always returns nullptr
      }
      // NOW disable canonicalization of normal operators
      // N.B. even if the initial input canonicalization was skipped, we still
      // need to disable subsequent nop canonicalization
      disable_nop_canonicalization();

      // find external_indices if we don't have them yet
      if (!external_indices_) {
        external_indices_ =
            extract_external_indices(*(expr_input_.as_shared_ptr<Product>()));
      } else {
        assert(
            extract_external_indices(*(expr_input_.as_shared_ptr<Product>())) ==
            *external_indices_);
      }

      // split off NormalOperators into input_
      auto first_nop_it = ranges::find_if(
          *expr_input_,
          [](const ExprPtr &expr) { return expr->is<NormalOperator<S>>(); });
      // if there are ops, split into the nop sequence and the c-number
      // "prefactor"
      if (first_nop_it != ranges::end(*expr_input_)) {
        // extract into prefactor and op sequence
        ExprPtr prefactor =
            ex<CProduct>(expr_input_->as<Product>().scalar(), ExprPtrList{});
        auto nopseq = std::make_shared<NormalOperatorSequence<S>>();
        for (const auto &factor : *expr_input_) {
          if (factor->template is<NormalOperator<S>>()) {
            nopseq->push_back(factor->template as<NormalOperator<S>>());
          } else {
            assert(factor->is_cnumber());
            *prefactor *= *factor;
          }
        }
        init_input(nopseq);

        // compute and record/analyze topological NormalOperator and Index
        // partitions
        if (use_topology_) {
          if (Logger::instance().wick_topology)
            std::wcout
                << "WickTheorem<S>::compute: input to topology computation = "
                << to_latex(expr_input_) << std::endl;

          // construct graph representation of the tensor product
          TensorNetwork tn(expr_input_->as<Product>().factors());
          auto [graph, vlabels, vcolors, vtypes] = tn.make_bliss_graph();
          const auto n = vlabels.size();
          assert(vtypes.size() == n);
          const auto &tn_edges = tn.edges();
          const auto &tn_tensors = tn.tensors();

          if (Logger::instance().wick_topology) {
            std::basic_ostringstream<wchar_t> oss;
            graph->write_dot(oss, vlabels);
            std::wcout
                << "WickTheorem<S>::compute: colored graph produced from TN = "
                << std::endl
                << oss.str() << std::endl;
          }

          // identify vertex indices of NormalOperator objects and Indices
          // 1. list of vertex indices corresponding to NormalOperator objects
          //    on the TN graph and their ordinals in NormalOperatorSequence
          //    N.B. for NormalOperators the vertex indices coincide with
          //    the ordinals
          container::map<size_t, size_t> nop_vidx_ord;
          // 2. list of vertex indices corresponding to Index objects on the TN
          //    graph that appear in the NormalOperatorSequence and
          //    their ordinals therein
          //    N.B. for Index objects the vertex indices do NOT coincide with
          //         the ordinals
          container::map<size_t, size_t> index_vidx_ord;
          {
            const auto &nop_labels = NormalOperator<S>::labels();
            const auto nop_labels_begin = begin(nop_labels);
            const auto nop_labels_end = end(nop_labels);

            using opseq_view_type =
                flattened_rangenest<NormalOperatorSequence<S>>;
            auto opseq_view = opseq_view_type(input_.get());
            const auto opseq_view_begin = ranges::begin(opseq_view);
            const auto opseq_view_end = ranges::end(opseq_view);

            // NormalOperators are not reordered by canonicalization, hence the
            // ordinal can be computed by counting
            std::size_t nop_ord = 0;
            for (size_t v = 0; v != n; ++v) {
              if (vtypes[v] == TensorNetwork::VertexType::TensorCore &&
                  (std::find(nop_labels_begin, nop_labels_end, vlabels[v]) !=
                   nop_labels_end)) {
                auto insertion_result = nop_vidx_ord.emplace(v, nop_ord++);
                assert(insertion_result.second);
              }
              if (vtypes[v] == TensorNetwork::VertexType::Index &&
                  !input_->empty()) {
                auto &idx = (tn_edges.begin() + v)->idx();
                auto idx_it_in_opseq = ranges::find_if(
                    opseq_view,
                    [&idx](const auto &v) { return v.index() == idx; });
                if (idx_it_in_opseq != opseq_view_end) {
                  const auto ord =
                      ranges::distance(opseq_view_begin, idx_it_in_opseq);
                  auto insertion_result = index_vidx_ord.emplace(v, ord);
                  assert(insertion_result.second);
                }
              }
            }
          }

          // compute and save graph automorphism generators
          std::vector<std::vector<unsigned int>> aut_generators;
          {
            bliss::Stats stats;
            graph->set_splitting_heuristic(bliss::Graph::shs_fsm);

            auto save_aut = [&aut_generators](const unsigned int n,
                                              const unsigned int *aut) {
              aut_generators.emplace_back(aut, aut + n);
            };

            graph->find_automorphisms(
                stats, &bliss::aut_hook<decltype(save_aut)>, &save_aut);

            if (Logger::instance().wick_topology) {
              std::basic_ostringstream<wchar_t> oss2;
              bliss::print_auts(aut_generators, oss2, vlabels);
              std::wcout << "WickTheorem<S>::compute: colored graph "
                            "automorphism generators = \n"
                         << oss2.str() << std::endl;
            }
          }

          // Use automorphisms to determine groups of topologically equivalent
          // NormalOperator and Op objects.
          // @param vertices maps vertex indices of the objects to their
          //        ordinals in the sequence of such objects within
          //        the NormalOperatorSequence
          // @param nontrivial_partitions_only if true, only partitions with
          //        more than one element are reported, else even trivial
          //        single-element partitions are reported
          // @param vertex_pair_exclude a callable that accepts 2 vertex
          //        indices and returns true if this pair of vertices is to
          //        be ignored by the automorphism analysis
          // @return the \c {vertex_to_partition_idx,npartitions} pair in
          //         which \c vertex_to_partition_idx maps vertex indices that
          //         are part of nontrivial partitions to their (0-based)
          //         partition indices and \c npartitions is the largest
          //         partition index that was allocated
          auto compute_partitions = [&aut_generators](
                                        const container::map<size_t, size_t>
                                            &vertices,
                                        bool nontrivial_partitions_only,
                                        auto &&vertex_pair_exclude) {
            container::map<size_t, size_t> vertex_to_partition_idx;
            int next_partition_idx = -1;

            // using each automorphism generator
            for (auto &&aut : aut_generators) {
              // skip automorphism generators that involve vertices outside
              // the given vertices map;
              // this prevents topology exploitation for spin-free Wick
              // TODO learn how to compute partitions correctly for
              //      spin-free cases
              const auto nv = aut.size();
              bool aut_contains_other_vertices = false;
              for (std::size_t v = 0; v != nv; ++v) {
                const auto v_is_in_aut = v != aut[v];
                if (v_is_in_aut && !vertices.contains(v)) {
                  aut_contains_other_vertices = true;
                  break;
                }
              }
              if (aut_contains_other_vertices) continue;

              // update partitions
              for (auto &&[v1, ord1] : vertices) {
                const auto v2 = aut[v1];
                if (v2 != v1 &&
                    !vertex_pair_exclude(
                        v1, v2)) {  // if the automorphism maps this vertex to
                                    // another ... they both must be in the same
                                    // partition
                  assert(vertices.find(v2) != vertices.end());
                  auto v1_partition_it = vertex_to_partition_idx.find(v1);
                  auto v2_partition_it = vertex_to_partition_idx.find(v2);
                  const bool v1_has_partition =
                      v1_partition_it != vertex_to_partition_idx.end();
                  const bool v2_has_partition =
                      v2_partition_it != vertex_to_partition_idx.end();
                  if (v1_has_partition &&
                      v2_has_partition) {  // both are in partitions? make sure
                                           // they are in the same partition.
                                           // N.B. this may leave gaps in
                                           // partition indices ... no biggie
                    const auto v1_part_idx = v1_partition_it->second;
                    const auto v2_part_idx = v2_partition_it->second;
                    if (v1_part_idx !=
                        v2_part_idx) {  // if they have different partition
                                        // indices, change the larger of the two
                                        // indices to match the lower
                      const auto target_part_idx =
                          std::min(v1_part_idx, v2_part_idx);
                      for (auto &v : vertex_to_partition_idx) {
                        if (v.second == v1_part_idx || v.second == v2_part_idx)
                          v.second = target_part_idx;
                      }
                    }
                  } else if (v1_has_partition) {  // only v1 is in a partition?
                                                  // place v2 in it
                    const auto v1_part_idx = v1_partition_it->second;
                    vertex_to_partition_idx.emplace(v2, v1_part_idx);
                  } else if (v2_has_partition) {  // only v2 is in a partition?
                                                  // place v1 in it
                    const auto v2_part_idx = v2_partition_it->second;
                    vertex_to_partition_idx.emplace(v1, v2_part_idx);
                  } else {  // neither is in a partition? place both in the next
                            // available partition
                    const size_t target_part_idx = ++next_partition_idx;
                    vertex_to_partition_idx.emplace(v1, target_part_idx);
                    vertex_to_partition_idx.emplace(v2, target_part_idx);
                  }
                }
              }
            }
            if (!nontrivial_partitions_only) {
              ranges::for_each(vertices, [&](const auto &vidx_ord) {
                auto &&[vidx, ord] = vidx_ord;
                if (vertex_to_partition_idx.find(vidx) ==
                    vertex_to_partition_idx.end()) {
                  vertex_to_partition_idx.emplace(vidx, ++next_partition_idx);
                }
              });
            }
            const auto npartitions = next_partition_idx;
            return std::make_tuple(vertex_to_partition_idx, npartitions);
          };
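          // e.g. (assuming no pairs are excluded) with generators {(0 1)} and
          // {(1 2)} over vertices {0,1,2}: the first generator places 0 and 1
          // in partition 0; the second maps 1 <-> 2, so 2 joins the same
          // partition, yielding the single partition {0,1,2}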

          // compute NormalOperator->partition map, convert to partition lists
          // (if any), and register via set_nop_partitions to be used in full
          // contractions
          auto do_not_skip_elements = [](size_t v1, size_t v2) {
            return false;
          };
          auto [nop_vidx2pidx, nop_npartitions] = compute_partitions(
              nop_vidx_ord, /* nontrivial_partitions_only = */ true,
              do_not_skip_elements);

          // converts a vertex-index-to-partition-index map into a sequence of
          // partitions, each composed of the ordinals (per vidx_ord) of the
          // vertices it contains
          // @param vidx2pidx a map from vertex index (in the TN) to its
          //        (0-based) partition index
          // @param npartitions the largest partition index in vidx2pidx
          // @param vidx_ord a map from vertex index `vidx` to the ordinal of
          //        the corresponding object
          // @return sequence of partitions, sorted by their smallest ordinals
          auto extract_partitions = [](const auto &vidx2pidx,
                                       const auto npartitions,
                                       const auto &vidx_ord) {
            container::svector<container::svector<size_t>> partitions;

            assert(npartitions > -1);
            const size_t max_pidx = npartitions;
            partitions.reserve(max_pidx);

            // iterate over all partition indices ... note that there may be
            // gaps so count the actual partitions
            size_t partition_cnt = 0;
            for (size_t p = 0; p <= max_pidx; ++p) {
              bool p_found = false;
              for (const auto &[vidx, pidx] : vidx2pidx) {
                if (pidx == p) {
                  // !!remember to map the vertex index into the operator
                  // index!!
                  assert(vidx_ord.find(vidx) != vidx_ord.end());
                  const auto ordinal = vidx_ord.find(vidx)->second;
                  if (p_found == false) {  // first time this is found
                    partitions.emplace_back(container::svector<size_t>{
                        static_cast<size_t>(ordinal)});
                  } else
                    partitions[partition_cnt].emplace_back(ordinal);
                  p_found = true;
                }
              }
              if (p_found) ++partition_cnt;
            }

            // sort each partition
            for (auto &partition : partitions) {
              ranges::sort(partition);
            }

            // sort partitions in the order of increasing first element
            ranges::sort(partitions, [](const auto &p1, const auto &p2) {
              return p1.front() < p2.front();
            });

            return partitions;
          };
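          // e.g. vidx2pidx = {10 -> 0, 12 -> 0, 14 -> 2} (npartitions = 2)
          // with vidx_ord = {10 -> 1, 12 -> 0, 14 -> 2} yields {{0,1},{2}}:
          // partition indices may have gaps, ordinals within a partition are
          // sorted, and partitions are ordered by their smallest ordinal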

          if (!nop_vidx2pidx.empty()) {
            container::svector<container::svector<size_t>> nop_partitions;

            nop_partitions = extract_partitions(nop_vidx2pidx, nop_npartitions,
                                                nop_vidx_ord);

            if (Logger::instance().wick_topology) {
              std::wcout
                  << "WickTheorem<S>::compute: topological nop partitions:{\n";
              ranges::for_each(nop_partitions, [](auto &&part) {
                std::wcout << "{";
                ranges::for_each(part,
                                 [](auto &&p) { std::wcout << p << " "; });
                std::wcout << "}";
              });
              std::wcout << "}" << std::endl;
            }

            this->set_nop_partitions(nop_partitions);
          }

          // compute the Index->partition map, convert to partition lists (if
          // any), and check that use_topology_ is compatible with the index
          // partitions.
          // Index partitions are constructed to *only* include Index
          // objects attached to the bra/ket of a NormalOperator! hence the
          // need to filter vertex pairs when computing partitions
          auto exclude_index_vertex_pair = [&tn_tensors, &tn_edges](size_t v1,
                                                                    size_t v2) {
            // v1 and v2 are vertex indices and also index the edges in the
            // TensorNetwork
            assert(v1 < tn_edges.size());
            assert(v2 < tn_edges.size());
            const auto &edge1 = *(tn_edges.begin() + v1);
            const auto &edge2 = *(tn_edges.begin() + v2);
            auto connected_to_same_nop = [&tn_tensors](int term1, int term2) {
              if (term1 == term2 && term1 != 0) {
                auto tensor_idx = std::abs(term1) - 1;
                const std::shared_ptr<AbstractTensor> &tensor_ptr =
                    tn_tensors.at(tensor_idx);
                if (std::dynamic_pointer_cast<NormalOperator<S>>(tensor_ptr))
                  return true;
              }
              return false;
            };
            const bool exclude =
                !(connected_to_same_nop(edge1.first(), edge2.first()) ||
                  connected_to_same_nop(edge1.first(), edge2.second()) ||
                  connected_to_same_nop(edge1.second(), edge2.first()) ||
                  connected_to_same_nop(edge1.second(), edge2.second()));
            return exclude;
          };
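          // i.e. a pair of Index vertices is retained only if their two edges
          // share a terminal that is a NormalOperator; indices attached only
          // to ordinary tensors are never grouped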

          // index_vidx2pidx maps vertex index (see
          // index_vidx_ord) to partition index
          container::map<size_t, size_t> index_vidx2pidx;
          int index_npartitions = -1;
          std::tie(index_vidx2pidx, index_npartitions) = compute_partitions(
              index_vidx_ord, /* nontrivial_partitions_only = */ false,
              exclude_index_vertex_pair);

          if (!index_vidx2pidx.empty()) {
            container::svector<container::svector<size_t>> index_partitions;

            index_partitions = extract_partitions(
                index_vidx2pidx, index_npartitions, index_vidx_ord);

            if (Logger::instance().wick_topology) {
              std::wcout << "WickTheorem<S>::compute: topological index "
                            "partitions:{\n";
              ranges::for_each(index_vidx2pidx, [&tn_edges](auto &&vidx_pidx) {
                auto &&[vidx, pidx] = vidx_pidx;
                assert(vidx < tn_edges.size());
                auto &idx = (tn_edges.begin() + vidx)->idx();
                std::wcout << "Index " << idx.full_label() << " -> partition "
                           << pidx << "\n";
              });
              std::wcout << "}" << std::endl;
            }

            this->set_op_partitions(index_partitions);
          }
        }

        if (!input_->empty()) {
          if (Logger::instance().wick_contract) {
            std::wcout
                << "WickTheorem<S>::compute: input to compute_nopseq = {\n";
            for (auto &&nop : input_) std::wcout << to_latex(nop) << "\n";
            std::wcout << "}" << std::endl;
          }
          auto result = compute_nopseq(count_only);
          if (result) {  // simplify if obtained nonzero ...
            result = prefactor * result;
            expand(result);
            this->reduce(result);
            rapid_simplify(result);
            canonicalize(result);
            rapid_simplify(
                result);  // rapid_simplify again since canonicalization may
                          // produce new opportunities (e.g. terms cancel, etc.)
          } else
            result = ex<Constant>(0);
          return result;
        }
      } else {  // product does not include ops
        return expr_input_;
      }
    }  // expr_input_->is<Product>()
    // ... else if NormalOperatorSequence already, compute ...
    else if (expr_input_->is<NormalOperatorSequence<S>>()) {
      abort();  // unreachable: the constructor converts a
                // NormalOperatorSequence<S> input into input_, hence
                // expr_input_ should no longer be nonnull here
      init_input(
          expr_input_.template as_shared_ptr<NormalOperatorSequence<S>>());
      // NB no simplification possible for a bare product w/ full contractions
      // ... partial contractions will need simplification
      return compute_nopseq(count_only);
    } else  // ... else do nothing
      return expr_input_;
  } else  // given a NormalOperatorSequence instead of an expression
    return compute_nopseq(count_only);
  abort();
}

template <Statistics S>
void WickTheorem<S>::reduce(ExprPtr &expr) const {
  if (Logger::instance().wick_reduce) {
    std::wcout << "WickTheorem<S>::reduce: input = "
               << to_latex_align(expr, 20, 1) << std::endl;
  }
  // there are 2 possibilities: expr is a single Product, or it's a Sum of
  // Products
  if (expr->type_id() == Expr::get_type_id<Product>()) {
    auto expr_cast = std::static_pointer_cast<Product>(expr);
    try {
      assert(external_indices_);
      detail::reduce_wick_impl<S>(expr_cast, *external_indices_);
      expr = expr_cast;
    } catch (detail::zero_result &) {
      expr = std::make_shared<Constant>(0);
    }
  } else {
    assert(expr->type_id() == Expr::get_type_id<Sum>());
    for (auto &&subexpr : *expr) {
      assert(subexpr->is<Product>());
      auto subexpr_cast = std::static_pointer_cast<Product>(subexpr);
      try {
        assert(external_indices_);
        detail::reduce_wick_impl<S>(subexpr_cast, *external_indices_);
        subexpr = subexpr_cast;
      } catch (detail::zero_result &) {
        subexpr = std::make_shared<Constant>(0);
      }
    }
  }

  if (Logger::instance().wick_reduce) {
    std::wcout << "WickTheorem<S>::reduce: result = "
               << to_latex_align(expr, 20, 1) << std::endl;
  }
}

template <Statistics S>
WickTheorem<S>::~WickTheorem() {}

}  // namespace sequant

#endif  // SEQUANT_WICK_IMPL_HPP