26 #ifndef TILEDARRAY_MATH_GEMM_HELPER_H__INCLUDED 27 #define TILEDARRAY_MATH_GEMM_HELPER_H__INCLUDED 42 madness::cblas::CBLAS_TRANSPOSE left_op_;
44 madness::cblas::CBLAS_TRANSPOSE right_op_;
46 unsigned int result_rank_;
54 unsigned int inner[2];
55 unsigned int outer[2];
65 const madness::cblas::CBLAS_TRANSPOSE
right_op,
79 if(
left_op == madness::cblas::NoTrans) {
81 left_.outer[1] = left_.inner[0] =
left_rank - contract_size;
85 left_.inner[1] = left_.outer[0] = contract_size;
90 if(
right_op == madness::cblas::NoTrans) {
92 right_.inner[1] = right_.outer[0] = contract_size;
96 right_.outer[1] = right_.inner[0] =
right_rank - contract_size;
106 left_op_(other.left_op_), right_op_(other.right_op_),
107 result_rank_(other.result_rank_),
108 left_(other.left_), right_(other.right_)
115 left_op_ = other.left_op_;
116 right_op_ = other.right_op_;
117 result_rank_ = other.result_rank_;
119 right_ = other.right_;
128 return (left_.rank + right_.rank - result_rank_) >> 1;
165 template <
typename R,
typename Left,
typename Right>
168 const auto* MADNESS_RESTRICT
const left_lower = left.lobound_data();
169 const auto* MADNESS_RESTRICT
const left_upper = left.upbound_data();
170 const auto* MADNESS_RESTRICT
const right_lower = right.lobound_data();
171 const auto* MADNESS_RESTRICT
const right_upper = right.upbound_data();
174 std::vector<std::size_t> lower, upper;
175 lower.reserve(result_rank_);
176 upper.reserve(result_rank_);
179 for(
unsigned int i = left_.outer[0]; i < left_.outer[1]; ++i) {
180 lower.push_back(left_lower[i]);
181 upper.push_back(left_upper[i]);
185 for(
unsigned int i = right_.outer[0]; i < right_.outer[1]; ++i) {
186 lower.push_back(right_lower[i]);
187 upper.push_back(right_upper[i]);
191 return R(lower, upper);
204 template <
typename Left,
typename Result>
206 return std::equal(left + left_.outer[0], left + left_.outer[1], result);
219 template <
typename Right,
typename Result>
221 return std::equal(right + right_.outer[0], right + right_.outer[1],
222 result + (left_.outer[1] - left_.outer[0]));
235 template <
typename Left,
typename Right>
237 return std::equal(left + left_.inner[0], left + left_.inner[1],
238 right + right_.inner[0]);
251 template <
typename Left,
typename Right>
253 const Left& left,
const Right& right)
const 258 const auto* MADNESS_RESTRICT
const left_extent = left.extent_data();
259 const auto* MADNESS_RESTRICT
const right_extent = right.extent_data();
263 for(
unsigned int i = left_.outer[0]; i < left_.outer[1]; ++i)
266 for(
unsigned int i = left_.inner[0]; i < left_.inner[1]; ++i)
269 for(
unsigned int i = right_.outer[0]; i < right_.outer[1]; ++i)
270 n *= right_extent[i];
273 madness::cblas::CBLAS_TRANSPOSE
left_op()
const {
return left_op_; }
274 madness::cblas::CBLAS_TRANSPOSE
right_op()
const {
return right_op_; }
280 #endif // TILEDARRAY_MATH_GEMM_HELPER_H__INCLUDED void compute_matrix_sizes(integer &m, integer &n, integer &k, const Left &left, const Right &right) const
Compute the matrix dimensions (m, n, k) that can be used in a *GEMM call.
unsigned int right_inner_begin() const
bool left_result_congruent(const Left &left, const Result &result) const
Test that the outer dimensions of left are congruent (have equal extent) with those of the result tensor.
bool right_result_congruent(const Right &right, const Result &result) const
Test that the outer dimensions of right are congruent (have equal extent) with those of the result tensor.
madness::cblas::CBLAS_TRANSPOSE right_op() const
bool left_right_congruent(const Left &left, const Right &right) const
Test that the inner dimensions of left are congruent (have equal extent) with those of right.
unsigned int result_rank() const
Result rank accessor.
void outer(const std::size_t m, const std::size_t n, const X *const x, const Y *const y, A *a, const Op &op)
Compute the outer product of x and y and use it to modify a.
GemmHelper & operator=(const GemmHelper &other)
Functor assignment operator.
R make_result_range(const Left &left, const Right &right) const
Construct a result range based on left and right ranges.
unsigned int left_outer_end() const
unsigned int num_contract_ranks() const
Compute the number of contracted ranks.
unsigned int right_inner_end() const
unsigned int right_outer_begin() const
GemmHelper(const GemmHelper &other)
Functor copy constructor.
unsigned int right_rank() const
Right-hand argument rank accessor.
Helper for mapping a tensor contraction onto a *GEMM call.
madness::cblas::CBLAS_TRANSPOSE left_op() const
unsigned int left_outer_begin() const
unsigned int right_outer_end() const
unsigned int left_rank() const
Left-hand argument rank accessor.
unsigned int left_inner_end() const
GemmHelper(const madness::cblas::CBLAS_TRANSPOSE left_op, const madness::cblas::CBLAS_TRANSPOSE right_op, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank)
unsigned int left_inner_begin() const