Contract two tensors over head/tail modes and accumulate into result using a custom element-wise multiply-add op. The contraction is done via a GEMM operation with fused indices as defined by gemm_config.
- Template Parameters
-
Result: The result tile type
Left: The left-hand tile type
Right: The right-hand tile type
ElementMultiplyAddOp: A callable type that implements a custom multiply-add operation, with signature void (Result::value_type& result, Left::value_type const& left, Right::value_type const& right)
- Parameters
-
result: The contracted result; this can be null, in which case it will be initialized as needed
left: The left-hand argument to be contracted
right: The right-hand argument to be contracted
gemm_config: A helper object used to simplify gemm operations
element_multiplyadd_op: A custom multiply-add operation for tensor elements
- Returns
- A tile whose element result[i,j] is obtained by executing
foreach k: element_multiplyadd_op(result[i,j], left[i,k], right[k,j])
For plain tensors, GEMM can be implemented (very inefficiently) using this method as follows:
gemm(result, left, right, gemm_config, [factor](auto& result, const auto& left, const auto& right) { result += factor * (left * right); });
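The semantics described above can be illustrated with a minimal, standalone sketch. It does not use the library's tile types or gemm_config: the `Matrix` struct, `contract_with_op`, and `scaled_product_example` are hypothetical names introduced only for illustration, and the sketch assumes both arguments have already been fused into plain row-major matrices. It shows how the `foreach k` loop invokes the multiply-add callable once per (i, k, j) triple, and how a capturing lambda reproduces a scaled GEMM.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a tile: a dense row-major matrix.
// This is NOT the library's tile type; it only illustrates the semantics.
template <typename T>
struct Matrix {
  std::size_t rows, cols;
  std::vector<T> data;
  Matrix(std::size_t r, std::size_t c, T init = T{})
      : rows(r), cols(c), data(r * c, init) {}
  T& operator()(std::size_t i, std::size_t j) { return data[i * cols + j]; }
  const T& operator()(std::size_t i, std::size_t j) const {
    return data[i * cols + j];
  }
};

// For every (i, j), executes op(result[i,j], left[i,k], right[k,j]) for each k.
// The callable has the shape void(R&, const L&, const Rt&).
template <typename R, typename L, typename Rt, typename Op>
void contract_with_op(Matrix<R>& result, const Matrix<L>& left,
                      const Matrix<Rt>& right, Op&& op) {
  assert(left.cols == right.rows);
  assert(result.rows == left.rows && result.cols == right.cols);
  for (std::size_t i = 0; i < left.rows; ++i)
    for (std::size_t j = 0; j < right.cols; ++j)
      for (std::size_t k = 0; k < left.cols; ++k)
        op(result(i, j), left(i, k), right(k, j));
}

// Scaled GEMM expressed through the custom multiply-add callable:
// result += factor * (left * right), applied element by element.
inline Matrix<double> scaled_product_example() {
  Matrix<double> left(2, 2), right(2, 2), result(2, 2, 0.0);
  left(0, 0) = 1; left(0, 1) = 2; left(1, 0) = 3; left(1, 1) = 4;
  right(0, 0) = 5; right(0, 1) = 6; right(1, 0) = 7; right(1, 1) = 8;
  const double factor = 2.0;
  contract_with_op(result, left, right,
                   [factor](double& r, const double& l, const double& rt) {
                     r += factor * (l * rt);
                   });
  return result;
}
```

Calling `scaled_product_example()` accumulates twice the ordinary matrix product into a zero-initialized result. Because the callable is invoked once per scalar multiply-add, this route is much slower than an optimized GEMM, which is consistent with the "very inefficiently" remark above.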