Program Listing for File tensor.hpp

Return to documentation for file (SeQuant/core/tensor.hpp)

//
// Created by Eduard Valeyev on 3/23/18.
//

#ifndef SEQUANT_TENSOR_HPP
#define SEQUANT_TENSOR_HPP

#include <SeQuant/core/abstract_tensor.hpp>
#include <SeQuant/core/attr.hpp>
#include <SeQuant/core/container.hpp>
#include <SeQuant/core/context.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/hash.hpp>
#include <SeQuant/core/index.hpp>
#include <SeQuant/core/latex.hpp>
#include <SeQuant/core/utility/strong.hpp>

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <type_traits>
#include <utility>
#include <vector>

#include <range/v3/all.hpp>

namespace sequant {

// Strong-type wrapper template `bra<T>` for objects associated with the bra,
// e.g. the bra index range passed to Tensor's constructors. Because bra<...>
// and ket<...> are distinct types, the constructors can tell the two index
// ranges apart at compile time.
// NOTE(review): the operations forwarded by the wrapper (range access, size)
// come from DEFINE_STRONG_TYPE_FOR_RANGE_AND_RANGESIZE in utility/strong.hpp.
DEFINE_STRONG_TYPE_FOR_RANGE_AND_RANGESIZE(bra);
// Strong-type wrapper template `ket<T>` for objects associated with the ket;
// the ket counterpart of `bra<T>`.
DEFINE_STRONG_TYPE_FOR_RANGE_AND_RANGESIZE(ket);

class Tensor : public Expr, public AbstractTensor, public Labeled {
 private:
  using index_container_type = container::svector<Index>;
  static auto make_indices(IndexList indices) { return indices; }
  static auto make_indices(WstrList index_labels) {
    index_container_type result;
    result.reserve(index_labels.size());
    for (const auto &label : index_labels) {
      result.push_back(Index{label});
    }
    return result;
  }
  static auto make_indices(
      std::initializer_list<const wchar_t *> index_labels) {
    index_container_type result;
    result.reserve(index_labels.size());
    for (const auto &label : index_labels) {
      result.push_back(Index{label});
    }
    return result;
  }
  template <typename IndexRange>
  static index_container_type make_indices(IndexRange &&indices) {
    if constexpr (std::is_same_v<index_container_type,
                                 std::decay_t<IndexRange>>) {
      return std::forward<IndexRange>(indices);
    } else {
      using ranges::begin;
      using ranges::end;
      return index_container_type(begin(indices), end(indices));
    }
  }

  auto braket() { return ranges::views::concat(bra_, ket_); }

  void assert_nonreserved_label(std::wstring_view label) const;

  // utility for dispatching to private ctor
  struct reserved_tag {};

  // list of friends who can make Tensor objects with reserved labels
  friend ExprPtr make_overlap(const Index &bra_index, const Index &ket_index);

  template <
      typename IndexRange1, typename IndexRange2,
      typename = std::enable_if_t<
          (meta::is_statically_castable_v<
              meta::range_value_t<IndexRange1>,
              Index>)&&(meta::
                            is_statically_castable_v<
                                meta::range_value_t<IndexRange2>, Index>)>>
  Tensor(std::wstring_view label, const bra<IndexRange1> &bra_indices,
         const ket<IndexRange2> &ket_indices, reserved_tag,
         Symmetry s = Symmetry::nonsymm,
         BraKetSymmetry bks = get_default_context().braket_symmetry(),
         ParticleSymmetry ps = ParticleSymmetry::symm)
      : label_(label),
        bra_(make_indices(bra_indices)),
        ket_(make_indices(ket_indices)),
        symmetry_(s),
        braket_symmetry_(bks),
        particle_symmetry_(ps) {
    validate_symmetries();
  }

  Tensor(std::wstring_view label, bra<index_container_type> &&bra_indices,
         ket<index_container_type> &&ket_indices, reserved_tag,
         Symmetry s = Symmetry::nonsymm,
         BraKetSymmetry bks = get_default_context().braket_symmetry(),
         ParticleSymmetry ps = ParticleSymmetry::symm)
      : label_(label),
        bra_(std::move(bra_indices)),
        ket_(std::move(ket_indices)),
        symmetry_(s),
        braket_symmetry_(bks),
        particle_symmetry_(ps) {
    validate_symmetries();
  }

 public:
  Tensor() = default;
  virtual ~Tensor();

  template <
      typename IndexRange1, typename IndexRange2,
      typename = std::enable_if_t<
          (meta::is_statically_castable_v<
              meta::range_value_t<IndexRange1>,
              Index>)&&(meta::
                            is_statically_castable_v<
                                meta::range_value_t<IndexRange2>, Index>)>>
  Tensor(std::wstring_view label, const bra<IndexRange1> &bra_indices,
         const ket<IndexRange2> &ket_indices, Symmetry s = Symmetry::nonsymm,
         BraKetSymmetry bks = get_default_context().braket_symmetry(),
         ParticleSymmetry ps = ParticleSymmetry::symm)
      : Tensor(label, bra_indices, ket_indices, reserved_tag{}, s, bks, ps) {
    assert_nonreserved_label(label_);
  }

  Tensor(std::wstring_view label, bra<index_container_type> &&bra_indices,
         ket<index_container_type> &&ket_indices,
         Symmetry s = Symmetry::nonsymm,
         BraKetSymmetry bks = get_default_context().braket_symmetry(),
         ParticleSymmetry ps = ParticleSymmetry::symm)
      : Tensor(label, std::move(bra_indices), std::move(ket_indices),
               reserved_tag{}, s, bks, ps) {
    assert_nonreserved_label(label_);
  }

  explicit operator bool() const {
    return !label_.empty() && symmetry_ != Symmetry::invalid &&
           braket_symmetry_ != BraKetSymmetry::invalid &&
           particle_symmetry_ != ParticleSymmetry::invalid;
  }

  std::wstring_view label() const override { return label_; }
  const auto &bra() const { return bra_; }
  const auto &ket() const { return ket_; }
  auto braket() const { return ranges::views::concat(bra_, ket_); }
  auto const_braket() const { return this->braket(); }
  Symmetry symmetry() const { return symmetry_; }
  BraKetSymmetry braket_symmetry() const { return braket_symmetry_; }
  ParticleSymmetry particle_symmetry() const { return particle_symmetry_; }

  std::size_t bra_rank() const { return bra_.size(); }
  std::size_t ket_rank() const { return ket_.size(); }
  std::size_t rank() const {
    if (bra_rank() != ket_rank()) {
      throw std::logic_error("Tensor::rank(): bra rank != ket rank");
    }
    return bra_rank();
  }

  std::wstring to_latex() const override {
    std::wstring result;
    std::vector<std::wstring> labels = {L"g", L"t", L"λ", L"t¹", L"λ¹"};
    bool add_bar =
        ranges::find(labels, this->label()) != labels.end() && this->rank() > 1;

    result = L"{";
    if ((this->symmetry() == Symmetry::antisymm) && add_bar)
      result += L"\\bar{";
    result += utf_to_latex(this->label());
    if ((this->symmetry() == Symmetry::antisymm) && add_bar) result += L"}";
    result += L"^{";
    for (const auto &i : this->ket()) result += sequant::to_latex(i);
    result += L"}_{";
    for (const auto &i : this->bra()) result += sequant::to_latex(i);
    result += L"}}";
    return result;
  }

  ExprPtr canonicalize() override;

  virtual void adjoint() override;

  template <template <typename, typename, typename... Args> class Map,
            typename... Args>
  bool transform_indices(const Map<Index, Index, Args...> &index_map) {
    bool mutated = false;
    ranges::for_each(braket(), [&](auto &idx) {
      if (idx.transform(index_map)) mutated = true;
    });
    if (mutated) this->reset_hash_value();
    return mutated;
  }

  type_id_type type_id() const override { return get_type_id<Tensor>(); };

  ExprPtr clone() const override { return ex<Tensor>(*this); }

  void reset_tags() const {
    ranges::for_each(braket(), [](const auto &idx) { idx.reset_tag(); });
  }

  hash_type bra_hash_value() const {
    if (!hash_value_)  // if hash not computed, or reset, recompute
      memoizing_hash();
    return *bra_hash_value_;
  }

 private:
  std::wstring label_{};
  sequant::bra<index_container_type> bra_{};
  sequant::ket<index_container_type> ket_{};
  Symmetry symmetry_ = Symmetry::invalid;
  BraKetSymmetry braket_symmetry_ = BraKetSymmetry::invalid;
  ParticleSymmetry particle_symmetry_ = ParticleSymmetry::invalid;
  mutable std::optional<hash_type>
      bra_hash_value_;  // memoized byproduct of memoizing_hash()
  bool is_adjoint_ = false;

  void validate_symmetries() {
    // (anti)symmetric bra or ket makes sense only for particle-symmetric
    // tensors
    if (symmetry_ == Symmetry::symm || symmetry_ == Symmetry::antisymm)
      assert(particle_symmetry_ == ParticleSymmetry::symm);
  }

  hash_type memoizing_hash() const override {
    using std::begin;
    using std::end;
    auto val = hash::range(begin(bra()), end(bra()));
    bra_hash_value_ = val;
    hash::range(val, begin(ket()), end(ket()));
    hash::combine(val, label_);
    hash::combine(val, symmetry_);
    hash_value_ = val;
    return *hash_value_;
  }
  void reset_hash_value() const override {
    Expr::reset_hash_value();
    bra_hash_value_.reset();
  }

  bool static_equal(const Expr &that) const override {
    const auto &that_cast = static_cast<const Tensor &>(that);
    if (this->label() == that_cast.label() &&
        this->symmetry() == that_cast.symmetry() &&
        this->bra_rank() == that_cast.bra_rank() &&
        this->ket_rank() == that_cast.ket_rank()) {
      // compare hash values first
      if (this->hash_value() ==
          that.hash_value())  // hash values agree -> do full comparison
        return this->bra() == that_cast.bra() && this->ket() == that_cast.ket();
      else
        return false;
    } else
      return false;
  }

  bool static_less_than(const Expr &that) const override {
    const auto &that_cast = static_cast<const Tensor &>(that);
    if (this == &that) return false;
    if (this->label() == that_cast.label()) {
      if (this->bra_rank() == that_cast.bra_rank()) {
        if (this->ket_rank() == that_cast.ket_rank()) {
          //          v1: compare hashes only
          //          return Expr::static_less_than(that);
          //          v2: compare fully
          if (this->bra_hash_value() == that_cast.bra_hash_value()) {
            return std::lexicographical_compare(
                this->ket().begin(), this->ket().end(), that_cast.ket().begin(),
                that_cast.ket().end());
          } else {
            return std::lexicographical_compare(
                this->bra().begin(), this->bra().end(), that_cast.bra().begin(),
                that_cast.bra().end());
          }
        } else {
          return this->ket_rank() < that_cast.ket_rank();
        }
      } else {
        return this->bra_rank() < that_cast.bra_rank();
      }
    } else {
      return this->label() < that_cast.label();
    }
  }

  // these implement the AbstractTensor interface
  AbstractTensor::const_any_view_randsz _bra() const override final {
    return ranges::counted_view<const Index *>(
        bra_.empty() ? nullptr : &(bra_[0]), bra_.size());
  }
  AbstractTensor::const_any_view_randsz _ket() const override final {
    return ranges::counted_view<const Index *>(
        ket_.empty() ? nullptr : &(ket_[0]), ket_.size());
  }
  AbstractTensor::const_any_view_rand _braket() const override final {
    return braket();
  }
  std::size_t _bra_rank() const override final { return bra_rank(); }
  std::size_t _ket_rank() const override final { return ket_rank(); }
  Symmetry _symmetry() const override final { return symmetry_; }
  BraKetSymmetry _braket_symmetry() const override final {
    return braket_symmetry_;
  }
  ParticleSymmetry _particle_symmetry() const override final {
    return particle_symmetry_;
  }
  std::size_t _color() const override final { return 0; }
  bool _is_cnumber() const override final { return true; }
  std::wstring_view _label() const override final { return label_; }
  std::wstring _to_latex() const override final { return to_latex(); }
  bool _transform_indices(
      const container::map<Index, Index> &index_map) override final {
    return transform_indices(index_map);
  }
  void _reset_tags() override final { reset_tags(); }
  bool operator<(const AbstractTensor &other) const override final {
    auto *other_tensor = dynamic_cast<const Tensor *>(&other);
    if (other_tensor) {
      const Expr *other_expr = static_cast<const Expr *>(other_tensor);
      return this->static_less_than(*other_expr);
    } else
      return false;  // TODO do we compare typeid? labels? probably the latter
  }

  AbstractTensor::any_view_randsz _bra_mutable() override final {
    this->reset_hash_value();
    return ranges::counted_view<Index *>(bra_.empty() ? nullptr : &(bra_[0]),
                                         bra_.size());
  }
  AbstractTensor::any_view_randsz _ket_mutable() override final {
    this->reset_hash_value();
    return ranges::counted_view<Index *>(ket_.empty() ? nullptr : &(ket_[0]),
                                         ket_.size());
  }

};  // class Tensor

// Shared-ownership handle to a Tensor.
using TensorPtr = std::shared_ptr<Tensor>;

/// @return the label used for overlap tensors built by make_overlap()
/// NOTE(review): presumably reserved, i.e. rejected for user-constructed
/// tensors by Tensor's non-reserved-label check; confirm in tensor.cpp.
inline std::wstring overlap_label() {
  std::wstring reserved_label(L"s");
  return reserved_label;
}

/// Builds the overlap tensor with a single bra index and a single ket index.
/// It uses the reserved overlap_label(), so it goes through Tensor's private
/// reserved_tag constructor (make_overlap is a friend of Tensor).
/// @param bra_index the bra index
/// @param ket_index the ket index
/// @return the overlap tensor as an expression
inline ExprPtr make_overlap(const Index &bra_index, const Index &ket_index) {
  Tensor overlap(overlap_label(), bra{bra_index}, ket{ket_index},
                 Tensor::reserved_tag{});
  return ex<Tensor>(std::move(overlap));
}

}  // namespace sequant

#endif  // SEQUANT_TENSOR_HPP