// Program Listing for File expr.cpp
// (SeQuant/core/expressions/expr.cpp)
//
// Created by Eduard Valeyev on 2019-02-06.
//
#include <SeQuant/core/algorithm.hpp>
#include <SeQuant/core/expressions/abstract_tensor.hpp>
#include <SeQuant/core/expressions/constant.hpp>
#include <SeQuant/core/expressions/expr.hpp>
#include <SeQuant/core/expressions/tensor.hpp>
#include <SeQuant/core/logger.hpp>
#include <SeQuant/core/runtime.hpp>
#include <SeQuant/core/tensor_canonicalizer.hpp>
#include <SeQuant/core/tensor_network.hpp>
#include <SeQuant/core/tensor_network/typedefs.hpp>
#include <SeQuant/core/utility/macros.hpp>
#include <range/v3/all.hpp>
#include <thread>
#include <vector>
namespace sequant {
// Deep-copies the pointee; a null ExprPtr clones to a null ExprPtr.
ExprPtr ExprPtr::clone() const & {
  if (*this) return ExprPtr(as_shared_ptr()->clone());
  return {};
}
// rvalue clone: no deep copy needed, ownership of this handle is transferred
ExprPtr ExprPtr::clone() && noexcept { return std::move(*this); }
// Views this ExprPtr as its std::shared_ptr base (base_type); the three
// overloads preserve const-ness and value category of *this.
ExprPtr::base_type &ExprPtr::as_shared_ptr() & {
  return static_cast<base_type &>(*this);
}
const ExprPtr::base_type &ExprPtr::as_shared_ptr() const & {
  return static_cast<const base_type &>(*this);
}
// rvalue overload: permits moving the underlying shared_ptr out
ExprPtr::base_type &&ExprPtr::as_shared_ptr() && {
  return static_cast<base_type &&>(*this);
}
// Dereference operators: return the pointee, asserting non-null in debug
// builds; overloads preserve const-ness and value category of *this.
Expr &ExprPtr::operator*() & {
  SEQUANT_ASSERT(this->operator bool());
  return *(this->get());
}
const Expr &ExprPtr::operator*() const & {
  SEQUANT_ASSERT(this->operator bool());
  return *(this->get());
}
// rvalue overload: returns an rvalue reference to the pointee
Expr &&ExprPtr::operator*() && {
  SEQUANT_ASSERT(this->operator bool());
  return std::move(*(this->get()));
}
// In-place addition: appends to an existing Sum, folds Constant + Constant,
// otherwise wraps both operands in a new Sum. Adding null is a no-op;
// adding to null copies other.
ExprPtr &ExprPtr::operator+=(const ExprPtr &other) {
  if (!other) return *this;
  if (!*this) {
    *this = other.clone();
    return *this;
  }
  if (as_shared_ptr()->is<Sum>()) {
    // append in place to the existing Sum
    as_shared_ptr()->operator+=(*other);
  } else if (as_shared_ptr()->is<Constant>() && other->is<Constant>()) {
    // fold two constants into one
    const auto sum =
        this->as<Constant>().value() + other->as<Constant>().value();
    *this = ex<Constant>(sum);
  } else {
    // general case: a fresh Sum of both terms
    *this = ex<Sum>(ExprPtrList{*this, other});
  }
  return *this;
}
// In-place subtraction: subtracts from an existing Sum, folds
// Constant - Constant, otherwise builds this + (-1)*other. Subtracting null
// is a no-op; subtracting from null yields (-1)*other.
ExprPtr &ExprPtr::operator-=(const ExprPtr &other) {
  if (!other) return *this;
  if (!*this) {
    *this = ex<Constant>(-1) * other.clone();
    return *this;
  }
  if (as_shared_ptr()->is<Sum>()) {
    // subtract in place from the existing Sum
    as_shared_ptr()->operator-=(*other);
  } else if (as_shared_ptr()->is<Constant>() && other->is<Constant>()) {
    // fold two constants into one
    const auto difference =
        this->as<Constant>().value() - other->as<Constant>().value();
    *this = ex<Constant>(difference);
  } else {
    // general case: this + (-1)*other as a fresh Sum
    *this = ex<Sum>(ExprPtrList{*this, ex<Product>(-1, ExprPtrList{other})});
  }
  return *this;
}
// In-place multiplication: appends to an existing Product, folds
// Constant * Constant, otherwise wraps both operands in a new Product.
// Multiplying by null is a no-op; multiplying null copies other.
ExprPtr &ExprPtr::operator*=(const ExprPtr &other) {
  if (!other) return *this;
  if (!*this) {
    *this = other.clone();
    return *this;
  }
  if (as_shared_ptr()->is<Product>()) {
    // append in place to the existing Product
    as_shared_ptr()->operator*=(*other);
  } else if (as_shared_ptr()->is<Constant>() && other->is<Constant>()) {
    // fold two constants into one
    const auto product =
        this->as<Constant>().value() * other->as<Constant>().value();
    *this = ex<Constant>(product);
  } else {
    // general case: a fresh Product of both factors
    *this = ex<Product>(ExprPtrList{*this, other});
  }
  return *this;
}
// Number of subexpressions of the pointee.
// Precondition: non-null; asserted in debug builds for consistency with
// operator* (the original dereferenced unconditionally).
std::size_t ExprPtr::size() const {
  SEQUANT_ASSERT(this->operator bool());
  return this->get()->size();
}
// LaTeX representation of the pointee; same non-null precondition.
std::wstring ExprPtr::to_latex() const {
  SEQUANT_ASSERT(this->operator bool());
  return as_shared_ptr()->to_latex();
}
// Builds the exception thrown by the default implementations of Expr's
// virtual interface when a derived class did not override them.
// @param fn name of the unimplemented member function
// @return a std::logic_error naming the function and the dynamic type
std::logic_error Expr::not_implemented(const char *fn) const {
  std::ostringstream oss;
  oss << "Expr::" << fn
      << " not implemented in this derived class (type_name=" << type_name()
      << ")";
  // std::logic_error accepts std::string directly; the extra c_str() round
  // trip in the original forced a redundant strlen + copy
  return std::logic_error(oss.str());
}
// Default implementations of Expr's virtual interface: each throws
// not_implemented unless the derived class overrides it.
std::wstring Expr::to_latex() const { throw not_implemented("to_latex"); }
std::wstring Expr::to_wolfram() const { throw not_implemented("to_wolfram"); }
ExprPtr Expr::clone() const { throw not_implemented("clone"); }
void Expr::adjoint() { throw not_implemented("adjoint"); }
Expr &Expr::operator*=(const Expr &) { throw not_implemented("operator*="); }
Expr &Expr::operator^=(const Expr &) { throw not_implemented("operator^="); }
Expr &Expr::operator+=(const Expr &) { throw not_implemented("operator+="); }
Expr &Expr::operator-=(const Expr &) { throw not_implemented("operator-="); }
// Returns the adjoint of expr as a new expression; expr itself is untouched.
ExprPtr adjoint(const ExprPtr &expr) {
  ExprPtr cloned = expr->clone();
  cloned->adjoint();
  return cloned;
}
// Adjoint of a Constant is its complex conjugate; the cached hash is
// invalidated since the value changed.
void Constant::adjoint() {
  value_ = conj(value_);
  reset_hash_value();
}
// Returns the variable's label.
std::wstring_view Variable::label() const { return label_; }
// Replaces the label and invalidates the cached hash.
void Variable::set_label(std::wstring label) {
  label_ = std::move(label);
  reset_hash_value();
}
// Toggles the complex-conjugation flag.
// NOTE(review): unlike set_label(), this does NOT call reset_hash_value();
// confirm the hash excludes conjugated_, otherwise the cached hash goes stale.
void Variable::conjugate() { conjugated_ = !conjugated_; }
// Whether this variable is complex-conjugated.
bool Variable::conjugated() const { return conjugated_; }
// LaTeX form: the label wrapped in braces; conjugation adds a superscript
// star around the whole thing.
std::wstring Variable::to_latex() const {
  std::wstring body = L"{" + utf_to_latex(label_) + L"}";
  if (!conjugated_) return body;
  return L"{" + body + L"^*" + L"}";
}
// Deep copy via the copy constructor.
ExprPtr Variable::clone() const { return ex<Variable>(*this); }
// Adjoint of a variable is its complex conjugate.
void Variable::adjoint() { conjugate(); }
// True iff every factor commutes with every factor at position >= 1 (the
// same pair set the original checked). The original kept iterating all
// remaining pairs after the answer was already false (`result &=` with no
// outer-loop exit); returning at the first non-commuting pair yields the
// identical result without the wasted O(n^2) commutes_with calls.
bool Product::is_commutative() const {
  const auto nfactors = size();
  for (size_t f = 0; f != nfactors; ++f) {
    for (size_t s = 1; s != nfactors; ++s) {
      if (!factors_[f]->commutes_with(*factors_[s])) return false;
    }
  }
  return true;
}
// Canonicalizes this Product in place; returns a null ExprPtr since any
// scalar byproduct is absorbed directly into scalar_.
// Pipeline: (1) recursively canonicalize non-tensor factors, (2) pull
// Variable factors out and sort them by label, (3) canonicalize the
// remaining factors either as a TensorNetwork (when all are tensors) or via
// a commutativity-respecting resort, (4) reinsert the Variables at the front.
ExprPtr Product::canonicalize_impl(CanonicalizeOptions opts) {
  // recursively canonicalize non-tensor subfactors (tensors will be
  // canonicalized as part of the TN built of all tensor factors of this) ...
  ranges::for_each(factors_, [this, opts](auto &factor) {
    if (factor.template is<AbstractTensor>()) {
      return;
    }
    auto bp = factor->canonicalize(opts);
    if (bp) {
      // a Constant byproduct of subfactor canonicalization folds into scalar_
      SEQUANT_ASSERT(bp->template is<Constant>());
      this->scalar_ *= std::static_pointer_cast<Constant>(bp)->value();
    }
  });
  if (Logger::instance().canonicalize) {
    std::wcout << "Product canonicalization(" << to_wstring(opts.method)
               << ") input: " << to_latex() << std::endl;
  }
  // pull out all Variables to the front
  auto variables = factors_ | ranges::views::filter([](const auto &factor) {
                     return factor.template is<Variable>();
                   }) |
                   ranges::to_vector;
  // sort variables
  ranges::sort(variables, [](const auto &first, const auto &second) {
    return first.template as<Variable>().label() <
           second.template as<Variable>().label();
  });
  // keep only the non-Variable factors for the canonicalization below
  factors_ = factors_ | ranges::views::filter([](const auto &factor) {
               return !factor.template is<Variable>();
             }) |
             ranges::to<decltype(factors_)>;
  // if there are no factors, insert variables back and return
  if (factors_.empty()) {
    factors_.insert(factors_.begin(), variables.begin(), variables.end());
    return {};
  }
  // a factor not castable to AbstractTensor counts as a non-tensor
  auto contains_nontensors = ranges::any_of(factors_, [](const auto &factor) {
    return std::dynamic_pointer_cast<AbstractTensor>(factor) == nullptr;
  });
  if (!contains_nontensors) {  // tensor network canonization is a special case
                               // that's done in
                               // TensorNetwork
    // builds a TN of the type carried by tn_null_ptr (value is always
    // nullptr, only the type matters) from factors_ and canonicalizes it;
    // returns the TN plus any Constant byproduct
    auto make_canonical_tn = [this, &opts](auto *tn_null_ptr) {
      using TN = std::decay_t<std::remove_pointer_t<decltype(tn_null_ptr)>>;
      ExprPtr canon_factor;
      TN tn(this->factors_);
      if constexpr (TN::version() == 3) {
        canon_factor = tn.canonicalize(
            TensorCanonicalizer::cardinal_tensor_labels(), opts);
      } else {
        // pre-v3 TN API takes a rapid flag and an optional named-index set
        using NamedIndexSet = tensor_network::NamedIndexSet;
        std::shared_ptr<NamedIndexSet> named_indices =
            !opts.named_indices
                ? nullptr
                : std::make_shared<NamedIndexSet>(opts.named_indices->begin(),
                                                  opts.named_indices->end());
        canon_factor = tn.canonicalize(
            TensorCanonicalizer::cardinal_tensor_labels(),
            opts.method == CanonicalizationMethod::Rapid, named_indices.get());
      }
      return std::pair{std::move(tn), canon_factor};
    };
    using TN = TensorNetwork;
    auto [tn, canon_factor] = make_canonical_tn(static_cast<TN *>(nullptr));
    // copy the canonicalized tensors back into factors_ (same count asserted)
    const auto &tensors = tn.tensors();
    using std::size;
    SEQUANT_ASSERT(size(tensors) == size(factors_));
    using std::begin;
    using std::end;
    std::transform(begin(tensors), end(tensors), begin(factors_),
                   [](const auto &tptr) {
                     auto exprptr = std::dynamic_pointer_cast<Expr>(tptr);
                     SEQUANT_ASSERT(exprptr);
                     return exprptr;
                   });
    if (canon_factor) scalar_ *= canon_factor->template as<Constant>().value();
    this->reset_hash_value();
  } else {  // if contains non-tensors, do commutation-checking resort
    // comparer that respects cardinal tensor labels
    auto &cardinal_tensor_labels =
        TensorCanonicalizer::cardinal_tensor_labels();
    auto local_compare = [&cardinal_tensor_labels](const ExprPtr &first,
                                                   const ExprPtr &second) {
      if (first->is<Labeled>() && second->is<Labeled>()) {
        const auto first_label = first->as<Labeled>().label();
        const auto second_label = second->as<Labeled>().label();
        if (first_label == second_label) return *first < *second;
        const auto first_is_cardinal_it = ranges::find_if(
            cardinal_tensor_labels,
            [&first_label](const std::wstring &l) { return l == first_label; });
        const auto first_is_cardinal =
            first_is_cardinal_it != ranges::end(cardinal_tensor_labels);
        const auto second_is_cardinal_it = ranges::find_if(
            cardinal_tensor_labels, [&second_label](const std::wstring &l) {
              return l == second_label;
            });
        const auto second_is_cardinal =
            second_is_cardinal_it != ranges::end(cardinal_tensor_labels);
        // cardinal labels sort before non-cardinal ones, ordered by their
        // position within cardinal_tensor_labels
        if (first_is_cardinal && second_is_cardinal)
          return first_is_cardinal_it < second_is_cardinal_it;
        else if (first_is_cardinal && !second_is_cardinal)
          return true;
        else if (!first_is_cardinal && second_is_cardinal)
          return false;
        else {
          SEQUANT_ASSERT(!first_is_cardinal && !second_is_cardinal);
          return *first < *second;
        }
      } else
        return *first < *second;
    };
    // ... then resort, respecting commutativity
    using std::begin;
    using std::end;
    if (static_commutativity()) {
      if (is_commutative()) {
        std::stable_sort(begin(factors_), end(factors_), local_compare);
      }
    } else {
      // must do bubble sort if not commuting to avoid swapping elements across
      // a noncommuting element
      bubble_sort(
          begin(factors_), end(factors_),
          [&local_compare](const ExprPtr &first, const ExprPtr &second) {
            // only order a pair if swapping it is legal (the factors commute)
            bool result = (first->commutes_with(*second))
                              ? local_compare(first, second)
                              : false;
            return result;
          });
    }
  }
  // reinsert Variables at the front
  factors_.insert(factors_.begin(), variables.begin(), variables.end());
  // TODO evaluate product of Tensors (turn this into Products of Products)
  if (Logger::instance().canonicalize)
    std::wcout << "Product canonicalization(" << to_wstring(opts.method)
               << ") result: " << to_latex() << std::endl;
  return {};  // side effects are absorbed into the scalar_
}
// Computes the adjoint in place: conjugate the scalar and rebuild the product
// from the adjoints of the factors in reverse order.
// The original contained an unused `using std::swap;` declaration (no swap is
// performed here); removed.
void Product::adjoint() {
  SEQUANT_ASSERT(static_commutativity() == false);  // assert no slicing
  auto adj_scalar = conj(scalar());
  using namespace ranges;
  auto adj_factors =
      factors() | views::reverse |
      views::transform([](auto &expr) { return ::sequant::adjoint(expr); });
  *this =
      Product(adj_scalar, ranges::begin(adj_factors), ranges::end(adj_factors));
}
// Canonicalizes using the method selected in opt.
ExprPtr Product::canonicalize(CanonicalizeOptions opt) {
  return this->canonicalize_impl(opt);
}
// Rapid canonicalization: same implementation, but opt.method must be Rapid.
ExprPtr Product::rapid_canonicalize(CanonicalizeOptions opt) {
  SEQUANT_ASSERT(opt.method == CanonicalizationMethod::Rapid);
  return this->canonicalize_impl(opt);
}
// Adjoint of a commutative product: conjugate the scalar and adjoint each
// factor; factor order is preserved.
void CProduct::adjoint() {
  auto adj_scalar = conj(scalar());
  using namespace ranges;
  // no need to reverse for commutative product
  auto adj_factors = factors() | views::transform([](auto &&expr) {
                       return ::sequant::adjoint(expr);
                     });
  *this = CProduct(adj_scalar, ranges::begin(adj_factors),
                   ranges::end(adj_factors));
}
// Adjoint of a noncommutative product: conjugate the scalar and adjoint each
// factor in REVERSED order (views::reverse below), as required for a
// noncommutative product. (The previous comment here — "no need to reverse
// for commutative product" — was a copy-paste from CProduct::adjoint and
// contradicted the code.)
void NCProduct::adjoint() {
  auto adj_scalar = conj(scalar());
  using namespace ranges;
  auto adj_factors =
      factors() | views::reverse |
      views::transform([](auto &&expr) { return ::sequant::adjoint(expr); });
  *this = NCProduct(adj_scalar, ranges::begin(adj_factors),
                    ranges::end(adj_factors));
}
void Sum::adjoint() {
using namespace ranges;
auto adj_summands = summands() | views::transform([](auto &&expr) {
return ::sequant::adjoint(expr);
});
*this = Sum(ranges::begin(adj_summands), ranges::end(adj_summands));
}
// Canonicalizes this Sum in place; returns a null ExprPtr since all side
// effects are absorbed into the summands. When multipass is true runs two
// passes (a rapid/lexicographic pass, then a full pass); otherwise a single
// rapid pass. Each pass canonicalizes the summands recursively, then merges
// identical summands through a HashingAccumulator.
ExprPtr Sum::canonicalize_impl(bool multipass, CanonicalizeOptions opts) {
  if (Logger::instance().canonicalize)
    std::wcout << "Sum::canonicalize_impl: input = "
               << to_latex_align(shared_from_this()) << std::endl;
  const auto npasses = multipass ? 2 : 1;
  for (auto pass = 0; pass != npasses; ++pass) {
    // even passes are rapid (lexicographic); odd passes use the full method
    const auto rapid = (pass % 2 == 0);
    // canonicalizing TNs in a sum requires treating named indices as
    // meaningful/distinct
    auto opts_copy = opts;
    opts_copy.ignore_named_index_labels =
        CanonicalizeOptions::IgnoreNamedIndexLabel::No;
    if (rapid) {
      opts_copy.method = CanonicalizationMethod::Lexicographic;
    } else
      opts_copy.method = opts.method | CanonicalizationMethod::Topological;
    // recursively canonicalize summands ...
    // using for_each and direct access to summands
    sequant::for_each(summands_, [pass, &opts_copy, &rapid](ExprPtr &summand) {
      ExprPtr bp;
      if (rapid) {
        bp = summand->rapid_canonicalize(opts_copy);
      } else {
        bp = summand->canonicalize(opts_copy);
      }
      if (bp) {
        // a Constant byproduct becomes a scalar prefactor of the summand
        SEQUANT_ASSERT(bp->template is<Constant>());
        summand = ex<Product>(std::static_pointer_cast<Constant>(bp)->value(),
                              ExprPtrList{summand});
      }
    });
    if (Logger::instance().canonicalize)
      std::wcout << "Sum::canonicalize_impl (pass=" << pass
                 << "): after canonicalizing summands = "
                 << to_latex_align(shared_from_this()) << std::endl;
    // fold summands into the accumulator so identical terms get combined
    HashingAccumulator acc;
    for (auto &summand : summands_) {
      acc.append(summand);
    }
    // last pass? sort by hash then by Expr::operator<
    // N.B. no point in differentiating between canonicalization methods here
    // since need to sort in both cases
    auto new_sum =
        (pass == npasses - 1) ? acc.make_canonicalized_sum() : acc.make_sum();
    this->swap(*new_sum);
    if (Logger::instance().canonicalize)
      std::wcout << "Sum::canonicalize_impl (pass=" << pass
                 << "): after reducing summands = "
                 << to_latex_align(shared_from_this()) << std::endl;
  }
  return {};  // side effects are absorbed into summands
}
// Appends summand to the accumulator. A summand whose hash matches an
// already-stored term is merged with it (combining scalar prefactors via
// Product::add_identical); otherwise it is stored as a new term. If flatten
// is true, a Sum summand is decomposed and each subsummand appended
// individually.
HashingAccumulator &HashingAccumulator::append(ExprPtr summand, bool flatten) {
  // flatten, if needed
  if (flatten && summand.is<Sum>()) {
    for (auto &subsummand : summand.as<Sum>().summands()) {
      this->append(subsummand, flatten);
    }
    return *this;
  }
  // process summand as a whole
  auto it = summands_.find(summand);
  if (it == summands_.end()) {
    // first occurrence of this term
    summands_.emplace(summand);
  } else {  // found existing term with the same hash
    auto existing_summand = *it;
    if (summand.template is<Product>()) {
      if (existing_summand.is<Product>()) {
        // both are products - add them
        existing_summand.as<Product>().add_identical(
            summand.template as<Product>());
      } else {
        // convert existing term to product and add
        // NOTE(review): the Product is seeded from a clone of the NEW summand
        // and the EXISTING one is added to it — this relies on the two terms
        // being identical up to scalar prefactor; confirm add_identical's
        // contract permits this asymmetry
        auto product_copy = std::make_shared<Product>(summand->clone());
        product_copy->add_identical(existing_summand);
        summands_.erase(it);
        summands_.emplace(std::move(product_copy));
      }
    } else {
      if (existing_summand.is<Product>()) {
        existing_summand.as<Product>().add_identical(summand);
      } else {
        // neither is a product - create new product
        // two identical non-Product terms combine to 2 * term
        auto product_form = std::make_shared<Product>();
        product_form->append(2, summand.template as<Expr>());
        summands_.erase(it);
        summands_.emplace(std::move(product_form));
      }
    }
  }
  return *this;
}
// Builds a Sum from the accumulated summands, dropping any that reduced to
// zero. If canonicalize is true, summands are ordered by hash value with
// ties broken by comparison.
SumPtr HashingAccumulator::make_sum_impl(bool canonicalize) {
  Sum::summands_type summands;
  summands.reserve(summands_.size());
  // const& avoids copying (and atomically ref-bumping) each shared_ptr
  // handle per iteration, which the original `for (auto summand : ...)` did
  for (const auto &summand : summands_) {
    if (!summand->is_zero()) {
      summands.push_back(summand);
    }
  }
  if (canonicalize) {
    ranges::sort(summands, [](const auto &e1, const auto &e2) {
      if (e1->hash_value() == e2->hash_value()) {
        // NOTE(review): this compares the ExprPtr handles themselves; if that
        // resolves to shared_ptr's pointer comparison the tie-break is
        // nondeterministic across runs — confirm ExprPtr provides value
        // comparison, else this should be *e1 < *e2
        return e1 < e2;
      } else {
        return e1->hash_value() < e2->hash_value();
      }
    });
  }
  return std::make_shared<Sum>(std::move(summands), Sum::move_only_tag{});
}
// Builds the Sum without sorting the summands.
SumPtr HashingAccumulator::make_sum() { return make_sum_impl(false); }
// Builds the Sum with summands sorted into canonical (hash) order.
SumPtr HashingAccumulator::make_canonicalized_sum() {
  return make_sum_impl(true);
}
// Collapses the accumulated summands to the simplest expression form:
// Constant 0 when empty, the lone summand when there is exactly one,
// otherwise a Sum (canonicalized when requested).
ExprPtr HashingAccumulator::make_expr(bool canonicalize) {
  const auto count = summands_.size();
  if (count == 0) return ex<Constant>(0);
  if (count == 1) return *(summands_.begin());
  return make_sum_impl(canonicalize);
}
// Returns true if expr1 and expr2 are equal up to a scalar prefactor.
bool proportional_to::operator()(const ExprPtr &expr1,
                                 const ExprPtr &expr2) const {
  if (expr1->type_id() !=
      expr2->type_id()) {  // if expr1 is a Product with single factor == expr2,
                           // or vice versa
    if (expr1.is<Product>()) {
      return expr1.as<Product>().factors().size() == 1 &&
             expr1.as<Product>().factors().front() == expr2;
    } else if (expr2.is<Product>()) {
      return expr2.as<Product>().factors().size() == 1 &&
             expr2.as<Product>().factors().front() == expr1;
    } else
      return false;
  }
  // expr1 and expr2 are same type
  // any two Constants are proportional to each other
  if (expr1.is<Constant>()) {
    return true;
  }
  // two Products are proportional iff their factor sequences compare equal
  // NOTE(review): the hash pre-filter assumes Product::hash_value is
  // insensitive to the scalar prefactor — confirm against Product's hasher
  if (expr1.is<Product>()) {
    return expr1->hash_value() == expr2->hash_value() &&
           expr1.as<Product>().factors() == expr2.as<Product>().factors();
  }
  return expr1 == expr2;
}
} // namespace sequant