diff --git a/lapack/impl/KokkosLapack_geqrf_spec.hpp b/lapack/impl/KokkosLapack_geqrf_spec.hpp index 6970c6dd2c..5410520c1c 100644 --- a/lapack/impl/KokkosLapack_geqrf_spec.hpp +++ b/lapack/impl/KokkosLapack_geqrf_spec.hpp @@ -53,7 +53,8 @@ struct geqrf_eti_spec_avail { Kokkos::View, \ Kokkos::MemoryTraits>, \ - Kokkos::View, \ Kokkos::MemoryTraits>> { \ enum : bool { value = true }; \ }; @@ -78,7 +79,6 @@ struct GEQRF { }; #if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY -//! Full specialization of geqrf for multi vectors. // Unification layer template struct GEQRF, \ Kokkos::MemoryTraits>, \ - Kokkos::View, \ Kokkos::MemoryTraits>, \ false, true>; @@ -128,7 +129,8 @@ struct GEQRF, \ Kokkos::MemoryTraits>, \ - Kokkos::View, \ Kokkos::MemoryTraits>, \ false, true>; diff --git a/lapack/src/KokkosLapack_geqrf.hpp b/lapack/src/KokkosLapack_geqrf.hpp index 506d3c60b7..1d26747cd3 100644 --- a/lapack/src/KokkosLapack_geqrf.hpp +++ b/lapack/src/KokkosLapack_geqrf.hpp @@ -15,7 +15,7 @@ //@HEADER /// \file KokkosLapack_geqrf.hpp -/// \brief Local dense linear solve +/// \brief QR factorization /// /// This file provides KokkosLapack::geqrf. This function performs a /// local (no MPI) QR factorization of a M-by-N matrix A. @@ -118,31 +118,33 @@ int geqrf(const ExecutionSpace& space, const AMatrix& A, const TWArray& Tau, } } - typedef Kokkos::View< + using RetArray = Kokkos::View; + RetArray rc("rc", 1); + + using AMatrix_Internal = Kokkos::View< typename AMatrix::non_const_value_type**, typename AMatrix::array_layout, - typename AMatrix::device_type, Kokkos::MemoryTraits > - AMatrix_Internal; - typedef Kokkos::View< + typename AMatrix::device_type, Kokkos::MemoryTraits>; + using TWArray_Internal = Kokkos::View< typename TWArray::non_const_value_type*, typename TWArray::array_layout, - typename TWArray::device_type, Kokkos::MemoryTraits > - TWArray_Internal; - AMatrix_Internal A_i = A; - TWArray_Internal Tau_i = Tau; - TWArray_Internal Work_i = Work; - - // This is the return value type and should always reside on host - using RViewInternalType = - Kokkos::View >; + typename TWArray::device_type, Kokkos::MemoryTraits>; + using RetArray_Internal = Kokkos::View< + int*, typename TWArray::array_layout, + typename TWArray::device_type, Kokkos::MemoryTraits>; - int result; - RViewInternalType R = RViewInternalType(&result); + AMatrix_Internal A_i = A; + TWArray_Internal Tau_i = Tau; + TWArray_Internal Work_i = Work; + RetArray_Internal rc_i = rc; KokkosLapack::Impl::GEQRF::geqrf(space, A_i, Tau_i, Work_i, - R); + RetArray_Internal>::geqrf(space, A_i, Tau_i, Work_i, + rc_i); + + typename RetArray_Internal::HostMirror h_rc = Kokkos::create_mirror_view(rc_i); + + Kokkos::deep_copy(h_rc, rc_i); - return result; + return h_rc[0]; } /// \brief Computes a QR factorization of a matrix A diff --git a/lapack/tpls/KokkosLapack_geqrf_tpl_spec_avail.hpp b/lapack/tpls/KokkosLapack_geqrf_tpl_spec_avail.hpp index f291bbe2a8..cc6f1e78a4 100644 --- a/lapack/tpls/KokkosLapack_geqrf_tpl_spec_avail.hpp +++ b/lapack/tpls/KokkosLapack_geqrf_tpl_spec_avail.hpp @@ -36,7 +36,7 @@ struct geqrf_tpl_spec_avail { Kokkos::MemoryTraits>, \ Kokkos::View, \ Kokkos::MemoryTraits>, \ - Kokkos::View, \ Kokkos::MemoryTraits>> { \ enum : bool { value = true }; \ }; @@ -95,7 +95,7 @@ namespace Impl { Kokkos::MemoryTraits>, \ Kokkos::View, \ Kokkos::MemoryTraits>, \ - Kokkos::View, \ Kokkos::MemoryTraits>> { \ enum : bool { value = true }; \ }; @@ -142,7 +142,7 @@ namespace Impl { Kokkos::MemoryTraits>, \ Kokkos::View, \ Kokkos::MemoryTraits>, \ - Kokkos::View, \ Kokkos::MemoryTraits>> { \ enum : bool { value = true }; \ }; diff --git a/lapack/tpls/KokkosLapack_geqrf_tpl_spec_decl.hpp b/lapack/tpls/KokkosLapack_geqrf_tpl_spec_decl.hpp index 056eef24da..fe25ce19a0 100644 --- a/lapack/tpls/KokkosLapack_geqrf_tpl_spec_decl.hpp +++ b/lapack/tpls/KokkosLapack_geqrf_tpl_spec_decl.hpp @@ -54,21 +54,21 @@ void lapackGeqrfWrapper(const AViewType& A, const TWViewType& Tau, using ALayout_t = typename AViewType::array_layout; static_assert(std::is_same_v, "KokkosLapack - geqrf: A needs to have a Kokkos::LayoutLeft"); - const int M = A.extent_int(0); - const int N = A.extent_int(1); - const int LDA = A.stride(1); - const int LWORK = static_cast(Work.extent(0)); + const int m = A.extent_int(0); + const int n = A.extent_int(1); + const int lda = A.stride(1); + const int lwork = static_cast(Work.extent(0)); if constexpr (Kokkos::ArithTraits::is_complex) { using MagType = typename Kokkos::ArithTraits::mag_type; - R() = HostLapack>::geqrf( - M, N, reinterpret_cast*>(A.data()), LDA, + R[0] = HostLapack>::geqrf( + m, n, reinterpret_cast*>(A.data()), lda, reinterpret_cast*>(Tau.data()), - reinterpret_cast*>(Work.data()), LWORK); + reinterpret_cast*>(Work.data()), lwork); } else { - R() = HostLapack::geqrf(M, N, A.data(), LDA, Tau.data(), - Work.data(), LWORK); + R[0] = HostLapack::geqrf(m, n, A.data(), lda, Tau.data(), + Work.data(), lwork); } } @@ -80,7 +80,7 @@ void lapackGeqrfWrapper(const AViewType& A, const TWViewType& Tau, Kokkos::MemoryTraits>, \ Kokkos::View, \ Kokkos::MemoryTraits>, \ - Kokkos::View, \ Kokkos::MemoryTraits>, \ true, \ geqrf_eti_spec_avail< \ @@ -89,7 +89,7 @@ void lapackGeqrfWrapper(const AViewType& A, const TWViewType& Tau, Kokkos::MemoryTraits>, \ Kokkos::View, \ Kokkos::MemoryTraits>, \ - Kokkos::View, \ Kokkos::MemoryTraits>>::value> { \ using AViewType = \ Kokkos::View, \ @@ -97,7 +97,7 @@ void lapackGeqrfWrapper(const AViewType& A, const TWViewType& Tau, using TWViewType = \ Kokkos::View, \ Kokkos::MemoryTraits>; \ - using RType = Kokkos::View, \ Kokkos::MemoryTraits>; \ \ static void geqrf(const EXECSPACE& /* space */, const AViewType& A, \ @@ -255,87 +255,68 @@ KOKKOSLAPACK_GEQRF_MAGMA(Kokkos::complex, Kokkos::LayoutLeft, namespace KokkosLapack { namespace Impl { -template -void cusolverGeqrfWrapper(const ExecutionSpace& space, const TWViewType& Work, - const AViewType& A, const TWViewType& Tau) { +template +void cusolverGeqrfWrapper(const ExecutionSpace& space, const AViewType& A, + const TWViewType& /* Work */, const TWViewType& Tau, + class RType& R) { + using memory_space = typename AViewType::memory_space; - using Scalar = typename TWViewType::non_const_value_type; - using ALayout_t = typename AViewType::array_layout; - using BLayout_t = typename TWViewType::array_layout; + using Scalar = typename AViewType::non_const_value_type; + using ALayout_t = typename AViewType::array_layout; + static_assert(std::is_same_v, + "KokkosLapack - cusolver geqrf: A needs to have a Kokkos::LayoutLeft"); const int m = A.extent_int(0); const int n = A.extent_int(1); - const int lda = std::is_same_v ? A.stride(0) - : A.stride(1); - - (void)B; - - const int nrhs = B.extent_int(1); - const int ldb = std::is_same_v ? B.stride(0) - : B.stride(1); + const int lda = A.stride(1); int lwork = 0; - Kokkos::View info("getrf info"); + + //Kokkos::View info("cusolver geqrf info"); CudaLapackSingleton& s = CudaLapackSingleton::singleton(); KOKKOS_CUSOLVER_SAFE_CALL_IMPL( cusolverDnSetStream(s.handle, space.cuda_stream())); if constexpr (std::is_same_v) { KOKKOS_CUSOLVER_SAFE_CALL_IMPL( - cusolverDnSgetrf_bufferSize(s.handle, m, n, A.data(), lda, &lwork)); - Kokkos::View Workspace("getrf workspace", lwork); - - KOKKOS_CUSOLVER_SAFE_CALL_IMPL(cusolverDnSgetrf(s.handle, m, n, A.data(), - lda, Workspace.data(), - IPIV.data(), info.data())); + cusolverDnSgeqrf_bufferSize(s.handle, m, n, A.data(), lda, &lwork)); + Kokkos::View Workspace("cusolver sgeqrf workspace", lwork); - KOKKOS_CUSOLVER_SAFE_CALL_IMPL( - cusolverDnSgetrs(s.handle, CUBLAS_OP_N, m, nrhs, A.data(), lda, - IPIV.data(), B.data(), ldb, info.data())); + KOKKOS_CUSOLVER_SAFE_CALL_IMPL(cusolverDnSgeqrf(s.handle, m, n, A.data(), + lda, Tau.data(), + Workspace.data(), lwork, /*info*/R.data())); } if constexpr (std::is_same_v) { KOKKOS_CUSOLVER_SAFE_CALL_IMPL( - cusolverDnDgetrf_bufferSize(s.handle, m, n, A.data(), lda, &lwork)); - Kokkos::View Workspace("getrf workspace", lwork); - - KOKKOS_CUSOLVER_SAFE_CALL_IMPL(cusolverDnDgetrf(s.handle, m, n, A.data(), - lda, Workspace.data(), - IPIV.data(), info.data())); + cusolverDnDgeqrf_bufferSize(s.handle, m, n, A.data(), lda, &lwork)); + Kokkos::View Workspace("cusolver dgeqrf workspace", lwork); - KOKKOS_CUSOLVER_SAFE_CALL_IMPL( - cusolverDnDgetrs(s.handle, CUBLAS_OP_N, m, nrhs, A.data(), lda, - IPIV.data(), B.data(), ldb, info.data())); + KOKKOS_CUSOLVER_SAFE_CALL_IMPL(cusolverDnDgeqrf(s.handle, m, n, A.data(), + lda, Tau.data(), + Workspace.data(), lwork, /*info*/R.data())); } if constexpr (std::is_same_v>) { - KOKKOS_CUSOLVER_SAFE_CALL_IMPL(cusolverDnCgetrf_bufferSize( + KOKKOS_CUSOLVER_SAFE_CALL_IMPL(cusolverDnCgeqrf_bufferSize( s.handle, m, n, reinterpret_cast(A.data()), lda, &lwork)); - Kokkos::View Workspace("getrf workspace", lwork); + Kokkos::View Workspace("cusolver cgeqrf workspace", lwork); KOKKOS_CUSOLVER_SAFE_CALL_IMPL( - cusolverDnCgetrf(s.handle, m, n, reinterpret_cast(A.data()), - lda, reinterpret_cast(Workspace.data()), - IPIV.data(), info.data())); - - KOKKOS_CUSOLVER_SAFE_CALL_IMPL(cusolverDnCgetrs( - s.handle, CUBLAS_OP_N, m, nrhs, reinterpret_cast(A.data()), - lda, IPIV.data(), reinterpret_cast(B.data()), ldb, - info.data())); + cusolverDnCgeqrf(s.handle, m, n, reinterpret_cast(A.data()), lda, + reinterpret_cast(Tau.data()), + reinterpret_cast(Workspace.data()), + lwork, /*info*/R.data())); } if constexpr (std::is_same_v>) { - KOKKOS_CUSOLVER_SAFE_CALL_IMPL(cusolverDnZgetrf_bufferSize( + KOKKOS_CUSOLVER_SAFE_CALL_IMPL(cusolverDnZgeqrf_bufferSize( s.handle, m, n, reinterpret_cast(A.data()), lda, &lwork)); - Kokkos::View Workspace("getrf workspace", + Kokkos::View Workspace("cusolver zgeqrf workspace", lwork); - KOKKOS_CUSOLVER_SAFE_CALL_IMPL(cusolverDnZgetrf( + KOKKOS_CUSOLVER_SAFE_CALL_IMPL(cusolverDnZgeqrf( s.handle, m, n, reinterpret_cast(A.data()), lda, - reinterpret_cast(Workspace.data()), IPIV.data(), - info.data())); - - KOKKOS_CUSOLVER_SAFE_CALL_IMPL(cusolverDnZgetrs( - s.handle, CUBLAS_OP_N, m, nrhs, - reinterpret_cast(A.data()), lda, IPIV.data(), - reinterpret_cast(B.data()), ldb, info.data())); + reinterpret_cast(Tau.data()), + reinterpret_cast(Workspace.data()), + lwork, /*info*/R.data())); } KOKKOS_CUSOLVER_SAFE_CALL_IMPL(cusolverDnSetStream(s.handle, NULL)); } @@ -348,6 +329,8 @@ void cusolverGeqrfWrapper(const ExecutionSpace& space, const TWViewType& Work, Kokkos::MemoryTraits>, \ Kokkos::View, \ Kokkos::MemoryTraits>, \ + Kokkos::View, \ + Kokkos::MemoryTraits>, \ true, \ geqrf_eti_spec_avail< \ Kokkos::Cuda, \ @@ -355,6 +338,9 @@ void cusolverGeqrfWrapper(const ExecutionSpace& space, const TWViewType& Work, Kokkos::Device, \ Kokkos::MemoryTraits>, \ Kokkos::View, \ + Kokkos::MemoryTraits>, \ + Kokkos::View, \ Kokkos::MemoryTraits>>::value> { \ using AViewType = Kokkos::View>; \ using TWViewType = \ Kokkos::View, \ + Kokkos::MemoryTraits>; \ + using RType = \ + Kokkos::View, \ Kokkos::MemoryTraits>; \ \ static void geqrf(const Kokkos::Cuda& space, const AViewType& A, \ - const TWViewType& Tau, const TWViewType& Work) { \ + const TWViewType& Tau, const TWViewType& Work, \ + const RType& R) { \ Kokkos::Profiling::pushRegion( \ "KokkosLapack::geqrf[TPL_CUSOLVER," #SCALAR "]"); \ geqrf_print_specialization(); \ \ - cusolverGeqrfWrapper(space, IPIV, A, B); \ + cusolverGeqrfWrapper(space, A, Tau, Work, R); \ Kokkos::Profiling::popRegion(); \ } \ }; @@ -420,7 +410,7 @@ void rocsolverGeqrfWrapper(const ExecutionSpace& space, const TWViewType& Work, const rocblas_int ldb = std::is_same_v ? B.stride(0) : B.stride(1); - Kokkos::View info("rocsolver info"); + Kokkos::View info("rocsolver geqrf info"); KokkosBlas::Impl::RocBlasSingleton& s = KokkosBlas::Impl::RocBlasSingleton::singleton(); @@ -459,6 +449,8 @@ void rocsolverGeqrfWrapper(const ExecutionSpace& space, const TWViewType& Work, Kokkos::MemoryTraits>, \ Kokkos::View, \ Kokkos::MemoryTraits>, \ + Kokkos::View>, \ true, \ geqrf_eti_spec_avail< \ Kokkos::HIP, \ @@ -467,21 +459,28 @@ void rocsolverGeqrfWrapper(const ExecutionSpace& space, const TWViewType& Work, Kokkos::MemoryTraits>, \ Kokkos::View, \ + Kokkos::MemoryTraits>, \ + Kokkos::View, \ Kokkos::MemoryTraits>>::value> { \ using AViewType = \ Kokkos::View, \ Kokkos::MemoryTraits>; \ using TWViewType = \ Kokkos::View, \ + Kokkos::MemoryTraits>; \ + using RType = \ + Kokkos::View, \ Kokkos::MemoryTraits>; \ \ static void geqrf(const Kokkos::HIP& space, const AViewType& A, \ - const TWViewType& Tau, const TWViewType& Work) { \ + const TWViewType& Tau, const TWViewType& Work, \ + const RType& R) { \ Kokkos::Profiling::pushRegion( \ "KokkosLapack::geqrf[TPL_ROCSOLVER," #SCALAR "]"); \ geqrf_print_specialization(); \ \ - rocsolverGeqrfWrapper(space, IPIV, A, B); \ + rocsolverGeqrfWrapper(space, A, Tau, Work, R); \ Kokkos::Profiling::popRegion(); \ } \ };