26 #ifndef TILEDARRAY_EXPRESSIONS_CONT_ENGINE_H__INCLUDED 27 #define TILEDARRAY_EXPRESSIONS_CONT_ENGINE_H__INCLUDED 35 namespace expressions {
44 template <
typename Derived>
101 permute_to_no_trans = 3,
120 find(
const VariableList&
vars, std::string var,
unsigned int i,
const unsigned int n) {
136 template <
typename L,
typename R>
139 left_op_(permute_to_no_trans), right_op_(permute_to_no_trans), op_(),
149 template <
typename L,
typename R,
typename S>
152 left_op_(permute_to_no_trans), right_op_(permute_to_no_trans), op_(),
169 if((left_op_ == permute_to_no_trans) || (right_op_ == permute_to_no_trans)) {
172 const unsigned int result_rank = target_vars.
dim();
173 const unsigned int inner_rank = (
left_.vars().dim() +
174 right_.vars().dim() - result_rank) >> 1;
175 const unsigned int left_outer_rank =
left_.vars().dim() - inner_rank;
179 bool target_partitioned =
true;
180 for(
unsigned int i = 0u; i < left_outer_rank; ++i)
181 target_partitioned = target_partitioned &&
182 (find(target_vars, left_vars_[i], 0u, left_outer_rank) < left_outer_rank);
186 if(target_partitioned) {
188 if(left_op_ == permute_to_no_trans) {
190 for(
unsigned int i = 0u; i < left_outer_rank; ++i) {
191 const std::string& var = target_vars[i];
192 const_cast<std::string&
>(left_vars_[i]) = var;
193 const_cast<std::string&
>(
vars_[i]) = var;
197 left_.perm_vars(left_vars_);
200 for(
unsigned int i = 0u; i < left_outer_rank; ++i)
201 const_cast<std::string&>(
vars_[i]) = left_vars_[i];
204 if(right_op_ == permute_to_no_trans) {
206 for(
unsigned int i = left_outer_rank, j = inner_rank; i < result_rank; ++i, ++j) {
207 const std::string& var = target_vars[i];
208 const_cast<std::string&
>(right_vars_[j]) = var;
209 const_cast<std::string&
>(
vars_[i]) = var;
213 right_.perm_vars(right_vars_);
216 for(
unsigned int i = left_outer_rank, j = inner_rank; i < result_rank; ++i, ++j)
217 const_cast<std::string&>(
vars_[i]) = right_vars_[j];
229 const unsigned int left_rank =
left_.vars().dim();
230 const unsigned int right_rank =
right_.vars().dim();
233 std::vector<std::string>& left_vars =
234 const_cast<std::vector<std::string>&
>(left_vars_.
data());
235 left_vars.reserve(left_rank);
236 std::vector<std::string>& right_vars =
237 const_cast<std::vector<std::string>&
>(right_vars_.
data());
238 right_vars.reserve(right_rank);
239 std::vector<std::string>& result_vars =
240 const_cast<std::vector<std::string>&
>(
vars_.
data());
241 result_vars.reserve(
std::max(left_rank, right_rank));
244 for(
unsigned int i = 0ul; i < left_rank; ++i) {
245 const std::string& var =
left_.vars()[i];
246 if(find(
right_.vars(), var, 0u, right_rank) == right_rank) {
248 left_vars.push_back(var);
249 result_vars.push_back(var);
252 right_vars.push_back(var);
257 const unsigned int inner_rank = right_vars.size();
258 const unsigned int left_outer_rank = left_vars.size();
259 const unsigned int right_outer_rank = right_rank - inner_rank;
260 const unsigned int result_rank = left_outer_rank + right_outer_rank;
263 result_vars.reserve(result_rank);
266 if(inner_rank == 0u) {
269 for(
unsigned int i = 0ul; i < right_rank; ++i) {
270 const std::string& var =
right_.vars()[i];
271 right_vars.push_back(var);
272 result_vars.push_back(var);
280 bool inner_vars_ordered =
true, left_is_no_trans =
true, left_is_trans =
true,
281 right_is_no_trans =
true, right_is_trans =
true;
289 const bool perm_left = (left_rank < right_rank) || ((left_rank == right_rank)
290 && (left_type::leaves <= right_type::leaves));
295 for(
unsigned int i = 0ul; i < right_rank; ++i) {
296 const std::string& var =
right_.vars()[i];
297 const unsigned int j = find(
left_.vars(), var, 0u, left_rank);
300 right_vars.push_back(var);
301 result_vars.push_back(var);
303 const unsigned int x = left_vars.size() - left_outer_rank;
306 inner_vars_ordered = inner_vars_ordered && (right_vars[x] == var);
307 left_is_no_trans = left_is_no_trans && (j >= left_outer_rank);
308 left_is_trans = left_is_trans && (j < inner_rank);
309 right_is_no_trans = right_is_no_trans && (i < inner_rank);
310 right_is_trans = right_is_trans && (i >= right_outer_rank);
313 if(inner_vars_ordered) {
315 left_vars.push_back(var);
316 }
else if(perm_left) {
319 left_vars.push_back(var);
321 left_is_no_trans = left_is_trans =
false;
325 left_vars.push_back(right_vars[x]);
326 right_is_no_trans = right_is_trans =
false;
334 if(left_is_no_trans) {
336 left_.permute_tiles(
false);
337 }
else if(left_is_trans) {
339 left_.permute_tiles(
false);
341 left_.perm_vars(left_vars_);
343 if(right_is_no_trans) {
344 right_op_ = no_trans;
345 right_.permute_tiles(
false);
346 }
else if(right_is_trans) {
348 right_.permute_tiles(
false);
350 right_.perm_vars(right_vars_);
362 left_.init_struct(left_vars_);
363 right_.init_struct(right_vars_);
368 const madness::cblas::CBLAS_TRANSPOSE left_op =
369 (left_op_ == trans ? madness::cblas::Trans : madness::cblas::NoTrans);
370 const madness::cblas::CBLAS_TRANSPOSE right_op =
371 (right_op_ == trans ? madness::cblas::Trans : madness::cblas::NoTrans);
374 if(target_vars !=
vars_) {
401 const unsigned int inner_rank = op_.gemm_helper().num_contract_ranks();
402 const unsigned int left_rank = op_.gemm_helper().left_rank();
403 const unsigned int right_rank = op_.gemm_helper().right_rank();
404 const unsigned int left_outer_rank = left_rank - inner_rank;
407 const size_type* MADNESS_RESTRICT
const left_tiles_size =
408 left_.trange().tiles_range().extent_data();
409 const size_type* MADNESS_RESTRICT
const left_element_size =
410 left_.trange().elements_range().extent_data();
411 const size_type* MADNESS_RESTRICT
const right_tiles_size =
412 right_.trange().tiles_range().extent_data();
413 const size_type* MADNESS_RESTRICT
const right_element_size =
414 right_.trange().elements_range().extent_data();
417 size_type M = 1ul, m = 1ul, N = 1ul, n = 1ul;
419 for(; i < left_outer_rank; ++i) {
420 M *= left_tiles_size[i];
421 m *= left_element_size[i];
423 for(; i < left_rank; ++i)
424 K_ *= left_tiles_size[i];
425 for(i = inner_rank; i < right_rank; ++i) {
426 N *= right_tiles_size[i];
427 n *= right_element_size[i];
449 const unsigned int left_rank = op_.gemm_helper().left_rank();
450 const unsigned int right_rank = op_.gemm_helper().right_rank();
451 const unsigned int inner_rank = op_.gemm_helper().num_contract_ranks();
452 const unsigned int left_outer_rank = left_rank - inner_rank;
455 typename trange_type::Ranges ranges(op_.gemm_helper().result_rank());
456 unsigned int i = 0ul;
457 for(
unsigned int x = 0ul; x < left_outer_rank; ++x, ++i) {
458 const unsigned int pi = (
perm ?
perm[i] : i);
459 ranges[pi] =
left_.trange().data()[x];
461 for(
unsigned int x = inner_rank; x < right_rank; ++x, ++i) {
462 const unsigned int pi = (
perm ?
perm[i] : i);
463 ranges[pi] =
right_.trange().data()[x];
469 const auto* MADNESS_RESTRICT
const left_extent =
470 left_.trange().tiles_range().extent_data();
471 const auto* MADNESS_RESTRICT
const right_extent =
472 right_.trange().tiles_range().extent_data();
475 for(
unsigned int l = left_outer_rank, r = 0ul; l < left_rank; ++l, ++r) {
476 if(
left_.trange().data()[l] !=
right_.trange().data()[r]) {
477 if(TiledArray::get_default_world().rank() == 0) {
479 if(left_extent[l] == right_extent[r]) {
481 "of the left- and right-hand arguments are not equal.");
485 "and right-hand arguments are not congruent:" \
486 <<
"\n left = " <<
left_.trange() \
487 <<
"\n right = " <<
right_.trange() );
489 TA_EXCEPTION(
"The contracted dimensions of the left- and " \
490 "right-hand expressions are not congruent.");
494 TA_EXCEPTION(
"The contracted dimensions of the left- and " 495 "right-hand expressions are not congruent.");
508 shape_gemm_helper(madness::cblas::NoTrans, madness::cblas::NoTrans,
509 op_.gemm_helper().result_rank(), op_.gemm_helper().left_rank(),
510 op_.gemm_helper().right_rank());
520 shape_gemm_helper(madness::cblas::NoTrans, madness::cblas::NoTrans,
521 op_.gemm_helper().result_rank(), op_.gemm_helper().left_rank(),
522 op_.gemm_helper().right_rank());
530 typename right_type::dist_eval_type,
op_type,
typename Derived::policy> impl_type;
532 typename left_type::dist_eval_type left =
left_.make_dist_eval();
533 typename right_type::dist_eval_type right =
right_.make_dist_eval();
535 std::shared_ptr<impl_type> pimpl =
546 std::stringstream ss;
561 left_.print(os, left_vars_);
562 right_.print(os, right_vars_);
571 #endif // TILEDARRAY_EXPRESSIONS_CONT_ENGINE_H__INCLUDED EngineTrait< Derived >::trange_type trange_type
Tiled range type.
std::string make_tag() const
Expression identification tag.
Multiplication expression.
trange_type make_trange(const Permutation &perm=Permutation()) const
Tiled range factory function.
std::shared_ptr< EngineParamOverride< Derived > > override_ptr_
The engine params overriding the default.
TiledArray::detail::ContractReduce< value_type, typename eval_trait< typename left_type::value_type >::type, typename eval_trait< typename right_type::value_type >::type, scalar_type > op_type
The tile operation type.
right_type right_
The right-hand argument.
EngineTrait< Derived >::right_type right_type
The right-hand expression type.
bool permute_tiles_
Result tile permutation flag (true == permute tile)
void print(ExprOStream os, const VariableList &target_vars) const
Expression print.
const Permutation & perm() const
Permutation accessor.
scalar_type factor_
Contraction scaling factor.
Permutation make_perm(const VariableList &target_vars) const
Permutation factory function.
void init_struct(const VariableList &target_vars)
Initialize result tensor structure.
BinaryEngine< Derived > BinaryEngine_
Binary base class type.
EngineTrait< MultEngine< Left, Right, Result > >::shape_type shape_type
Shape type.
World * world() const
World accessor.
KroneckerDeltaTile< _N >::numeric_type max(const KroneckerDeltaTile< _N > &arg)
EngineTrait< MultEngine< Left, Right, Result > >::op_type op_type
The tile operation type.
const shape_type & shape() const
Shape accessor.
derived_type & derived()
Cast this object to its derived type.
const std::vector< std::string > & data() const
Multiplication expression.
const std::shared_ptr< pmap_interface > & pmap() const
Process map accessor.
VariableList vars_
The variable list of this expression.
EngineTrait< Derived >::value_type value_type
The result tile type.
EngineTrait< MultEngine< Left, Right, Result > >::size_type size_type
Size type.
shape_type make_shape() const
Non-permuting shape factory function.
std::shared_ptr< Pmap > make_pmap() const
Construct a cyclic process map.
EngineTrait< MultEngine< Left, Right, Result > >::trange_type trange_type
Tiled range type.
EngineTrait< MultEngine< Left, Right, Result > >::dist_eval_type dist_eval_type
The distributed evaluator type.
unsigned int dim() const
Returns the number of strings in the variable list.
const VariableList & vars() const
Variable list accessor.
ContEngine< Derived > ContEngine_
This class type.
void perm_vars(const VariableList &target_vars)
Set the variable list for this expression.
EngineTrait< Derived >::left_type left_type
The left-hand expression type.
EngineTrait< Derived >::size_type size_type
Size type.
Variable list manages a list of variable strings.
EngineTrait< Derived >::shape_type shape_type
Shape type.
ContEngine(const ScalMultExpr< L, R, S > &expr)
Constructor.
Multiplication expression engine.
shape_type make_shape(const Permutation &perm) const
Permuting shape factory function.
std::shared_ptr< pmap_interface > pmap_
The process map for the result tensor.
EngineTrait< Derived >::dist_eval_type dist_eval_type
The distributed evaluator type.
World * world_
The world where this expression will be evaluated.
ExprEngine< Derived > ExprEngine_
Expression engine base class type.
void inc()
Increment the number of tabs.
Contraction to *GEMM helper.
Permutation of a sequence of objects indexed by base-0 indices.
trange_type make_trange() const
Non-permuting tiled range factory function.
Permutation perm_
The permutation that will be applied to the result.
void init_vars()
Initialize the variable list of this expression.
void init_distribution(World *world, std::shared_ptr< pmap_interface > pmap)
Initialize result tensor distribution.
Contract and (sum) reduce operation.
void init_distribution(World *world, const std::shared_ptr< pmap_interface > &pmap)
Initialize result tensor distribution.
shape_type shape_
The shape of the result tensor.
std::shared_ptr< Pmap > make_row_phase_pmap(const size_type cols) const
Construct a row-phased cyclic process map.
std::shared_ptr< Pmap > make_col_phase_pmap(const size_type rows) const
Construct a column-phased cyclic process map.
Distributed contraction evaluator implementation.
ContEngine(const MultExpr< L, R > &expr)
Constructor.
Expression output stream.
EngineTrait< Derived >::scalar_type scalar_type
Tile scalar type.
dist_eval_type make_dist_eval() const
trange_type trange_
The tiled range of the result tensor.
#define TA_USER_ERROR_MESSAGE(m)
left_type left_
The left-hand argument.
EngineTrait< Derived >::policy policy
The result policy type.
void print(ExprOStream &os, const VariableList &target_vars) const
Expression print.
EngineTrait< Derived >::pmap_interface pmap_interface
Process map interface type.
void dec()
Decrement the number of tabs.