26 #ifndef TILEDARRAY_MATH_BLAS_H__INCLUDED
27 #define TILEDARRAY_MATH_BLAS_H__INCLUDED
32 #include <blas/dot.hh>
33 #include <blas/gemm.hh>
34 #include <blas/scal.hh>
35 #include <blas/util.hh>
36 #include <blas/wrappers.hh>
47 static constexpr
auto NoTranspose = Op::NoTrans;
48 static constexpr
auto Transpose = Op::Trans;
49 static constexpr
auto ConjTranspose = Op::ConjTrans;
// Fragment of a helper that classifies an Op value; its signature and return
// statements are not visible in this excerpt.
// NOTE(review): presumably maps NoTranspose -> 0, Transpose -> 1 and
// ConjTranspose -> 2, which is the encoding gemm's packed case constants
// (notrans_trans = 0x4, trans_notrans = 0x1, ...) rely on — confirm.
54 if (op == NoTranspose)
56 else if (op == Transpose)
// Remaining branch (not visible) presumably handles op == ConjTranspose.
62 template <
typename T,
int Options = ::Eigen::ColMajor>
63 using Matrix = ::Eigen::Matrix<T, ::Eigen::Dynamic, ::Eigen::Dynamic, Options>;
// Dense Eigen column vector (run-time length x 1, column-major) with element
// type T; the template header declaring T is not visible in this excerpt.
66 using Vector = ::Eigen::Matrix<T, ::Eigen::Dynamic, 1, ::Eigen::ColMajor>;
// Generic Eigen-based gemm: C = alpha * op(A) * op(B) [+ beta * C].
// All three matrices are interpreted as ROW-MAJOR with leading dimensions
// lda / ldb / ldc. C is m x n, op(A) is m x k and op(B) is k x n.
// alpha and beta have their own types (S1, S2), independent of the element
// types of A, B and C (T1, T2, T3).
// Part of the signature (op_a, op_b, m, n, lda, b, ldb) is not visible in this
// excerpt but is used by the body.
70 template <
typename S1,
typename T1,
typename T2,
typename S2,
typename T3>
72 const integer k,
const S1 alpha,
const T1* a,
74 const S2 beta, T3* c,
const integer ldc) {
// Case labels for dispatching on the (op_a, op_b) pair. Each value packs the
// op of A into bits 0-1 and the op of B into bits 2-3, using
// NoTrans = 0, Trans = 1, ConjTrans = 2.
// NOTE(review): the switch expression is not visible here — confirm it
// builds the same packing (e.g. code(op_a) | (code(op_b) << 2)).
76 static const unsigned int notrans_notrans = 0x00000000,
77 notrans_trans = 0x00000004,
78 trans_notrans = 0x00000001,
79 trans_trans = 0x00000005,
80 notrans_conjtrans = 0x00000008,
81 trans_conjtrans = 0x00000009,
82 conjtrans_notrans = 0x00000002,
83 conjtrans_trans = 0x00000006,
84 conjtrans_conjtrans = 0x0000000a;
// Row-major dynamic Eigen matrix types for A (T1), B (T2) and C (T3).
87 typedef Eigen::Matrix<T1, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
89 typedef Eigen::Matrix<T2, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
91 typedef Eigen::Matrix<T3, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
// Non-owning views over the caller's buffers.
// - A is stored as m x k when op_a == NoTranspose, otherwise as k x m.
// - B is stored as k x n when op_b == NoTranspose, otherwise as n x k.
// OuterStride carries the leading dimension, so a view may be a sub-block of a
// larger row-major buffer.
93 Eigen::Map<const matrixA_type, Eigen::AutoAlign, Eigen::OuterStride<>> A(
94 a, (op_a == NoTranspose ? m : k), (op_a == NoTranspose ? k : m),
95 Eigen::OuterStride<>(lda));
96 Eigen::Map<const matrixB_type, Eigen::AutoAlign, Eigen::OuterStride<>> B(
97 b, (op_b == NoTranspose ? k : n), (op_b == NoTranspose ? n : k),
98 Eigen::OuterStride<>(ldb));
99 Eigen::Map<matrixC_type, Eigen::AutoAlign, Eigen::OuterStride<>> C(
100 c, m, n, Eigen::OuterStride<>(ldc));
// Each case below has two forms: one with and one without the "+ beta * C" term.
// The form without it is presumably taken when beta == 0. C is then written but
// never read, as in reference BLAS, where C need not be set on input when
// beta == 0. Stale or NaN contents therefore cannot leak into the result.
102 const bool beta_is_nonzero = (beta !=
static_cast<S2
>(0));
// noalias() tells Eigen that C does not overlap A or B, so the product is not
// first evaluated into a temporary. Callers must not pass a C buffer that
// overlaps a or b.
105 case notrans_notrans:
107 C.noalias() = alpha * A * B + beta * C;
109 C.noalias() = alpha * A * B;
113 C.noalias() = alpha * A * B.transpose() + beta * C;
115 C.noalias() = alpha * A * B.transpose();
119 C.noalias() = alpha * A.transpose() * B + beta * C;
121 C.noalias() = alpha * A.transpose() * B;
125 C.noalias() = alpha * A.transpose() * B.transpose() + beta * C;
127 C.noalias() = alpha * A.transpose() * B.transpose();
// ConjTranspose operands use adjoint(), i.e. the conjugate transpose.
130 case notrans_conjtrans:
132 C.noalias() = alpha * A * B.adjoint() + beta * C;
134 C.noalias() = alpha * A * B.adjoint();
136 case trans_conjtrans:
138 C.noalias() = alpha * A.transpose() * B.adjoint() + beta * C;
140 C.noalias() = alpha * A.transpose() * B.adjoint();
142 case conjtrans_notrans:
144 C.noalias() = alpha * A.adjoint() * B + beta * C;
146 C.noalias() = alpha * A.adjoint() * B;
148 case conjtrans_trans:
150 C.noalias() = alpha * A.adjoint() * B.transpose() + beta * C;
152 C.noalias() = alpha * A.adjoint() * B.transpose();
154 case conjtrans_conjtrans:
156 C.noalias() = alpha * A.adjoint() * B.adjoint() + beta * C;
158 C.noalias() = alpha * A.adjoint() * B.adjoint();
164 const integer k,
const float alpha,
const float* a,
166 const float beta,
float* c,
const integer ldc) {
// float overload: forwards to BLAS++ ::blas::gemm with column-major layout.
// A row-major m x n C is the same memory as a column-major n x m C^T, and
// C^T = op(B)^T * op(A)^T. This is why the operands (b before a), the ops
// (op_b before op_a) and the sizes (n before m) are passed in swapped order.
167 ::blas::gemm(::blas::Layout::ColMajor, op_b, op_a, n, m, k, alpha, b, ldb, a,
172 const integer k,
const double alpha,
const double* a,
174 const double beta,
double* c,
const integer ldc) {
// double overload: computes the row-major product through a column-major BLAS
// call on the transposed problem, C^T = op(B)^T * op(A)^T. This is why b/a,
// op_b/op_a and n/m appear in swapped order.
175 ::blas::gemm(::blas::Layout::ColMajor, op_b, op_a, n, m, k, alpha, b, ldb, a,
180 const integer k,
const std::complex<float> alpha,
181 const std::complex<float>* a,
const integer lda,
182 const std::complex<float>* b,
const integer ldb,
183 const std::complex<float> beta, std::complex<float>* c,
// std::complex<float> overload: row-major product through a column-major BLAS
// call on C^T = op(B)^T * op(A)^T, hence the swapped b/a, op_b/op_a and n/m.
// The swap is also valid for ConjTranspose, since (X^T)^H = conj(X) = (X^H)^T.
185 ::blas::gemm(::blas::Layout::ColMajor, op_b, op_a, n, m, k, alpha, b, ldb, a,
190 const integer k,
const std::complex<double> alpha,
191 const std::complex<double>* a,
const integer lda,
192 const std::complex<double>* b,
const integer ldb,
193 const std::complex<double> beta, std::complex<double>* c,
// std::complex<double> overload: row-major product through a column-major BLAS
// call on C^T = op(B)^T * op(A)^T, hence the swapped b/a, op_b/op_a and n/m.
// The swap is also valid for ConjTranspose, since (X^T)^H = conj(X) = (X^H)^T.
195 ::blas::gemm(::blas::Layout::ColMajor, op_b, op_a, n, m, k, alpha, b, ldb, a,
// Generic scale for an arbitrary element type U with scalar factor alpha of
// type T. It only participates in overload resolution when
// detail::is_numeric_v<T> holds (enable_if on the return type, which is void).
// NOTE(review): the body is not visible here; presumably it scales the n
// contiguous elements of x by alpha, like the ::blas::scal-based overloads.
201 template <
typename T,
typename U>
202 inline typename std::enable_if<detail::is_numeric_v<T>>::type
scale(
203 const integer n,
const T alpha, U* x) {
// Bodies of the BLAS-backed scale overloads (most signatures are not visible
// here). Each forwards to BLAS++ ::blas::scal with unit stride (incx = 1),
// scaling the n contiguous elements of x by alpha in place.
208 ::blas::scal(n, alpha, x, 1);
212 ::blas::scal(n, alpha, x, 1);
216 std::complex<float>* x) {
217 ::blas::scal(n, alpha, x, 1);
221 std::complex<double>* x) {
222 ::blas::scal(n, alpha, x, 1);
// Real-alpha / complex-vector scale overloads. The real factor is promoted to
// a complex scalar with zero imaginary part, so the complex ::blas::scal
// (unit stride) can scale x in place.
225 inline void scale(
const integer n,
const float alpha, std::complex<float>* x) {
226 ::blas::scal(n, std::complex<float>{alpha, 0}, x, 1);
230 std::complex<double>* x) {
231 ::blas::scal(n, std::complex<double>{alpha, 0}, x, 1);
// Template header of a generic overload whose declaration and body are not
// visible in this excerpt.
// NOTE(review): presumably the generic dot(n, x, y) for element types T and
// U — confirm.
236 template <
typename T,
typename U>
// Typed dot-product overloads over n contiguous elements of x and y, for
// float, double, std::complex<float> and std::complex<double>. The bodies are
// not visible in this excerpt.
// NOTE(review): for the complex overloads, confirm whether the body
// conjugates x (dotc) or not (dotu). Callers depend on that convention.
// NOTE(review): the float overload takes `const integer n` while the others
// take `integer n`. This is harmless (top-level const on a by-value
// parameter) but inconsistent.
241 inline float dot(
const integer n,
const float* x,
const float* y) {
245 inline double dot(
integer n,
const double* x,
const double* y) {
249 inline std::complex<float>
dot(
integer n,
const std::complex<float>* x,
250 const std::complex<float>* y) {
254 inline std::complex<double>
dot(
integer n,
const std::complex<double>* x,
255 const std::complex<double>* y) {
268 #endif // TILEDARRAY_MATH_BLAS_H__INCLUDED