Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tpetra: use kokkos kernels BsrMatrix spmv #12103

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,23 @@ struct BsrMatrixSpMVTensorCoreFunctorParams {
int leagueDim_y;
};

/// \brief Type trait that is true only for plain floating-point scalar types.
///
/// \tparam T type to test; top-level const/volatile qualifiers are ignored.
///
/// \c value is true when \c T (after removing cv-qualifiers) is exactly
/// \c double, \c float or \c Kokkos::Experimental::half_t, and false for every
/// other type.
template <typename T>
struct is_scalar {
  // constexpr static data members are implicitly inline in C++17, so `value`
  // may be odr-used (e.g. bound to a const reference) without needing a
  // separate out-of-class definition.
  static constexpr bool value =
      std::is_same_v<std::remove_cv_t<T>, double> ||
      std::is_same_v<std::remove_cv_t<T>, float> ||
      std::is_same_v<std::remove_cv_t<T>, Kokkos::Experimental::half_t>;
};

/// \brief Type trait that is true only if all three types satisfy is_scalar.
///
/// \tparam T1 first type to test
/// \tparam T2 second type to test
/// \tparam T3 third type to test
///
/// This excludes complex types and Sacado vectors, since is_scalar is false
/// for them.
template <typename T1, typename T2, typename T3>
struct all_scalar {
  // constexpr static data members are implicitly inline in C++17, so `value`
  // may be odr-used without needing a separate out-of-class definition.
  static constexpr bool value =
      is_scalar<T1>::value && is_scalar<T2>::value && is_scalar<T3>::value;
};

/// \brief Functor for the BsrMatrix SpMV multivector implementation utilizing
/// tensor cores.
///
Expand Down Expand Up @@ -466,18 +483,20 @@ struct BsrMatrixSpMVTensorCoreDispatcher {
// to be used to avoid instantiating on unsupported types
static void tag_dispatch(std::false_type, YScalar, AMatrix, XMatrix, YScalar,
YMatrix) {
KokkosKernels::Impl::throw_runtime_exception(
"Tensor core SpMV is only supported for non-complex types in GPU "
"execution spaces");
}
const std::type_info &tia = typeid(AScalar);
const std::type_info &tix = typeid(XScalar);
const std::type_info &tiy = typeid(YScalar);

/*true if none of T1, T2, or T3 are complex*/
template <typename T1, typename T2, typename T3>
struct none_complex {
const static bool value = !Kokkos::ArithTraits<T1>::is_complex &&
!Kokkos::ArithTraits<T2>::is_complex &&
!Kokkos::ArithTraits<T3>::is_complex;
};
std::stringstream ss;

ss << "Tensor core SpMV is only supported for scalar types in GPU "
"execution spaces.";
ss << " AScalar was " << tia.name() << ".";
ss << " XScalar was " << tix.name() << ".";
ss << " YScalar was " << tiy.name() << ".";

KokkosKernels::Impl::throw_runtime_exception(ss.str());
}

/*true if T1::execution_space, T2, or T3 are all GPU exec space*/
template <typename T1, typename T2, typename T3>
Expand All @@ -491,7 +510,7 @@ struct BsrMatrixSpMVTensorCoreDispatcher {
YMatrix y) {
// tag will be false unless all conditions are met
using tag = std::integral_constant<
bool, none_complex<AScalar, XScalar, YScalar>::value &&
bool, all_scalar<AScalar, XScalar, YScalar>::value &&
all_gpu<typename AMatrix::execution_space,
typename XMatrix::execution_space,
typename YMatrix::execution_space>::value>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,12 @@ struct SPMV_MV_BSRMATRIX<AT, AO, AD, AM, AS, XT, XL, XD, XM, YT, YL, YD, YM,
if (controls.getParameter("algorithm") == ALG_TC)
method = Method::TensorCores;
// can't use tensor cores for non-scalar types (e.g. complex)
if (Kokkos::ArithTraits<YScalar>::is_complex) method = Method::Fallback;
if (Kokkos::ArithTraits<XScalar>::is_complex) method = Method::Fallback;
if (Kokkos::ArithTraits<AScalar>::is_complex) method = Method::Fallback;
if (!KokkosSparse::Experimental::Impl::is_scalar<YScalar>::value)
method = Method::Fallback;
if (!KokkosSparse::Experimental::Impl::is_scalar<XScalar>::value)
method = Method::Fallback;
if (!KokkosSparse::Experimental::Impl::is_scalar<AScalar>::value)
method = Method::Fallback;
// can't use tensor cores outside GPU
if (!KokkosKernels::Impl::kk_is_gpu_exec_space<
typename AMatrix::execution_space>())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1628,7 +1628,11 @@ class ViewMapping< DstTraits , SrcTraits ,
typename DstTraits::array_layout(
dims[0] , dims[1] , dims[2] , dims[3] ,
dims[4] , dims[5] , dims[6] , dims[7] ) );
dst.m_impl_handle = src.m_impl_handle.scalar_ptr ;

// For CudaLDGFetch, which doesn't define operator=() for pointer RHS
// but does define a constructor
//dst.m_impl_handle = src.m_impl_handle.scalar_ptr ;
dst.m_impl_handle = typename DstType::handle_type(src.m_impl_handle.scalar_ptr);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,28 @@

namespace Stokhos {

namespace Impl {

/// \brief Maps a view type to the same view type without the random-access
/// memory trait.
///
/// Primary template: used whenever the partial specialization below does not
/// apply. The view type is passed through unchanged.
template <typename ViewType, typename Enabled = void>
class RemoveRandomAccess {
public:
  using type = ViewType;
};

/// \brief Partial specialization selected when
/// ViewType::memory_traits::is_random_access is true.
///
/// Rebuilds the view with the same data type, layout and device, but with a
/// memory-traits mask restricted to the Unmanaged, Atomic, Restrict and
/// Aligned bits of the original mask.
template <typename ViewType>
class RemoveRandomAccess<
    ViewType,
    std::enable_if_t<ViewType::memory_traits::is_random_access> > {
public:
  /// Memory-traits mask of the incoming view type.
  static constexpr unsigned M0 = ViewType::memory_traits::impl_value;

  /// M0 with every bit other than Unmanaged/Atomic/Restrict/Aligned cleared.
  static constexpr unsigned M1 =
      M0 & (Kokkos::Unmanaged | Kokkos::Atomic | Kokkos::Restrict | Kokkos::Aligned);

  using type = Kokkos::View<typename ViewType::data_type,
                            typename ViewType::array_layout,
                            typename ViewType::device_type,
                            Kokkos::MemoryTraits<M1> >;
};

} // namespace Impl

//----------------------------------------------------------------------------
// Specialization of KokkosSparse::CrsMatrix for Sacado::UQ::PCE scalar type
//----------------------------------------------------------------------------
Expand Down Expand Up @@ -107,8 +129,8 @@ class Multiply< KokkosSparse::CrsMatrix< const Sacado::UQ::PCE<MatrixStorage>,

typedef typename matrix_type::StaticCrsGraphType matrix_graph_type;
typedef typename matrix_values_type::array_type matrix_array_type;
typedef typename input_vector_type::array_type input_array_type;
typedef typename output_vector_type::array_type output_array_type;
typedef typename Impl::RemoveRandomAccess< typename input_vector_type::array_type >::type input_array_type;
typedef typename Impl::RemoveRandomAccess< typename output_vector_type::array_type >::type output_array_type;

typedef typename MatrixValue::value_type matrix_scalar;
typedef typename InputVectorValue::value_type input_scalar;
Expand Down Expand Up @@ -504,8 +526,8 @@ class Multiply< KokkosSparse::CrsMatrix< const Sacado::UQ::PCE<MatrixStorage>,

typedef typename matrix_type::StaticCrsGraphType matrix_graph_type;
typedef typename matrix_values_type::array_type matrix_array_type;
typedef typename input_vector_type::array_type input_array_type;
typedef typename output_vector_type::array_type output_array_type;
typedef typename Impl::RemoveRandomAccess< typename input_vector_type::array_type >::type input_array_type;
typedef typename Impl::RemoveRandomAccess< typename output_vector_type::array_type >::type output_array_type;

typedef typename MatrixValue::value_type matrix_scalar;
typedef typename InputVectorValue::value_type input_scalar;
Expand Down Expand Up @@ -1043,8 +1065,8 @@ class MeanMultiply< KokkosSparse::CrsMatrix< const Sacado::UQ::PCE<MatrixStorage
struct BlockKernel {
typedef typename MatrixDevice::execution_space execution_space;
typedef typename Kokkos::FlatArrayType<matrix_values_type>::type matrix_array_type;
typedef typename input_vector_type::array_type input_array_type;
typedef typename output_vector_type::array_type output_array_type;
typedef typename Impl::RemoveRandomAccess< typename input_vector_type::array_type >::type input_array_type;
typedef typename Impl::RemoveRandomAccess< typename output_vector_type::array_type >::type output_array_type;

const matrix_array_type m_A_values ;
const matrix_graph_type m_A_graph ;
Expand Down Expand Up @@ -1166,8 +1188,8 @@ class MeanMultiply< KokkosSparse::CrsMatrix< const Sacado::UQ::PCE<MatrixStorage
struct Kernel {
typedef typename MatrixDevice::execution_space execution_space;
typedef typename Kokkos::FlatArrayType<matrix_values_type>::type matrix_array_type;
typedef typename input_vector_type::array_type input_array_type;
typedef typename output_vector_type::array_type output_array_type;
typedef typename Impl::RemoveRandomAccess< typename input_vector_type::array_type >::type input_array_type;
typedef typename Impl::RemoveRandomAccess< typename output_vector_type::array_type >::type output_array_type;

const matrix_array_type m_A_values ;
const matrix_graph_type m_A_graph ;
Expand Down
Loading