From f0be0c97d193b5c4df3653f4dfe4179695bb57e6 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Fri, 15 Dec 2023 10:45:59 -0500 Subject: [PATCH] Tensor::gemm involving custom elem_op supports batching --- src/TiledArray/tensor/tensor.h | 75 ++++++++++++++++++++++++---------- tests/einsum.cpp | 4 +- 2 files changed, 55 insertions(+), 24 deletions(-) diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h index f3076c4514..c901dc0f4b 100644 --- a/src/TiledArray/tensor/tensor.h +++ b/src/TiledArray/tensor/tensor.h @@ -292,10 +292,12 @@ class Tensor { /// Construct a tensor with a range equal to \c range. The data is /// uninitialized. /// \param range The range of the tensor - explicit Tensor(const range_type& range) - : Tensor(range, 1, default_construct{true}) {} + /// \param batch_size The batch size (default is 1) + explicit Tensor(const range_type& range, size_type batch_size = 1) + : Tensor(range, batch_size, default_construct{true}) {} - /// Construct a tensor with a fill value + /// Construct a tensor of tensor values, setting all elements to the same + /// value /// \param range An array with the size of of each dimension /// \param value The value of the tensor elements @@ -312,12 +314,14 @@ class Tensor { new (data + i) value_type(cloner(value)); } - /// Construct a tensor with a fill value + /// Construct a tensor of scalars, setting all elements to the same value /// \param range An array with the size of of each dimension /// \param value The value of the tensor elements - template >::type* = nullptr> + template && + !detail::is_tensor::value>::type* = + nullptr> Tensor(const range_type& range, const Value& value) : Tensor(range, 1, default_construct{false}) { detail::tensor_init([value]() -> Value { return value; }, *this); @@ -358,7 +362,7 @@ class Tensor { math::uninitialized_copy_vector(range.volume(), u, this->data()); } - Tensor(const Range& range, std::initializer_list il) + explicit Tensor(const Range& range, std::initializer_list il) : Tensor(range, il.begin()) {} /// Construct a copy of a tensor interface object @@ -1004,6 +1008,22 @@ class Tensor { /// \return A mutable pointer to the tensor data pointer data() { return this->data_.get(); } + /// @param[in] batch_idx the batch index + /// @pre `batch_idx < this->batch_size()` + /// @return A const pointer to the tensor data of the batch \p batch_idx + const_pointer batch_data(size_t batch_idx) const { + TA_ASSERT(batch_idx < this->batch_size()); + return data() + batch_idx * size(); + } + + /// @param[in] batch_idx the batch index + /// @pre `batch_idx < this->batch_size()` + /// @return A const pointer to the tensor data of the batch \p batch_idx + pointer batch_data(size_t batch_idx) { + TA_ASSERT(batch_idx < this->batch_size()); + return data() + batch_idx * size(); + } + /// Read-only shared_ptr to the data /// \return A const shared_ptr to the tensor data @@ -2194,6 +2214,8 @@ class Tensor { TA_ASSERT(left.range().rank() == gemm_helper.left_rank()); TA_ASSERT(!right.empty()); TA_ASSERT(right.range().rank() == gemm_helper.right_rank()); + TA_ASSERT(left.batch_size() == right.batch_size()); + const auto batch_sz = left.batch_size(); // Check that the inner dimensions of left and right match TA_ASSERT(gemm_helper.left_right_congruent(left.range().extent_data(), @@ -2207,7 +2229,8 @@ class Tensor { if (this->empty()) { // initialize, if empty *this = Tensor(gemm_helper.make_result_range(left.range(), - right.range())); + right.range()), + batch_sz); } else { // Check that the outer dimensions of left match the corresponding // dimensions in result @@ -2230,6 +2253,9 @@ class Tensor { TA_ASSERT(ignore_tile_position() || gemm_helper.right_result_congruent( right.range().upbound_data(), this->range_.upbound_data())); + + // check that batch size of this matches that of left and right + TA_ASSERT(this->batch_size() == batch_sz); } // Compute gemm dimensions @@ -2243,20 +2269,25 @@ class Tensor { const integer ldb = (gemm_helper.right_op() == TiledArray::math::blas::NoTranspose ? N : K); - for (integer m = 0; m != M; ++m) { - for (integer n = 0; n != N; ++n) { - auto c_offset = m * N + n; - for (integer k = 0; k != K; ++k) { - auto a_offset = - gemm_helper.left_op() == TiledArray::math::blas::NoTranspose - ? m * lda + k - : k * lda + m; - auto b_offset = - gemm_helper.right_op() == TiledArray::math::blas::NoTranspose - ? k * ldb + n - : n * ldb + k; - elem_muladd_op(*(this->data() + c_offset), *(left.data() + a_offset), - *(right.data() + b_offset)); + for (integer b = 0; b != batch_size(); ++b) { + auto this_data = this->batch_data(b); + auto left_data = left.batch_data(b); + auto right_data = right.batch_data(b); + for (integer m = 0; m != M; ++m) { + for (integer n = 0; n != N; ++n) { + auto c_offset = m * N + n; + for (integer k = 0; k != K; ++k) { + auto a_offset = + gemm_helper.left_op() == TiledArray::math::blas::NoTranspose + ? m * lda + k + : k * lda + m; + auto b_offset = + gemm_helper.right_op() == TiledArray::math::blas::NoTranspose + ? k * ldb + n + : n * ldb + k; + elem_muladd_op(*(this_data + c_offset), *(left_data + a_offset), + *(right_data + b_offset)); + } } } } diff --git a/tests/einsum.cpp b/tests/einsum.cpp index eb2ffe1869..eb976b31f5 100644 --- a/tests/einsum.cpp +++ b/tests/einsum.cpp @@ -604,10 +604,10 @@ BOOST_AUTO_TEST_CASE(ijk_mn_eq_ij_mn_times_kj_mn) { return result; }; - TiledRange lhs_trange{{0, 2, 4}, {0, 5}}; + TiledRange lhs_trange{{0, 2, 4}, {0, 2, 5}}; auto lhs = random_tot_darr(world, lhs_trange); - TiledRange rhs_trange{{0, 2, 4, 6}, {0, 5}}; + TiledRange rhs_trange{{0, 2, 4, 6}, {0, 2, 5}}; auto rhs = random_tot_darr(world, rhs_trange); dist_array_t result; BOOST_REQUIRE_NO_THROW(