Skip to content

Commit

Permalink
test_common: Update VanillaGemm
Browse files Browse the repository at this point in the history
  - Fix VanillaGemm to work with batch_size_last_dim=true
  when Cuda is enabled.
  • Loading branch information
e10harvey committed Mar 10, 2021
1 parent a7558b5 commit 4ea0e4c
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions test_common/KokkosKernels_TestUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,24 +167,29 @@ namespace Test {
// Element (value) types of the three view types this functor is templated on.
using ScalarA = typename ViewTypeA::value_type;
using ScalarB = typename ViewTypeB::value_type;
using ScalarC = typename ViewTypeC::value_type;
// Rank-2, LayoutStride view types used by operator() to hold one batch entry
// (a 2-D slice) of A, B and C. A single fixed type per operand is needed
// because operator() assigns either a leading-index slice or a trailing-index
// slice to the same variable depending on batch_size_last_dim.
// NOTE(review): all three aliases use ViewTypeA::device_type, including the
// ones for B and C -- presumably A, B and C always share a device; confirm.
using SubviewTypeA = typename Kokkos::View<ScalarA**, Kokkos::LayoutStride, typename ViewTypeA::device_type>;
using SubviewTypeB = typename Kokkos::View<ScalarB**, Kokkos::LayoutStride, typename ViewTypeA::device_type>;
using SubviewTypeC = typename Kokkos::View<ScalarC**, Kokkos::LayoutStride, typename ViewTypeA::device_type>;

// Scalar coefficients; presumably the GEMM scaling factors
// (C = beta*C + alpha*op(A)*op(B)) -- their use is not visible here; confirm.
ScalarA alpha;
ScalarC beta;

KOKKOS_INLINE_FUNCTION
void operator()(const typename Kokkos::TeamPolicy<ExecutionSpace>::member_type& team) const {
int i = team.league_rank();
SubviewTypeA _A;
SubviewTypeB _B;
SubviewTypeC _C;

auto _A = Kokkos::subview(A, i, Kokkos::ALL(), Kokkos::ALL());
auto _B = Kokkos::subview(B, i, Kokkos::ALL(), Kokkos::ALL());
auto _C = Kokkos::subview(C, i, Kokkos::ALL(), Kokkos::ALL());
if (batch_size_last_dim) {
_A = Kokkos::subview(A, Kokkos::ALL(), Kokkos::ALL(), i);
_B = Kokkos::subview(B, Kokkos::ALL(), Kokkos::ALL(), i);
_C = Kokkos::subview(C, Kokkos::ALL(), Kokkos::ALL(), i);
} else {
_A = Kokkos::subview(A, i, Kokkos::ALL(), Kokkos::ALL());
_B = Kokkos::subview(B, i, Kokkos::ALL(), Kokkos::ALL());
_C = Kokkos::subview(C, i, Kokkos::ALL(), Kokkos::ALL());
}
using SubviewTypeA = decltype(_A);
using SubviewTypeB = decltype(_B);
using SubviewTypeC = decltype(_C);
struct SharedVanillaGEMM<SubviewTypeA,SubviewTypeB,SubviewTypeC,ExecutionSpace> vgemm;
vgemm.A_t = A_t; vgemm.B_t = B_t;
vgemm.A_c = A_c; vgemm.B_c = B_c;
Expand Down

0 comments on commit 4ea0e4c

Please sign in to comment.