Program Listing for File expr.cpp

Return to documentation for file (SeQuant/core/expr.cpp)

//
// Created by Eduard Valeyev on 2019-02-06.
//

#include <SeQuant/core/abstract_tensor.hpp>
#include <SeQuant/core/algorithm.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/logger.hpp>
#include <SeQuant/core/tensor.hpp>
#include <SeQuant/core/tensor_network.hpp>

#include <range/v3/all.hpp>

#include <thread>
#include <vector>

namespace sequant {

ExprPtr ExprPtr::clone() const & {
  if (!*this) return {};
  return ExprPtr(as_shared_ptr()->clone());
}

ExprPtr ExprPtr::clone() && noexcept { return std::move(*this); }

ExprPtr::base_type &ExprPtr::as_shared_ptr() & {
  return static_cast<base_type &>(*this);
}
const ExprPtr::base_type &ExprPtr::as_shared_ptr() const & {
  return static_cast<const base_type &>(*this);
}
ExprPtr::base_type &&ExprPtr::as_shared_ptr() && {
  return static_cast<base_type &&>(*this);
}

Expr &ExprPtr::operator*() & {
  assert(this->operator bool());
  return *(this->get());
}

const Expr &ExprPtr::operator*() const & {
  assert(this->operator bool());
  return *(this->get());
}

Expr &&ExprPtr::operator*() && {
  assert(this->operator bool());
  return std::move(*(this->get()));
}

ExprPtr &ExprPtr::operator+=(const ExprPtr &other) {
  if (!*this) {
    *this = other.clone();
  } else if (as_shared_ptr()->is<Sum>()) {
    as_shared_ptr()->operator+=(*other);
  } else if (as_shared_ptr()->is<Constant>() && other->is<Constant>()) {
    *this = ex<Constant>(this->as<Constant>().value() +
                         other->as<Constant>().value());
  } else {
    *this = ex<Sum>(ExprPtrList{*this, other});
  }
  return *this;
}

ExprPtr &ExprPtr::operator-=(const ExprPtr &other) {
  if (!*this) {
    *this = ex<Constant>(-1) * other.clone();
  } else if (as_shared_ptr()->is<Sum>()) {
    as_shared_ptr()->operator-=(*other);
  } else if (as_shared_ptr()->is<Constant>() && other->is<Constant>()) {
    *this = ex<Constant>(this->as<Constant>().value() -
                         other->as<Constant>().value());
  } else {
    *this = ex<Sum>(ExprPtrList{*this, ex<Product>(-1, ExprPtrList{other})});
  }
  return *this;
}

ExprPtr &ExprPtr::operator*=(const ExprPtr &other) {
  if (!*this) {
    *this = other.clone();
  } else if (as_shared_ptr()->is<Product>()) {
    as_shared_ptr()->operator*=(*other);
  } else if (as_shared_ptr()->is<Constant>() && other->is<Constant>()) {
    *this = ex<Constant>(this->as<Constant>().value() *
                         other->as<Constant>().value());
  } else {
    *this = ex<Product>(ExprPtrList{*this, other});
  }
  return *this;
}

std::wstring ExprPtr::to_latex() const { return as_shared_ptr()->to_latex(); }

std::logic_error Expr::not_implemented(const char *fn) const {
  std::ostringstream oss;
  oss << "Expr::" << fn
      << " not implemented in this derived class (type_name=" << type_name()
      << ")";
  return std::logic_error(oss.str().c_str());
}

std::wstring Expr::to_latex() const { throw not_implemented("to_latex"); }

std::wstring Expr::to_wolfram() const { throw not_implemented("to_wolfram"); }

ExprPtr Expr::clone() const { throw not_implemented("clone"); }

void Expr::adjoint() { throw not_implemented("adjoint"); }

Expr &Expr::operator*=(const Expr &that) {
  throw not_implemented("operator*=");
}

Expr &Expr::operator^=(const Expr &that) {
  throw not_implemented("operator^=");
}

Expr &Expr::operator+=(const Expr &that) {
  throw not_implemented("operator+=");
}

Expr &Expr::operator-=(const Expr &that) {
  throw not_implemented("operator-=");
}

ExprPtr adjoint(const ExprPtr &expr) {
  auto result = expr->clone();
  result->adjoint();
  return result;
}

void Constant::adjoint() {
  value_ = conj(value_);
  reset_hash_value();
}

std::wstring_view Variable::label() const { return label_; }

void Variable::conjugate() { conjugated_ = !conjugated_; }

bool Variable::conjugated() const { return conjugated_; }

std::wstring Variable::to_latex() const {
  std::wstring result = L"{" + utf_to_latex(label_) + L"}";
  if (conjugated_) result = L"{" + result + L"^*" + L"}";
  return result;
}

ExprPtr Variable::clone() const { return ex<Variable>(*this); }

void Variable::adjoint() { conjugate(); }

bool Product::is_commutative() const {
  bool result = true;
  const auto nfactors = size();
  for (size_t f = 0; f != nfactors; ++f) {
    for (size_t s = 1; result && s != nfactors; ++s) {
      result &= factors_[f]->commutes_with(*factors_[s]);
    }
  }
  return result;
}

ExprPtr Product::canonicalize_impl(bool rapid) {
  // recursively canonicalize subfactors ...
  ranges::for_each(factors_, [this](auto &factor) {
    auto bp = factor->canonicalize();
    if (bp) {
      assert(bp->template is<Constant>());
      this->scalar_ *= std::static_pointer_cast<Constant>(bp)->value();
    }
  });

  if (Logger::instance().canonicalize) {
    std::wcout << "Product canonicalization(" << (rapid ? "fast" : "slow")
               << ") input: " << to_latex() << std::endl;
  }

  // pull out all Variables to the front
  auto variables = factors_ | ranges::views::filter([](const auto &factor) {
                     return factor.template is<Variable>();
                   }) |
                   ranges::to_vector;
  factors_ = factors_ | ranges::views::filter([](const auto &factor) {
               return !factor.template is<Variable>();
             }) |
             ranges::to<decltype(factors_)>;

  auto contains_nontensors = ranges::any_of(factors_, [](const auto &factor) {
    return std::dynamic_pointer_cast<AbstractTensor>(factor) == nullptr;
  });
  if (!contains_nontensors) {  // tensor network canonization is a special case
                               // that's done in
                               // TensorNetwork
    TensorNetwork tn(factors_);
    auto canon_factor =
        tn.canonicalize(TensorCanonicalizer::cardinal_tensor_labels(), rapid);
    const auto &tensors = tn.tensors();
    using std::size;
    assert(size(tensors) == size(factors_));
    using std::begin;
    using std::end;
    std::transform(begin(tensors), end(tensors), begin(factors_),
                   [](const auto &tptr) {
                     auto exprptr = std::dynamic_pointer_cast<Expr>(tptr);
                     assert(exprptr);
                     return exprptr;
                   });
    if (canon_factor) scalar_ *= canon_factor->as<Constant>().value();
    this->reset_hash_value();
  } else {  // if contains non-tensors, do commutation-checking resort

    // comparer that respects cardinal tensor labels
    auto &cardinal_tensor_labels =
        TensorCanonicalizer::cardinal_tensor_labels();
    auto local_compare = [&cardinal_tensor_labels](const ExprPtr &first,
                                                   const ExprPtr &second) {
      if (first->is<Labeled>() && second->is<Labeled>()) {
        const auto first_label = first->as<Tensor>().label();
        const auto second_label = second->as<Tensor>().label();
        if (first_label == second_label) return *first < *second;
        const auto first_is_cardinal_it = ranges::find_if(
            cardinal_tensor_labels,
            [&first_label](const std::wstring &l) { return l == first_label; });
        const auto first_is_cardinal =
            first_is_cardinal_it != ranges::end(cardinal_tensor_labels);
        const auto second_is_cardinal_it = ranges::find_if(
            cardinal_tensor_labels, [&second_label](const std::wstring &l) {
              return l == second_label;
            });
        const auto second_is_cardinal =
            second_is_cardinal_it != ranges::end(cardinal_tensor_labels);
        if (first_is_cardinal && second_is_cardinal)
          return first_is_cardinal_it < second_is_cardinal_it;
        else if (first_is_cardinal && !second_is_cardinal)
          return true;
        else if (!first_is_cardinal && second_is_cardinal)
          return false;
        else {
          assert(!first_is_cardinal && !second_is_cardinal);
          return *first < *second;
        }
      } else
        return *first < *second;
    };

    // ... then resort, respecting commutativity
    using std::begin;
    using std::end;
    if (static_commutativity()) {
      if (is_commutative()) {
        std::stable_sort(begin(factors_), end(factors_), local_compare);
      }
    } else {
      // must do bubble sort if not commuting to avoid swapping elements across
      // a noncommuting element
      bubble_sort(
          begin(factors_), end(factors_),
          [&local_compare](const ExprPtr &first, const ExprPtr &second) {
            bool result = (first->commutes_with(*second))
                              ? local_compare(first, second)
                              : false;
            return result;
          });
    }
  }

  // sort and reinsert Variables at the front
  ranges::sort(variables, [](const auto &first, const auto &second) {
    return first.template as<Variable>().label() <
           second.template as<Variable>().label();
  });
  factors_.insert(factors_.begin(), variables.begin(), variables.end());

  // TODO evaluate product of Tensors (turn this into Products of Products)

  if (Logger::instance().canonicalize)
    std::wcout << "Product canonicalization(" << (rapid ? "fast" : "slow")
               << ") result: " << to_latex() << std::endl;

  return {};  // side effects are absorbed into the scalar_
}

void Product::adjoint() {
  assert(static_commutativity() == false);  // assert no slicing
  auto adj_scalar = conj(scalar());
  using namespace ranges;
  auto adj_factors =
      factors() | views::reverse |
      views::transform([](auto &expr) { return ::sequant::adjoint(expr); });
  using std::swap;
  *this =
      Product(adj_scalar, ranges::begin(adj_factors), ranges::end(adj_factors));
}

ExprPtr Product::canonicalize() {
  return this->canonicalize_impl(/* rapid = */ false);
}

ExprPtr Product::rapid_canonicalize() {
  return this->canonicalize_impl(/* rapid = */ true);
}

void CProduct::adjoint() {
  auto adj_scalar = conj(scalar());
  using namespace ranges;
  // no need to reverse for commutative product
  auto adj_factors = factors() | views::transform([](auto &&expr) {
                       return ::sequant::adjoint(expr);
                     });
  *this = CProduct(adj_scalar, ranges::begin(adj_factors),
                   ranges::end(adj_factors));
}

void NCProduct::adjoint() {
  auto adj_scalar = conj(scalar());
  using namespace ranges;
  // no need to reverse for commutative product
  auto adj_factors =
      factors() | views::reverse |
      views::transform([](auto &&expr) { return ::sequant::adjoint(expr); });
  *this = NCProduct(adj_scalar, ranges::begin(adj_factors),
                    ranges::end(adj_factors));
}

void Sum::adjoint() {
  using namespace ranges;
  auto adj_summands = summands() | views::transform([](auto &&expr) {
                        return ::sequant::adjoint(expr);
                      });
  *this = Sum(ranges::begin(adj_summands), ranges::end(adj_summands));
}

ExprPtr Sum::canonicalize_impl(bool multipass) {
  if (Logger::instance().canonicalize)
    std::wcout << "Sum::canonicalize_impl: input = "
               << to_latex_align(shared_from_this()) << std::endl;

  const auto npasses = multipass ? 3 : 1;
  for (auto pass = 0; pass != npasses; ++pass) {
    // recursively canonicalize summands ...
    const auto nsubexpr = ranges::size(*this);
    for (std::size_t i = 0; i != nsubexpr; ++i) {
      auto bp = (pass % 2 == 0) ? summands_[i]->rapid_canonicalize()
                                : summands_[i]->canonicalize();
      if (bp) {
        assert(bp->template is<Constant>());
        summands_[i] =
            ex<Product>(std::static_pointer_cast<Constant>(bp)->value(),
                        ExprPtrList{summands_[i]});
      }
    };

    if (Logger::instance().canonicalize)
      std::wcout << "Sum::canonicalize_impl (pass=" << pass
                 << "): after canonicalizing summands = "
                 << to_latex_align(shared_from_this()) << std::endl;

    // ... then resort according to size, then hash values
    using std::begin;
    using std::end;
    std::stable_sort(begin(summands_), end(summands_),
                     [](const auto &first, const auto &second) {
                       const auto first_size = sequant::size(first);
                       const auto second_size = sequant::size(second);

                       return (first_size == second_size)
                                  ? *first < *second
                                  : first_size < second_size;
                     });

    if (Logger::instance().canonicalize)
      std::wcout << "Sum::canonicalize_impl (pass=" << pass
                 << "): after hash-sorting summands = "
                 << to_latex_align(shared_from_this()) << std::endl;

    // ... then reduce terms whose hash values are identical
    auto first_it = begin(summands_);
    auto hash_comparer = [](const auto &first, const auto &second) {
      return first->hash_value() == second->hash_value();
    };
    while ((first_it = std::adjacent_find(first_it, end(summands_),
                                          hash_comparer)) != end(summands_)) {
      assert((*first_it)->hash_value() == (*(first_it + 1))->hash_value());
      // find first element whose hash is not equal to (*first_it)->hash_value()
      auto plast_it = std::find_if_not(
          first_it + 1, end(summands_), [first_it](const auto &elem) {
            return (*first_it)->hash_value() == elem->hash_value();
          });
      const auto nidentical = plast_it - first_it;
      assert(nidentical > 1);
      auto reduce_range = [first_it, this, nidentical](auto &begin, auto &end) {
        if ((*first_it)->template is<Tensor>()) {
          Product tensor_as_Product{};
          tensor_as_Product.append(nidentical, (*first_it)->as<Tensor>());
          (*first_it) = std::make_shared<Product>(tensor_as_Product);
          this->summands_.erase(first_it + 1, end);
        } else if ((*first_it)->template is<Product>()) {
          auto &prod = (*first_it)->template as<Product>();
          for (auto it = begin + 1; it != end; ++it) {
            if ((*it)->template is<Tensor>()) {
              Product tensor_as_Product{};
              tensor_as_Product.append(1, (*it)->template as<Tensor>());
              (*it) = std::make_shared<Product>(tensor_as_Product);
            }
            if ((*it)->template is<Product>()) {
              prod.add_identical((*it)->template as<Product>());
            }
          }
          auto summands_to_erase = std::pair{first_it + 1, end};
          if (prod.is_zero()) summands_to_erase.first = first_it;
          this->summands_.erase(summands_to_erase.first,
                                summands_to_erase.second);
        }
      };
      reduce_range(first_it, plast_it);
    }

    if (Logger::instance().canonicalize)
      std::wcout << "Sum::canonicalize_impl (pass=" << pass
                 << "): after reducing summands = "
                 << to_latex_align(shared_from_this()) << std::endl;
  }

  return {};  // side effects are absorbed into summands
}

}  // namespace sequant