Program Listing for File index.hpp

Return to documentation for file (SeQuant/core/index.hpp)

//
// Created by Eduard Valeyev on 3/20/18.
//

#ifndef SEQUANT_INDEX_H
#define SEQUANT_INDEX_H

#include <SeQuant/core/container.hpp>
#include <SeQuant/core/hash.hpp>
#include <iostream>
// #include <SeQuant/core/space.hpp>
#include <SeQuant/core/context.hpp>
#include <SeQuant/core/tag.hpp>
#include <SeQuant/core/utility/string.hpp>
// Only needed due to a (likely) compiler bug in Apple Clang
// #include <SeQuant/core/attr.hpp>

#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstdint>
#include <cwchar>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <map>
#include <mutex>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include <range/v3/all.hpp>

// set to 0 to drop the mutex and atomic counters used by IndexFactory
// (IndexFactory is then not thread-safe)
#define SEQUANT_INDEX_THREADSAFE 1

namespace sequant {

class Index;
// brace-enclosed list of index labels, e.g. {L"i_1", L"a_2"}
using WstrList = std::initializer_list<std::wstring_view>;
// brace-enclosed list of Index objects
using IndexList = std::initializer_list<Index>;

/// An index: a label such as `i_1` that belongs to an IndexSpace and may depend
/// on a list of "proto" indices. Index derives from Taggable so that
/// algorithms can attach a temporary tag to it.
///
/// Labels have the form `<base>` or `<base>_<ordinal>`. Ordinals that are
/// >= min_tmp_index() are reserved for temporary indices (see
/// make_tmp_index()); the public constructors that take a label reject them.
class Index : public Taggable {
  // Counter backing next_tmp_index(); atomic so that concurrent callers get
  // distinct values.
  static auto &tmp_index_accessor() {
    // initialized so that the first call to next_tmp_index will return
    // min_tmp_index()
    static std::atomic<std::size_t> index = min_tmp_index() - 1;
    return index;
  }

 public:
  using index_vector = container::vector<Index>;

  Index() = default;

  /// Constructs an index from a label, a space and a brace-enclosed list of
  /// proto indices.
  /// @param label the label; converted to std::wstring via to_wstring()
  /// @param space the space of this index
  /// @param proto_indices indices on which this index depends
  /// @param symmetric_proto_indices if true, @p proto_indices are sorted
  ///        (treated as symmetric w.r.t. permutations)
  /// @throw std::invalid_argument if the label ordinal is >= min_tmp_index(),
  ///        or (debug builds only) if @p proto_indices contains duplicates
  template <typename String,
            typename = std::enable_if_t<meta::is_basic_string_convertible_v<
                std::remove_reference_t<String>>>>
  Index(String &&label, const IndexSpace &space, IndexList proto_indices,
        bool symmetric_proto_indices = true)
      : label_(to_wstring(std::forward<String>(label))),
        space_(space),
        proto_indices_(proto_indices),
        symmetric_proto_indices_(symmetric_proto_indices) {
    canonicalize_proto_indices();
    check_for_duplicate_proto_indices();
    check_nontmp_label();
  }

  /// Same as the IndexList overload, but takes the proto indices as a vector
  /// (moved into this object).
  template <typename String,
            typename = std::enable_if_t<meta::is_basic_string_convertible_v<
                std::remove_reference_t<String>>>>
  Index(String &&label, const IndexSpace &space,
        container::vector<Index> proto_indices,
        bool symmetric_proto_indices = true)
      : label_(to_wstring(std::forward<String>(label))),
        space_(space),
        proto_indices_(std::move(proto_indices)),
        symmetric_proto_indices_(symmetric_proto_indices) {
    canonicalize_proto_indices();
    check_for_duplicate_proto_indices();
    check_nontmp_label();
  }

  /// Constructs an index without proto indices from a label alone. The space
  /// is looked up in the registry returned by
  /// `get_default_context().index_space_registry()`, or is default_space if
  /// that registry is not set.
  /// @throw std::invalid_argument if the label ordinal is >= min_tmp_index()
  template <typename String,
            typename = std::enable_if_t<meta::is_basic_string_convertible_v<
                std::remove_reference_t<String>>>>
  Index(String &&label)
      // std::forward is only a cast; all arguments (including retrieve(label))
      // are evaluated before the delegated ctor moves from the label. The
      // delegated ctor also performs check_nontmp_label().
      : Index(
            std::forward<String>(label),
            get_default_context().index_space_registry()
                ? get_default_context().index_space_registry()->retrieve(label)
                : Index::default_space,
            {}) {}

  /// Constructs an index from either an existing Index (its label and space
  /// are reused) or a label (its space is looked up in the default context's
  /// registry, falling back to default_space), plus a brace-enclosed list of
  /// proto indices given either as Index objects or as labels.
  /// @throw std::invalid_argument if a label (not an Index) is given whose
  ///        ordinal is >= min_tmp_index(), or (debug builds only) if the proto
  ///        indices contain duplicates
  template <typename IndexOrIndexLabel, typename I,
            typename = std::enable_if_t<
                (std::is_same_v<std::decay_t<IndexOrIndexLabel>, Index> ||
                 meta::is_basic_string_convertible_v<
                     std::decay_t<IndexOrIndexLabel>>)>>
  Index(IndexOrIndexLabel &&index_or_index_label,
        std::initializer_list<I> proto_indices,
        bool symmetric_proto_indices = true)
      : symmetric_proto_indices_(symmetric_proto_indices) {
    if constexpr (!std::is_same_v<std::decay_t<IndexOrIndexLabel>, Index>) {
      // to_wstring (as in the other label ctors) also accepts labels that are
      // not already wide strings
      label_ =
          to_wstring(std::forward<IndexOrIndexLabel>(index_or_index_label));
      // only user-supplied labels are checked; an existing Index was already
      // validated (or deliberately created as a temporary)
      check_nontmp_label();
      space_ =
          get_default_context().index_space_registry()
              ? get_default_context().index_space_registry()->retrieve(label_)
              : Index::default_space;
    } else {
      label_ = index_or_index_label.label();
      space_ = index_or_index_label.space();
    }
    if constexpr (!std::is_same_v<std::decay_t<I>, Index>) {
      if (proto_indices.size() != 0) {
        proto_indices_.reserve(proto_indices.size());
        for (const auto &plabel : proto_indices)
          proto_indices_.push_back(Index(plabel));
      }
    } else
      proto_indices_ = proto_indices;
    canonicalize_proto_indices();
    check_for_duplicate_proto_indices();
  }

  /// Same as the initializer_list overload, but takes the proto indices as any
  /// container convertible to container::vector<Index>.
  template <
      typename IndexOrIndexLabel, typename IndexContainer,
      typename = std::enable_if_t<
          std::is_convertible_v<std::remove_reference_t<IndexContainer>,
                                container::vector<Index>> &&
          (std::is_same_v<std::decay_t<IndexOrIndexLabel>, Index> ||
           meta::is_wstring_convertible_v<std::decay_t<IndexOrIndexLabel>>)>>
  Index(IndexOrIndexLabel &&index_or_index_label,
        IndexContainer &&proto_indices, bool symmetric_proto_indices = true)
      : proto_indices_(std::forward<IndexContainer>(proto_indices)),
        symmetric_proto_indices_(symmetric_proto_indices) {
    if constexpr (!std::is_same_v<std::decay_t<IndexOrIndexLabel>, Index>) {
      label_ = index_or_index_label;
      check_nontmp_label();
      space_ =
          get_default_context().index_space_registry()
              ? get_default_context().index_space_registry()->retrieve(label_)
              : Index::default_space;
    } else {
      label_ = index_or_index_label.label();
      space_ = index_or_index_label.space();
    }
    canonicalize_proto_indices();
    check_for_duplicate_proto_indices();
  }

  /// Constructs an index in @p space from either an existing Index (keeping
  /// its label, proto indices and their symmetry; an rvalue Index is moved
  /// from) or from a label (no proto indices).
  /// @throw std::invalid_argument if a label (not an Index) is given whose
  ///        ordinal is >= min_tmp_index()
  template <typename IndexOrIndexLabel>
  Index(IndexOrIndexLabel &&index_or_index_label, IndexSpace space) {
    if constexpr (std::is_same_v<IndexOrIndexLabel, Index>) {
      // rvalue Index: steal its label and proto indices
      label_ = std::move(index_or_index_label.label_);
      space_ = std::move(space);
      proto_indices_ = std::move(index_or_index_label.proto_indices_);
      symmetric_proto_indices_ = index_or_index_label.symmetric_proto_indices_;
      canonicalize_proto_indices();
      check_for_duplicate_proto_indices();
    } else if constexpr (std::is_same_v<std::decay_t<IndexOrIndexLabel>,
                                        Index>) {
      label_ = index_or_index_label.label();
      space_ = std::move(space);
      proto_indices_ = index_or_index_label.proto_indices_;
      symmetric_proto_indices_ = index_or_index_label.symmetric_proto_indices_;
      canonicalize_proto_indices();
      check_for_duplicate_proto_indices();
    } else {
      label_ =
          to_wstring(std::forward<IndexOrIndexLabel>(index_or_index_label));
      space_ = std::move(space);
      // only user-supplied labels are checked; re-spacing an existing
      // (possibly temporary) Index must not throw
      check_nontmp_label();
    }
  }

  /// @return a copy of this index with its space replaced by @p space
  [[nodiscard]] Index replace_space(IndexSpace space) const {
    return Index(*this, std::move(space));
  }

  /// @return a copy of this index whose space has the same base key and
  ///         attr but quantum numbers @p qns
  [[nodiscard]] Index replace_qns(QuantumNumbersAttr qns) const {
    return Index(*this, IndexSpace(this->space().base_key(),
                                   this->space().attr(), std::move(qns)));
  }

  Taggable &tag() { return static_cast<Taggable &>(*this); }
  const Taggable &tag() const { return static_cast<const Taggable &>(*this); }
  /// resets the tag of this index and of each of its proto indices
  void reset_tag() const {
    this->tag().reset();
    ranges::for_each(proto_indices_,
                     [](const Index &idx) { idx.tag().reset(); });
  }

  /// @return a new temporary index in @p space labeled
  ///         `<space.base_key()>_<next_tmp_index()>`
  static Index make_tmp_index(const IndexSpace &space) {
    Index result;
    result.label_ =
        space.base_key() + L'_' + std::to_wstring(Index::next_tmp_index());
    result.space_ = space;
    return result;
  }

  /// Same as make_tmp_index(space), but with proto indices.
  /// @param proto_indices a range of Index objects
  /// @param symmetric_proto_indices if true, the proto indices are sorted
  template <typename IndexRange, typename = std::enable_if_t<meta::is_range_v<
                                     std::remove_reference_t<IndexRange>>>>
  static Index make_tmp_index(const IndexSpace &space,
                              IndexRange &&proto_indices,
                              bool symmetric_proto_indices = true) {
    Index result;
    result.label_ =
        space.base_key() + L'_' + std::to_wstring(Index::next_tmp_index());
    result.space_ = space;
    if constexpr (std::is_convertible_v<std::remove_reference_t<IndexRange>,
                                        Index::index_vector>) {
      result.proto_indices_ = std::forward<IndexRange>(proto_indices);
    } else {
      result.proto_indices_ = proto_indices | ranges::to<Index::index_vector>;
    }
    result.symmetric_proto_indices_ = symmetric_proto_indices;
    result.canonicalize_proto_indices();
    result.check_for_duplicate_proto_indices();
    return result;
  }

  /// Splits @p label at its first '_'.
  /// @return {part before '_', part after '_'}, or {label, empty} if there is
  ///         no '_'; both views point into @p label
  static std::pair<std::wstring_view, std::wstring_view> make_split_label(
      std::wstring_view label) {
    const auto underscore_position = label.find(L'_');
    if (underscore_position == std::wstring_view::npos) return {label, {}};
    // substr keeps the suffix bounded by the view (a single-pointer
    // wstring_view ctor would scan for a NUL past the end of the view)
    return {label.substr(0, underscore_position),
            label.substr(underscore_position + 1)};
  }

  /// Inverse of make_split_label: joins the two parts with '_' (no '_' is
  /// added if @p ordinal_label is empty).
  static std::wstring make_merged_label(std::wstring_view base_label,
                                        std::wstring_view ordinal_label) {
    if (ordinal_label.empty())
      return std::wstring(base_label);
    else {
      auto result = std::wstring(base_label) + L'_';
      result.append(ordinal_label);
      return result;
    }
  }

  std::wstring_view label() const { return label_; }

  /// @return make_split_label(label())
  std::pair<std::wstring_view, std::wstring_view> split_label() const {
    return make_split_label(this->label());
  }

  [[deprecated(
      "use to_string to produce TiledArray-compatible index label "
      "representation")]] std::string
  ascii_label() const;

  std::string to_string() const;

  /// @return the label followed by the full labels of all proto indices
  ///         (just label() if there are none); computed lazily and cached
  std::wstring_view full_label() const {
    if (!has_proto_indices()) return label();
    if (full_label_) return *full_label_;
    std::wstring result = label_;
    ranges::for_each(proto_indices_, [&result](const Index &idx) {
      result += idx.full_label();
    });
    full_label_ = std::move(result);
    return *full_label_;
  }

  /// @return make_label_plus_suffix(label(), suffix)
  template <typename WS, typename = std::enable_if_t<(
                             meta::is_wstring_convertible_v<std::decay_t<WS>>)>>
  [[nodiscard]] std::wstring make_label_plus_suffix(WS &&suffix) const {
    return Index::make_label_plus_suffix(this->label(),
                                         std::forward<WS>(suffix));
  }

  /// Inserts @p suffix into @p label just before its first '_' (or appends it
  /// if there is no '_').
  template <typename WS1, typename WS2,
            typename = std::enable_if_t<
                (meta::is_wstring_or_view_v<std::decay_t<WS1>> &&
                 meta::is_wstring_convertible_v<std::decay_t<WS2>>)>>
  [[nodiscard]] static std::wstring make_label_plus_suffix(WS1 &&label,
                                                           WS2 &&suffix) {
    auto underscore_position = label.find(L'_');
    std::wstring result;
    if (underscore_position == std::wstring::npos) {
      result = std::forward<WS1>(label);
      result += suffix;
    } else {
      result = label.substr(0, underscore_position);
      result += suffix;
      result += label.substr(underscore_position);
    }
    return result;
  }

  /// @return make_label_minus_substring(label(), substr)
  template <typename WS, typename = std::enable_if_t<(
                             meta::is_wstring_convertible_v<std::decay_t<WS>>)>>
  [[nodiscard]] std::wstring make_label_minus_substring(WS &&substr) const {
    return Index::make_label_minus_substring(this->label(),
                                             std::forward<WS>(substr));
  }

  /// Removes the first occurrence of @p substr from the part of @p label that
  /// precedes its first '_' (or from the whole label if there is no '_').
  template <typename WS1, typename WS2,
            typename = std::enable_if_t<
                (meta::is_wstring_or_view_v<std::decay_t<WS1>> &&
                 meta::is_wstring_convertible_v<std::decay_t<WS2>>)>>
  [[nodiscard]] static std::wstring make_label_minus_substring(WS1 &&label,
                                                               WS2 &&substr) {
    auto underscore_position = label.find(L'_');
    std::wstring result;

    auto erase = [](auto &result, const auto &substr) {
      auto pos = result.find(substr);
      if (pos != std::wstring::npos) {
        if constexpr (std::is_same_v<std::decay_t<WS2>, std::wstring> ||
                      std::is_same_v<std::decay_t<WS2>, std::wstring_view>) {
          result.erase(pos, substr.size());
        } else if constexpr (std::is_same_v<std::decay_t<WS2>,
                                            const wchar_t *> ||
                             std::is_same_v<std::decay_t<WS2>, wchar_t *>) {
          // wide C string (arrays decay to these pointer types): wcslen, not
          // strlen, measures wchar_t strings
          result.erase(pos, std::wcslen(substr));
        } else {
          // presumably a single wchar_t
          result.erase(pos, 1);
        }
      }
    };

    if (underscore_position == std::wstring::npos) {
      result = std::forward<WS1>(label);
      erase(result, substr);
    } else {
      result = label.substr(0, underscore_position);
      erase(result, substr);
      result += label.substr(underscore_position);
    }
    return result;
  }

  const IndexSpace &space() const { return space_; }

  bool has_proto_indices() const { return !proto_indices_.empty(); }
  const index_vector &proto_indices() const { return proto_indices_; }
  bool symmetric_proto_indices() const { return symmetric_proto_indices_; }
  /// @return an index with the same label and space but no proto indices
  ///         (and no tag); no label check is performed since this index
  ///         already exists
  Index drop_proto_indices() const {
    Index result;
    result.label_ = label_;
    result.space_ = space_;
    return result;
  }

  std::wstring to_latex() const;

  /*template <typename... Attrs>
  std::wstring to_wolfram(Attrs &&...attrs) const {
    auto protect_subscript = [](const std::wstring_view str) {
      auto subsc_pos = str.rfind(L'_');
      if (subsc_pos == std::wstring_view::npos)
        return std::wstring(str);
      else {
        assert(subsc_pos + 1 < str.size());
        std::wstring result = L"\\!\\(\\*SubscriptBox[\\(";
        result += std::wstring(str.substr(0, subsc_pos));
        result += L"\\), \\(";
        result += std::wstring(str.substr(subsc_pos + 1));
        result += L"\\)]\\)";
        return result;
      }
    };

    using namespace std::literals;
    std::wstring result =
        L"particleIndex[\""s + protect_subscript(this->label()) + L"\"";
    if (this->has_proto_indices()) {
      assert(false && "not yet supported");
    }
    using namespace std::literals;
    result += L","s + ::sequant::to_wolfram(space());
    ((result += ((L","s + ::sequant::to_wolfram(std::forward<Attrs>(attrs))))),
     ...);
    result += L"]";
    return result;
  }*/

  /// @return hash of the space attrs of the indices in @p protoindex_range
  template <typename Range>
  static auto proto_indices_color(const Range &protoindex_range) {
    auto space_attr_view =
        protoindex_range | ranges::views::transform([](const Index &idx) {
          return int64_t(idx.space().attr());
        });
    return hash::range(ranges::begin(space_attr_view),
                       ranges::end(space_attr_view));
  }

  auto proto_indices_color() const {
    return proto_indices_color(proto_indices_);
  }

  /// @return hash of this index's space attr, combined with the proto-index
  ///         color when proto indices are present
  auto color() const {
    if (has_proto_indices()) {
      auto result = proto_indices_color();
      hash::combine(result, int64_t(space().attr()));
      return result;
    } else {
      auto result = hash::value(int64_t(space().attr()));
      return result;
    }
  }

  /// first ordinal used for temporary indices (defined out of line)
  static const std::size_t min_tmp_index();

  /// @return the next temporary-index ordinal (atomically incremented)
  static std::size_t next_tmp_index() { return ++tmp_index_accessor(); }

  static void reset_tmp_index();

  /// Replaces this index and/or its proto indices according to @p index_map.
  /// Replaced indices are tagged with 0.
  /// @return true if this index or any proto index was changed
  template <template <typename, typename, typename... Args> class Map,
            typename... Args>
  bool transform(const Map<Index, Index, Args...> &index_map) {
    bool mutated = false;

    // outline:
    // - try replacing this first
    //   - if this is replaced by an index with protoindices, the protoindices
    //   should not be tagged since they are original and may need to be
    //   replaced also
    // - if not found, try replacing protoindices
    // - if protoindices mutated, try replacing this again

    // is this tagged already? if yes, can't skip, the protoindices may need to
    // be transformed also
    const auto this_is_tagged = this->tag().has_value();
    // sanity check that tag = 0
    if (this_is_tagged) {
      assert(this->tag().value<int>() == 0);
    } else {  // only try replacing this if not already tagged
      auto it = index_map.find(*this);
      if (it != index_map.end()) {
        *this = it->second;
        this->tag().assign(0);
        mutated = true;
      }
    }

    if (!mutated) {
      bool proto_indices_transformed = false;
      for (auto &&subidx : proto_indices_) {
        if (subidx.transform(index_map)) proto_indices_transformed = true;
      }
      if (proto_indices_transformed) {
        mutated = true;
        canonicalize_proto_indices();
        if (!this_is_tagged) {  // if protoindices were mutated, try again, but
                                // only if no tag yet
          auto it = index_map.find(*this);
          if (it != index_map.end()) {
            *this = it->second;
            this->tag().assign(0);
            mutated = true;
          }
        }
      }
    }
    if (mutated) {
      // cached full label no longer matches label/proto indices
      full_label_.reset();
    }
    return mutated;
  }

  /// orders Index objects (and wstring_views, transparently) by label only
  struct LabelCompare {
    using is_transparent = void;
    bool operator()(const Index &first, const Index &second) const {
      return first.label() < second.label();
    }
    bool operator()(const Index &first, const std::wstring_view &second) const {
      return first.label() < second;
    }
    bool operator()(const std::wstring_view &first, const Index &second) const {
      return first < second.label();
    }
  };

  /// orders by space, then by proto indices (ignores labels)
  struct TypeCompare {
    bool operator()(const Index &first, const Index &second) const {
      bool result;
      if (first.space() == second.space()) {
        result = first.proto_indices() < second.proto_indices();
      } else
        result = first.space() < second.space();
      return result;
    }
  };

  /// equality of space and proto indices (ignores labels)
  struct TypeEquality {
    bool operator()(const Index &first, const Index &second) const {
      bool result = (first.space() == second.space()) &&
                    (first.proto_indices() == second.proto_indices());
      return result;
    }
  };

 private:
  std::wstring label_{};
  IndexSpace space_{};
  // the indices on which this index depends; must be unique
  index_vector proto_indices_{};
  // whether proto_indices_ is symmetric w.r.t. permutations; if true,
  // proto_indices_ is kept sorted
  bool symmetric_proto_indices_ = true;

  // lazily computed by full_label(); reset by transform()
  mutable std::optional<std::wstring> full_label_;

  // space used by the label-only ctor when no registry is available
  const static IndexSpace default_space;

  inline void canonicalize_proto_indices();

  inline void check_for_duplicate_proto_indices();

  // throws if label_ carries an ordinal reserved for temporary indices, i.e.
  // one that is >= min_tmp_index() (the first value next_tmp_index() returns)
  void check_nontmp_label() const {
    const auto index = label_index(label_);
    if (index && *index >= min_tmp_index()) {
      throw std::invalid_argument(
          "Index ctor: label index must be less than the value returned by "
          "min_tmp_index()");
    }
  }

  // Returns the number after the last '_' of label (parsed with wcstol, so a
  // non-numeric suffix yields 0), or nullopt if there is no '_'.
  // NOTE: wcstol stops only at a non-digit or NUL, so label must be backed by
  // a NUL-terminated buffer (true for label_).
  static std::optional<std::size_t> label_index(std::wstring_view label) {
    const auto underscore_position = label.rfind(L'_');
    if (underscore_position != std::wstring::npos) {
      assert(underscore_position + 1 <
             label.size());  // check that there is at least one char past the
                             // underscore
      return std::wcstol(
          label.substr(underscore_position + 1, std::wstring::npos).data(),
          nullptr, 10);
    } else
      return {};
  }

  friend class IndexFactory;

  // used by IndexFactory; bypasses the check for nontmp index
  Index(std::wstring_view label, const IndexSpace *space) noexcept
      : label_(label), space_(*space), proto_indices_() {}

  friend bool operator==(const Index &i1, const Index &i2) {
    return i1.space() == i2.space() && i1.label() == i2.label() &&
           i1.proto_indices() == i2.proto_indices();
  }

  friend bool operator!=(const Index &i1, const Index &i2) {
    return !(i1 == i2);
  }

  // orders by quantum numbers, then by tags (only if both are tagged and the
  // tags differ), then by space, label and proto indices
  friend bool operator<(const Index &i1, const Index &i2) {
    // compare qns, tags and spaces in that sequence

    // bind by const& to avoid copies (lifetime extended if qns() returns by
    // value)
    const auto &i1_Q = i1.space().qns();
    const auto &i2_Q = i2.space().qns();

    auto compare_space = [&i1, &i2]() {
      if (i1.space() == i2.space()) {
        if (i1.label() == i2.label()) {
          return i1.proto_indices() < i2.proto_indices();
        } else {
          return i1.label() < i2.label();
        }
      } else {
        return i1.space() < i2.space();
      }
    };

    if (i1_Q == i2_Q) {
      const bool have_tags = i1.tag().has_value() && i2.tag().has_value();

      if (!have_tags || i1.tag() == i2.tag()) {
        return compare_space();
      } else {
        return i1.tag() < i2.tag();
      }
    } else {
      return i1_Q < i2_Q;
    }
  }

};  // class Index

// Fallback space used by Index's label-only constructor when
// get_default_context().index_space_registry() is not set; built from an empty
// string and the `reserved` Type and QuantumNumbers.
inline const IndexSpace Index::default_space{
    L"", IndexSpace::Type::reserved, IndexSpace::QuantumNumbers::reserved};

// Debug-only sanity check: throws std::invalid_argument if two proto indices
// compare equal (operator==). Compiled out when NDEBUG is defined.
void Index::check_for_duplicate_proto_indices() {
#ifndef NDEBUG
  // Compare every pair with operator== instead of sorting + adjacent_find:
  // operator< also orders by tag, so two indices that are == but carry
  // different tags are not guaranteed to end up adjacent after sorting.
  // O(n^2), presumably fine since proto-index lists are short (debug only).
  const auto last = proto_indices_.end();
  for (auto it = proto_indices_.begin(); it != last; ++it) {
    if (std::find(std::next(it), last, *it) != last) {
      throw std::invalid_argument(
          "Index ctor: duplicate proto indices detected");
    }
  }
#endif
}

// Puts symmetric proto indices in canonical (stably sorted, operator<) order;
// non-symmetric proto indices keep the order they were given in.
void Index::canonicalize_proto_indices() {
  if (!symmetric_proto_indices_) return;
  std::stable_sort(proto_indices_.begin(), proto_indices_.end());
}

/// Per-thread record of the parity of the number of Index swaps performed via
/// sequant::swap(Index&, Index&).
class IndexSwapper {
 public:
  IndexSwapper() : even_num_of_swaps_(true) {}

  /// @return the calling thread's instance
  static IndexSwapper &thread_instance() {
    static thread_local IndexSwapper per_thread{};
    return per_thread;
  }

  /// @return true if an even number of swaps happened since construction or
  ///         the last reset()
  bool even_num_of_swaps() const { return even_num_of_swaps_.load(); }
  /// restores even parity
  void reset() { even_num_of_swaps_.store(true); }

 private:
  std::atomic<bool> even_num_of_swaps_;
  // flips the parity; only called by swap(Index&, Index&)
  void toggle() {
    const bool was_even = even_num_of_swaps_.load();
    even_num_of_swaps_.store(!was_even);
  }

  friend inline void swap(Index &, Index &);
};

/// Swaps two indices (via moves) and flips the calling thread's swap parity
/// tracked by IndexSwapper.
inline void swap(Index &first, Index &second) {
  Index held = std::move(first);
  first = std::move(second);
  second = std::move(held);
  IndexSwapper::thread_instance().toggle();
}

class IndexFactory {
 public:
  IndexFactory() = default;
  template <typename IndexValidator>
  explicit IndexFactory(IndexValidator validator,
                        size_t min_index = Index::min_tmp_index())
      : min_index_(min_index), validator_(validator) {
    assert(min_index_ > 0);
  }

  Index make(const IndexSpace &space) {
    Index result;
    bool valid = false;
    do {
      auto counter_it = counters_.begin();
      {  // if don't have a counter for this space
#if SEQUANT_INDEX_THREADSAFE
        std::scoped_lock lock(mutex_);
#endif
        if ((counter_it = counters_.find(space)) == counters_.end()) {
          counters_[space] = min_index_ - 1;
          counter_it = counters_.find(space);
        }
      }
      result = Index(
          space.base_key() + L'_' + std::to_wstring(++(counter_it->second)),
          &space);
      valid = validator_ ? validator_(result) : true;
    } while (!valid);
    return result;
  }

  Index make(const Index &idx) {
    const auto &space = idx.space();
    Index result;
    bool valid = false;
    do {
      auto counter_it = counters_.begin();
      {  // if don't have a counter for this space
#if SEQUANT_INDEX_THREADSAFE
        std::scoped_lock lock(mutex_);
#endif
        if ((counter_it = counters_.find(space)) == counters_.end()) {
          counters_[space] = min_index_ - 1;
          counter_it = counters_.find(space);
        }
      }
      result = Index(Index(space.base_key() + L'_' +
                               std::to_wstring(++(counter_it->second)),
                           &space),
                     idx.proto_indices());
      valid = validator_ ? validator_(result) : true;
    } while (!valid);
    return result;
  }

 private:
  std::size_t min_index_ = Index::min_tmp_index();
  std::function<bool(const Index &)> validator_ = {};
#if SEQUANT_INDEX_THREADSAFE
  std::mutex mutex_;
  // boost::container::flat_map needs copyable value, which std::atomic is not,
  // so must use std::map
  std::map<IndexSpace, std::atomic<std::size_t>> counters_;
#else
  // until multithreaded skip atomic
  container::map<IndexSpace, std::size_t> counters_;
#endif
};


/// Hash of an Index: the hash of its proto-index range combined with the hash
/// of its label. The space is not hashed; indices that compare equal share
/// space, label and proto indices, so they hash equally.
inline auto hash_value(const Index &idx) {
  using std::begin;
  using std::end;
  const auto &protos = idx.proto_indices();
  auto seed = hash::range(begin(protos), end(protos));
  hash::combine(seed, idx.label());
  return seed;
}

/// Builds a Container holding one Index per entry of @p index_labels, in the
/// same order; each Index is created from its label alone (Index{label}).
/// @tparam Container a container of Index supporting push_back
template <typename Container>
auto make_indices(WstrList index_labels = {}) {
  Container indices;
  std::for_each(index_labels.begin(), index_labels.end(),
                [&indices](const auto &lbl) { indices.push_back(Index{lbl}); });
  return indices;
}

}  // namespace sequant

#endif  // SEQUANT_INDEX_H