Skip to content

Commit

Permalink
#9 Include scalar type in templates
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikołaj Zuzek committed Jun 30, 2021
1 parent e2db00d commit ef87ab2
Showing 1 changed file with 23 additions and 24 deletions.
47 changes: 23 additions & 24 deletions src/sparse/impl/KokkosSparse_spmv_impl_block_crs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ struct BSPMV_Functor {
//
constexpr size_t bmax = 12;

using Scalar = default_scalar;
using Ordinal = default_lno_t;
using Offset = default_size_type;
using Layout = default_layout;
Expand All @@ -254,16 +253,11 @@ using device_type = typename Kokkos::Device<
Kokkos::DefaultExecutionSpace,
typename Kokkos::DefaultExecutionSpace::memory_space>;

using crs_matrix_t_ =
typename KokkosSparse::CrsMatrix<Scalar, Ordinal, device_type, void,
Offset>;

using values_type = typename crs_matrix_t_::values_type;
/***********************************/

#ifdef KOKKOS_ENABLE_SERIAL

template <int M>
template <int M, typename Scalar>
inline void spmv_serial_gemv(const Scalar *Aval, const Ordinal lda,
const Scalar *x_ptr,
std::array<Scalar, Impl::bmax> &y) {
Expand All @@ -278,16 +272,17 @@ inline void spmv_serial_gemv(const Scalar *Aval, const Ordinal lda,
//
// Explicit blockSize=N case
//
template <class StaticGraph, Ordinal blockSize>
template <Ordinal blockSize>
struct SpMV_SerialNoTranspose {

template <typename Scalar, class StaticGraph>
static inline void spmv(const Scalar alpha, const Scalar *Avalues,
const StaticGraph &Agraph, const Scalar *x,
Scalar *y, const Ordinal bs) {

assert(blockSize == bs);
const Ordinal numBlockRows = Agraph.numRows();
std::array<double, Impl::bmax> tmp{0};
std::array<Scalar, Impl::bmax> tmp{0};
const Ordinal blockSize2 = blockSize * blockSize;
for (Ordinal iblock = 0; iblock < numBlockRows; ++iblock) {
const auto jbeg = Agraph.row_map[iblock];
Expand All @@ -313,8 +308,10 @@ struct SpMV_SerialNoTranspose {
//
// Special blockSize=1 case (optimized)
//
template <class StaticGraph>
struct SpMV_SerialNoTranspose<StaticGraph, 1> {
template<>
struct SpMV_SerialNoTranspose<1> {

template <typename Scalar, class StaticGraph>
static inline void spmv(const Scalar alpha, const Scalar *Avalues,
const StaticGraph &Agraph, const Scalar *x,
Scalar *y, const Ordinal blockSize) {
Expand All @@ -324,7 +321,7 @@ struct SpMV_SerialNoTranspose<StaticGraph, 1> {
for (Ordinal i = 0; i < numBlockRows; ++i) {
const auto jbeg = Agraph.row_map[i];
const auto jend = Agraph.row_map[i + 1];
double tmp = 0.0;
Scalar tmp = 0.0;
for (Ordinal j = jbeg; j < jend; ++j) {
const auto alpha_value1 = alpha * Avalues[j];
const auto col_idx1 = Agraph.entries[j];
Expand All @@ -340,8 +337,10 @@ struct SpMV_SerialNoTranspose<StaticGraph, 1> {
//
// --- Basic approach for large block sizes
//
template <class StaticGraph>
struct SpMV_SerialNoTranspose<StaticGraph, 0> {

template<>
struct SpMV_SerialNoTranspose<0> {
template <typename Scalar, class StaticGraph>
static inline void spmv(const Scalar alpha, const Scalar *Avalues,
const StaticGraph &A_graph, const Scalar *x,
Scalar *y, const Ordinal blockSize) {
Expand Down Expand Up @@ -588,7 +587,7 @@ void spMatVec_no_transpose(KokkosKernels::Experimental::Controls controls,
if (std::is_same< typename AMatrix_Internal::device_type::execution_space,
Kokkos::Serial>::value) {
Utils::eti_expand<Impl::bmax>(blockSize, [&]<int fixedBlockSize>(const int blockSize) {
SpMV_SerialNoTranspose<decltype(A_graph), fixedBlockSize>::spmv(alpha, &A_internal.values[0], A_graph, &x[0], &y[0], blockSize);
SpMV_SerialNoTranspose<fixedBlockSize>::spmv((AT)alpha, &A_internal.values[0], A_graph, &x[0], &y[0], blockSize);
});
return;
}
Expand Down Expand Up @@ -689,7 +688,7 @@ struct BSPMV_Transpose_Functor {
? ATV::conj(Aval_ptr[kr + ic * block_size])
: Aval_ptr[kr + ic * block_size];
Kokkos::atomic_add(&yvec[kr],
static_cast<Scalar>(alpha * val * xvalue));
static_cast<value_type>(alpha * val * xvalue));
}
}
}
Expand Down Expand Up @@ -796,7 +795,7 @@ void bspmv_raw_openmp_transpose(typename YVector::const_value_type& s_a,
const auto xvalue = xvec[ic];
for (Ordinal kr = 0; kr < blockSize; ++kr) {
Kokkos::atomic_add( &yvec[kr],
static_cast<Scalar>(s_a * Aval_ptr[ic + kr * blockSize] * xvalue) );
static_cast<value_type>(s_a * Aval_ptr[ic + kr * blockSize] * xvalue) );
}
}
}
Expand Down Expand Up @@ -908,7 +907,7 @@ void spMatVec_transpose(
/* ******************* */


template <int M>
template <int M, typename Scalar>
inline void spmv_transpose_gemv(const Scalar alpha, const Scalar *Aval,
const Ordinal lda, const Ordinal xrow,
const Scalar *x_ptr, Scalar *y) {
Expand All @@ -924,7 +923,7 @@ inline void spmv_transpose_gemv(const Scalar alpha, const Scalar *Aval,
//
// Explicit blockSize=N case
//
template <class StaticGraph, Ordinal blockSize>
template <typename Scalar, class StaticGraph, Ordinal blockSize>
struct SpMV_SerialTranspose {

static inline void spmv(const Scalar alpha, const Scalar *Avalues,
Expand Down Expand Up @@ -952,8 +951,8 @@ struct SpMV_SerialTranspose {
//
// Special blockSize=1 case (optimized)
//
template <class StaticGraph>
struct SpMV_SerialTranspose<StaticGraph, 1> {
template <typename Scalar, class StaticGraph>
struct SpMV_SerialTranspose<Scalar, StaticGraph, 1> {

static inline void spmv(const Scalar alpha, const Scalar *Avalues,
const StaticGraph &Agraph, const Scalar *x,
Expand All @@ -976,8 +975,8 @@ struct SpMV_SerialTranspose<StaticGraph, 1> {
//
// --- Basic approach for large block sizes
//
template <class StaticGraph>
struct SpMV_SerialTranspose<StaticGraph, 0> {
template <typename Scalar, class StaticGraph>
struct SpMV_SerialTranspose<Scalar, StaticGraph, 0> {

static inline void spmv(const Scalar alpha, const Scalar *A_values,
const StaticGraph &A_graph, const Scalar *x,
Expand Down Expand Up @@ -1065,7 +1064,7 @@ void spMatVec_transpose(KokkosKernels::Experimental::Controls controls,
if (std::is_same< typename AMatrix_Internal::device_type::execution_space,
Kokkos::Serial>::value) {
Utils::eti_expand<Impl::bmax>(blockSize, [&]<int N>(const int blockSize) {
SpMV_SerialTranspose<decltype(A_graph), N>::spmv(alpha, &A_internal.values[0], A_graph, &x[0], &y[0], blockSize);
SpMV_SerialTranspose<AT, decltype(A_graph), N>::spmv((AT)alpha, &A_internal.values[0], A_graph, &x[0], &y[0], blockSize);
});
return;
}
Expand Down

0 comments on commit ef87ab2

Please sign in to comment.