Tensor::gemm involving custom elem_op supports batching
evaleev committed Dec 15, 2023
1 parent 02a7db7 commit f0be0c9
Showing 2 changed files with 55 additions and 24 deletions.
75 changes: 53 additions & 22 deletions src/TiledArray/tensor/tensor.h
@@ -292,10 +292,12 @@ class Tensor {
   /// Construct a tensor with a range equal to \c range. The data is
   /// uninitialized.
   /// \param range The range of the tensor
-  explicit Tensor(const range_type& range)
-      : Tensor(range, 1, default_construct{true}) {}
+  /// \param batch_size The batch size (default is 1)
+  explicit Tensor(const range_type& range, size_type batch_size = 1)
+      : Tensor(range, batch_size, default_construct{true}) {}

-  /// Construct a tensor with a fill value
+  /// Construct a tensor of tensor values, setting all elements to the same
+  /// value

   /// \param range An array with the size of each dimension
   /// \param value The value of the tensor elements
@@ -312,12 +314,14 @@ class Tensor {
       new (data + i) value_type(cloner(value));
   }

-  /// Construct a tensor with a fill value
+  /// Construct a tensor of scalars, setting all elements to the same value

   /// \param range An array with the size of each dimension
   /// \param value The value of the tensor elements
-  template <typename Value, typename std::enable_if<
-                                detail::is_numeric_v<Value>>::type* = nullptr>
+  template <typename Value,
+            typename std::enable_if<std::is_convertible_v<Value, value_type> &&
+                                    !detail::is_tensor<Value>::value>::type* =
+                nullptr>
   Tensor(const range_type& range, const Value& value)
       : Tensor(range, 1, default_construct{false}) {
     detail::tensor_init([value]() -> Value { return value; }, *this);
@@ -358,7 +362,7 @@ class Tensor {
     math::uninitialized_copy_vector(range.volume(), u, this->data());
   }

-  Tensor(const Range& range, std::initializer_list<T> il)
+  explicit Tensor(const Range& range, std::initializer_list<T> il)
       : Tensor(range, il.begin()) {}

   /// Construct a copy of a tensor interface object
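
Note: with the constructor change above, the Range argument describes a single
batch and the allocation holds batch_size such blocks back to back. A minimal
usage sketch (the TA namespace alias and double element type are illustrative
assumptions, not part of this diff):

// Assumes <tiledarray.h>; each batch is a 2x3 block of volume 6.
TA::Range r{2, 3};
TA::Tensor<double> t(r, 4);  // one allocation sized for 4 batches
// t.size() == r.volume() == 6; elements allocated == t.batch_size() * t.size()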
@@ -1004,6 +1008,22 @@ class Tensor {
   /// \return A mutable pointer to the tensor data
   pointer data() { return this->data_.get(); }

+  /// @param[in] batch_idx the batch index
+  /// @pre `batch_idx < this->batch_size()`
+  /// @return A const pointer to the tensor data of the batch \p batch_idx
+  const_pointer batch_data(size_t batch_idx) const {
+    TA_ASSERT(batch_idx < this->batch_size());
+    return data() + batch_idx * size();
+  }
+
+  /// @param[in] batch_idx the batch index
+  /// @pre `batch_idx < this->batch_size()`
+  /// @return A mutable pointer to the tensor data of the batch \p batch_idx
+  pointer batch_data(size_t batch_idx) {
+    TA_ASSERT(batch_idx < this->batch_size());
+    return data() + batch_idx * size();
+  }
+
   /// Read-only shared_ptr to the data

   /// \return A const shared_ptr to the tensor data
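
Note: these accessors pin down the batch layout: batches are contiguous, so
batch_data(b) == data() + b * size(). Continuing the hypothetical sketch above
(std::fill_n from <algorithm>):

// Fill each batch of the illustrative tensor t with its batch index.
for (std::size_t b = 0; b != t.batch_size(); ++b)
  std::fill_n(t.batch_data(b), t.size(), static_cast<double>(b));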
@@ -2194,6 +2214,8 @@ class Tensor {
     TA_ASSERT(left.range().rank() == gemm_helper.left_rank());
     TA_ASSERT(!right.empty());
     TA_ASSERT(right.range().rank() == gemm_helper.right_rank());
+    TA_ASSERT(left.batch_size() == right.batch_size());
+    const auto batch_sz = left.batch_size();

     // Check that the inner dimensions of left and right match
     TA_ASSERT(gemm_helper.left_right_congruent(left.range().extent_data(),
@@ -2207,7 +2229,8 @@

     if (this->empty()) {  // initialize, if empty
       *this = Tensor(gemm_helper.make_result_range<range_type>(left.range(),
-                                                               right.range()));
+                                                               right.range()),
+                     batch_sz);
     } else {
       // Check that the outer dimensions of left match the corresponding
       // dimensions in result
@@ -2230,6 +2253,9 @@ class Tensor {
       TA_ASSERT(ignore_tile_position() ||
                 gemm_helper.right_result_congruent(
                     right.range().upbound_data(), this->range_.upbound_data()));
+
+      // check that batch size of this matches that of left and right
+      TA_ASSERT(this->batch_size() == batch_sz);
     }

     // Compute gemm dimensions
@@ -2243,20 +2269,25 @@ class Tensor {
     const integer ldb =
         (gemm_helper.right_op() == TiledArray::math::blas::NoTranspose ? N : K);

-    for (integer m = 0; m != M; ++m) {
-      for (integer n = 0; n != N; ++n) {
-        auto c_offset = m * N + n;
-        for (integer k = 0; k != K; ++k) {
-          auto a_offset =
-              gemm_helper.left_op() == TiledArray::math::blas::NoTranspose
-                  ? m * lda + k
-                  : k * lda + m;
-          auto b_offset =
-              gemm_helper.right_op() == TiledArray::math::blas::NoTranspose
-                  ? k * ldb + n
-                  : n * ldb + k;
-          elem_muladd_op(*(this->data() + c_offset), *(left.data() + a_offset),
-                         *(right.data() + b_offset));
+    for (integer b = 0; b != batch_size(); ++b) {
+      auto this_data = this->batch_data(b);
+      auto left_data = left.batch_data(b);
+      auto right_data = right.batch_data(b);
+      for (integer m = 0; m != M; ++m) {
+        for (integer n = 0; n != N; ++n) {
+          auto c_offset = m * N + n;
+          for (integer k = 0; k != K; ++k) {
+            auto a_offset =
+                gemm_helper.left_op() == TiledArray::math::blas::NoTranspose
+                    ? m * lda + k
+                    : k * lda + m;
+            auto b_offset =
+                gemm_helper.right_op() == TiledArray::math::blas::NoTranspose
+                    ? k * ldb + n
+                    : n * ldb + k;
+            elem_muladd_op(*(this_data + c_offset), *(left_data + a_offset),
+                           *(right_data + b_offset));
+          }
         }
       }
     }
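
Note: the kernel change wraps the original M x N x K triple loop in an outer
loop over batches, so batch b of the result accumulates only from batch b of
left and right via the per-batch pointers added above. The custom element op is
invoked with the (result, left, right) accumulate shape seen in the call; a
scalar stand-in, as a sketch only:

// Illustrative stand-in matching the call shape elem_muladd_op(result, left, right);
// in TiledArray the arguments may themselves be inner tensors.
auto elem_muladd_op = [](auto& result, const auto& left, const auto& right) {
  result += left * right;
};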
4 changes: 2 additions & 2 deletions tests/einsum.cpp
@@ -604,10 +604,10 @@ BOOST_AUTO_TEST_CASE(ijk_mn_eq_ij_mn_times_kj_mn) {
     return result;
   };

-  TiledRange lhs_trange{{0, 2, 4}, {0, 5}};
+  TiledRange lhs_trange{{0, 2, 4}, {0, 2, 5}};
   auto lhs = random_tot_darr(world, lhs_trange);

-  TiledRange rhs_trange{{0, 2, 4, 6}, {0, 5}};
+  TiledRange rhs_trange{{0, 2, 4, 6}, {0, 2, 5}};
   auto rhs = random_tot_darr(world, rhs_trange);
   dist_array_t result;
   BOOST_REQUIRE_NO_THROW(
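
Note: the test tweak splits the second mode of both tiled ranges into two tiles
({0, 5} becomes {0, 2, 5}), presumably to exercise the batched-gemm path across
more than one tile per mode. In TiledArray, each inner list of a TiledRange is
a sorted sequence of tile boundaries:

// Mode 0 spans [0,4) as tiles [0,2) and [2,4);
// mode 1 spans [0,5) as tiles [0,2) and [2,5).
TiledRange lhs_trange{{0, 2, 4}, {0, 2, 5}};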
