Program Listing for File binary_node.hpp¶
↰ Return to documentation for file (SeQuant/core/binary_node.hpp
)
#ifndef SEQUANT_BINARY_NODE_HPP
#define SEQUANT_BINARY_NODE_HPP
#include <iostream>
#include <memory>
#include <range/v3/numeric/accumulate.hpp>
#include <sstream>
#include <stdexcept>
#include <utility>
namespace sequant {
template <typename>
class FullBinaryNode;
struct PreOrder {};
struct PostOrder {};
struct InOrder {};
namespace {
struct VisitLeaf {};
struct VisitInternal {};
struct VisitAll {};
template <
typename T, typename V, typename Order, typename NodeType,
typename = std::enable_if_t<std::is_invocable_v<V, FullBinaryNode<T>>>>
void visit(FullBinaryNode<T> const& node, V const& f, Order, NodeType) {
static_assert(std::is_same_v<Order, PreOrder> ||
std::is_same_v<Order, InOrder> ||
std::is_same_v<Order, PostOrder>,
"Unsupported visit order");
static_assert(std::is_same_v<NodeType, VisitLeaf> ||
std::is_same_v<NodeType, VisitInternal> ||
std::is_same_v<NodeType, VisitAll>,
"Not sure which nodes to visit");
if (node.leaf()) {
if constexpr (!std::is_same_v<NodeType, VisitInternal>) f(node);
} else {
if constexpr (std::is_same_v<Order, PreOrder> &&
!std::is_same_v<NodeType, VisitLeaf>)
f(node);
visit(node.left(), f, Order{}, NodeType{});
if constexpr (std::is_same_v<Order, InOrder> &&
!std::is_same_v<NodeType, VisitLeaf>)
f(node);
visit(node.right(), f, Order{}, NodeType{});
if constexpr (std::is_same_v<Order, PostOrder> &&
!std::is_same_v<NodeType, VisitLeaf>)
f(node);
}
}
} // namespace
template <typename T>
class FullBinaryNode {
public:
using node_ptr = std::unique_ptr<FullBinaryNode<T>>;
using value_type = T;
private:
T data_;
node_ptr left_{nullptr};
node_ptr right_{nullptr};
node_ptr deep_copy() const {
return leaf() ? std::make_unique<FullBinaryNode<T>>(data_)
: std::make_unique<FullBinaryNode<T>>(
data_, left_->deep_copy(), right_->deep_copy());
}
static node_ptr const& checked_ptr_access(node_ptr const& n) {
if (n)
return n;
else
throw std::runtime_error(
"Dereferenced nullptr: use leaf() method to check leaf node.");
}
public:
explicit FullBinaryNode(T d) : data_{std::move(d)} {}
FullBinaryNode(T d, T l, T r)
: data_{std::move(d)},
left_{std::make_unique<FullBinaryNode>(std::move(l))},
right_{std::make_unique<FullBinaryNode>(std::move(r))} {}
FullBinaryNode(T d, FullBinaryNode<T> l, FullBinaryNode<T> r)
: data_{std::move(d)},
left_{std::make_unique<FullBinaryNode<T>>(std::move(l))},
right_{std::make_unique<FullBinaryNode<T>>(std::move(r))} {}
FullBinaryNode(T d, node_ptr&& l, node_ptr&& r)
: data_{std::move(d)}, left_{std::move(l)}, right_{std::move(r)} {}
FullBinaryNode(FullBinaryNode<T> const& other)
: data_{other.data_},
left_{other.left_ ? other.left_->deep_copy() : nullptr},
right_{other.right_ ? other.right_->deep_copy() : nullptr} {}
FullBinaryNode& operator=(FullBinaryNode<T> const& other) {
auto temp = other.deep_copy();
data_ = std::move(temp->data_);
left_ = std::move(temp->left_);
right_ = std::move(temp->right_);
return *this;
}
FullBinaryNode(FullBinaryNode<T>&&) = default;
FullBinaryNode& operator=(FullBinaryNode<T>&&) = default;
FullBinaryNode const& left() const { return *checked_ptr_access(left_); }
FullBinaryNode const& right() const { return *checked_ptr_access(right_); }
[[nodiscard]] bool leaf() const { return !(left_ || right_); }
T const& operator*() const { return data_; }
T& operator*() { return data_; }
T const* operator->() const { return &data_; }
T* operator->() { return &data_; }
template <typename Cont, typename F>
FullBinaryNode(Cont const& container, F&& binarize) {
using value_type = decltype(*ranges::begin(container));
static_assert(std::is_invocable_v<F, value_type const&>,
"Binarizer to handle terminal nodes missing");
using return_data_t = std::invoke_result_t<F, value_type const&>;
static_assert(
std::is_invocable_v<F, return_data_t const&, return_data_t const&>,
"Binarizer to handle non-terminal nodes missing");
static_assert(
std::is_same_v<return_data_t,
std::invoke_result_t<F, return_data_t const&,
return_data_t const&>>,
"function(...) and function(..., ...) have different return types");
using ranges::accumulate;
using ranges::begin;
using ranges::end;
auto node =
accumulate(begin(container) + 1, end(container), // range
FullBinaryNode{binarize(*begin(container))}, // init
[&binarize](auto&& acc, const auto& val) { // predicate
auto rnode = FullBinaryNode{binarize(val)};
return FullBinaryNode{binarize(*acc, *rnode),
std::move(acc), std::move(rnode)};
});
*this = std::move(node);
}
template <
typename F, typename Order = PostOrder,
std::enable_if_t<
std::is_void_v<std::invoke_result_t<F, FullBinaryNode<T> const&>>,
bool> = true>
void visit(F const& visitor, Order = {}) const {
sequant::visit(*this, //
visitor, //
Order{}, //
VisitAll{});
}
template <
typename F,
std::enable_if_t<std::is_invocable_r_v<bool, F, FullBinaryNode<T> const&>,
bool> = true>
void visit(F const& visitor) const {
if (visitor(*this) && !leaf()) {
left().visit(visitor);
right().visit(visitor);
}
}
template <
typename F,
std::enable_if_t<std::is_invocable_r_v<bool, F, FullBinaryNode<T> const&>,
bool> = true>
void visit_internal(F const& visitor) const {
if (leaf()) return;
if (visitor(*this)) {
left().visit_internal(visitor);
right().visit_internal(visitor);
}
}
template <
typename F, typename Order = PostOrder,
std::enable_if_t<
std::is_void_v<std::invoke_result_t<F, FullBinaryNode<T> const&>>,
bool> = true>
void visit_internal(F const& visitor, Order = {}) const {
sequant::visit(*this, //
visitor, //
Order{}, //
VisitInternal{});
}
template <
typename F, typename Order = PostOrder,
std::enable_if_t<
std::is_void_v<std::invoke_result_t<F, FullBinaryNode<T> const&>>,
bool> = true>
void visit_leaf(F const& visitor, Order = {}) const {
sequant::visit(*this, //
visitor, //
Order{}, //
VisitLeaf{});
}
private:
template <typename Ostream, typename F>
[[maybe_unused]] int digraph(Ostream& os, F const& label_gen,
int count = 0) const {
os << "node" << count << "[label=" << label_gen(*this) << "];\n";
if (this->leaf()) return count;
auto lcount = left().digraph(os, label_gen, count + 1);
auto rcount = right().digraph(os, label_gen, lcount + 1);
os << "node" << count << " -> "
<< "node" << count + 1 << ";\n";
os << "node" << count << " -> "
<< "node" << lcount + 1 << ";\n";
return rcount;
}
template <typename Ostream, typename F, typename G>
void tikz(Ostream& os, F const& label_gen, G const& spec_gen,
size_t indent = 2) const {
auto pad = [](Ostream& o, size_t i) {
for (size_t j = 0; j < i; ++j) o << " ";
};
// pad(os, indent);
os << "node [" << spec_gen(*this) << "]"
<< " {" << label_gen(*this) << "}";
if (leaf()) return;
os << "\n";
pad(os, indent);
os << "child {";
left().tikz(os, label_gen, spec_gen, indent + 2);
os << "}";
os << "\n";
pad(os, indent);
os << "child {";
right().tikz(os, label_gen, spec_gen, indent + 2);
os << "}";
}
public:
template <typename string_t, typename F>
string_t digraph(F const& label_gen, string_t const& graph_name = {}) const {
static_assert(std::is_invocable_r_v<string_t, F, FullBinaryNode<T> const&>,
"node label generator F(FullBinaryNode<T> const &) should "
"return string_t");
auto oss = std::basic_ostringstream{string_t{}};
oss << "digraph " << graph_name << "{\n";
this->digraph(oss, label_gen, 0);
oss << "}";
oss.flush();
return oss.str();
}
template <typename string_t>
string_t tikz(
std::function<string_t(FullBinaryNode<T> const&)> label_gen,
std::function<string_t(FullBinaryNode<T> const&)> spec_gen) const {
auto oss = std::basic_ostringstream{string_t{}};
oss << "\\tikz{\n\\";
tikz(oss, label_gen, spec_gen);
oss << "\n}";
oss.flush();
return oss.str();
}
}; // FullBinaryNode<T>
template <typename T, typename U>
bool operator==(FullBinaryNode<T> const& lhs, FullBinaryNode<U> const& rhs) {
return ((*lhs == *rhs) &&
((lhs.leaf() && rhs.leaf()) ||
(lhs.left() == rhs.left() && lhs.right() == rhs.right())));
}
} // namespace sequant
#endif // SEQUANT_BINARY_NODE_HPP