Program Listing for File expr.cpp
↰ Return to documentation for file (SeQuant/core/utility/expr.cpp)
#include <SeQuant/core/container.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/utility/expr.hpp>
#include <SeQuant/core/utility/indices.hpp>
#include <SeQuant/core/utility/macros.hpp>
#include <SeQuant/core/utility/string.hpp>
#include <range/v3/view/concat.hpp>
#include <algorithm>
#include <bitset>
#include <climits>
#include <iterator>
#include <optional>
#include <sstream>
#include <string>
#include <typeinfo>
namespace sequant {
// Top-level diff means a diff of the object instance itself without
// regard for any contained subexpressions
// Top-level diff means a diff of the object instance itself without
// regard for any contained subexpressions
//
// Renders a complex number as text:
//   - zero imaginary part -> just the real part, e.g. "3"
//   - zero real part      -> "<imag>*i", e.g. "-2*i"
//   - otherwise           -> "(<re> + <im>*i)" or, for a negative imaginary
//                            part, "(<re> - <|im|>*i)"
template <typename T>
std::string to_string(const Complex<T> &c) {
  const auto &re = c.real();
  const auto &im = c.imag();
  std::stringstream out;

  if (im == 0) {
    out << re;
    return out.str();
  }

  if (re == 0) {
    out << im << "*i";
    return out.str();
  }

  out << "(" << re;
  if (im < 0) {
    // Print the magnitude after an explicit minus sign.
    out << " - " << (-im);
  } else {
    out << " + " << im;
  }
  out << "*i)";
  return out.str();
}
// Constants have no substructure: the only possible difference is their
// value. Returns "<lhs value> vs. <rhs value>", or an empty string when the
// two constants compare equal.
std::string toplevel_diff(const Constant &lhs, const Constant &rhs) {
  if (!(lhs == rhs)) {
    const std::string lhs_text = to_string(lhs.value());
    const std::string rhs_text = to_string(rhs.value());
    return lhs_text + " vs. " + rhs_text;
  }
  return std::string{};
}
std::string toplevel_diff(const Variable &lhs, const Variable &rhs) {
if (lhs == rhs) {
return {};
}
if (lhs.label() != rhs.label()) {
return to_string(lhs.label()) + " vs. " + to_string(rhs.label());
}
return (lhs.conjugated() ? "conjugated"
: "non-conjugated" + std::string(" vs. ")) +
(rhs.conjugated() ? "conjugated" : "non-conjugated");
}
std::string toplevel_diff(const Index &lhs, const Index &rhs);
template <typename LRange, typename RRange>
std::string diff_indices(const LRange &lhs, const RRange &rhs) {
auto lhs_size = std::distance(std::begin(lhs), std::end(lhs));
auto rhs_size = std::distance(std::begin(rhs), std::end(rhs));
if (lhs_size != rhs_size) {
return std::to_string(lhs_size) + " indices vs. " +
std::to_string(rhs_size) + " indices";
}
auto lhs_it = std::begin(lhs);
auto rhs_it = std::begin(rhs);
std::string diff;
for (std::size_t i = 0; i < static_cast<std::size_t>(lhs_size); ++i) {
const Index &lhs_idx = *lhs_it++;
const Index &rhs_idx = *rhs_it++;
std::string subdiff = toplevel_diff(lhs_idx, rhs_idx);
if (subdiff.empty()) {
continue;
}
if (!diff.empty()) {
diff += ", ";
}
diff += "#" + std::to_string(i + 1) + ": " + subdiff;
}
return diff;
}
// Explains why two index spaces compare unequal.
// Attributes are probed in a fixed order: type, quantum numbers, base key,
// approximate size. A description of the first one that differs is returned.
// Equal spaces yield an empty string. Attribute masks are printed as 32-bit
// strings of '0'/'1' characters.
std::string diff_spaces(const IndexSpace &lhs, const IndexSpace &rhs) {
  if (lhs == rhs) {
    return {};
  }

  using AttrSet = std::bitset<sizeof(std::uint32_t) * CHAR_BIT>;
  const auto &l_attr = lhs.attr();
  const auto &r_attr = rhs.attr();

  if (l_attr.type() != r_attr.type()) {
    return std::string("Types differ: ") +
           AttrSet(lhs.type().to_int32()).to_string() + " vs. " +
           AttrSet(rhs.type().to_int32()).to_string();
  }
  if (l_attr.qns() != r_attr.qns()) {
    return std::string("QNs differ: ") +
           AttrSet(lhs.qns().to_int32()).to_string() + " vs. " +
           AttrSet(rhs.qns().to_int32()).to_string();
  }
  if (lhs.base_key() != rhs.base_key()) {
    return std::string("Base key differs: ") + to_string(lhs.base_key()) +
           " vs. " + to_string(rhs.base_key());
  }
  if (lhs.approximate_size() != rhs.approximate_size()) {
    return std::string("Size differs: ") +
           std::to_string(lhs.approximate_size()) + " vs. " +
           std::to_string(rhs.approximate_size());
  }

  // operator== reported a difference that none of the probes above found.
  SEQUANT_UNREACHABLE;
  return {};
}
// Explains why two indices compare unequal.
// Properties are checked in a fixed order: full label, index space, presence
// of proto-indices, the proto-indices themselves, tag. The first difference
// found is reported. Equal indices yield an empty string. If none of these
// properties differs, the program aborts.
std::string toplevel_diff(const Index &lhs, const Index &rhs) {
  if (lhs == rhs) {
    return {};
  }

  const auto &lhs_label = lhs.full_label();
  const auto &rhs_label = rhs.full_label();
  if (lhs_label != rhs_label) {
    return to_string(lhs_label) + " vs. " + to_string(rhs_label);
  }

  // Index spaces have no textual representation of their own, so describe
  // the attribute that differs instead.
  if (lhs.space() != rhs.space()) {
    return "Spaces differ: " + diff_spaces(lhs.space(), rhs.space());
  }

  const bool lhs_protos = lhs.has_proto_indices();
  const bool rhs_protos = rhs.has_proto_indices();
  if (lhs_protos != rhs_protos) {
    const std::string lhs_word = lhs_protos ? "with" : "without";
    const std::string rhs_word = rhs_protos ? "with" : "without";
    return lhs_word + " vs. " + rhs_word + " proto-indices";
  }

  if (lhs.proto_indices() != rhs.proto_indices()) {
    return "Proto indices differ: " +
           diff_indices(lhs.proto_indices(), rhs.proto_indices());
  }

  if (lhs.tag() != rhs.tag()) {
    return "Different tags";
  }

  // No known property explains the inequality.
  SEQUANT_ABORT("Unexpected difference between indices");
}
// Explains why two tensors compare unequal, ignoring nested subexpressions.
// Checks run in this order: label, slot count, symmetry, column (particle)
// symmetry, bra-ket symmetry, bra indices, ket indices, aux indices. The first
// difference found is reported. Equal tensors yield an empty string. If none
// of these checks finds a difference, the program aborts.
std::string toplevel_diff(const Tensor &lhs, const Tensor &rhs) {
  if (lhs == rhs) {
    return {};
  }
  if (lhs.label() != rhs.label()) {
    return "Names differ: " + to_string(lhs.label()) + " vs. " +
           to_string(rhs.label());
  }
  if (lhs.slots().size() != rhs.slots().size()) {
    return std::to_string(lhs.slots().size()) + " indices vs. " +
           std::to_string(rhs.slots().size()) + " indices";
  }
  if (lhs.symmetry() != rhs.symmetry()) {
    return "Symmetry differs: " + to_string(to_wstring(lhs.symmetry())) +
           " vs. " + to_string(to_wstring(rhs.symmetry()));
  }
  if (lhs.column_symmetry() != rhs.column_symmetry()) {
    return "Particle-Symmetry differs: " +
           to_string(to_wstring(lhs.column_symmetry())) + " vs. " +
           to_string(to_wstring(rhs.column_symmetry()));
  }
  if (lhs.braket_symmetry() != rhs.braket_symmetry()) {
    return "BraKet-Symmetry differs: " +
           to_string(to_wstring(lhs.braket_symmetry())) + " vs. " +
           to_string(to_wstring(rhs.braket_symmetry()));
  }
  // Each index group must be diffed against its own counterpart. The ket
  // branch previously diffed bra() and the aux branch diffed ket().
  if (lhs.bra() != rhs.bra()) {
    return "Bra indices differ: " + diff_indices(lhs.bra(), rhs.bra());
  }
  if (lhs.ket() != rhs.ket()) {
    return "Ket indices differ: " + diff_indices(lhs.ket(), rhs.ket());
  }
  if (lhs.aux() != rhs.aux()) {
    return "Aux indices differ: " + diff_indices(lhs.aux(), rhs.aux());
  }
  // Really, this shouldn't produce an empty diff as the objects compare as
  // non-equal but we have run out of ideas of what to check
  SEQUANT_ABORT("Unhandled difference between tensors");
}
// A Sum owns nothing besides its summands, which diff() compares as nested
// subexpressions. Two sums therefore never differ at the top level, and the
// result is always empty.
std::string toplevel_diff(const Sum & /*lhs*/, const Sum & /*rhs*/) {
  return std::string{};
}
std::string toplevel_diff(const Product &lhs, const Product &rhs) {
if (lhs.scalar() != rhs.scalar()) {
return "Prefactor differs: " +
toplevel_diff(Constant(lhs.scalar()), Constant(rhs.scalar()));
}
return {};
}
// Produces a human-readable description of how `lhs` and `rhs` differ.
// Returns an empty string when they compare equal.
//
// Differences are reported at the shallowest level that explains them:
//   1. different expression types (typeid name, which is
//      implementation-defined, plus type_id());
//   2. different numbers of subexpressions;
//   3. differences inside subexpressions, found recursively and reported
//      once per differing subexpression, numbered from 1;
//   4. differences in the node's own data (see the toplevel_diff overloads).
std::string diff(const Expr &lhs, const Expr &rhs) {
  if (lhs == rhs) {
    return {};
  }
  if (lhs.type_id() != rhs.type_id()) {
    return std::string("Types differ: ") + typeid(lhs).name() + " (" +
           std::to_string(lhs.type_id()) + ") vs. " + typeid(rhs).name() +
           " (" + std::to_string(rhs.type_id()) + ")";
  }

  // Each side's size must come from its own iterators. Otherwise a size
  // mismatch goes unnoticed and the loop below can walk past the end of rhs.
  const auto lhs_size = std::distance(std::begin(lhs), std::end(lhs));
  const auto rhs_size = std::distance(std::begin(rhs), std::end(rhs));
  if (lhs_size != rhs_size) {
    return "Sizes differ: " + std::to_string(lhs_size) + " vs. " +
           std::to_string(rhs_size);
  }

  std::string diff_str;
  auto lhs_it = std::begin(lhs);
  auto rhs_it = std::begin(rhs);
  for (std::size_t i = 0; i < static_cast<std::size_t>(lhs_size);
       ++i, ++lhs_it, ++rhs_it) {
    const Expr &lhs_nested = **lhs_it;
    const Expr &rhs_nested = **rhs_it;
    std::string nested_diff = diff(lhs_nested, rhs_nested);
    if (nested_diff.empty()) {
      continue;
    }
    if (diff_str.empty()) {
      diff_str += "Subexpression diff begin:\n";
    }
    diff_str +=
        "Sub-Expr #" + std::to_string(i + 1) + ":\n" + nested_diff + "\n";
  }
  if (!diff_str.empty()) {
    diff_str += "Subexpression diff end";
    return diff_str;
  }

  // All children agree, so the difference lies in this node's own data.
  if (lhs.is<Sum>()) {
    diff_str = toplevel_diff(lhs.as<Sum>(), rhs.as<Sum>());
  } else if (lhs.is<Product>()) {
    diff_str = toplevel_diff(lhs.as<Product>(), rhs.as<Product>());
  } else if (lhs.is<Tensor>()) {
    diff_str = toplevel_diff(lhs.as<Tensor>(), rhs.as<Tensor>());
  } else if (lhs.is<Constant>()) {
    diff_str = toplevel_diff(lhs.as<Constant>(), rhs.as<Constant>());
  } else if (lhs.is<Variable>()) {
    diff_str = toplevel_diff(lhs.as<Variable>(), rhs.as<Variable>());
  } else {
    SEQUANT_ABORT("Unhandled expression type");
  }
  return diff_str;
}
// Reports a validation failure from inside an is_valid() overload.
// If the enclosing function's `msg` out-parameter (std::string *) is
// non-null, `message` is stored in it. The enclosing function then returns
// false. The body is wrapped in do { } while (0) so the macro expands to a
// single statement and composes safely with unbraced if/else.
#define SEQUANT_EXPR_INVALID(message) \
  do {                                \
    if (msg) {                        \
      *msg = message;                 \
    }                                 \
    return false;                     \
  } while (0)
bool is_valid(const ExprPtr &expr, std::string *msg) {
if (!expr) {
SEQUANT_EXPR_INVALID("Expression is null");
}
return is_valid(*expr, msg);
}
bool is_valid(const Expr &expr, std::string *msg) {
if (!expr.is_atom()) {
// Validate children first
for (const ExprPtr ¤t : expr) {
if (!is_valid(current, msg)) {
return false;
}
}
}
if (expr.is<Variable>()) {
// Nothing to validate
} else if (expr.is<Constant>()) {
const Constant &c = expr.as<Constant>();
if (denominator(c.value().real()) == 0) {
SEQUANT_EXPR_INVALID("Denominator of real part of constant is zero");
}
if (denominator(c.value().imag()) == 0) {
SEQUANT_EXPR_INVALID("Denominator of imaginary part of constant is zero");
}
} else if (expr.is<Tensor>()) {
// Nothing to validate
} else if (expr.is<Product>()) {
const Product &prod = expr.as<Product>();
auto factor = prod.scalar();
if (denominator(factor.real()) == 0) {
SEQUANT_EXPR_INVALID(
"Denominator of real part of product factor is zero");
}
if (denominator(factor.imag()) == 0) {
SEQUANT_EXPR_INVALID(
"Denominator of imaginary part of product factor is zero");
}
// Check that indices don't appear more than 2 times
container::map<Index, std::size_t> index_counter;
for (const ExprPtr &factor : prod.factors()) {
IndexGroups<> indices = get_unique_indices(*factor);
for (const Index &idx :
ranges::views::concat(indices.bra, indices.ket, indices.aux)) {
index_counter[idx] += 1;
}
}
for (const auto &[idx, count] : index_counter) {
if (count > 2) {
SEQUANT_EXPR_INVALID("Index " + toUtf8(idx.full_label()) +
" appears more than 2 times");
}
}
} else if (expr.is<Sum>()) {
// Verify that all summands have the same external indices
const Sum &sum = expr.as<Sum>();
auto extractor = [](const ExprPtr &expr) {
return get_unique_indices(expr);
};
const IndexGroups<> ref = extractor(sum.summand(0));
auto compare = [&ref](const IndexGroups<> &grps) {
return std::ranges::is_permutation(ref.bra, grps.bra) &&
std::ranges::is_permutation(ref.ket, grps.ket) &&
std::ranges::is_permutation(ref.aux, grps.aux);
};
bool consistent = std::ranges::all_of(sum.summands(), compare, extractor);
if (!consistent) {
SEQUANT_EXPR_INVALID("Inconsistent external indices in sum");
}
} else {
SEQUANT_ASSERT(false, "Unsupported expression type in is_valid");
}
return true;
}
bool is_valid(const ResultExpr &expr, std::string *msg) {
if (!is_valid(expr.expression(), msg)) {
return false;
}
// We need to make sure to remove any symmetrizers from the expression in
// order to not mess up the determination of external indices
ExprPtr rhs = expr.expression().clone();
pop_tensor(rhs, L"A");
pop_tensor(rhs, L"S");
IndexGroups<> externals = get_unique_indices(rhs);
if (!std::ranges::is_permutation(expr.bra(), externals.bra)) {
SEQUANT_EXPR_INVALID(
"Bra indices of result are inconsistent with the rhs expression");
}
if (!std::ranges::is_permutation(expr.ket(), externals.ket)) {
SEQUANT_EXPR_INVALID(
"Ket indices of result are inconsistent with the rhs expression");
}
if (!std::ranges::is_permutation(expr.aux(), externals.aux)) {
SEQUANT_EXPR_INVALID(
"Aux indices of result are inconsistent with the rhs expression");
}
// TODO: check whether specified symmetries of result are fulfilled in the rhs
// expression
return true;
}
#undef SEQUANT_EXPR_INVALID
// Returns `expr` with its indices renamed according to `index_replacements`
// and scaled by `scaling_factor`.
//   - Constant / Variable: the original handle is returned as-is when
//     scaling_factor == 1. Otherwise returns scaling_factor * expr.
//   - Tensor: the tensor is cloned, its indices are transformed and its tags
//     are reset. It is multiplied by the factor when that factor is not 1.
//   - Product: rebuilt factor by factor (tensor factors are transformed,
//     scalar factors are cloned), then scaled.
//   - Sum: each summand is transformed recursively with the same factor.
// Throws std::runtime_error for any other expression type, including
// unsupported factor types inside a Product.
ExprPtr transform_expr(const ExprPtr &expr,
                       const container::map<Index, Index> &index_replacements,
                       Constant::scalar_type scaling_factor) {
  const bool is_scalar_leaf = expr->is<Constant>() || expr->is<Variable>();
  if (is_scalar_leaf) {
    if (scaling_factor == 1) {
      return expr;
    }
    return ex<Constant>(scaling_factor) * expr;
  }

  // Clone a tensor and apply the index map to the copy.
  const auto relabel_tensor =
      [&index_replacements](const ExprPtr &source) -> ExprPtr {
    ExprPtr copy = source.clone();
    auto &copy_tensor = copy->as<AbstractTensor>();
    transform_indices(copy_tensor, index_replacements);
    reset_tags(copy_tensor);
    return copy;
  };

  if (expr->is<AbstractTensor>()) {
    ExprPtr relabeled = relabel_tensor(expr);
    if (scaling_factor != 1) {
      relabeled = relabeled * ex<Constant>(scaling_factor);
    }
    return relabeled;
  }

  if (expr->is<Product>()) {
    const Product &source = expr->as<Product>();
    auto rebuilt = std::make_shared<Product>();
    rebuilt->scale(source.scalar());
    for (auto &&factor : source) {
      if (factor->is<AbstractTensor>()) {
        rebuilt->append(1, relabel_tensor(factor));
      } else if (factor->is<Variable>() || factor->is<Constant>()) {
        rebuilt->append(1, factor->clone());
      } else {
        throw std::runtime_error("Invalid Expr type in transform_product");
      }
    }
    rebuilt->scale(scaling_factor);
    return rebuilt;
  }

  if (expr->is<Sum>()) {
    auto rebuilt = std::make_shared<Sum>();
    for (auto &summand : *expr) {
      rebuilt->append(
          transform_expr(summand, index_replacements, scaling_factor));
    }
    return rebuilt;
  }

  throw std::runtime_error("Invalid Expr type in transform_expr");
}
// Removes the tensor labelled `label` from `expression`, modifying
// `expression` in place. Returns the removed tensor, or std::nullopt if no
// tensor with that label was found.
//
// Behaviour by expression type:
//  - Sum: recurses into every summand and rebuilds the sum from the
//    (possibly modified) summands. Once some summand has yielded a tensor,
//    every later summand must yield an equal one (checked by SEQUANT_ASSERT).
//  - Product: recurses into every factor. Every popped tensor must equal the
//    first one found (SEQUANT_ASSERT). Factors that have become the zero
//    Constant (what a popped Tensor is replaced with) are dropped. The result
//    is then simplified: several factors, or one factor with a non-unit
//    prefactor, stay a Product; one factor with a unit prefactor is
//    unwrapped; no factors at all becomes Constant(0).
//  - Tensor: if its label matches, it is returned and `expression` is
//    replaced with Constant(0).
//  - Constant / Variable: left untouched; returns std::nullopt.
// Any other expression type throws std::runtime_error.
//
// NOTE(review): in the Product case the prefactor is discarded when no
// factors remain (e.g. popping A from 2*A yields Constant(0), not 2). Any
// pre-existing zero Constant factor is also dropped, not only popped
// tensors. Confirm both are intended.
std::optional<ExprPtr> pop_tensor(ExprPtr &expression,
std::wstring_view label) {
std::optional<ExprPtr> tensor;
if (expression->is<Sum>()) {
Sum result{};
for (ExprPtr &term : expression.as<Sum>()) {
std::optional<ExprPtr> popped = pop_tensor(term, label);
// Remember the first non-empty result; later results must match it.
if (!tensor.has_value()) {
tensor = popped;
}
SEQUANT_ASSERT(tensor == popped);
result.append(std::move(term));
}
expression.as<Sum>() = std::move(result);
return tensor;
}
if (expression->is<Product>()) {
// Rebuild the product from the surviving factors, keeping the prefactor.
Product result;
result.scale(expression.as<Product>().scalar());
for (ExprPtr &factor : expression.as<Product>().factors()) {
std::optional<ExprPtr> popped = pop_tensor(factor, label);
if (!tensor.has_value()) {
tensor = popped;
}
// Factors that popped nothing are fine; any popped tensor must match.
SEQUANT_ASSERT(!popped.has_value() || tensor == popped);
// A popped tensor leaves Constant(0) behind; do not carry it over.
if (!factor.is<Constant>() || !factor.as<Constant>().is_zero()) {
result.append(1, std::move(factor), Product::Flatten::No);
}
}
// Collapse trivial products: keep a real product, unwrap a lone factor with
// unit prefactor, or fall back to zero when nothing is left.
if (result.size() > 1 || (result.size() == 1 && result.scalar() != 1)) {
expression.as<Product>() = std::move(result);
} else if (result.size() == 1) {
expression = std::move(result.factor(0));
} else {
expression = ex<Constant>(0);
}
return tensor;
}
if (expression->is<Tensor>()) {
// Matching tensor: hand it back and leave zero in its place.
if (expression.as<Tensor>().label() == label) {
tensor = expression;
expression = ex<Constant>(0);
}
return tensor;
}
if (expression->is<Constant>() || expression->is<Variable>()) {
// Scalars contain no tensors.
return tensor;
}
throw std::runtime_error("Unhandled expression type in pop_tensor");
}
} // namespace sequant