Program Listing for File binary_node.hpp

Return to documentation for file (SeQuant/core/binary_node.hpp)

#ifndef SEQUANT_BINARY_NODE_HPP
#define SEQUANT_BINARY_NODE_HPP

#include <SeQuant/core/utility/macros.hpp>

#include <cstddef>
#include <functional>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string_view>
#include <type_traits>
#include <utility>

#include <range/v3/numeric/accumulate.hpp>
#include <range/v3/view.hpp>

namespace sequant {

// Forward declaration so that the trait below can be specialized before the
// class template itself is defined.
template <typename>
class FullBinaryNode;

namespace meta {

/// Compile-time predicate: `true` exactly when the argument is a (cv-unqualified)
/// specialization of FullBinaryNode, `false` for every other type.
template <typename>
inline constexpr bool is_full_binary_node = false;

/// Partial specialization that recognizes FullBinaryNode<Payload>.
template <typename Payload>
inline constexpr bool is_full_binary_node<FullBinaryNode<Payload>> = true;

}  // namespace meta

/// Bit-mask describing at which points of a depth-first walk an internal node
/// is handed to the visitor. Individual orders occupy one bit each so that they
/// can be combined with `operator|` and tested with `operator&`.
enum class TreeTraversal {
  /// No bit set.
  None = 0,
  /// Bit 0.
  PreOrder = 1 << 0,
  /// Bit 1.
  PostOrder = 1 << 1,
  /// Bit 2.
  InOrder = 1 << 2,

  // Named two-bit combinations
  PreAndPostOrder = PreOrder | PostOrder,
  PreAndInOrder = PreOrder | InOrder,
  PostAndInOrder = PostOrder | InOrder,

  /// All three bits set.
  Any = PreOrder | PostOrder | InOrder,
};

/// Bitwise union of two traversal masks.
///
/// @return A TreeTraversal whose underlying value has every bit set that is
/// set in @p lhs or in @p rhs.
constexpr TreeTraversal operator|(TreeTraversal lhs, TreeTraversal rhs) {
  using Bits = std::underlying_type_t<TreeTraversal>;
  const auto merged = static_cast<Bits>(lhs) | static_cast<Bits>(rhs);
  return static_cast<TreeTraversal>(merged);
}

/// Bitwise intersection of two traversal masks.
///
/// @return A TreeTraversal whose underlying value keeps only the bits that are
/// set in both @p lhs and @p rhs.
constexpr TreeTraversal operator&(TreeTraversal lhs, TreeTraversal rhs) {
  using Bits = std::underlying_type_t<TreeTraversal>;
  const auto common = static_cast<Bits>(lhs) & static_cast<Bits>(rhs);
  return static_cast<TreeTraversal>(common);
}

// Turns the runtime TreeTraversal value `order` into a compile-time template
// argument: expands to a switch statement that, for each enumerator, calls
// `functionName<TreeTraversal::X> functionArgs`.
//
//  - `order`        expression of type TreeTraversal that is switched on
//  - `functionName` name of a function template whose first template
//                   parameter is a TreeTraversal value
//  - `functionArgs` the parenthesized argument list, pasted verbatim after
//                   the template-id
//
// TreeTraversal::None expands to no call at all. There is no `default` label,
// so a value that matches none of the enumerators also results in no call.
// The macro is #undef'd again at the end of this header.
#define TRAVERSAL_TO_TEMPLATE_ARG(order, functionName, functionArgs) \
  switch (order) {                                                   \
    case TreeTraversal::None:                                        \
      break;                                                         \
    case TreeTraversal::PreOrder:                                    \
      functionName<TreeTraversal::PreOrder> functionArgs;            \
      break;                                                         \
    case TreeTraversal::PostOrder:                                   \
      functionName<TreeTraversal::PostOrder> functionArgs;           \
      break;                                                         \
    case TreeTraversal::InOrder:                                     \
      functionName<TreeTraversal::InOrder> functionArgs;             \
      break;                                                         \
    case TreeTraversal::PreAndPostOrder:                             \
      functionName<TreeTraversal::PreAndPostOrder> functionArgs;     \
      break;                                                         \
    case TreeTraversal::PreAndInOrder:                               \
      functionName<TreeTraversal::PreAndInOrder> functionArgs;       \
      break;                                                         \
    case TreeTraversal::PostAndInOrder:                              \
      functionName<TreeTraversal::PostAndInOrder> functionArgs;      \
      break;                                                         \
    case TreeTraversal::Any:                                         \
      functionName<TreeTraversal::Any> functionArgs;                 \
      break;                                                         \
  }

namespace {

// Tag types passed as the last argument of the free function visit() to select
// which nodes are handed to the visitor.

// Only leaves are passed to the visitor; the pre-/in-/post-order calls for
// internal nodes are compiled out.
struct VisitLeaf {};

// Only internal (non-leaf) nodes are passed to the visitor; the call for
// leaves is compiled out.
struct VisitInternal {};

// Leaves and internal nodes are both passed to the visitor.
struct VisitAll {};

/// @brief Calls a tree visitor on a node and normalizes its result to `bool`.
///
/// The visitor may have one of two call signatures. If it is invocable as
/// `f(node, context)` that form is used; otherwise it must be invocable as
/// `f(node)` (enforced by a `static_assert`).
///
/// @param f       The visitor; forwarded with its original value category.
/// @param node    The node handed to the visitor; forwarded as well.
/// @param context Traversal context passed on to two-argument visitors.
/// @return `true` if the visitor returns `void`; otherwise the visitor's
///         result converted to `bool`.
template <typename Visitor, typename Node>
bool invoke_tree_visitor(Visitor&& f, Node&& node, TreeTraversal context) {
  if constexpr (std::is_invocable_v<Visitor, Node, TreeTraversal>) {
    // Context-aware visitor: f(node, context)
    using result_type = std::invoke_result_t<Visitor, Node, TreeTraversal>;
    if constexpr (std::is_void_v<result_type>) {
      std::forward<Visitor>(f)(std::forward<Node>(node), context);
      return true;
    } else {
      return static_cast<bool>(
          std::forward<Visitor>(f)(std::forward<Node>(node), context));
    }
  } else {
    // Plain visitor: f(node)
    static_assert(std::is_invocable_v<Visitor, Node>,
                  "Visitor must be a callable that takes a FullBinaryNode<T> "
                  "and optionally a TreeTraversal argument");
    using result_type = std::invoke_result_t<Visitor, Node>;
    if constexpr (std::is_void_v<result_type>) {
      std::forward<Visitor>(f)(std::forward<Node>(node));
      return true;
    } else {
      return static_cast<bool>(
          std::forward<Visitor>(f)(std::forward<Node>(node)));
    }
  }
}

/// @brief Returns a pointer to the parent of @p node.
///
/// @return `nullptr` if `node.root()` is true, otherwise the address of
///         `node.parent()`. The pointee is const exactly when @p node is.
template <typename Node>
std::remove_reference_t<Node>* get_parent_ptr(Node&& node) {
  if (node.root()) {
    return nullptr;
  }

  return &node.parent();
}

/// @brief Iteratively walks the subtree rooted at @p node, depth-first and
/// left before right, and hands the selected nodes to @p f.
///
/// Leaves are passed to the visitor once, with context `TreeTraversal::Any`
/// (unless @p NodeType is `VisitInternal`). Internal nodes are passed to the
/// visitor up to three times - before their children (`PreOrder`), between
/// the children (`InOrder`) and after the children (`PostOrder`) - but only
/// for those orders whose bit is set in @p order and only if @p NodeType is
/// not `VisitLeaf`.
///
/// If the visitor returns something that converts to `false`, the walk moves
/// straight back up to the parent of the node just visited, skipping whatever
/// part of that node's subtree has not been entered yet.
///
/// The walk is confined to the subtree of @p node: it ends as soon as it would
/// step to the parent of @p node, so it may be started on a non-root node.
///
/// @tparam order    Bit-mask selecting the orders for internal nodes.
/// @tparam NodeType One of `VisitLeaf`, `VisitInternal`, `VisitAll`.
/// @param node Root of the subtree to walk.
/// @param f    Visitor, see invoke_tree_visitor() for accepted signatures.
template <TreeTraversal order, typename Node, typename Visitor,
          typename NodeType>
void visit(Node&& node, Visitor&& f, NodeType) {
  static_assert(std::is_same_v<NodeType, VisitLeaf> ||
                    std::is_same_v<NodeType, VisitInternal> ||
                    std::is_same_v<NodeType, VisitAll>,
                "Not sure which nodes to visit");

  using node_type = std::remove_reference_t<Node>;

  node_type* const start_ptr = &node;

  // Sentinel marking the end of the walk: the parent of the start node, or
  // nullptr if the start node is a root. Without it, a walk started on a
  // non-root node would climb into the parent and continue through the rest
  // of the enclosing tree.
  node_type* const stop_ptr = get_parent_ptr(node);

  node_type* current_ptr = start_ptr;
  const std::remove_cv_t<node_type>* previous_ptr = start_ptr;

  // Implementation note: we use a loop rather than recursive function calls as
  // the latter implicitly imposes a maximum depth of a tree we can visit
  // without triggering a stack overflow that will crash the program.
  while (current_ptr && current_ptr != stop_ptr) {
    // Node is deduced as an lvalue reference type by all callers in this
    // file, so `const Node&` collapses to exactly that reference type.
    const Node& current = *current_ptr;

    bool continue_with_subtree = true;

    // Note: Invoking the visitor function may change the node object!
    // Hence, the need to repeatedly check for whether current might be a leaf

    if (current.leaf()) {
      // Arrived at a tree's leaf
      if constexpr (!std::is_same_v<NodeType, VisitInternal>) {
        continue_with_subtree =
            invoke_tree_visitor(f, current, TreeTraversal::Any);
      }

      // Move back up to parent
      current_ptr = get_parent_ptr(current);
    } else if (previous_ptr == &current.left()) {
      if constexpr ((order & TreeTraversal::InOrder) ==
                        TreeTraversal::InOrder &&
                    !std::is_same_v<NodeType, VisitLeaf>) {
        continue_with_subtree =
            invoke_tree_visitor(f, current, TreeTraversal::InOrder);
      }

      // Finished visiting left, now visit right
      current_ptr = current.leaf() ? get_parent_ptr(current) : &current.right();
    } else if (previous_ptr == &current.right()) {
      if constexpr ((order & TreeTraversal::PostOrder) ==
                        TreeTraversal::PostOrder &&
                    !std::is_same_v<NodeType, VisitLeaf>) {
        continue_with_subtree =
            invoke_tree_visitor(f, current, TreeTraversal::PostOrder);
      }

      // Finished visiting right, now move back up to parent (if any)
      current_ptr = get_parent_ptr(current);
    } else {
      // We either just started (at the start node, which need not be a root)
      // or we descended into this node from its parent.
      SEQUANT_ASSERT(current.root() || current_ptr == start_ptr ||
                     previous_ptr == &current.parent());
      if constexpr ((order & TreeTraversal::PreOrder) ==
                        TreeTraversal::PreOrder &&
                    !std::is_same_v<NodeType, VisitLeaf>) {
        continue_with_subtree =
            invoke_tree_visitor(f, current, TreeTraversal::PreOrder);
      }

      // Coming from parent (or just started), start by visiting left
      current_ptr = current.leaf() ? get_parent_ptr(current) : &current.left();
    }

    previous_ptr = &current;

    if (!continue_with_subtree) {
      // Overwrite to make next target the parent (if any)
      current_ptr = get_parent_ptr(current);
    }
  }
}

}  // namespace

template <typename T>
class FullBinaryNode {
 public:
  using node_ptr = std::unique_ptr<FullBinaryNode<T>>;
  using value_type = T;

 private:
  T data_;

  node_ptr left_{nullptr};

  node_ptr right_{nullptr};

  FullBinaryNode<T>* parent_{nullptr};

  node_ptr deep_copy() const {
    return leaf() ? std::make_unique<FullBinaryNode<T>>(data_)
                  : std::make_unique<FullBinaryNode<T>>(
                        data_, left_->deep_copy(), right_->deep_copy());
  }

  template <typename Ptr>
  static Ptr const& checked_ptr_access(Ptr const& n) {
    if (n)
      return n;
    else
      throw std::runtime_error(
          "Dereferenced nullptr: use leaf() or root() methods to check for "
          "leaf and root nodes");
  }

 public:
  explicit FullBinaryNode(T d) : data_{std::move(d)} {}

  FullBinaryNode(T d, T l, T r)
      : FullBinaryNode(std::move(d),
                       std::make_unique<FullBinaryNode>(std::move(l)),
                       std::make_unique<FullBinaryNode>(std::move(r))) {}

  FullBinaryNode(T d, FullBinaryNode<T> l, FullBinaryNode<T> r)
      : FullBinaryNode(std::move(d),
                       std::make_unique<FullBinaryNode<T>>(std::move(l)),
                       std::make_unique<FullBinaryNode<T>>(std::move(r))) {}

  FullBinaryNode(T d, node_ptr&& l, node_ptr&& r)
      : data_{std::move(d)}, left_{std::move(l)}, right_{std::move(r)} {
    if (left_) {
      left_->parent_ = this;
    }
    if (right_) {
      right_->parent_ = this;
    }
  }

  FullBinaryNode(FullBinaryNode<T> const& other)
      : data_{other.data_},
        left_{other.left_ ? other.left_->deep_copy() : nullptr},
        right_{other.right_ ? other.right_->deep_copy() : nullptr},
        parent_(nullptr) {
    if (left_) {
      left_->parent_ = this;
    }
    if (right_) {
      right_->parent_ = this;
    }
  }

  FullBinaryNode& operator=(FullBinaryNode<T> const& other) {
    auto temp = other.deep_copy();
    data_ = std::move(temp->data_);
    left_ = std::move(temp->left_);
    right_ = std::move(temp->right_);
    if (left_) {
      left_->parent_ = this;
    }
    if (right_) {
      right_->parent_ = this;
    }
    // parent_ remains unchanged
    return *this;
  }

  FullBinaryNode(FullBinaryNode<T>&& other) : data_(other.data_) {
    // Note: we explicitly have to initialize data_ in the member initializer
    // list as not doing that would impose default-constructibility on T. The
    // assumption is that all halfway decent compilers will optimize this extra
    // copy away.
    *this = std::move(other);
  }

  FullBinaryNode& operator=(FullBinaryNode<T>&& node) {
    data_ = std::move(node.data_);

    // We have to save a temporary copy of these, in case the node we're moving
    // from is pointed to (and thus owned) by either left_.
    // If we don't do this, overwriting of the owning pointer leads to deleting
    // node, in which case subsequent accesses to it are invalid.
    auto left_tmp = std::move(left_);

    left_ = std::move(node.left_);
    right_ = std::move(node.right_);

    if (left_) {
      left_->parent_ = this;
    }
    if (right_) {
      right_->parent_ = this;
    }

    // parent_ remains unchanged

    return *this;
  }

  FullBinaryNode const& left() const { return *checked_ptr_access(left_); }
  FullBinaryNode& left() { return *checked_ptr_access(left_); }

  FullBinaryNode const& right() const { return *checked_ptr_access(right_); }
  FullBinaryNode& right() { return *checked_ptr_access(right_); }

  [[nodiscard]] bool leaf() const { return !(left_ || right_); }

  [[nodiscard]] bool root() const { return !parent_; }

  [[nodiscard]] const FullBinaryNode<T>& parent() const {
    return *checked_ptr_access(parent_);
  }

  [[nodiscard]] FullBinaryNode<T>& parent() {
    return *checked_ptr_access(parent_);
  }

  [[nodiscard]] std::size_t size() const {
    if (leaf()) {
      return 1;
    }

    return left().size() + right().size() + 1;
  }

  T const& operator*() const { return data_; }

  T& operator*() { return data_; }

  T const* operator->() const { return &data_; }

  T* operator->() { return &data_; }

  template <typename Cont, typename F>
  FullBinaryNode(Cont const& container, F&& binarize) {
    using value_type = decltype(*ranges::begin(container));
    static_assert(std::is_invocable_v<F, value_type const&>,
                  "Binarizer to handle terminal nodes missing");

    using return_data_t = std::invoke_result_t<F, value_type const&>;

    static_assert(
        std::is_invocable_v<F, return_data_t const&, return_data_t const&>,
        "Binarizer to handle non-terminal nodes missing");

    static_assert(
        std::is_same_v<return_data_t,
                       std::invoke_result_t<F, return_data_t const&,
                                            return_data_t const&>>,
        "function(...) and function(..., ...) have different return types");

    using ranges::accumulate;
    using ranges::begin;
    using ranges::end;

    auto node =
        accumulate(begin(container) + 1, end(container),         // range
                   FullBinaryNode{binarize(*begin(container))},  // init
                   [&binarize](auto&& acc, const auto& val) {    // predicate
                     auto rnode = FullBinaryNode{binarize(val)};
                     return FullBinaryNode{binarize(*acc, *rnode),
                                           std::move(acc), std::move(rnode)};
                   });

    *this = std::move(node);
  }

  template <typename F>
  void visit(F&& visitor, TreeTraversal order = TreeTraversal::PreOrder) const {
    TRAVERSAL_TO_TEMPLATE_ARG(order, sequant::visit,
                              (*this, std::forward<F>(visitor), VisitAll{}));
  }

  template <typename F>
  void visit(F&& visitor, TreeTraversal order = TreeTraversal::PreOrder) {
    TRAVERSAL_TO_TEMPLATE_ARG(order, sequant::visit,
                              (*this, std::forward<F>(visitor), VisitAll{}));
  }

  template <typename F>
  void visit_internal(F&& visitor,
                      TreeTraversal order = TreeTraversal::PreOrder) const {
    TRAVERSAL_TO_TEMPLATE_ARG(
        order, sequant::visit,
        (*this, std::forward<F>(visitor), VisitInternal{}));
  }

  template <typename F>
  void visit_internal(F&& visitor,
                      TreeTraversal order = TreeTraversal::PreOrder) {
    TRAVERSAL_TO_TEMPLATE_ARG(
        order, sequant::visit,
        (*this, std::forward<F>(visitor), VisitInternal{}));
  }

  template <typename F>
  void visit_leaf(F&& visitor,
                  TreeTraversal order = TreeTraversal::PreOrder) const {
    TRAVERSAL_TO_TEMPLATE_ARG(order, sequant::visit,
                              (*this, std::forward<F>(visitor), VisitLeaf{}));
  }

  template <typename F>
  void visit_leaf(F&& visitor, TreeTraversal order = TreeTraversal::PreOrder) {
    TRAVERSAL_TO_TEMPLATE_ARG(order, sequant::visit,
                              (*this, std::forward<F>(visitor), VisitLeaf{}));
  }

 private:
  template <typename F, typename Os>
  [[maybe_unused]] int digraph(F label_gen, Os& os, int count) const {
    os << "node" << count << "[label=" << label_gen(*this) << "];\n";
    if (this->leaf()) return count;
    auto lcount = left().digraph(label_gen, os, count + 1);
    auto rcount = right().digraph(label_gen, os, lcount + 1);
    os << "node" << count << " -> "
       << "node" << count + 1 << ";\n";
    os << "node" << count << " -> "
       << "node" << lcount + 1 << ";\n";

    return rcount;
  }

  template <typename Ostream, typename F, typename G>
  void tikz(Ostream& os, F label_gen, G spec_gen, size_t indent = 2) const {
    auto pad = [](Ostream& o, size_t i) {
      for (size_t j = 0; j < i; ++j) o << " ";
    };

    // pad(os, indent);

    os << "node [" << spec_gen(*this) << "]"
       << " {" << label_gen(*this) << "}";
    if (leaf()) return;
    os << "\n";

    pad(os, indent);
    os << "child {";
    left().tikz(os, label_gen, spec_gen, indent + 2);
    os << "}";
    os << "\n";

    pad(os, indent);
    os << "child {";
    right().tikz(os, label_gen, spec_gen, indent + 2);
    os << "}";
  }

 public:
  template <typename F,
            typename String = std::invoke_result_t<F, FullBinaryNode>>
  String digraph(F label_gen,
                 std::basic_string_view<typename String::value_type>
                     graph_name = {}) const {
    auto oss = std::basic_ostringstream{String{}};

    oss << "digraph " << graph_name << "{\n";
    this->digraph(label_gen, oss, 0);
    oss << "}";
    oss.flush();

    return oss.str();
  }


  template <typename L,
            typename String = std::invoke_result_t<L, FullBinaryNode>,
            typename S = std::function<String(FullBinaryNode)>>
  String tikz(
      L label_gen, S spec_gen = [](auto&&) { return String{}; }) const {
    auto oss = std::basic_ostringstream{String{}};
    oss << "\\tikz[binary tree layout]{\n\\";
    tikz(oss, label_gen, spec_gen);
    oss << "\n}";
    oss.flush();
    return oss.str();
  }

};  // FullBinaryNode<T>

/// @brief Structural equality of two trees.
///
/// Two nodes compare equal if their payloads compare equal with `==`, both are
/// leaves or both are internal nodes, and - for internal nodes - their left
/// children and their right children compare equal recursively.
///
/// A leaf never equals an internal node, even if the payloads match.
template <typename T, typename U>
bool operator==(FullBinaryNode<T> const& lhs, FullBinaryNode<U> const& rhs) {
  if (!static_cast<bool>(*lhs == *rhs)) {
    return false;
  }

  if (lhs.leaf() || rhs.leaf()) {
    // Do not descend: asking a leaf for its children throws.
    return lhs.leaf() && rhs.leaf();
  }

  return lhs.left() == rhs.left() && lhs.right() == rhs.right();
}

/// @brief Creates a tree of the same shape as @p node in which every payload
/// has been replaced by the result of calling @p fun on it.
///
/// @param node Root of the tree to transform; it is not modified.
/// @param fun  Callable applied to the payload of every node.
/// @return A new FullBinaryNode whose payload type is deduced from the result
///         of @p fun.
template <typename T, typename F,
          typename = std::enable_if_t<std::is_invocable_v<F, T>>>
auto transform_node(FullBinaryNode<T> const& node, F fun) {
  // Note: this return statement has to come first, as it is what fixes the
  // deduced return type for the recursive calls below.
  if (node.leaf()) {
    return FullBinaryNode(fun(*node));
  }

  auto payload = fun(*node);
  auto mapped_left = transform_node(node.left(), fun);
  auto mapped_right = transform_node(node.right(), fun);

  return FullBinaryNode(std::move(payload), std::move(mapped_left),
                        std::move(mapped_right));
}

template <typename Rng,                                //
          typename F,                                  //
          typename Node = ranges::range_value_t<Rng>,  //
          typename = std::enable_if_t<meta::is_full_binary_node<Node>>>
Node fold_left_to_node(Rng rng, F op) {
  using Value = typename Node::value_type;

  constexpr bool invoke_on_value =  //
      std::is_invocable_r_v<Value, F, Value, Value>;

  constexpr bool invoke_on_node =  //
      std::is_invocable_r_v<Value, F, Node, Node>;

  static_assert(invoke_on_value || invoke_on_node);

  using ranges::size;
  SEQUANT_ASSERT(size(rng) > 0);

  using ranges::views::move;
  using ranges::views::tail;
  return ranges::accumulate(
      rng | tail | move, std::move(ranges::front(rng)),
      [&op](auto&& l, auto&& r) {
        if constexpr (invoke_on_node) {
          auto&& val = op(l, r);
          return FullBinaryNode(std::move(val), std::move(l), std::move(r));

        } else {
          auto&& val = op(*l, *r);
          return FullBinaryNode(std::move(val), std::move(l), std::move(r));
        }
      });
}

}  // namespace sequant

#undef TRAVERSAL_TO_TEMPLATE_ARG

#endif  // SEQUANT_BINARY_NODE_HPP