#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <iterator>
#include <limits>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <utility>

#include <SeQuant/core/abstract_tensor.hpp>
#include <SeQuant/core/container.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/index.hpp>
#include <SeQuant/core/tensor_network.hpp>

#if __cplusplus >= 202002L
#include <bit>
#endif

namespace {

///
/// Tells whether exactly one bit of \p x is set, i.e. whether \p x is a
/// non-zero power of two.
///
/// When compiled as C++20 or later this forwards to std::has_single_bit
/// (which accepts unsigned integer types only); otherwise a portable bit
/// trick is used.
///
/// \param x The value to test.
/// \return true if \p x has exactly one bit set, false otherwise
///         (in particular false for zero).
///
template <typename T>
constexpr bool has_single_bit(T x) noexcept {
#if __cplusplus < 202002L
  // x & (x - 1) clears the lowest set bit; the result is zero only when that
  // was the sole set bit.
  return x != 0 && (x & (x - 1)) == 0;
#else
  return std::has_single_bit(x);
#endif
}

}  // namespace

namespace sequant {

// Forward declaration: in this header Tensor is only named as the template
// argument of `is<Tensor>()` checks, so the full definition is not required
// at this point.
class Tensor;

// EvalNode optimize(ExprPtr const& expr);

namespace opt {

namespace {

///
/// Visits the two-way splits of the bit set \p n.
///
/// Every candidate \c n_ in the range [1, n/2] is tried; the pair
/// <tt>(n & n_, (n - n_) & n)</tt> is reported to \p func whenever the two
/// parts together cover all bits of \p n. Because only candidates up to n/2
/// are tried, each unordered split is reported once.
///
/// \tparam I An unsigned integral type used as the bit mask.
/// \tparam F A callable invocable as <tt>func(I, I)</tt>.
/// \param n The bit mask to split. Nothing is visited when \p n is zero or has
///          a single bit set.
/// \param func Called with the (left, right) parts of each split.
///
template <
    typename I, typename F,
    typename = std::enable_if_t<std::is_integral_v<I> && std::is_unsigned_v<I>>,
    typename = std::enable_if_t<std::is_invocable_v<F, I, I>>>
void biparts(I n, F const& func) {
  if (n == 0) return;
  // Integer halving is exact for every value of I; going through double
  // (std::floor(n / 2.0)) loses precision once n exceeds 2^53.
  I const h = static_cast<I>(n / 2);
  for (I n_ = 1; n_ <= h; ++n_) {
    auto const l = n & n_;
    auto const r = (n - n_) & n;
    if ((l | r) == n) func(l, r);
  }
}

///
/// Estimates the cost of combining two operands as the product of the extents
/// of all indices involved.
///
/// \tparam IdxToSz A callable mapping an Index to its extent (size_t).
/// \param idxsz Maps each index to its extent.
/// \param commons Indices shared by both operands.
/// \param diffs Indices that appear in only one of the operands.
/// \return The product of the extents of all indices in \p commons and
///         \p diffs, or 0 when that product is exactly 1.
///
/// \note A product of 1 is treated as "no indices at all". The same value also
///       results when every listed index has extent 1, in which case 0 is
///       returned as well.
///
template <typename IdxToSz,
          std::enable_if_t<std::is_invocable_r_v<size_t, IdxToSz, Index>,
                           bool> = true>
double ops_count(IdxToSz const& idxsz, container::svector<Index> const& commons,
                 container::svector<Index> const& diffs) {
  double ops = 1.0;
  for (auto&& idx : ranges::views::concat(commons, diffs))
    ops *= std::invoke(idxsz, idx);
  // ops == 1.0 implies both commons and diffs empty
  return ops == 1.0 ? 0 : ops;
}

// Evaluation sequence as a flat list of ints. From its uses in this header:
// non-negative entries are tensor positions within the network, and -1 marks
// a binary operation on the two most recent operands (e.g. `{0, 1, -1}` is
// what the optimizer returns for a two-tensor network).
using eval_seq_t = container::svector<int>;

///
/// Per-subset record used by the single-term optimizer's table.
///
/// Kept as a plain aggregate: it is initialized positionally as
/// `OptRes{indices, flops, sequence}`.
///
struct OptRes {
  /// Indices attached to this intermediate (kept sorted by label where it is
  /// filled in by the optimizer).
  container::svector<sequant::Index> indices;

  /// Accumulated operation count needed to form this intermediate.
  double flops;

  /// Evaluation sequence that produces this intermediate.
  eval_seq_t sequence;
};

///
/// Collects the indices that occur in both \p idxs1 and \p idxs2.
///
/// \param idxs1 A range of Index objects sorted by Index::LabelCompare.
/// \param idxs2 A range of Index objects sorted by Index::LabelCompare.
/// \return The set intersection of the two ranges, ordered by
///         Index::LabelCompare.
///
/// \pre Both ranges are sorted with Index::LabelCompare (checked by assert).
///
template <typename I1, typename I2>
container::svector<Index> common_indices(I1 const& idxs1, I2 const& idxs2) {
  using std::back_inserter;
  using std::begin;
  using std::end;
  using std::set_intersection;

  assert(std::is_sorted(begin(idxs1), end(idxs1), Index::LabelCompare{}));
  assert(std::is_sorted(begin(idxs2), end(idxs2), Index::LabelCompare{}));

  container::svector<Index> result;

  set_intersection(begin(idxs1), end(idxs1), begin(idxs2), end(idxs2),
                   back_inserter(result), Index::LabelCompare{});
  return result;
}

///
/// Collects the indices that occur in exactly one of \p idxs1 and \p idxs2.
///
/// \param idxs1 A range of Index objects sorted by Index::LabelCompare.
/// \param idxs2 A range of Index objects sorted by Index::LabelCompare.
/// \return The symmetric difference of the two ranges, ordered by
///         Index::LabelCompare.
///
/// \pre Both ranges are sorted with Index::LabelCompare (checked by assert).
///
template <typename I1, typename I2>
container::svector<Index> diff_indices(I1 const& idxs1, I2 const& idxs2) {
  using std::back_inserter;
  using std::begin;
  using std::end;
  using std::set_symmetric_difference;

  assert(std::is_sorted(begin(idxs1), end(idxs1), Index::LabelCompare{}));
  assert(std::is_sorted(begin(idxs2), end(idxs2), Index::LabelCompare{}));

  container::svector<Index> result;

  set_symmetric_difference(begin(idxs1), end(idxs1), begin(idxs2), end(idxs2),
                           back_inserter(result), Index::LabelCompare{});
  return result;
}

///
/// Finds an evaluation sequence for the tensors of \p network that minimizes
/// the operation count estimated by ops_count.
///
/// A table with one entry per subset of tensors (subsets are bit masks) is
/// filled in increasing mask order. Single-tensor subsets are seeded with the
/// tensor's own indices and zero cost; every other subset takes the cheapest
/// of its two-way splits as enumerated by biparts.
///
/// \tparam IdxToSz A callable mapping an Index to its extent (size_t).
/// \param network The tensor network whose tensors are to be ordered.
/// \param idxsz Maps each index to its extent.
/// \return The evaluation sequence for the full network: tensor positions
///         interleaved with -1 markers, each -1 combining the two most recent
///         operands. `{0}` for one tensor, `{0, 1, -1}` for two.
///
/// \note The table has 2^nt entries, so this is only usable for small
///       networks.
///
template <typename IdxToSz,
          std::enable_if_t<std::is_invocable_r_v<size_t, IdxToSz, Index>,
                           bool> = true>
eval_seq_t single_term_opt(TensorNetwork const& network, IdxToSz const& idxsz) {
  // number of terms
  auto const nt = network.tensors().size();
  if (nt == 1) return eval_seq_t{0};
  if (nt == 2) return eval_seq_t{0, 1, -1};

  // Subsets are addressed by size_t bit masks; compute the table size once,
  // in size_t, so that the shift is never performed in (narrower) int.
  assert(static_cast<size_t>(nt) <
         static_cast<size_t>(std::numeric_limits<size_t>::digits));
  size_t const nsubsets = size_t{1} << nt;

  // label-sorted bra+ket indices of every tensor, by tensor position
  auto nth_tensor_indices = container::svector<container::svector<Index>>{};
  nth_tensor_indices.reserve(nt);

  for (std::size_t i = 0; i < nt; ++i) {
    auto const& tnsr = *network.tensors().at(i);
    auto bk = container::svector<Index>{};
    bk.reserve(bra_rank(tnsr) + ket_rank(tnsr));
    for (auto&& idx : braket(tnsr)) bk.push_back(idx);

    // common_indices/diff_indices require label-sorted input
    ranges::sort(bk, Index::LabelCompare{});
    nth_tensor_indices.push_back(std::move(bk));
  }

  container::svector<OptRes> results(nsubsets, OptRes{{}, 0, {}});

  // power_pos is used, and incremented, only when the
  // result[1<<0]
  // result[1<<1]
  // result[1<<2]
  // and so on are set
  size_t power_pos = 0;
  for (size_t n = 1; n < nsubsets; ++n) {
    double curr_cost = std::numeric_limits<double>::max();
    std::pair<size_t, size_t> curr_parts{0, 0};
    container::svector<Index> curr_indices{};

    // function to find the optimal partition
    auto scan_parts = [&curr_cost,                          //
                       &curr_parts,                         //
                       &curr_indices,                       //
                       &results = std::as_const(results),   //
                       &idxsz](                             //
                          size_t lpart, size_t rpart) {
      auto commons =
          common_indices(results[lpart].indices, results[rpart].indices);
      auto diffs = diff_indices(results[lpart].indices, results[rpart].indices);
      auto new_cost = ops_count(idxsz,           //
                                commons, diffs)  //
                      + results[lpart].flops     //
                      + results[rpart].flops;
      // on a tie the later split wins
      if (new_cost <= curr_cost) {
        curr_cost = new_cost;
        curr_parts = decltype(curr_parts){lpart, rpart};
        // the indices left on the combined intermediate
        curr_indices = std::move(diffs);
      }
    };

    // visits nothing when n has a single bit set
    biparts(n, scan_parts);

    auto& curr_result = results[n];
    if (has_single_bit(n)) {
      // evaluation of a single atomic tensor
      curr_result.flops = 0;
      curr_result.indices = std::move(nth_tensor_indices[power_pos]);
      curr_result.sequence = eval_seq_t{static_cast<int>(power_pos++)};
    } else {
      curr_result.flops = curr_cost;
      curr_result.indices = std::move(curr_indices);
      auto const& first = results[curr_parts.first].sequence;
      auto const& second = results[curr_parts.second].sequence;

      // the part whose sequence starts with the lower tensor position goes
      // first; the trailing -1 combines the two parts
      curr_result.sequence =
          (first[0] < second[0] ? ranges::views::concat(first, second)
                                : ranges::views::concat(second, first)) |
          ranges::to<eval_seq_t>;
      curr_result.sequence.push_back(-1);
    }
  }

  return results[nsubsets - 1].sequence;
}

}  // namespace

// Declaration only; the definition is not in this header.
// Takes an expression by const reference, returns an expression, and is
// declared noexcept.
// NOTE(review): presumably yields the factor(s) of a product after the first
// one -- confirm against the definition.
ExprPtr tail_factor(ExprPtr const& expr) noexcept;

// Declaration only; the definition is not in this header.
// Returns nothing and takes the ExprPtr by value, so any effect has to be on
// the expression the pointer refers to. Declared noexcept.
// NOTE(review): presumably hoists scalar prefactors of nested products to the
// outermost product -- confirm against the definition.
void pull_scalar(sequant::ExprPtr expr) noexcept;

///
/// Re-parenthesizes the tensors of \p prod into nested binary products that
/// follow the evaluation sequence computed by the TensorNetwork overload of
/// single_term_opt.
///
/// \tparam IdxToSz A callable that can be invoked with an Index.
/// \param prod The product to optimize.
/// \param idxsz Maps each index to its extent.
/// \return A new Product expression. Products with fewer than three factors
///         are copied unchanged (scalar included, no flattening).
///
/// NOTE(review): on the long-product path every Product is built with scalar
/// 1, so prod.scalar() is not carried over by the code below -- confirm
/// whether a scaling step is expected here.
/// NOTE(review): the final `as<Product>()` assumes at least two Tensor
/// factors were found -- confirm that callers guarantee this.
///
template <typename IdxToSz,
          std::enable_if_t<std::is_invocable_v<IdxToSz, Index>, bool> = true>
ExprPtr single_term_opt(Product const& prod, IdxToSz const& idxsz) {
  using ranges::views::filter;
  using ranges::views::reverse;

  if (prod.factors().size() < 3)
    return ex<Product>(Product{prod.scalar(), prod.factors().begin(),
                               prod.factors().end(), Product::Flatten::No});
  auto const tensors =
      prod | filter(&ExprPtr::template is<Tensor>) | ranges::to_vector;
  auto seq = single_term_opt(TensorNetwork{tensors}, idxsz);

  // Operand stack for replaying the postfix sequence: a tensor position
  // pushes that tensor, -1 replaces the two topmost operands by their product.
  auto result = container::svector<ExprPtr>{};
  for (auto i : seq)
    if (i == -1) {
      // right operand is on top, left operand directly below it
      auto rexpr = *result.rbegin();
      result.pop_back();
      auto lexpr = *result.rbegin();
      result.pop_back();
      auto p = Product{1, ExprPtrList{lexpr, rexpr}, Product::Flatten::No};
      result.push_back(ex<Product>(Product{
          1, p.factors().begin(), p.factors().end(), Product::Flatten::No}));
    } else {
      result.push_back(tensors.at(i));
    }

  // Variables are not part of the tensor network; put them back in front,
  // iterating in reverse so that their original relative order is kept.
  auto& p_ = (*result.rbegin()).as<Product>();
  for (auto&& v : prod | reverse | filter(&Expr::template is<Variable>))
    p_.prepend(1, v, Product::Flatten::No);

  return *result.rbegin();
}

// Declaration only; the definition is not in this header.
// Maps a Sum to a list of lists of size_t.
// NOTE(review): presumably each inner list holds the positions of summands
// that belong to one group -- confirm against the definition.
container::vector<container::vector<size_t>> clusters(Sum const& expr);

// Declaration only; the definition is not in this header.
// Produces a new Sum from the given one; the optimize() template in this
// header applies it to the sum of already-optimized summands.
// NOTE(review): presumably changes only the order of the summands -- confirm
// against the definition.
Sum reorder(Sum const& sum);

///
/// Optimizes an expression for evaluation, dispatching on its dynamic type.
///
/// - Tensor: returned as a clone.
/// - Product: handed to opt::single_term_opt.
/// - Sum: every summand is optimized recursively, then the resulting sum is
///   passed through opt::reorder.
///
/// \tparam IdxToSize A callable mapping an Index to its extent (size_t).
/// \param expr The expression to optimize.
/// \param idx2size Maps each index to its extent.
/// \return The optimized expression.
/// \throws std::runtime_error if \p expr is neither a Tensor, a Product nor a
///         Sum.
///
template <typename IdxToSize,
          typename =
              std::enable_if_t<std::is_invocable_r_v<size_t, IdxToSize, Index>>>
ExprPtr optimize(ExprPtr const& expr, IdxToSize const& idx2size) {
  using ranges::views::transform;

  if (expr->is<Tensor>()) return expr->clone();

  if (expr->is<Product>())
    return opt::single_term_opt(expr->as<Product>(), idx2size);

  if (expr->is<Sum>()) {
    auto smands = *expr | transform([&idx2size](auto&& s) {
      return optimize(s, idx2size);
    }) | ranges::to_vector;
    auto sum = Sum{smands.begin(), smands.end()};
    return ex<Sum>(opt::reorder(sum));
  }

  throw std::runtime_error{"Optimization attempted on unsupported Expr type"};
}

}  // namespace opt

// Declaration only; the definition is not in this header.
// Overload of optimize() that takes no index-to-extent callable.
// NOTE(review): presumably supplies a default index-extent mapping and
// forwards to the two-argument template -- confirm against the definition.
ExprPtr optimize(ExprPtr const& expr);

}  // namespace sequant