Skip to content

Commit

Permalink
CRS: Use Kokkos device function macros rather than duplicating code when compiling for GPU targets
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilMiller committed Jan 31, 2023
1 parent 52586ef commit 71e0eca
Showing 1 changed file with 2 additions and 38 deletions.
40 changes: 2 additions & 38 deletions core/src/Kokkos_Crs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,11 @@ struct CountAndFillBase {
Functor m_functor;
counts_type m_counts;
struct Count {};
inline void operator()(Count, size_type i) const {
KOKKOS_FUNCTION void operator()(Count, size_type i) const {
m_counts(i) = m_functor(i, nullptr);
}
struct Fill {};
inline void operator()(Fill, size_type i) const {
KOKKOS_FUNCTION void operator()(Fill, size_type i) const {
auto j = m_crs.row_map(i);
/* we don't want to access entries(entries.size()), even if it's just to get
its address and never use it. this can happen when row (i) is empty and
Expand All @@ -323,42 +323,6 @@ struct CountAndFillBase {
CountAndFillBase(CrsType& crs, Functor const& f) : m_crs(crs), m_functor(f) {}
};

#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)
// EXEC_SPACE is a file-local shorthand for the execution space that the
// specialization below is written for. If both backends are enabled, only
// the CUDA branch is taken. The macro is #undef'd right after the
// specialization so that this generic name does not leak into code that
// includes this header.
#if defined(KOKKOS_ENABLE_CUDA)
#define EXEC_SPACE Kokkos::Cuda
#elif defined(KOKKOS_ENABLE_HIP)
#define EXEC_SPACE Kokkos::HIP
#endif
/* Specialization of CountAndFillBase for EXEC_SPACE whose call operators are
   declared __device__.

   It holds a copy of the Crs object, a copy of the user functor and a
   counts array (same type as the Crs row map), and offers two tagged call
   operators:
     - Count: stores the functor's return value for row i into m_counts(i);
     - Fill:  hands the functor a pointer to where row i's entries start. */
template <class CrsType, class Functor>
struct CountAndFillBase<CrsType, Functor, EXEC_SPACE> {
  using data_type    = typename CrsType::data_type;
  using size_type    = typename CrsType::size_type;
  using row_map_type = typename CrsType::row_map_type;
  using counts_type  = row_map_type;
  CrsType m_crs;
  Functor m_functor;
  // Not set by the constructor below; left default-constructed here.
  counts_type m_counts;
  // Tag selecting the counting pass.
  struct Count {};
  // Calls the functor with a null output pointer and records its return
  // value as the count for row i.
  __device__ inline void operator()(Count, size_type i) const {
    m_counts(i) = m_functor(i, nullptr);
  }
  // Tag selecting the fill pass.
  struct Fill {};
  // Calls the functor with a pointer to entries(row_map(i)), or nullptr when
  // that offset equals the number of entries.
  __device__ inline void operator()(Fill, size_type i) const {
    auto j = m_crs.row_map(i);
    /* we don't want to access entries(entries.size()), even if it's just to
       get its address and never use it. this can happen when row (i) is empty
       and all rows after it are also empty. we could compare to
       row_map(i + 1), but that is a read from global memory, whereas
       extent(0) should be part of the View in registers (or constant
       memory) */
    data_type* fill = (j == static_cast<decltype(j)>(m_crs.entries.extent(0)))
                          ? nullptr
                          : (&(m_crs.entries(j)));
    m_functor(i, fill);
  }
  // Stores copies of the Crs object and of the functor.
  CountAndFillBase(CrsType& crs, Functor const& f) : m_crs(crs), m_functor(f) {}
};
// The shorthand is no longer needed past this point.
#undef EXEC_SPACE
#endif

template <class CrsType, class Functor>
struct CountAndFill : public CountAndFillBase<CrsType, Functor> {
using base_type = CountAndFillBase<CrsType, Functor>;
Expand Down

0 comments on commit 71e0eca

Please sign in to comment.