tile_interface.h
Go to the documentation of this file.
1 /*
2  * This file is a part of TiledArray.
3  * Copyright (C) 2014 Virginia Tech
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * Justus Calvin
19  * Department of Chemistry, Virginia Tech
20  *
21  * tile_interface.h
22  * Sep 29, 2014
23  *
24  */
25 
26 #ifndef TILEDARRAY_NONINTRUSIVE_API_TENSOR_H__INCLUDED
27 #define TILEDARRAY_NONINTRUSIVE_API_TENSOR_H__INCLUDED
28 
30 #include <TiledArray/type_traits.h>
31 #include <iterator>
32 #include <vector>
33 
34 namespace TiledArray {
35 
36 // Forward declaration
37 namespace math {
38 class GemmHelper;
39 } // namespace math
40 namespace detail {
41 template <typename, typename>
42 class LazyArrayTile;
43 } // namespace detail
44 
268 // Empty operations ----------------------------------------------------------
269 
// Emptiness test: no tile-specific wrapper is provided. Instead the C++17
// std::empty (from <iterator>) is made visible here, so `empty(arg)` calls
// `arg.empty()` for class types and also works for built-in arrays.
using ::std::empty;
273 
274 // Subtraction ---------------------------------------------------------------
275 
/// Subtract \p right from \p left (non-intrusive interface).
///
/// Delegates to the member call `left.subt(right)`.
/// \tparam Left The left-hand argument type
/// \tparam Right The right-hand argument type
/// \param left The left-hand argument
/// \param right The right-hand argument
/// \return The result of `left.subt(right)`, returned by value (`auto`
/// deduction decays references)
template <class Left, class Right>
inline auto subt(const Left& left, const Right& right) {
  return left.subt(right);
}
287 
289 
297 template <
298  typename Left, typename Right, typename Scalar,
299  typename std::enable_if<detail::is_numeric_v<Scalar>>::type* = nullptr>
300 inline auto subt(const Left& left, const Right& right, const Scalar factor) {
301  return left.subt(right, factor);
302 }
303 
305 
312 template <typename Left, typename Right, typename Perm,
313  typename = std::enable_if_t<detail::is_permutation_v<Perm>>>
314 inline auto subt(const Left& left, const Right& right, const Perm& perm) {
315  return left.subt(right, perm);
316 }
317 
319 
328 template <
329  typename Left, typename Right, typename Scalar, typename Perm,
330  typename std::enable_if<detail::is_numeric_v<Scalar> &&
331  detail::is_permutation_v<Perm>>::type* = nullptr>
332 inline auto subt(const Left& left, const Right& right, const Scalar factor,
333  const Perm& perm) {
334  return left.subt(right, factor, perm);
335 }
336 
338 
344 template <
345  typename Arg, typename Scalar,
346  typename std::enable_if<detail::is_numeric_v<Scalar>>::type* = nullptr>
347 inline auto subt(const Arg& arg, const Scalar value) {
348  return arg.subt(value);
349 }
350 
352 
359 template <
360  typename Arg, typename Scalar, typename Perm,
361  typename std::enable_if<detail::is_numeric_v<Scalar> &&
362  detail::is_permutation_v<Perm>>::type* = nullptr>
363 inline auto subt(const Arg& arg, const Scalar value, const Perm& perm) {
364  return arg.subt(value, perm);
365 }
366 
/// In-place subtraction via the member call `result.subt_to(arg)`.
///
/// \param result The tile modified by the member call
/// \param arg The argument subtracted from \p result
/// \return The `Result&` returned by `result.subt_to(arg)`
template <class Result, class Arg>
inline auto subt_to(Result& result, const Arg& arg) -> Result& {
  return result.subt_to(arg);
}
378 
380 
388 template <
389  typename Result, typename Arg, typename Scalar,
390  typename std::enable_if<detail::is_numeric_v<Scalar>>::type* = nullptr>
391 inline Result& subt_to(Result& result, const Arg& arg, const Scalar factor) {
392  return result.subt_to(arg, factor);
393 }
394 
396 
402 template <
403  typename Result, typename Scalar,
404  typename std::enable_if<detail::is_numeric_v<Scalar>>::type* = nullptr>
405 inline Result& subt_to(Result& result, const Scalar value) {
406  return result.subt_to(value);
407 }
408 
/// Type produced by an unqualified `subt(...)` call with arguments of types
/// `Args...` (each obtained via std::declval).
template <class... Args>
using result_of_subt_t = decltype(subt(std::declval<Args>()...));

/// Type produced by an unqualified `subt_to(...)` call with arguments of
/// types `Args...` (each obtained via std::declval).
template <class... Args>
using result_of_subt_to_t = decltype(subt_to(std::declval<Args>()...));
414 
415 // Multiplication operations -------------------------------------------------
416 
/// Element-wise multiplication via the member call `left.mult(right)`.
///
/// \param left The left-hand argument
/// \param right The right-hand argument
/// \return The result of `left.mult(right)`, returned by value
template <class Left, class Right>
inline auto mult(const Left& left, const Right& right) {
  return left.mult(right);
}
428 
430 
438 template <typename Left, typename Right, typename Scalar,
439  std::enable_if_t<TiledArray::detail::is_numeric_v<Scalar>>* = nullptr>
440 inline auto mult(const Left& left, const Right& right, const Scalar factor) {
441  return left.mult(right, factor);
442 }
443 
445 
452 template <
453  typename Left, typename Right, typename Perm,
454  typename = std::enable_if_t<detail::is_permutation_v<Perm> &&
455  detail::has_member_function_mult_anyreturn_v<
456  const Left, const Right&, const Perm&>>>
457 inline auto mult(const Left& left, const Right& right, const Perm& perm) {
458  return left.mult(right, perm);
459 }
460 
462 
471 template <typename Left, typename Right, typename Scalar, typename Perm,
472  std::enable_if_t<TiledArray::detail::is_numeric_v<Scalar> &&
473  detail::is_permutation_v<Perm>>* = nullptr>
474 inline auto mult(const Left& left, const Right& right, const Scalar factor,
475  const Perm& perm) {
476  return left.mult(right, factor, perm);
477 }
478 
/// In-place multiplication via the member call `result.mult_to(arg)`.
///
/// \param result The tile modified by the member call
/// \param arg The multiplier argument
/// \return The `Result&` returned by `result.mult_to(arg)`
template <class Result, class Arg>
inline auto mult_to(Result& result, const Arg& arg) -> Result& {
  return result.mult_to(arg);
}
490 
492 
500 template <typename Result, typename Arg, typename Scalar,
501  std::enable_if_t<TiledArray::detail::is_numeric_v<Scalar>>* = nullptr>
502 inline Result& mult_to(Result& result, const Arg& arg, const Scalar factor) {
503  return result.mult_to(arg, factor);
504 }
505 
/// Type produced by an unqualified `mult(...)` call with arguments of types
/// `Args...` (each obtained via std::declval).
template <class... Args>
using result_of_mult_t = decltype(mult(std::declval<Args>()...));

/// Type produced by an unqualified `mult_to(...)` call with arguments of
/// types `Args...` (each obtained via std::declval).
template <class... Args>
using result_of_mult_to_t = decltype(mult_to(std::declval<Args>()...));
511 
512 // Generic element-wise binary operations
513 // ---------------------------------------------
514 
// clang-format off
/// Element-wise binary transform producing a new tile.
///
/// Delegates to `left.binary(right, op)`, perfect-forwarding \p op.
/// \param left The left-hand argument
/// \param right The right-hand argument
/// \param op The element-wise operation, forwarded with std::forward
/// \return The result of the member call; `decltype(auto)` preserves its exact type (including reference-ness)
// clang-format on
template <class Left, class Right, class Op>
inline decltype(auto) binary(const Left& left, const Right& right, Op&& op) {
  return left.binary(right, std::forward<Op>(op));
}
530 
531 // clang-format off
533 
543 // clang-format on
544 template <typename Left, typename Right, typename Op, typename Perm,
545  typename = std::enable_if_t<detail::is_permutation_v<Perm>>>
546 inline decltype(auto) binary(const Left& left, const Right& right, Op&& op,
547  const Perm& perm) {
548  return left.binary(right, std::forward<Op>(op), perm);
549 }
550 
// clang-format off
/// Element-wise binary transform applied in place to \p left.
///
/// Delegates to `left.inplace_binary(right, op)`, perfect-forwarding \p op.
/// \param left The tile modified by the member call
/// \param right The right-hand argument
/// \param op The element-wise operation, forwarded with std::forward
/// \return The `Left&` returned by the member call
// clang-format on
template <class Left, class Right, class Op>
inline auto inplace_binary(Left& left, const Right& right, Op&& op) -> Left& {
  return left.inplace_binary(right, std::forward<Op>(op));
}
566 
/// Type produced by an unqualified `binary(...)` call with arguments of types
/// `T...` (each obtained via std::declval).
template <typename... T>
using result_of_binary_t = decltype(binary(std::declval<T>()...));

/// Type produced by an unqualified `inplace_binary(...)` call with arguments
/// of types `T...` (each obtained via std::declval). The first type normally
/// needs to be an lvalue reference (e.g. `Tile&`), because `inplace_binary`
/// takes its first argument by non-const reference.
template <typename... T>
using result_of_inplace_binary_t =
    decltype(inplace_binary(std::declval<T>()...));
573 
574 // Scaling operations --------------------------------------------------------
575 
576 // see tile_interface/scale.h
577 
578 // Negation operations -------------------------------------------------------
579 
/// Negation via the member call `arg.neg()`.
///
/// \param arg The tile argument
/// \return The result of `arg.neg()`, returned by value
template <class Arg>
inline auto neg(const Arg& arg) {
  return arg.neg();
}
589 
591 
596 template <typename Arg, typename Perm,
597  typename = std::enable_if_t<detail::is_permutation_v<Perm>>>
598 inline auto neg(const Arg& arg, const Perm& perm) {
599  return arg.neg(perm);
600 }
601 
/// In-place negation via the member call `result.neg_to()`.
///
/// \param result The tile modified by the member call
/// \return The `Result&` returned by `result.neg_to()`
template <class Result>
inline auto neg_to(Result& result) -> Result& {
  return result.neg_to();
}
611 
/// Type produced by an unqualified `neg(...)` call with arguments of types
/// `Args...` (each obtained via std::declval).
template <class... Args>
using result_of_neg_t = decltype(neg(std::declval<Args>()...));

/// Type produced by an unqualified `neg_to(...)` call with arguments of types
/// `Args...` (each obtained via std::declval).
template <class... Args>
using result_of_neg_to_t = decltype(neg_to(std::declval<Args>()...));
617 
618 // Complex conjugate operations ---------------------------------------------
619 
/// Complex conjugation via the member call `arg.conj()`.
///
/// \param arg The tile argument
/// \return The result of `arg.conj()`, returned by value
template <class Arg>
inline auto conj(const Arg& arg) {
  return arg.conj();
}
629 
631 
637 template <typename Arg, typename Scalar,
638  typename std::enable_if<
639  TiledArray::detail::is_numeric_v<Scalar>>::type* = nullptr>
640 inline auto conj(const Arg& arg, const Scalar factor) {
641  return arg.conj(factor);
642 }
643 
645 
650 template <typename Arg, typename Perm,
651  typename = std::enable_if_t<detail::is_permutation_v<Perm>>>
652 inline auto conj(const Arg& arg, const Perm& perm) {
653  return arg.conj(perm);
654 }
655 
657 
664 template <
665  typename Arg, typename Scalar, typename Perm,
666  typename std::enable_if<TiledArray::detail::is_numeric_v<Scalar> &&
667  detail::is_permutation_v<Perm>>::type* = nullptr>
668 inline auto conj(const Arg& arg, const Scalar factor, const Perm& perm) {
669  return arg.conj(factor, perm);
670 }
671 
/// In-place complex conjugation via the member call `result.conj_to()`.
///
/// \param result The tile modified by the member call
/// \return The `Result&` returned by `result.conj_to()`
template <class Result>
inline auto conj_to(Result& result) -> Result& {
  return result.conj_to();
}
681 
683 
689 template <typename Result, typename Scalar,
690  typename std::enable_if<
691  TiledArray::detail::is_numeric_v<Scalar>>::type* = nullptr>
692 inline Result& conj_to(Result& result, const Scalar factor) {
693  return result.conj_to(factor);
694 }
695 
/// Type produced by an unqualified `conj(...)` call with arguments of types
/// `Args...` (each obtained via std::declval).
template <class... Args>
using result_of_conj_t = decltype(conj(std::declval<Args>()...));

/// Type produced by an unqualified `conj_to(...)` call with arguments of
/// types `Args...` (each obtained via std::declval).
template <class... Args>
using result_of_conj_to_t = decltype(conj_to(std::declval<Args>()...));
701 
702 // Generic element-wise unary operations
703 // ---------------------------------------------
704 
// clang-format off
/// Element-wise unary transform producing a new tile.
///
/// Delegates to `arg.unary(op)`, perfect-forwarding \p op.
/// \param arg The tile argument
/// \param op The element-wise operation, forwarded with std::forward
/// \return The result of the member call; `decltype(auto)` preserves its exact type (including reference-ness)
// clang-format on
template <class Arg, class Op>
inline decltype(auto) unary(const Arg& arg, Op&& op) {
  return arg.unary(std::forward<Op>(op));
}
718 
719 // clang-format off
721 
728 // clang-format on
729 template <typename Arg, typename Op, typename Perm,
730  typename = std::enable_if_t<detail::is_permutation_v<Perm>>>
731 inline decltype(auto) unary(const Arg& arg, Op&& op, const Perm& perm) {
732  return arg.unary(std::forward<Op>(op), perm);
733 }
734 
// clang-format off
/// Element-wise unary transform applied in place to \p arg.
///
/// Delegates to `arg.inplace_unary(op)`, perfect-forwarding \p op.
/// \param arg The tile modified by the member call
/// \param op The element-wise operation, forwarded with std::forward
/// \return The `Result&` returned by the member call
// clang-format on
template <class Result, class Op>
inline auto inplace_unary(Result& arg, Op&& op) -> Result& {
  return arg.inplace_unary(std::forward<Op>(op));
}
748 
/// Type produced by an unqualified `unary(...)` call with arguments of types
/// `Args...` (each obtained via std::declval).
template <class... Args>
using result_of_unary_t = decltype(unary(std::declval<Args>()...));

/// Type produced by an unqualified `inplace_unary(...)` call with arguments
/// of types `Args...` (each obtained via std::declval).
template <class... Args>
using result_of_inplace_unary_t =
    decltype(inplace_unary(std::declval<Args>()...));
754 
755 // Contraction operations ----------------------------------------------------
756 
758 
769 template <typename Left, typename Right, typename Scalar,
770  std::enable_if_t<TiledArray::detail::is_numeric_v<Scalar>>* = nullptr>
771 inline auto gemm(const Left& left, const Right& right, const Scalar factor,
772  const math::GemmHelper& gemm_config) {
773  return left.gemm(right, factor, gemm_config);
774 }
775 
778 
791 template <typename Result, typename Left, typename Right, typename Scalar,
792  std::enable_if_t<TiledArray::detail::is_numeric_v<Scalar>>* = nullptr>
793 inline Result& gemm(Result& result, const Left& left, const Right& right,
794  const Scalar factor, const math::GemmHelper& gemm_config) {
795  return result.gemm(left, right, factor, gemm_config);
796 }
797 
800 
830 template <typename Result, typename Left, typename Right,
831  typename ElementMultiplyAddOp,
832  std::enable_if_t<std::is_invocable_r_v<
833  void, std::remove_reference_t<ElementMultiplyAddOp>,
834  typename Result::value_type&, const typename Left::value_type&,
835  const typename Right::value_type&>>* = nullptr>
836 inline Result& gemm(Result& result, const Left& left, const Right& right,
837  const math::GemmHelper& gemm_config,
838  ElementMultiplyAddOp&& element_multiplyadd_op) {
839  return result.gemm(
840  left, right, gemm_config,
841  std::forward<ElementMultiplyAddOp>(element_multiplyadd_op));
842 }
843 
/// Type produced by an unqualified `gemm(...)` call with arguments of types
/// `Args...` (each obtained via std::declval).
template <class... Args>
using result_of_gemm_t = decltype(gemm(std::declval<Args>()...));
846 
847 // Reduction operations ------------------------------------------------------
848 
850 
854 // template <typename Arg>
855 // inline auto trace(const Arg& arg) {
856 // return arg.trace();
857 //}
858 
/// Sum of the elements of a tile via the member call `arg.sum()`.
///
/// \param arg The tile argument
/// \return The result of `arg.sum()`, returned by value
template <class Arg>
inline auto sum(const Arg& arg) {
  return arg.sum();
}
868 
/// Product of the elements of a tile via the member call `arg.product()`.
///
/// \param arg The tile argument
/// \return The result of `arg.product()`, returned by value
template <class Arg>
inline auto product(const Arg& arg) {
  return arg.product();
}
878 
/// Squared norm of a tile via the member call `arg.squared_norm()`.
///
/// \param arg The tile argument
/// \return The result of `arg.squared_norm()`, returned by value
template <class Arg>
inline auto squared_norm(const Arg& arg) {
  return arg.squared_norm();
}
889 
/// Norm of a tile via the member call `arg.norm()`.
///
/// \param arg The tile argument
/// \return The result of `arg.norm()`, returned by value
template <class Arg>
inline auto norm(const Arg& arg) {
  return arg.norm();
}
899 
/// Norm of a tile written into \p result, computed as
/// `arg.template norm<ResultType>()`.
///
/// \tparam ResultType The type requested from the member template `norm`
/// \param arg The tile argument
/// \param result Receives the value returned by the member call
template <class Arg, class ResultType>
inline auto norm(const Arg& arg, ResultType& result) -> void {
  result = arg.template norm<ResultType>();
}
910 
/// Maximum element of a tile via the member call `arg.max()`.
///
/// \param arg The tile argument
/// \return The result of `arg.max()`, returned by value
template <class Arg>
inline auto max(const Arg& arg) {
  return arg.max();
}
920 
/// Minimum element of a tile via the member call `arg.min()`.
///
/// \param arg The tile argument
/// \return The result of `arg.min()`, returned by value
template <class Arg>
inline auto min(const Arg& arg) {
  return arg.min();
}
930 
/// Largest absolute element value of a tile via `arg.abs_max()`.
///
/// \param arg The tile argument
/// \return The result of `arg.abs_max()`, returned by value
template <class Arg>
inline auto abs_max(const Arg& arg) {
  return arg.abs_max();
}
940 
/// Smallest absolute element value of a tile via `arg.abs_min()`.
///
/// \param arg The tile argument
/// \return The result of `arg.abs_min()`, returned by value
template <class Arg>
inline auto abs_min(const Arg& arg) {
  return arg.abs_min();
}
950 
/// Dot product of two tiles via the member call `left.dot(right)`.
///
/// \param left The left-hand argument
/// \param right The right-hand argument
/// \return The result of `left.dot(right)`, returned by value
template <class Left, class Right>
inline auto dot(const Left& left, const Right& right) {
  return left.dot(right);
}
962 
/// Inner product of two tiles via the member call `left.inner_product(right)`.
///
/// \param left The left-hand argument
/// \param right The right-hand argument
/// \return The result of `left.inner_product(right)`, returned by value
template <class Left, class Right>
inline auto inner_product(const Left& left, const Right& right) {
  return left.inner_product(right);
}
974 
// No trace() wrapper is enabled yet; once it is, its result alias should be:
// template <typename T>
// using result_of_trace_t = decltype(trace(std::declval<T>()));

/// Type returned by an unqualified `sum(arg)` call for an argument of type T.
template <typename T>
using result_of_sum_t = decltype(sum(std::declval<T>()));

/// Type returned by an unqualified `product(arg)` call for an argument of
/// type T.
template <typename T>
using result_of_product_t = decltype(product(std::declval<T>()));

/// Type returned by an unqualified `squared_norm(arg)` call for an argument
/// of type T.
template <typename T>
using result_of_squared_norm_t = decltype(squared_norm(std::declval<T>()));

/// Type of an unqualified two-argument `norm(arg, result)` call for an
/// argument of type T and an output of type ResultType (defaults to T).
///
/// NOTE(review): the generic two-argument `norm(const Arg&, ResultType&)` in
/// this header returns `void`, so whenever that overload is selected this
/// alias is `void`, not the norm's value type. Confirm that callers expect
/// this before relying on it.
template <typename T, typename ResultType = T>
using result_of_norm_t =
    decltype(norm(std::declval<T>(), std::declval<ResultType&>()));

/// Type returned by an unqualified `max(arg)` call for an argument of type T.
template <typename T>
using result_of_max_t = decltype(max(std::declval<T>()));

/// Type returned by an unqualified `min(arg)` call for an argument of type T.
template <typename T>
using result_of_min_t = decltype(min(std::declval<T>()));

/// Type returned by an unqualified `abs_max(arg)` call for an argument of
/// type T.
template <typename T>
using result_of_abs_max_t = decltype(abs_max(std::declval<T>()));

/// Type returned by an unqualified `abs_min(arg)` call for an argument of
/// type T.
template <typename T>
using result_of_abs_min_t = decltype(abs_min(std::declval<T>()));

/// Type returned by an unqualified `dot(left, right)` call for arguments of
/// types L and R.
template <typename L, typename R>
using result_of_dot_t = decltype(dot(std::declval<L>(), std::declval<R>()));
1005 
1008 } // namespace TiledArray
1009 
1010 #endif /* TILEDARRAY_NONINTRUSIVE_API_TENSOR_H__INCLUDED */
decltype(auto) subt(const Tile< Left > &left, const Tile< Right > &right)
Subtract tile arguments.
Definition: tile.h:879
decltype(conj(std::declval< T >()...)) result_of_conj_t
decltype(abs_max(std::declval< T >())) result_of_abs_max_t
decltype(auto) unary(const Tile< Arg > &arg, Op &&op)
Unary element-wise transform producing a new tile.
Definition: tile.h:1344
Contraction to *GEMM helper.
Definition: gemm_helper.h:40
::blas::Op Op
Definition: blas.h:46
decltype(norm(std::declval< T >(), std::declval< ResultType & >())) result_of_norm_t
auto inner_product(const DistArray< Tile, Policy > &a, const DistArray< Tile, Policy > &b)
Definition: dist_array.h:1647
Tile< Left > & inplace_binary(Tile< Left > &left, const Tile< Right > &right, Op &&op)
Binary element-wise in-place transform.
Definition: tile.h:1157
decltype(min(std::declval< T >())) result_of_min_t
Tile< Result > & mult_to(Tile< Result > &result, const Tile< Arg > &arg)
Multiply to the result tile.
Definition: tile.h:1081
decltype(auto) binary(const Tile< Left > &left, const Tile< Right > &right, Op &&op)
Binary element-wise transform producing a new tile.
Definition: tile.h:1118
decltype(auto) conj(const Tile< Arg > &arg)
Create a complex conjugated copy of a tile.
Definition: tile.h:1256
decltype(mult_to(std::declval< T >()...)) result_of_mult_to_t
decltype(abs_min(std::declval< T >())) result_of_abs_min_t
auto dot(const DistArray< Tile, Policy > &a, const DistArray< Tile, Policy > &b)
Definition: dist_array.h:1640
decltype(auto) norm(const Tile< Arg > &arg)
Vector 2-norm of a tile.
Definition: tile.h:1527
decltype(auto) min(const Tile< Arg > &arg)
Minimum element of a tile.
Definition: tile.h:1559
decltype(binary(std::declval< T >()...)) result_of_binary_t
decltype(neg_to(std::declval< T >()...)) result_of_neg_to_t
auto abs_min(const DistArray< Tile, Policy > &a)
Definition: dist_array.h:1630
auto abs_max(const DistArray< Tile, Policy > &a)
Definition: dist_array.h:1635
decltype(gemm(std::declval< T >()...)) result_of_gemm_t
decltype(auto) neg(const Tile< Arg > &arg)
Negate the tile argument.
Definition: tile.h:1218
decltype(subt(std::declval< T >()...)) result_of_subt_t
decltype(auto) mult(const Tile< Left > &left, const Tile< Right > &right)
Multiplication tile arguments.
Definition: tile.h:1018
auto squared_norm(const DistArray< Tile, Policy > &a)
Definition: dist_array.h:1655
decltype(inplace_unary(std::declval< T >()...)) result_of_inplace_unary_t
decltype(conj_to(std::declval< T >()...)) result_of_conj_to_t
Tile< Result > & neg_to(Tile< Result > &result)
In-place negate tile.
Definition: tile.h:1243
decltype(unary(std::declval< T >()...)) result_of_unary_t
decltype(max(std::declval< T >())) result_of_max_t
Tile< Result > & inplace_unary(Tile< Result > &arg, Op &&op)
Unary element-wise in-place transform.
Definition: tile.h:1374
decltype(dot(std::declval< L >(), std::declval< R >())) result_of_dot_t
constexpr bool empty()
Test for empty tensors in an empty list.
Definition: utility.h:320
decltype(auto) product(const Tile< Arg > &arg)
Multiply the elements of a tile.
Definition: tile.h:1506
decltype(squared_norm(std::declval< T >())) result_of_squared_norm_t
decltype(neg(std::declval< T >()...)) result_of_neg_t
decltype(auto) max(const Tile< Arg > &arg)
Maximum element of a tile.
Definition: tile.h:1549
decltype(inplace_binary(std::declval< T >()...)) result_of_inplace_binary_t
decltype(auto) sum(const Tile< Arg > &arg)
Sum the elements of a tile.
Definition: tile.h:1496
Tile< Result > & subt_to(Tile< Result > &result, const Tile< Arg > &arg)
Subtract from the result tile.
Definition: tile.h:972
decltype(mult(std::declval< T >()...)) result_of_mult_t
decltype(subt_to(std::declval< T >()...)) result_of_subt_to_t
decltype(product(std::declval< T >())) result_of_product_t
Tile< Result > & conj_to(Tile< Result > &result)
In-place complex conjugate a tile.
Definition: tile.h:1311
decltype(sum(std::declval< T >())) result_of_sum_t
decltype(auto) gemm(const Tile< Left > &left, const Tile< Right > &right, const Scalar factor, const math::GemmHelper &gemm_config)
Contract 2 tensors over head/tail modes and scale the product.
Definition: tile.h:1396