diff --git a/test_common/KokkosKernels_TestUtils.hpp b/test_common/KokkosKernels_TestUtils.hpp index ad546fe0b4..43f2d48460 100644 --- a/test_common/KokkosKernels_TestUtils.hpp +++ b/test_common/KokkosKernels_TestUtils.hpp @@ -167,24 +167,29 @@ namespace Test { using ScalarA = typename ViewTypeA::value_type; using ScalarB = typename ViewTypeB::value_type; using ScalarC = typename ViewTypeC::value_type; + using SubviewTypeA = typename Kokkos::View; + using SubviewTypeB = typename Kokkos::View; + using SubviewTypeC = typename Kokkos::View; + ScalarA alpha; ScalarC beta; KOKKOS_INLINE_FUNCTION void operator()(const typename Kokkos::TeamPolicy::member_type& team) const { int i = team.league_rank(); + SubviewTypeA _A; + SubviewTypeB _B; + SubviewTypeC _C; - auto _A = Kokkos::subview(A, i, Kokkos::ALL(), Kokkos::ALL()); - auto _B = Kokkos::subview(B, i, Kokkos::ALL(), Kokkos::ALL()); - auto _C = Kokkos::subview(C, i, Kokkos::ALL(), Kokkos::ALL()); if (batch_size_last_dim) { _A = Kokkos::subview(A, Kokkos::ALL(), Kokkos::ALL(), i); _B = Kokkos::subview(B, Kokkos::ALL(), Kokkos::ALL(), i); _C = Kokkos::subview(C, Kokkos::ALL(), Kokkos::ALL(), i); + } else { + _A = Kokkos::subview(A, i, Kokkos::ALL(), Kokkos::ALL()); + _B = Kokkos::subview(B, i, Kokkos::ALL(), Kokkos::ALL()); + _C = Kokkos::subview(C, i, Kokkos::ALL(), Kokkos::ALL()); } - using SubviewTypeA = decltype(_A); - using SubviewTypeB = decltype(_B); - using SubviewTypeC = decltype(_C); struct SharedVanillaGEMM vgemm; vgemm.A_t = A_t; vgemm.B_t = B_t; vgemm.A_c = A_c; vgemm.B_c = B_c;