Program Listing for File fusion.cpp

Return to documentation for file (SeQuant/core/optimize/fusion.cpp)

#include <SeQuant/core/optimize/fusion.hpp>

#include <SeQuant/core/complex.hpp>
#include <SeQuant/core/container.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/expr_operator.hpp>

#include <range/v3/algorithm.hpp>
#include <range/v3/iterator.hpp>
#include <range/v3/view.hpp>

#include <array>
#include <cassert>

namespace sequant::opt {

using ranges::views::drop;
using ranges::views::reverse;
using ranges::views::zip;

namespace {

// Converts a Product into an ExprPtr.
// A product whose scalar is 1 and which holds exactly one factor is unwrapped
// to that factor (which need not be a tensor). Any other product is re-wrapped
// as a Product expression built with Product::Flatten::No, so its factors are
// kept exactly as they are.
// Internal linkage: this helper is private to this translation unit and must
// not clash with same-named entities elsewhere in sequant::opt.
ExprPtr lift_tensor(Product const& p) {
  if (p.scalar() == 1 && p.size() == 1) return p.factor(0);
  return ex<Product>(Product{p.scalar(), p.factors().begin(),
                             p.factors().end(), Product::Flatten::No});
}

}  // namespace

// Computes both candidate factorizations of lhs + rhs up front:
//  - left_  : result of fuse_left,  i.e. f (a + b) with f the shared leading
//             factors, or nullptr if the first factors differ;
//  - right_ : result of fuse_right, i.e. (a + b) f with f the shared trailing
//             factors, or nullptr if the last factors differ.
Fusion::Fusion(Product const& lhs, Product const& rhs)
    : left_{fuse_left(lhs, rhs)}, right_{fuse_right(lhs, rhs)} {}

// Returns the left-factored form f (a + b) computed at construction
// (nullptr when lhs and rhs share no leading factor).
ExprPtr Fusion::left() const {
  return this->left_;
}

// Returns the right-factored form (a + b) f computed at construction
// (nullptr when lhs and rhs share no trailing factor).
ExprPtr Fusion::right() const {
  return this->right_;
}

ExprPtr Fusion::fuse_left(Product const& lhs, Product const& rhs) {
  auto fac = container::svector<ExprPtr>{};

  for (auto&& [e1, e2] : zip(lhs.factors(), rhs.factors())) {
    if (e1 == e2)
      fac.push_back(e1);
    else
      break;
  }

  if (fac.empty()) return nullptr;

  auto lsmand = lhs.factors() | drop(fac.size());
  auto rsmand = rhs.factors() | drop(fac.size());

  auto fac_prod = Product{fac.begin(), fac.end()};
  auto lsmand_prod = Product{lsmand.begin(), lsmand.end()};
  auto rsmand_prod = Product{rsmand.begin(), rsmand.end()};

  assert(lhs.scalar().imag().is_zero() && rhs.scalar().imag().is_zero() &&
         "Complex valued gcd not supported");
  auto scalars_fused = fuse_scalar(lhs.scalar().real(), rhs.scalar().real());

  fac_prod.scale(scalars_fused.at(0));
  lsmand_prod.scale(scalars_fused.at(1));
  rsmand_prod.scale(scalars_fused.at(2));

  // f (a + b)

  auto f = lift_tensor(fac_prod);
  auto a = lift_tensor(lsmand_prod);
  auto b = lift_tensor(rsmand_prod);

  return ex<Product>(ExprPtrList{f, ex<Sum>(ExprPtrList{a, b})});
}

ExprPtr Fusion::fuse_right(Product const& lhs, Product const& rhs) {
  auto fac = container::svector<ExprPtr>{};

  for (auto&& [e1, e2] :
       zip(lhs.factors() | reverse, rhs.factors() | reverse)) {
    if (e1 == e2)
      fac.push_back(e1);
    else
      break;
  }

  if (fac.empty()) return nullptr;

  ranges::reverse(fac);
  auto lsmand = lhs.factors() | reverse | drop(fac.size()) | reverse;
  auto rsmand = rhs.factors() | reverse | drop(fac.size()) | reverse;

  auto fac_prod = Product{fac.begin(), fac.end()};
  auto lsmand_prod = Product{lsmand.begin(), lsmand.end()};
  auto rsmand_prod = Product{rsmand.begin(), rsmand.end()};

  assert(lhs.scalar().imag().is_zero() && rhs.scalar().imag().is_zero() &&
         "Complex valued gcd not supported");
  auto scalars_fused = fuse_scalar(lhs.scalar().real(), rhs.scalar().real());

  fac_prod.scale(scalars_fused.at(0));
  lsmand_prod.scale(scalars_fused.at(1));
  rsmand_prod.scale(scalars_fused.at(2));

  // (a + b) f

  auto a = lift_tensor(lsmand_prod);
  auto b = lift_tensor(rsmand_prod);
  auto f = lift_tensor(fac_prod);

  return ex<Product>(ExprPtrList{ex<Sum>(ExprPtrList{a, b}), f});
}

// Greatest common divisor of two rationals n1/d1 and n2/d2.
// Both fractions are rewritten over the common denominator d1*d2, giving
// numerators n1*d2 and n2*d1; their integer gcd over d1*d2 is the result.
rational Fusion::gcd_rational(rational const& left, rational const& right) {
  auto&& x = left.real();
  auto&& y = right.real();

  auto&& x_den = denominator(x);
  auto&& y_den = denominator(y);
  auto&& x_num = numerator(x);
  auto&& y_num = numerator(y);

  // numerators after rescaling both fractions to denominator x_den * y_den
  auto common_num = gcd(x_num * y_den, y_num * x_den);
  return {common_num, x_den * y_den};
}

// Splits the scalars of two products into a common part and two residuals:
// returns {fused, left_fused, right_fused} with
//   left == fused * left_fused  and  right == fused * right_fused,
// where fused is gcd_rational(left, right). When both inputs are negative,
// all three values are negated so the common part carries the sign
// (assumes gcd_rational yields a non-negative value -- TODO confirm).
// If both inputs are zero the gcd is zero; instead of dividing by zero,
// fused is set to 1 and the (zero) inputs are returned as residuals.
std::array<rational, 3> Fusion::fuse_scalar(rational const& left,
                                            rational const& right) {
  auto fused = gcd_rational(left, right);
  if (fused == 0) {
    // gcd_rational is zero only when both numerators are zero
    return {rational(1), left, right};
  }
  rational left_fused = left / fused;
  rational right_fused = right / fused;
  if (left < 0 && right < 0) {
    fused *= -1;
    left_fused *= -1;
    right_fused *= -1;
  }
  return {fused, left_fused, right_fused};
}

}  // namespace sequant::opt