Skip to content

Commit

Permalink
Merge Pull Request #6583 from mhoemmen/Trilinos/TSQR-Jan2020
Browse files Browse the repository at this point in the history
Automatically Merged using Trilinos Pull Request AutoTester
PR Title: TSQR: Automatically detect CUBLAS & CUSOLVER TPLs; improve TPL handle wrappers
PR Author: mhoemmen
  • Loading branch information
trilinos-autotester authored Jan 16, 2020
2 parents a679309 + f01fcb3 commit 12e7efa
Show file tree
Hide file tree
Showing 19 changed files with 241 additions and 131 deletions.
5 changes: 5 additions & 0 deletions packages/tpetra/tsqr/cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,8 @@ SET(LIB_REQUIRED_DEP_TPLS)
SET(LIB_OPTIONAL_DEP_TPLS CUBLAS CUSOLVER)
SET(TEST_REQUIRED_DEP_TPLS)
SET(TEST_OPTIONAL_DEP_TPLS)

# When the CUDA TPL is enabled, tentatively enable the optional cuBLAS
# and cuSOLVER TPLs (both listed in LIB_OPTIONAL_DEP_TPLS above), so
# that users do not have to enable them by hand.
# NOTE(review): "tentatively" presumably means an explicit user setting
# for either TPL still takes precedence -- confirm against the TriBITS
# documentation.
IF(TPL_ENABLE_CUDA)
TRIBITS_TPL_TENTATIVELY_ENABLE(CUBLAS)
TRIBITS_TPL_TENTATIVELY_ENABLE(CUSOLVER)
ENDIF()
2 changes: 1 addition & 1 deletion packages/tpetra/tsqr/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ TRIBITS_ADD_LIBRARY(
# / from this directory, or to / from the 'impl' subdirectory. That ensures
# that running "make" will also rerun CMake in order to regenerate Makefiles.
#
# Behold: another such change, and another.
# Behold: Another such change that I have wrought, and another.
#
2 changes: 0 additions & 2 deletions packages/tpetra/tsqr/src/Tsqr_CombineDefault.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,6 @@ namespace TSQR {
const ordinal_type num_cols_Q,
const ordinal_type num_cols_C) const override
{
using STS = Teuchos::ScalarTraits<Scalar>;

const int ncols = num_cols_Q < num_cols_C ?
num_cols_C : num_cols_Q;
const int nrows = num_rows_Q + ncols;
Expand Down
74 changes: 40 additions & 34 deletions packages/tpetra/tsqr/src/Tsqr_CuSolverNodeTsqr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,12 +442,8 @@ namespace TSQR {
Scalar A[],
const LocalOrdinal lda) const
{
using TSQR::Impl::CuSolver;
using TSQR::Impl::CuSolverHandle;

auto info = get_info ();
CuSolver<Scalar> solver
{CuSolverHandle::getSingleton (), info.data ()};
TSQR::Impl::CuSolver<Scalar> solver (info.data ());
const int lwork =
solver.compute_QR_lwork (numRows, numCols, A, lda);
// Avoid constant reallocation by setting a minimum lwork.
Expand All @@ -469,12 +465,8 @@ namespace TSQR {
Scalar C[],
const LocalOrdinal ldc) const
{
using TSQR::Impl::CuSolver;
using TSQR::Impl::CuSolverHandle;

auto info = get_info ();
CuSolver<Scalar> solver
{CuSolverHandle::getSingleton (), info.data ()};
TSQR::Impl::CuSolver<Scalar> solver (info.data ());
const char side = 'L';
const char trans = apply_type.toString ()[0];
const int lwork =
Expand Down Expand Up @@ -622,26 +614,48 @@ namespace TSQR {
}
else { // A_view_top is NOT contiguous
// Packed device version of R.
Impl::device_mat_view_type<kokkos_value_type> R_copy;
Impl::device_mat_view_type<kokkos_value_type> R_contig_d;
try {
using Impl::get_contiguous_device_mat_view;
R_copy = get_contiguous_device_mat_view (matrixStorage_,
ncols, ncols);
R_contig_d = get_contiguous_device_mat_view (matrixStorage_,
ncols, ncols);
}
TSQR_IMPL_CATCH( "R_copy = get_contiguous_device_mat_view threw: " );
TSQR_IMPL_CATCH( "R_contig_d = get_contiguous_device_mat_view threw: " );

TEUCHOS_ASSERT( size_t (R_copy.extent (0)) == size_t (ncols) );
TEUCHOS_ASSERT( size_t (R_copy.extent (1)) == size_t (ncols) );
TEUCHOS_ASSERT( size_t (R_copy.stride (1)) == size_t (ncols) );
TEUCHOS_ASSERT( size_t (R_contig_d.extent (0)) == size_t (ncols) );
TEUCHOS_ASSERT( size_t (R_contig_d.extent (1)) == size_t (ncols) );
TEUCHOS_ASSERT( size_t (R_contig_d.stride (1)) == size_t (ncols) );

try {
Kokkos::deep_copy (R_copy, A_view_top);
Kokkos::deep_copy (R_contig_d, A_view_top);
}
TSQR_IMPL_CATCH( "Kokkos::deep_copy(R_copy, A_view_top) threw: ");
try {
Kokkos::deep_copy (R_view, R_copy);
TSQR_IMPL_CATCH( "Kokkos::deep_copy(R_contig_d, A_view_top) threw: ");

if (R_view.extent (0) < R_view.stride (1)) {
// R_view is not contiguous, so we can't deep_copy directly
// from R_contig_d (device View) to R_view (host View). We
// need an intermediate contiguous host View, R_contig_h.
auto R_contig_h =
Impl::get_contiguous_host_mat_view (hostMatrixStorage_,
ncols, ncols);
TEUCHOS_ASSERT( size_t (R_contig_h.extent (0)) == size_t (ncols) );
TEUCHOS_ASSERT( size_t (R_contig_h.extent (1)) == size_t (ncols) );
TEUCHOS_ASSERT( size_t (R_contig_h.stride (1)) == size_t (ncols) );
try {
Kokkos::deep_copy (R_contig_h, R_contig_d);
}
TSQR_IMPL_CATCH( "Kokkos::deep_copy(R_contig_h, R_contig_d) threw: ");
try {
Kokkos::deep_copy (R_view, R_contig_h);
}
TSQR_IMPL_CATCH( "Kokkos::deep_copy(R_view, R_contig_h) threw: ");
}
else { // R_view is contiguous, so we can deep_copy directly
try {
Kokkos::deep_copy (R_view, R_contig_d);
}
TSQR_IMPL_CATCH( "Kokkos::deep_copy(R_view, R_contig_d) threw: ");
}
TSQR_IMPL_CATCH( "Kokkos::deep_copy(R_view, R_copy) threw: ");
}

try {
Expand Down Expand Up @@ -682,11 +696,7 @@ namespace TSQR {
const int lwork (work.extent (0));
auto info = get_info ();

using TSQR::Impl::CuSolver;
using TSQR::Impl::CuSolverHandle;
CuSolver<Scalar> solver
{CuSolverHandle::getSingleton (), info.data ()};

TSQR::Impl::CuSolver<Scalar> solver (info.data ());
TSQR_IMPL_CHECK_LAST_CUDA_ERROR( "TSQR::CuSolverNodeTsqr::factor, "
"before solver.compute_QR" );
try {
Expand Down Expand Up @@ -793,10 +803,7 @@ namespace TSQR {
const int lwork (work.extent (0));
auto info = get_info ();

using TSQR::Impl::CuSolver;
using TSQR::Impl::CuSolverHandle;
CuSolver<Scalar> solver
{CuSolverHandle::getSingleton (), info.data ()};
TSQR::Impl::CuSolver<Scalar> solver (info.data ());
solver.apply_Q_factor (side, trans,
nrows, ncols_C, ncols_Q,
Q, ldq, tau_raw,
Expand Down Expand Up @@ -994,9 +1001,8 @@ namespace TSQR {
constexpr Scalar ZERO {};
constexpr Scalar ONE (1.0);

using TSQR::Impl::CuBlas;
using TSQR::Impl::CuBlasHandle;
CuBlas<Scalar> blas {CuBlasHandle::getSingleton ()};
using Impl::CuBlas;
CuBlas<Scalar> blas;

const char transa = 'N';
const char transb = 'N';
Expand Down
9 changes: 6 additions & 3 deletions packages/tpetra/tsqr/src/Tsqr_Impl_CuBlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,11 @@ class RawCuBlas<CudaValue<std::complex<float>>::type> {
#endif // defined(HAVE_TPETRATSQR_COMPLEX)

template<class Scalar>
CuBlas<Scalar>::CuBlas (CuBlasHandle handle) :
CuBlas<Scalar>::CuBlas () :
handle_ (getCuBlasHandleSingleton ()) {}

template<class Scalar>
CuBlas<Scalar>::CuBlas (const std::shared_ptr<CuBlasHandle>& handle) :
handle_ (handle) {}

template<class Scalar>
Expand All @@ -112,8 +116,7 @@ gemm (const char transa,
const Scalar beta,
Scalar* C, const int ldc)
{
auto rawHandle =
reinterpret_cast<cublasHandle_t> (handle_.getHandle ());
cublasHandle_t rawHandle = handle_->getHandle ();
const cublasOperation_t cuTransa = cuBlasTrans (transa);
const cublasOperation_t cuTransb = cuBlasTrans (transb);

Expand Down
8 changes: 5 additions & 3 deletions packages/tpetra/tsqr/src/Tsqr_Impl_CuBlas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#include "TpetraTSQR_config.h"
#if defined(HAVE_TPETRATSQR_CUBLAS)
# include "Tsqr_Impl_CuBlasHandle.hpp"
# include "Tsqr_Impl_CuBlasHandle_fwd.hpp"
# if defined(HAVE_TPETRATSQR_COMPLEX)
# include <complex>
# endif // HAVE_TPETRATSQR_COMPLEX
Expand All @@ -14,7 +14,9 @@ namespace Impl {
template<class Scalar>
class CuBlas {
public:
CuBlas (CuBlasHandle handle);
// Use the default handle.
CuBlas ();
CuBlas (const std::shared_ptr<CuBlasHandle>& handle);

void
gemm (const char transa,
Expand All @@ -27,7 +29,7 @@ class CuBlas {
Scalar* C, const int ldc);

private:
CuBlasHandle handle_;
std::shared_ptr<CuBlasHandle> handle_;
};

extern template class CuBlas<double>;
Expand Down
21 changes: 14 additions & 7 deletions packages/tpetra/tsqr/src/Tsqr_Impl_CuBlasHandle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,25 @@
#ifdef HAVE_TPETRATSQR_CUBLAS
#include "Kokkos_Core.hpp"
#include "Teuchos_Assert.hpp"
#include <cublas_v2.h>

namespace TSQR {
namespace Impl {

cublasHandle_t cuBlasRawHandle_ = nullptr;

CuBlasHandle::CuBlasHandle (void* handle) :
CuBlasHandle::CuBlasHandle (cublasHandle_t handle) :
handle_ (handle)
{}

CuBlasHandle CuBlasHandle::getSingleton ()
cublasHandle_t
CuBlasHandle::getHandle () const {
return handle_;
}

std::shared_ptr<CuBlasHandle> getCuBlasHandleSingleton ()
{
static int called_before = 0;
if (called_before == 0) {
static std::shared_ptr<CuBlasHandle> singleton_;
if (singleton_.get () == nullptr) {
auto finalizer = [] () {
if (cuBlasRawHandle_ != nullptr) {
(void) cublasDestroy (cuBlasRawHandle_);
Expand All @@ -27,10 +31,13 @@ CuBlasHandle CuBlasHandle::getSingleton ()
Kokkos::push_finalize_hook (finalizer);
auto status = cublasCreate (&cuBlasRawHandle_);
TEUCHOS_ASSERT( status == CUBLAS_STATUS_SUCCESS );
called_before = 1;

singleton_ = std::shared_ptr<CuBlasHandle>
(new CuBlasHandle (cuBlasRawHandle_));
}
TEUCHOS_ASSERT( cuBlasRawHandle_ != nullptr );
return CuBlasHandle (cuBlasRawHandle_);
TEUCHOS_ASSERT( singleton_.get () != nullptr );
return singleton_;
}

} // namespace Impl
Expand Down
27 changes: 13 additions & 14 deletions packages/tpetra/tsqr/src/Tsqr_Impl_CuBlasHandle.hpp
Original file line number Diff line number Diff line change
@@ -1,28 +1,27 @@
#ifndef TSQR_IMPL_CUBLASHANDLE_HPP
#define TSQR_IMPL_CUBLASHANDLE_HPP

#include "TpetraTSQR_config.h"
#include "Tsqr_Impl_CuBlasHandle_fwd.hpp"
#ifdef HAVE_TPETRATSQR_CUBLAS
#include <cublas_v2.h>

namespace TSQR {
namespace Impl {

class CuBlasHandle {
private:
// This is actually a cublasHandle_t, which is a pointer type.
void* handle_ {nullptr};
public:
CuBlasHandle () = delete;
CuBlasHandle (const CuBlasHandle&) = delete;
CuBlasHandle& operator= (const CuBlasHandle&) = delete;
CuBlasHandle (CuBlasHandle&&) = delete;
CuBlasHandle& operator= (CuBlasHandle&&) = delete;

CuBlasHandle (void* handle);
CuBlasHandle (cublasHandle_t handle);
cublasHandle_t getHandle () const;

public:
static CuBlasHandle getSingleton ();

// This is not really encapsulation, because the "handle" type is
// just a pointer. However, it lets us define cuBlas wrapper
// functions without needing to make them friends of CuBlasHandle.
void* getHandle () const {
return handle_;
}
private:
// cublasHandle_t is actually a pointer type.
cublasHandle_t handle_ {nullptr};
};

} // namespace Impl
Expand Down
30 changes: 30 additions & 0 deletions packages/tpetra/tsqr/src/Tsqr_Impl_CuBlasHandle_fwd.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#ifndef TSQR_IMPL_CUBLASHANDLE_FWD_HPP
#define TSQR_IMPL_CUBLASHANDLE_FWD_HPP

#include "TpetraTSQR_config.h"
#ifdef HAVE_TPETRATSQR_CUBLAS

#include <memory>

namespace TSQR {
namespace Impl {

/// \class CuBlasHandle
/// \brief Opaque wrapper for cublasHandle_t (cuBLAS handle instance)
///
/// \note To developers: Do not expose the declaration of this class
/// to downstream code. Users should only deal with this class by
/// the forward declaration and functions available in this header
/// file. Do not expose cuBLAS headers or extern declarations to
/// downstream code.
class CuBlasHandle;

/// \brief Get TSQR's global cuBLAS handle wrapper.
///
/// The definition (in Tsqr_Impl_CuBlasHandle.cpp) creates the wrapper
/// on the first call and returns the same shared_ptr on every later
/// call. On that first call it also registers a hook with
/// Kokkos::push_finalize_hook that destroys the underlying raw cuBLAS
/// handle.
///
/// \return Non-null shared pointer to the singleton CuBlasHandle.
///
/// \note NOTE(review): because the raw handle is destroyed by a Kokkos
///   finalize hook, the returned wrapper presumably must not be used
///   after Kokkos::finalize -- confirm.
/// \note NOTE(review): the first-call initialization does not appear
///   to be synchronized; confirm that concurrent first calls cannot
///   happen.
std::shared_ptr<CuBlasHandle> getCuBlasHandleSingleton();

} // namespace Impl
} // namespace TSQR

#endif // HAVE_TPETRATSQR_CUBLAS

#endif // TSQR_IMPL_CUBLASHANDLE_FWD_HPP
27 changes: 14 additions & 13 deletions packages/tpetra/tsqr/src/Tsqr_Impl_CuSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,14 @@ class RawCuSolver<CudaValue<std::complex<float>>::type> {
#endif // defined(HAVE_TPETRATSQR_COMPLEX)

template<class Scalar>
CuSolver<Scalar>::CuSolver (CuSolverHandle handle, int* const info) :
CuSolver<Scalar>::CuSolver (int* const info) :
handle_ (getCuSolverHandleSingleton ()), info_ (info)
{}

template<class Scalar>
CuSolver<Scalar>::
CuSolver (const std::shared_ptr<CuSolverHandle>& handle,
int* const info) :
handle_ (handle), info_ (info)
{}

Expand All @@ -433,8 +440,7 @@ compute_QR_lwork (const int nrows,
Scalar A[],
const int lda) const
{
auto rawHandle =
reinterpret_cast<cusolverDnHandle_t> (handle_.getHandle ());
cusolverDnHandle_t rawHandle = handle_->getHandle ();
int lwork = 0;

using IST = typename CudaValue<Scalar>::type;
Expand All @@ -459,8 +465,7 @@ compute_QR (const int nrows,
Scalar work[],
const int lwork) const
{
auto rawHandle =
reinterpret_cast<cusolverDnHandle_t> (handle_.getHandle ());
cusolverDnHandle_t rawHandle = handle_->getHandle ();

using IST = typename CudaValue<Scalar>::type;
IST* A_raw = reinterpret_cast<IST*> (A);
Expand Down Expand Up @@ -488,8 +493,7 @@ apply_Q_factor_lwork (const char side,
Scalar C[],
const int ldc) const
{
auto rawHandle =
reinterpret_cast<cusolverDnHandle_t> (handle_.getHandle ());
cusolverDnHandle_t rawHandle = handle_->getHandle ();
const cublasSideMode_t cuSide = cuBlasSide (side);
const cublasOperation_t cuTrans = cuBlasTrans (trans);
int lwork = 0;
Expand Down Expand Up @@ -525,8 +529,7 @@ apply_Q_factor (const char side,
Scalar work[],
const int lwork) const
{
auto rawHandle =
reinterpret_cast<cusolverDnHandle_t> (handle_.getHandle ());
cusolverDnHandle_t rawHandle = handle_->getHandle ();
const cublasSideMode_t cuSide = cuBlasSide (side);
const cublasOperation_t cuTrans = cuBlasTrans (trans);

Expand All @@ -552,8 +555,7 @@ compute_explicit_Q_lwork(const int m, const int n, const int k,
Scalar A[], const int lda,
const Scalar tau[]) const
{
auto rawHandle =
reinterpret_cast<cusolverDnHandle_t> (handle_.getHandle ());
cusolverDnHandle_t rawHandle = handle_->getHandle ();
int lwork = 0;

using IST = typename CudaValue<Scalar>::type;
Expand All @@ -576,8 +578,7 @@ compute_explicit_Q(const int m, const int n, const int k,
const Scalar tau[],
Scalar work[], const int lwork) const
{
auto rawHandle =
reinterpret_cast<cusolverDnHandle_t> (handle_.getHandle ());
cusolverDnHandle_t rawHandle = handle_->getHandle ();
using IST = typename CudaValue<Scalar>::type;
IST* A_raw = reinterpret_cast<IST*> (A);
const IST* tau_raw = reinterpret_cast<const IST*> (tau);
Expand Down
Loading

0 comments on commit 12e7efa

Please sign in to comment.