20 #ifndef TILEDARRAY_PERMUTATION_H__INCLUDED 21 #define TILEDARRAY_PERMUTATION_H__INCLUDED 33 bool operator==(
const Permutation&,
const Permutation&);
34 std::ostream&
operator<<(std::ostream&,
const Permutation&);
35 template <
typename T, std::
size_t N>
36 inline std::array<T,N>
operator*(
const Permutation&,
const std::array<T, N>&);
37 template <
typename T, std::
size_t N>
38 inline std::array<T,N>&
operator*=(std::array<T,N>&,
const Permutation&);
39 template <
typename T,
typename A>
40 inline std::vector<T>
operator*(
const Permutation&,
const std::vector<T, A>&);
41 template <
typename T,
typename A>
42 inline std::vector<T, A>&
operator*=(std::vector<T, A>&,
const Permutation&);
44 inline std::vector<T>
operator*(
const Permutation&,
const T* MADNESS_RESTRICT
const);
56 template <
typename Perm,
typename Arg,
typename Result>
57 inline void permute_array(
const Perm& perm,
const Arg& arg, Result& result) {
59 const unsigned int n =
size(arg);
60 for(
unsigned int i = 0u; i < n; ++i) {
61 const typename Perm::index_type pi = perm[i];
128 std::vector<index_type> p_;
132 template <
typename InIter>
133 bool valid_permutation(InIter first, InIter last) {
135 using diff_type =
typename std::iterator_traits<InIter>::difference_type;
136 const diff_type n = std::distance(first, last);
138 for(; first != last; ++first) {
139 const diff_type value = *first;
140 result = result && value >= 0 && (value < n) && (std::count(first, last, *first) == 1ul);
165 template <
typename InIter,
166 typename std::enable_if<detail::is_input_iterator<InIter>::value>::type* =
nullptr>
170 TA_ASSERT( valid_permutation(first, last) );
177 template <
typename Integer>
190 TA_ASSERT( valid_permutation(p_.begin(), p_.end()) );
197 template <
typename Integer,
198 typename std::enable_if<std::is_integral<Integer>::value>::type* =
nullptr>
248 std::vector<std::vector<index_type> >
cycles()
const {
250 std::vector<std::vector<index_type>> result;
252 std::vector<bool> placed_in_cycle(p_.size(),
false);
257 if (not placed_in_cycle[i]) {
258 std::vector<index_type> cycle(1,i);
259 placed_in_cycle[i] =
true;
262 while (next_i != i) {
263 cycle.push_back(next_i);
264 placed_in_cycle[next_i] =
true;
268 if (cycle.size() != 1) {
269 std::sort(cycle.begin(), cycle.end());
270 result.emplace_back(cycle);
285 result.p_.reserve(
dim);
286 for(
unsigned int i = 0u; i <
dim; ++i)
287 result.p_.emplace_back(i);
301 const unsigned int n = p_.size();
304 result.p_.reserve(n);
306 for(
unsigned int i = 0u; i < n; ++i) {
309 result.p_.emplace_back(result_i);
323 result.p_.resize(n, 0ul);
354 result = result.
mult(value);
355 value = value.
mult(value);
365 operator bool()
const {
return ! p_.empty(); }
375 const std::vector<index_type>&
data()
const {
return p_; }
382 template <
typename Archive>
396 return (p1.
dim() == p2.
dim())
397 && std::equal(p1.
data().begin(), p1.
data().end(), p2.
data().begin());
417 return std::lexicographical_compare(p1.
data().begin(), p1.
data().end(),
418 p2.
data().begin(), p2.
data().end());
427 std::size_t n = p.
dim();
429 for (
unsigned int dim = 0; dim < n - 1; ++dim)
430 output << dim <<
"->" << p.
data()[dim] <<
", ";
431 output << n - 1 <<
"->" << p.
data()[n - 1] <<
"}";
457 return (p1 = p1 * p2);
483 template <
typename T, std::
size_t N>
486 std::array<T,N> result;
500 template <
typename T, std::
size_t N>
503 const std::array<T,N> temp = a;
517 template <
typename T,
typename A>
520 std::vector<T> result(perm.
dim());
534 template <
typename T,
typename A>
536 const std::vector<T, A> temp = v;
547 template <
typename T>
549 const unsigned int n = perm.
dim();
550 std::vector<T> result(n);
551 for(
unsigned int i = 0u; i < n; ++i) {
553 const T ptr_i = ptr[i];
554 result[perm_i] = ptr_i;
561 #endif // TILEDARRAY_PERMUTATION_H__INCLUDED constexpr bool operator==(const DenseShape &a, const DenseShape &b)
void permute_array(const Perm &perm, const Arg &arg, Result &result)
Create a permuted copy of an array.
Permutation & operator=(const Permutation &)=default
const_iterator end() const
End element iterator factory function.
bool operator<(const Permutation &p1, const Permutation &p2)
Permutation less-than operator.
const_iterator cend() const
End element iterator factory function.
std::vector< std::vector< index_type > > cycles() const
Cycles decomposition.
constexpr bool operator!=(const DenseShape &a, const DenseShape &b)
bool operator!() const
Not operator.
std::array< T, N > operator*(const Permutation &, const std::array< T, N > &)
Permute a std::array.
Permutation inv() const
Construct the inverse of this permutation.
Permutation(const std::vector< Integer > &a)
Array constructor.
std::vector< index_type >::const_iterator const_iterator
index_type dim() const
Domain size accessor.
const std::vector< index_type > & data() const
Permutation data accessor.
constexpr std::size_t size(T(&)[N])
Array size accessor.
index_type operator[](unsigned int i) const
Element accessor.
Permutation operator-(const Permutation &perm)
Inverse permutation operator.
Permutation(std::initializer_list< Integer > list)
Construct a permutation from an initializer list.
Permutation identity() const
Identity permutation factory function.
std::array< T, N > & operator*=(std::array< T, N > &, const Permutation &)
In-place permute a std::array.
Permutation operator^(const Permutation &perm, int n)
Raise perm to the n-th power.
Permutation(std::vector< index_type > &&a)
Move constructor from a std::vector.
std::ostream & operator<<(std::ostream &os, const DistArray< Tile, Policy > &a)
Add the tensor to an output stream.
void serialize(Archive &ar)
Serialize permutation.
Permutation of a sequence of objects indexed by zero-based indices.
const_iterator cbegin() const
Begin element iterator factory function.
const_iterator begin() const
Begin element iterator factory function.
Permutation(InIter first, InIter last)
Construct a permutation from the range [first, last).
Permutation mult(const Permutation &other) const
Product of this permutation by other.
Permutation pow(int n) const
Raise this permutation to the n-th power.
static Permutation identity(const unsigned int dim)
Identity permutation factory function.