// Program Listing for File binary_node.hpp
// Return to documentation for file (SeQuant/core/binary_node.hpp)
#ifndef SEQUANT_BINARY_NODE_HPP
#define SEQUANT_BINARY_NODE_HPP
#include <SeQuant/core/utility/macros.hpp>
#include <cstddef>
#include <functional>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string_view>
#include <type_traits>
#include <utility>
#include <range/v3/numeric/accumulate.hpp>
#include <range/v3/view.hpp>
namespace sequant {
template <typename>
class FullBinaryNode;
namespace meta {
/// True iff the (unqualified) type is a FullBinaryNode<T> for some T.
/// Declared `inline` so that the header-defined variable template has a single
/// definition across all translation units.
template <typename>
inline constexpr bool is_full_binary_node{false};
template <typename T>
inline constexpr bool is_full_binary_node<FullBinaryNode<T>>{true};
}  // namespace meta
/// Bit flags selecting at which point(s) of a depth-first walk an internal
/// node is handed to a visitor. Single flags may be combined with `operator|`
/// and tested with `operator&`.
enum class TreeTraversal {
  None = 0,
  PreOrder = 1 << 0,   ///< before descending into the children
  PostOrder = 1 << 1,  ///< after both children have been visited
  InOrder = 1 << 2,    ///< between the left and the right child
  PreAndPostOrder = PreOrder | PostOrder,
  PreAndInOrder = PreOrder | InOrder,
  PostAndInOrder = PostOrder | InOrder,
  Any = PreOrder | PostOrder | InOrder,
};

/// Union of two traversal flag sets.
constexpr TreeTraversal operator|(TreeTraversal lhs, TreeTraversal rhs) {
  using bits_t = std::underlying_type_t<TreeTraversal>;
  const bits_t merged = static_cast<bits_t>(lhs) | static_cast<bits_t>(rhs);
  return static_cast<TreeTraversal>(merged);
}

/// Intersection of two traversal flag sets.
constexpr TreeTraversal operator&(TreeTraversal lhs, TreeTraversal rhs) {
  using bits_t = std::underlying_type_t<TreeTraversal>;
  const bits_t common = static_cast<bits_t>(lhs) & static_cast<bits_t>(rhs);
  return static_cast<TreeTraversal>(common);
}
// TRAVERSAL_TO_TEMPLATE_ARG(order, functionName, functionArgs)
//
// Turns the run-time TreeTraversal value `order` into the matching
// compile-time template argument: expands to the call
// `functionName<TreeTraversal::X> functionArgs`, where X is the enumerator
// equal to `order` and `functionArgs` is a parenthesized argument list.
// `TreeTraversal::None` performs no call.
//
// The expansion is wrapped in `do { ... } while (false)` so that the macro is a
// single statement that must be terminated by a semicolon and composes safely
// with unbraced if/else.
#define TRAVERSAL_TO_TEMPLATE_ARG(order, functionName, functionArgs) \
  do {                                                               \
    switch (order) {                                                 \
      case TreeTraversal::None:                                      \
        break;                                                       \
      case TreeTraversal::PreOrder:                                  \
        functionName<TreeTraversal::PreOrder> functionArgs;          \
        break;                                                       \
      case TreeTraversal::PostOrder:                                 \
        functionName<TreeTraversal::PostOrder> functionArgs;         \
        break;                                                       \
      case TreeTraversal::InOrder:                                   \
        functionName<TreeTraversal::InOrder> functionArgs;           \
        break;                                                       \
      case TreeTraversal::PreAndPostOrder:                           \
        functionName<TreeTraversal::PreAndPostOrder> functionArgs;   \
        break;                                                       \
      case TreeTraversal::PreAndInOrder:                             \
        functionName<TreeTraversal::PreAndInOrder> functionArgs;     \
        break;                                                       \
      case TreeTraversal::PostAndInOrder:                            \
        functionName<TreeTraversal::PostAndInOrder> functionArgs;    \
        break;                                                       \
      case TreeTraversal::Any:                                       \
        functionName<TreeTraversal::Any> functionArgs;               \
        break;                                                       \
    }                                                                \
  } while (false)
namespace {
/// Tag for `visit`: hand only leaf nodes to the visitor.
struct VisitLeaf {};
/// Tag for `visit`: hand only internal (non-leaf) nodes to the visitor.
struct VisitInternal {};
/// Tag for `visit`: hand every node to the visitor.
struct VisitAll {};
/// Calls the visitor on `node`, passing `context` as a second argument if the
/// visitor accepts one, and calling it with `node` alone otherwise.
///
/// \return `true` if the visitor returns void; otherwise the visitor's result
///         converted to bool. `visit` uses a `false` result to skip the rest of
///         the current subtree.
template <typename Visitor, typename Node>
bool invoke_tree_visitor(Visitor&& f, Node&& node, TreeTraversal context) {
  if constexpr (std::is_invocable_v<Visitor, Node, TreeTraversal>) {
    using result_type = std::invoke_result_t<Visitor, Node, TreeTraversal>;
    if constexpr (std::is_void_v<result_type>) {
      std::forward<Visitor>(f)(std::forward<Node>(node), context);
      return true;
    } else {
      return static_cast<bool>(
          std::forward<Visitor>(f)(std::forward<Node>(node), context));
    }
  } else {
    static_assert(std::is_invocable_v<Visitor, Node>,
                  "Visitor must be a callable that takes a FullBinaryNode<T> "
                  "and optionally a TreeTraversal argument");
    using result_type = std::invoke_result_t<Visitor, Node>;
    if constexpr (std::is_void_v<result_type>) {
      std::forward<Visitor>(f)(std::forward<Node>(node));
      return true;
    } else {
      return static_cast<bool>(
          std::forward<Visitor>(f)(std::forward<Node>(node)));
    }
  }
}
/// Address of `node`'s parent, or `nullptr` when `node` is a root.
/// The returned pointer carries the same constness as `node`.
template <typename Node>
std::remove_reference_t<Node>* get_parent_ptr(Node&& node) {
  if (node.root()) {
    return nullptr;
  }
  return &node.parent();
}
/// Iteratively walks, depth-first, the subtree rooted at `node` and hands the
/// nodes selected by `NodeType` to the visitor `f`.
///
/// Leaves are reported once with `TreeTraversal::Any` (unless `NodeType` is
/// `VisitInternal`). Internal nodes are reported with `PreOrder`, `InOrder`
/// and/or `PostOrder` as requested by `order` (unless `NodeType` is
/// `VisitLeaf`). If `f` returns a value that converts to `false`, the rest of
/// the current node's subtree is skipped and the walk resumes at its parent.
///
/// The walk never climbs above `node`, even when `node` is not a root.
///
/// \tparam order traversal points at which internal nodes are reported
/// \param node root of the subtree to walk
/// \param f visitor, invoked through invoke_tree_visitor
template <TreeTraversal order, typename Node, typename Visitor,
          typename NodeType>
void visit(Node&& node, Visitor&& f, NodeType) {
  static_assert(std::is_same_v<NodeType, VisitLeaf> ||
                    std::is_same_v<NodeType, VisitInternal> ||
                    std::is_same_v<NodeType, VisitAll>,
                "Not sure which nodes to visit");

  using node_t = std::remove_reference_t<Node>;
  using bare_node_t = std::remove_cv_t<node_t>;

  node_t* const start_ptr = &node;
  // Parent of `n` within the visited subtree. The start node has no parent
  // there, so the walk cannot escape into an enclosing tree.
  const auto up = [start_ptr](node_t& n) -> node_t* {
    return &n == start_ptr ? nullptr : get_parent_ptr(n);
  };

  node_t* current_ptr = start_ptr;
  const bare_node_t* previous_ptr = start_ptr;
  // Implementation note: we use a loop rather than recursive function calls as
  // the latter implicitly imposes a maximum depth of a tree we can visit
  // without triggering a stack overflow that will crash the program.
  while (current_ptr) {
    auto& current = *current_ptr;
    bool continue_with_subtree = true;
    // Note: Invoking the visitor function may change the node object!
    // Hence, the need to repeatedly check for whether current might be a leaf
    if (current.leaf()) {
      // Arrived at a tree's leaf
      if constexpr (!std::is_same_v<NodeType, VisitInternal>) {
        continue_with_subtree =
            invoke_tree_visitor(f, current, TreeTraversal::Any);
      }
      // Move back up to parent
      current_ptr = up(current);
    } else if (previous_ptr == &current.left()) {
      if constexpr ((order & TreeTraversal::InOrder) ==
                        TreeTraversal::InOrder &&
                    !std::is_same_v<NodeType, VisitLeaf>) {
        continue_with_subtree =
            invoke_tree_visitor(f, current, TreeTraversal::InOrder);
      }
      // Finished visiting left, now visit right
      current_ptr = current.leaf() ? up(current) : &current.right();
    } else if (previous_ptr == &current.right()) {
      if constexpr ((order & TreeTraversal::PostOrder) ==
                        TreeTraversal::PostOrder &&
                    !std::is_same_v<NodeType, VisitLeaf>) {
        continue_with_subtree =
            invoke_tree_visitor(f, current, TreeTraversal::PostOrder);
      }
      // Finished visiting right, now move back up to parent (if any)
      current_ptr = up(current);
    } else {
      // Either this is the start node, or we just descended from the parent
      SEQUANT_ASSERT(&current == start_ptr ||
                     previous_ptr == &current.parent());
      if constexpr ((order & TreeTraversal::PreOrder) ==
                        TreeTraversal::PreOrder &&
                    !std::is_same_v<NodeType, VisitLeaf>) {
        continue_with_subtree =
            invoke_tree_visitor(f, current, TreeTraversal::PreOrder);
      }
      // Start by visiting left (unless the visitor turned current into a leaf)
      current_ptr = current.leaf() ? up(current) : &current.left();
    }
    previous_ptr = &current;
    if (!continue_with_subtree) {
      // Overwrite to make next target the parent (if any)
      current_ptr = up(current);
    }
  }
}
} // namespace
template <typename T>
class FullBinaryNode {
public:
using node_ptr = std::unique_ptr<FullBinaryNode<T>>;
using value_type = T;
private:
T data_;
node_ptr left_{nullptr};
node_ptr right_{nullptr};
FullBinaryNode<T>* parent_{nullptr};
node_ptr deep_copy() const {
return leaf() ? std::make_unique<FullBinaryNode<T>>(data_)
: std::make_unique<FullBinaryNode<T>>(
data_, left_->deep_copy(), right_->deep_copy());
}
template <typename Ptr>
static Ptr const& checked_ptr_access(Ptr const& n) {
if (n)
return n;
else
throw std::runtime_error(
"Dereferenced nullptr: use leaf() or root() methods to check for "
"leaf and root nodes");
}
public:
explicit FullBinaryNode(T d) : data_{std::move(d)} {}
FullBinaryNode(T d, T l, T r)
: FullBinaryNode(std::move(d),
std::make_unique<FullBinaryNode>(std::move(l)),
std::make_unique<FullBinaryNode>(std::move(r))) {}
FullBinaryNode(T d, FullBinaryNode<T> l, FullBinaryNode<T> r)
: FullBinaryNode(std::move(d),
std::make_unique<FullBinaryNode<T>>(std::move(l)),
std::make_unique<FullBinaryNode<T>>(std::move(r))) {}
FullBinaryNode(T d, node_ptr&& l, node_ptr&& r)
: data_{std::move(d)}, left_{std::move(l)}, right_{std::move(r)} {
if (left_) {
left_->parent_ = this;
}
if (right_) {
right_->parent_ = this;
}
}
FullBinaryNode(FullBinaryNode<T> const& other)
: data_{other.data_},
left_{other.left_ ? other.left_->deep_copy() : nullptr},
right_{other.right_ ? other.right_->deep_copy() : nullptr},
parent_(nullptr) {
if (left_) {
left_->parent_ = this;
}
if (right_) {
right_->parent_ = this;
}
}
FullBinaryNode& operator=(FullBinaryNode<T> const& other) {
auto temp = other.deep_copy();
data_ = std::move(temp->data_);
left_ = std::move(temp->left_);
right_ = std::move(temp->right_);
if (left_) {
left_->parent_ = this;
}
if (right_) {
right_->parent_ = this;
}
// parent_ remains unchanged
return *this;
}
FullBinaryNode(FullBinaryNode<T>&& other) : data_(other.data_) {
// Note: we explicitly have to initialize data_ in the member initializer
// list as not doing that would impose default-constructibility on T. The
// assumption is that all halfway decent compilers will optimize this extra
// copy away.
*this = std::move(other);
}
FullBinaryNode& operator=(FullBinaryNode<T>&& node) {
data_ = std::move(node.data_);
// We have to save a temporary copy of these, in case the node we're moving
// from is pointed to (and thus owned) by either left_.
// If we don't do this, overwriting of the owning pointer leads to deleting
// node, in which case subsequent accesses to it are invalid.
auto left_tmp = std::move(left_);
left_ = std::move(node.left_);
right_ = std::move(node.right_);
if (left_) {
left_->parent_ = this;
}
if (right_) {
right_->parent_ = this;
}
// parent_ remains unchanged
return *this;
}
FullBinaryNode const& left() const { return *checked_ptr_access(left_); }
FullBinaryNode& left() { return *checked_ptr_access(left_); }
FullBinaryNode const& right() const { return *checked_ptr_access(right_); }
FullBinaryNode& right() { return *checked_ptr_access(right_); }
[[nodiscard]] bool leaf() const { return !(left_ || right_); }
[[nodiscard]] bool root() const { return !parent_; }
[[nodiscard]] const FullBinaryNode<T>& parent() const {
return *checked_ptr_access(parent_);
}
[[nodiscard]] FullBinaryNode<T>& parent() {
return *checked_ptr_access(parent_);
}
[[nodiscard]] std::size_t size() const {
if (leaf()) {
return 1;
}
return left().size() + right().size() + 1;
}
T const& operator*() const { return data_; }
T& operator*() { return data_; }
T const* operator->() const { return &data_; }
T* operator->() { return &data_; }
template <typename Cont, typename F>
FullBinaryNode(Cont const& container, F&& binarize) {
using value_type = decltype(*ranges::begin(container));
static_assert(std::is_invocable_v<F, value_type const&>,
"Binarizer to handle terminal nodes missing");
using return_data_t = std::invoke_result_t<F, value_type const&>;
static_assert(
std::is_invocable_v<F, return_data_t const&, return_data_t const&>,
"Binarizer to handle non-terminal nodes missing");
static_assert(
std::is_same_v<return_data_t,
std::invoke_result_t<F, return_data_t const&,
return_data_t const&>>,
"function(...) and function(..., ...) have different return types");
using ranges::accumulate;
using ranges::begin;
using ranges::end;
auto node =
accumulate(begin(container) + 1, end(container), // range
FullBinaryNode{binarize(*begin(container))}, // init
[&binarize](auto&& acc, const auto& val) { // predicate
auto rnode = FullBinaryNode{binarize(val)};
return FullBinaryNode{binarize(*acc, *rnode),
std::move(acc), std::move(rnode)};
});
*this = std::move(node);
}
template <typename F>
void visit(F&& visitor, TreeTraversal order = TreeTraversal::PreOrder) const {
TRAVERSAL_TO_TEMPLATE_ARG(order, sequant::visit,
(*this, std::forward<F>(visitor), VisitAll{}));
}
template <typename F>
void visit(F&& visitor, TreeTraversal order = TreeTraversal::PreOrder) {
TRAVERSAL_TO_TEMPLATE_ARG(order, sequant::visit,
(*this, std::forward<F>(visitor), VisitAll{}));
}
template <typename F>
void visit_internal(F&& visitor,
TreeTraversal order = TreeTraversal::PreOrder) const {
TRAVERSAL_TO_TEMPLATE_ARG(
order, sequant::visit,
(*this, std::forward<F>(visitor), VisitInternal{}));
}
template <typename F>
void visit_internal(F&& visitor,
TreeTraversal order = TreeTraversal::PreOrder) {
TRAVERSAL_TO_TEMPLATE_ARG(
order, sequant::visit,
(*this, std::forward<F>(visitor), VisitInternal{}));
}
template <typename F>
void visit_leaf(F&& visitor,
TreeTraversal order = TreeTraversal::PreOrder) const {
TRAVERSAL_TO_TEMPLATE_ARG(order, sequant::visit,
(*this, std::forward<F>(visitor), VisitLeaf{}));
}
template <typename F>
void visit_leaf(F&& visitor, TreeTraversal order = TreeTraversal::PreOrder) {
TRAVERSAL_TO_TEMPLATE_ARG(order, sequant::visit,
(*this, std::forward<F>(visitor), VisitLeaf{}));
}
private:
template <typename F, typename Os>
[[maybe_unused]] int digraph(F label_gen, Os& os, int count) const {
os << "node" << count << "[label=" << label_gen(*this) << "];\n";
if (this->leaf()) return count;
auto lcount = left().digraph(label_gen, os, count + 1);
auto rcount = right().digraph(label_gen, os, lcount + 1);
os << "node" << count << " -> "
<< "node" << count + 1 << ";\n";
os << "node" << count << " -> "
<< "node" << lcount + 1 << ";\n";
return rcount;
}
template <typename Ostream, typename F, typename G>
void tikz(Ostream& os, F label_gen, G spec_gen, size_t indent = 2) const {
auto pad = [](Ostream& o, size_t i) {
for (size_t j = 0; j < i; ++j) o << " ";
};
// pad(os, indent);
os << "node [" << spec_gen(*this) << "]"
<< " {" << label_gen(*this) << "}";
if (leaf()) return;
os << "\n";
pad(os, indent);
os << "child {";
left().tikz(os, label_gen, spec_gen, indent + 2);
os << "}";
os << "\n";
pad(os, indent);
os << "child {";
right().tikz(os, label_gen, spec_gen, indent + 2);
os << "}";
}
public:
template <typename F,
typename String = std::invoke_result_t<F, FullBinaryNode>>
String digraph(F label_gen,
std::basic_string_view<typename String::value_type>
graph_name = {}) const {
auto oss = std::basic_ostringstream{String{}};
oss << "digraph " << graph_name << "{\n";
this->digraph(label_gen, oss, 0);
oss << "}";
oss.flush();
return oss.str();
}
template <typename L,
typename String = std::invoke_result_t<L, FullBinaryNode>,
typename S = std::function<String(FullBinaryNode)>>
String tikz(
L label_gen, S spec_gen = [](auto&&) { return String{}; }) const {
auto oss = std::basic_ostringstream{String{}};
oss << "\\tikz[binary tree layout]{\n\\";
tikz(oss, label_gen, spec_gen);
oss << "\n}";
oss.flush();
return oss.str();
}
}; // FullBinaryNode<T>
/// Structural equality: two trees are equal iff they have the same shape and
/// the data of corresponding nodes compare equal. A leaf never equals an
/// internal node. This is checked explicitly, because calling `left()` on a
/// leaf would throw.
template <typename T, typename U>
bool operator==(FullBinaryNode<T> const& lhs, FullBinaryNode<U> const& rhs) {
  if (!static_cast<bool>(*lhs == *rhs)) return false;
  if (lhs.leaf() || rhs.leaf()) return lhs.leaf() && rhs.leaf();
  return lhs.left() == rhs.left() && lhs.right() == rhs.right();
}
/// Returns a new tree with the same shape as `node`. Each node's data is
/// `fun` applied to the data of the corresponding node of `node`.
///
/// `fun` is invoked in pre-order: the node first, then its left subtree, then
/// its right subtree. Named locals enforce this order, because the evaluation
/// order of arguments within one call expression is unspecified, and a
/// stateful `fun` would otherwise behave differently across compilers.
template <typename T, typename F,
          typename = std::enable_if_t<std::is_invocable_v<F, T>>>
auto transform_node(FullBinaryNode<T> const& node, F fun) {
  if (node.leaf()) return FullBinaryNode(fun(*node));
  auto data = fun(*node);
  auto left = transform_node(node.left(), fun);
  auto right = transform_node(node.right(), fun);
  return FullBinaryNode(std::move(data), std::move(left), std::move(right));
}
/// Left-folds a non-empty range of FullBinaryNode objects into one tree. For
/// nodes n0, n1, ..., nk the result has the shape ((n0 n1) n2) ... nk.
///
/// The data of each new internal node is computed by `op`. `op` is called with
/// the two child nodes if it accepts nodes; otherwise it is called with their
/// data. Its result must be implicitly convertible to `Node::value_type`.
///
/// \param rng range of nodes (taken by value); its elements are moved into
///        the result
/// \param op binary operation producing the data of internal nodes
template <typename Rng,                                          //
          typename F,                                            //
          typename Node = ranges::range_value_t<Rng>,            //
          typename = std::enable_if_t<meta::is_full_binary_node<Node>>>
Node fold_left_to_node(Rng rng, F op) {
  using Value = typename Node::value_type;
  constexpr bool invoke_on_value =  //
      std::is_invocable_r_v<Value, F, Value, Value>;
  constexpr bool invoke_on_node =  //
      std::is_invocable_r_v<Value, F, Node, Node>;
  static_assert(invoke_on_value || invoke_on_node);
  using ranges::size;
  SEQUANT_ASSERT(size(rng) > 0);
  using ranges::views::move;
  using ranges::views::tail;
  return ranges::accumulate(
      rng | tail | move, std::move(ranges::front(rng)),
      [&op](auto&& l, auto&& r) {
        // Materialize the new data as an independent `Value` before `l` and
        // `r` are moved into the new node. `op` may return a type that is only
        // convertible to `Value`, or a reference into `l`/`r`.
        if constexpr (invoke_on_node) {
          Value val = op(l, r);
          return Node(std::move(val), std::move(l), std::move(r));
        } else {
          Value val = op(*l, *r);
          return Node(std::move(val), std::move(l), std::move(r));
        }
      });
}
} // namespace sequant
#undef TRAVERSAL_TO_TEMPLATE_ARG
#endif // SEQUANT_BINARY_NODE_HPP