diff --git a/packages/kokkos-kernels/src/sparse/KokkosSparse_sptrsv_handle.hpp b/packages/kokkos-kernels/src/sparse/KokkosSparse_sptrsv_handle.hpp index 157c0e9e9245..33d064a4215e 100644 --- a/packages/kokkos-kernels/src/sparse/KokkosSparse_sptrsv_handle.hpp +++ b/packages/kokkos-kernels/src/sparse/KokkosSparse_sptrsv_handle.hpp @@ -98,6 +98,7 @@ class SPTRSVHandle { typedef typename Kokkos::View nnz_row_view_temp_t; typedef typename Kokkos::View nnz_row_view_t; typedef typename nnz_row_view_t::HostMirror host_nnz_row_view_t; + typedef typename Kokkos::View int_row_view_t; // typedef typename row_lno_persistent_work_view_t::HostMirror row_lno_persistent_work_host_view_t; //Host view type typedef typename Kokkos::View> nnz_row_unmanaged_view_t; // for rank1 subviews @@ -310,6 +311,7 @@ class SPTRSVHandle { #ifdef KOKKOSKERNELS_ENABLE_TPL_CUSPARSE SPTRSVcuSparseHandleType *cuSPARSEHandle; + int_row_view_t tmp_int_rowmap; #endif #ifdef KOKKOSKERNELS_ENABLE_SUPERNODAL_SPTRSV @@ -409,6 +411,7 @@ class SPTRSVHandle { require_symbolic_chain_phase(false) #ifdef KOKKOSKERNELS_ENABLE_TPL_CUSPARSE , cuSPARSEHandle(nullptr) + , tmp_int_rowmap() #endif #ifdef KOKKOSKERNELS_ENABLE_SUPERNODAL_SPTRSV , merge_supernodes (false) @@ -832,6 +835,28 @@ class SPTRSVHandle { SPTRSVcuSparseHandleType *get_cuSparseHandle(){ return this->cuSPARSEHandle; } + + void allocate_tmp_int_rowmap (size_type N) { + tmp_int_rowmap = int_row_view_t(Kokkos::ViewAllocateWithoutInitializing("tmp_int_rowmap"), N); + } + template + int_row_view_t get_int_rowmap_view_copy (const RowViewType & rowmap) { + Kokkos::deep_copy(tmp_int_rowmap, rowmap); + return tmp_int_rowmap; + } + template + int* get_int_rowmap_ptr_copy (const RowViewType & rowmap) { + Kokkos::deep_copy(tmp_int_rowmap, rowmap); + Kokkos::fence(); + return tmp_int_rowmap.data(); + } + int_row_view_t get_int_rowmap_view () { + return tmp_int_rowmap; + } + int* get_int_rowmap_ptr () { + return tmp_int_rowmap.data(); + } + #endif diff --git a/packages/kokkos-kernels/src/sparse/impl/KokkosSparse_sptrsv_cuSPARSE_impl.hpp b/packages/kokkos-kernels/src/sparse/impl/KokkosSparse_sptrsv_cuSPARSE_impl.hpp index 7c903b738cf7..592413534f10 100644 --- a/packages/kokkos-kernels/src/sparse/impl/KokkosSparse_sptrsv_cuSPARSE_impl.hpp +++ b/packages/kokkos-kernels/src/sparse/impl/KokkosSparse_sptrsv_cuSPARSE_impl.hpp @@ -92,9 +92,11 @@ namespace Impl{ int nnz = entries.extent_int(0); int pBufferSize; + if (!std::is_same::value) + sptrsv_handle->allocate_tmp_int_rowmap(row_map.extent(0)); + const int* rm = !std::is_same::value ? sptrsv_handle->get_int_rowmap_ptr_copy(row_map) : (const int*)row_map.data(); + const int* ent = entries.data(); const scalar_type* vals = values.data(); - const size_type* rm = row_map.data(); - const idx_type* ent = entries.data(); if (std::is_same::value) { cusparseDcsrsv2_bufferSize( @@ -221,9 +223,10 @@ namespace Impl{ int nnz = entries.extent_int(0); + //const size_type* rm = row_map.data(); + const int* rm = !std::is_same::value ? sptrsv_handle->get_int_rowmap_ptr() : (const int*)row_map.data(); + const int* ent = entries.data(); const scalar_type* vals = values.data(); - const size_type* rm = row_map.data(); - const idx_type* ent = entries.data(); const scalar_type* bv = rhs.data(); scalar_type* xv = lhs.data();