Program Listing for File optimize.cpp
Return to documentation for file (SeQuant/core/optimize/optimize.cpp)
#include <SeQuant/core/binary_node.hpp>
#include <SeQuant/core/complex.hpp>
#include <SeQuant/core/container.hpp>
#include <SeQuant/core/eval_expr.hpp>
#include <SeQuant/core/eval_node.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/hash.hpp>
#include <SeQuant/core/optimize.hpp>
#include <range/v3/iterator/basic_iterator.hpp>
#include <range/v3/range/access.hpp>
#include <range/v3/view/tail.hpp>
#include <range/v3/view/transform.hpp>
#include <range/v3/view/view.hpp>
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <stack>
#include <utility>
#include <vector>
namespace sequant {
class Tensor;
namespace opt {
/// Drops the leading factor of \p expr.
///
/// - Tensor: there is nothing to drop; a fresh clone of the tensor is returned.
/// - Product: a product of the same scalar and all factors except the first.
///   If the product has exactly two factors and a unit scalar, the remaining
///   factor is returned as-is instead of being wrapped in a single-factor
///   Product.
/// - Anything else is treated as a Sum: tail_factor is applied to every
///   summand and the results are collected into a new Sum.
///
/// NOTE(review): the Tensor branch returns a clone, while the two-factor
/// branch returns the existing child pointer (`expr->at(1)`), not a copy —
/// confirm callers never mutate the result in place.
/// The function is noexcept, so any exception raised while cloning or building
/// the result terminates the program.
ExprPtr tail_factor(ExprPtr const& expr) noexcept {
  if (expr->is<Tensor>()) return expr->clone();

  if (expr->is<Product>()) {
    auto scal = expr->as<Product>().scalar();
    // Avoid forming a Product whose only content is one factor with scalar 1.
    bool const single_factor_remains = scal == 1 && expr->size() == 2;
    if (single_factor_remains) return expr->at(1);

    auto rest = ranges::views::tail(*expr);
    return ex<Product>(Product{scal, ranges::begin(rest), ranges::end(rest)});
  }

  // Fallback (assumed Sum): strip the head factor of every summand.
  auto tails = *expr | ranges::views::transform([](auto const& summand) {
                 return tail_factor(summand);
               });
  return ex<Sum>(Sum{ranges::begin(tails), ranges::end(tails)});
}
/// Hoists the scalars of the immediate Product factors of \p expr into
/// \p expr's own scalar, in place.
///
/// Does nothing unless \p expr is a Product. Only direct children are
/// inspected; deeper nesting is not visited. Each child Product whose scalar
/// s is nonzero is rescaled by 1/s, and the outer product is multiplied by s,
/// so the value of the whole expression is unchanged.
///
/// A child with a zero scalar is left untouched: it cannot be normalized to 1
/// without dividing by zero, and the whole product is zero either way.
/// The outer scalar is never divided by, so an outer scalar of 0 is safe.
/// (Dividing by zero here would previously escape this noexcept function and
/// terminate the program.)
///
/// \param expr the expression to rewrite; mutated through the shared pointer
void pull_scalar(ExprPtr expr) noexcept {
  if (!expr->is<Product>()) return;
  auto& prod = expr->as<Product>();
  for (auto&& x : *expr) {
    if (!x->is<Product>()) continue;
    auto& p = x->as<Product>();
    // Copy: p.scale() below mutates the value that p.scalar() refers to.
    auto const s = p.scalar();
    if (s == 0) continue;
    p.scale(1 / s);
    // Only the scalar of prod changes; its factor list (being iterated) does
    // not.
    prod.scale(s);
  }
}
bool has_only_single_atom(const ExprPtr& term) {
if (term->is_atom()) {
return true;
}
// Recursively check that all elements in the expression tree have only a
// single element in them. At this point this means checking for Sum or
// Product objects that only have a single addend or factor respectively.
return term->size() == 1 && has_only_single_atom(*term->begin());
}
/// Groups the summands of \p expr into clusters of summands that share
/// evaluation intermediates.
///
/// Each summand is converted to an evaluation tree (eval_node<EvalExpr>).
/// The hash of every internal node is recorded together with the summand's
/// position; for a summand that reduces to a single atom, the root is
/// recorded instead. Summands sharing a hash are linked, and clusters are
/// collected by a depth-first traversal of those links.
///
/// \param expr the sum whose summands are clustered
/// \return clusters of positions (indices into \p expr). A position appears
///         in at most one cluster, and within a cluster positions are listed
///         in DFS visiting order.
///
/// NOTE(review): links point from the smallest position sharing a hash to the
/// other positions, and traversal follows them only in that direction. Two
/// summands related only through a common later summand may therefore land
/// in different clusters — confirm this is intended.
container::vector<container::vector<size_t>> clusters(Sum const& expr) {
  using hash_t = size_t;
  using pos_t = size_t;
  using stack_t = std::stack<pos_t, container::vector<pos_t>>;

  // node hash -> positions of the summands whose eval tree contains that node
  container::map<hash_t, container::set<pos_t>> positions;
  {
    pos_t pos = 0;
    auto visitor = [&positions, &pos](auto const& n) {
      auto h = hash::value(*n);
      if (auto&& found = positions.find(h); found != positions.end()) {
        found->second.emplace(pos);
      } else {
        positions.emplace(h, decltype(positions)::mapped_type{pos});
      }
    };
    for (auto const& term : expr) {
      auto const node = eval_node<EvalExpr>(term);
      if (has_only_single_atom(term)) {
        visitor(node);
      } else {
        node.visit_internal(visitor);
      }
      ++pos;
    }
  }

  // For every shared hash, link its smallest position to all the others.
  // Links therefore always point from a smaller to a larger position; a
  // target may be listed several times under the same source.
  container::map<pos_t, container::vector<pos_t>> connections;
  {
    for (auto const& [_, v] : positions) {
      auto const v0 = ranges::front(v);
      auto const v_ = ranges::views::tail(v) |
                      ranges::to<decltype(connections)::mapped_type>;
      if (auto&& found = connections.find(v0); found != connections.end())
        for (auto p : v_) found->second.push_back(p);
      else
        connections.emplace(v0, v_);
    }
  }
  positions.clear();

  container::vector<container::vector<pos_t>> result;
  {
    container::set<pos_t> visited;
    for (auto k : connections | ranges::views::keys) {
      if (visited.contains(k)) continue;
      stack_t dfs_stack;
      dfs_stack.push(k);
      container::vector<pos_t> clstr;
      while (!dfs_stack.empty()) {
        auto p = dfs_stack.top();
        dfs_stack.pop();
        // A node popped a second time has already been expanded. By LIFO
        // order, all of its successors have been visited by now. Expanding
        // it again would only re-push its whole subtree (exponential work
        // with duplicate or diamond-shaped links) without changing clstr.
        if (visited.contains(p)) continue;
        visited.emplace(p);
        clstr.push_back(p);
        if (auto&& found = connections.find(p); found != connections.end())
          // Push in reverse so that smaller targets are popped first.
          for (auto p_ : ranges::views::reverse(found->second))
            dfs_stack.push(p_);
      }
      result.emplace_back(std::move(clstr));
    }
  }
  return result;
}
/// Returns a new Sum with the summands of \p sum rearranged so that the
/// members of each cluster (as computed by clusters()) are adjacent. Clusters
/// are emitted in the order clusters() returns them, and positions within a
/// cluster in their listed order.
/// Asserts (debug builds only) that no summand was lost or duplicated.
Sum reorder(Sum const& sum) {
  Sum reordered;
  auto const groups = clusters(sum);
  for (auto const& group : groups) {
    for (auto const idx : group) {
      reordered.append(sum.at(idx));
    }
  }
  assert(reordered.size() == sum.size());
  return reordered;
}
} // namespace opt
/// Optimizes \p expr with opt::optimize, using each index space's
/// approximate_size() as the extent of an index.
ExprPtr optimize(ExprPtr const& expr) {
  auto const index_extent = [](Index const& ix) {
    return ix.space().approximate_size();
  };
  return opt::optimize(expr, index_extent);
}
} // namespace sequant