Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Merge pull request #609 from senior-zero/enh-main/github/use_libcudacxx_operators
Browse files Browse the repository at this point in the history

Use libcu++ function objects
  • Loading branch information
gevtushenko authored Jan 7, 2023
2 parents 3abfcc1 + ad3c3c7 commit 12332f0
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 30 deletions.
56 changes: 33 additions & 23 deletions cub/thread/thread_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@
#pragma once

#include <cub/config.cuh>
#include <cub/util_cpp_dialect.cuh>
#include <cub/util_type.cuh>

#include <cuda/std/functional>
#include <cuda/std/type_traits>
#include <cuda/std/utility>

Expand All @@ -51,6 +53,33 @@ CUB_NAMESPACE_BEGIN
* @{
*/

/// @brief Inequality functor (wraps an equality functor and negates it)
///
/// Produces `!op(t, u)` for the wrapped equality operator `op`, so any
/// user-supplied equality predicate can be reused where an inequality
/// predicate is required.
template <typename EqualityOp>
struct InequalityWrapper
{
  /// Wrapped equality operator
  EqualityOp op;

  /// Constructor
  /// @param op  Equality functor to wrap (stored by value)
  __host__ __device__ __forceinline__ InequalityWrapper(EqualityOp op)
      : op(op)
  {}

  /// Boolean inequality operator, returns `t != u`
  ///
  /// `const`-qualified so the wrapper is usable through const references,
  /// consistent with the `cuda::std` function objects used in this header.
  template <typename T, typename U>
  __host__ __device__ __forceinline__ bool operator()(T &&t, U &&u) const
  {
    return !op(::cuda::std::forward<T>(t), ::cuda::std::forward<U>(u));
  }
};

#if CUB_CPP_DIALECT > 2011
using Equality = ::cuda::std::equal_to<>;
using Inequality = ::cuda::std::not_equal_to<>;
using Sum = ::cuda::std::plus<>;
using Difference = ::cuda::std::minus<>;
using Division = ::cuda::std::divides<>;
#else
/// @brief Default equality functor
struct Equality
{
Expand All @@ -73,26 +102,6 @@ struct Inequality
}
};

/// @brief Inequality functor (wraps an equality functor and negates it)
///
/// Produces `!op(t, u)` for the wrapped equality operator `op`, so any
/// user-supplied equality predicate can be reused where an inequality
/// predicate is required.
template <typename EqualityOp>
struct InequalityWrapper
{
  /// Wrapped equality operator
  EqualityOp op;

  /// Constructor
  /// @param op  Equality functor to wrap (stored by value)
  __host__ __device__ __forceinline__ InequalityWrapper(EqualityOp op)
      : op(op)
  {}

  /// Boolean inequality operator, returns `t != u`
  ///
  /// Uses `::cuda::std::forward` (from `<cuda/std/utility>`, included above)
  /// rather than unqualified `std::forward`: this header does not include
  /// `<utility>` directly, and every other functor in the file forwards via
  /// libcu++. `const`-qualified so the wrapper is usable through const
  /// references.
  template <typename T, typename U>
  __host__ __device__ __forceinline__ bool operator()(T &&t, U &&u) const
  {
    return !op(::cuda::std::forward<T>(t), ::cuda::std::forward<U>(u));
  }
};

/// @brief Default sum functor
struct Sum
{
Expand Down Expand Up @@ -128,6 +137,7 @@ struct Division
return ::cuda::std::forward<T>(t) / ::cuda::std::forward<U>(u);
}
};
#endif

/// @brief Default max functor
struct Max
Expand Down Expand Up @@ -367,10 +377,10 @@ struct BinaryFlip

template <typename T, typename U>
__device__ auto
operator()(T &&t, U &&u) -> decltype(binary_op(std::forward<U>(u),
std::forward<T>(t)))
operator()(T &&t, U &&u) -> decltype(binary_op(::cuda::std::forward<U>(u),
::cuda::std::forward<T>(t)))
{
return binary_op(std::forward<U>(u), std::forward<T>(t));
return binary_op(::cuda::std::forward<U>(u), ::cuda::std::forward<T>(t));
}
};

Expand Down
4 changes: 3 additions & 1 deletion cub/util_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

#pragma once

#include <cuda/std/utility>

#include <cub/detail/device_synchronize.cuh>
#include <cub/util_arch.cuh>
#include <cub/util_cpp_dialect.cuh>
Expand Down Expand Up @@ -282,7 +284,7 @@ public:

// We don't use `CubDebug` here because we let the user code
// decide whether or not errors are hard errors.
payload.error = std::forward<Invocable>(f)(payload.attribute);
payload.error = ::cuda::std::forward<Invocable>(f)(payload.attribute);
if (payload.error)
// Clear the global CUDA error state which may have been
// set by the last call. Otherwise, errors may "leak" to
Expand Down
12 changes: 6 additions & 6 deletions cub/util_macro.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@

#pragma once

#include "util_namespace.cuh"
#include <cuda/std/utility>

#include <utility>
#include "util_namespace.cuh"

CUB_NAMESPACE_BEGIN

Expand All @@ -59,17 +59,17 @@ CUB_NAMESPACE_BEGIN
/// Returns the smaller of `t` and `u`; when neither compares less, the second
/// operand `u` is returned (`t < u` is false on ties).
/// `CUB_PREVENT_MACRO_SUBSTITUTION` shields the name `min` from macro
/// expansion (e.g. the Windows `min` macro).
/// NOTE(review): the conditional over two forwarding references can
/// materialize a temporary when `T` and `U` are different types — confirm
/// callers do not bind the result to a reference that outlives the full
/// expression.
template <typename T, typename U>
constexpr __host__ __device__ auto min CUB_PREVENT_MACRO_SUBSTITUTION(T &&t,
U &&u)
-> decltype(t < u ? ::cuda::std::forward<T>(t) : ::cuda::std::forward<U>(u))
{
return t < u ? ::cuda::std::forward<T>(t) : ::cuda::std::forward<U>(u);
}

/// Returns the larger of `t` and `u`; when neither compares less, the first
/// operand `t` is returned (`t < u` is false on ties). Only `operator<` is
/// required of the operand types, mirroring `min` above.
/// `CUB_PREVENT_MACRO_SUBSTITUTION` shields the name `max` from macro
/// expansion (e.g. the Windows `max` macro).
/// NOTE(review): the conditional over two forwarding references can
/// materialize a temporary when `T` and `U` are different types — confirm
/// callers do not bind the result to a reference that outlives the full
/// expression.
template <typename T, typename U>
constexpr __host__ __device__ auto max CUB_PREVENT_MACRO_SUBSTITUTION(T &&t,
U &&u)
-> decltype(t < u ? ::cuda::std::forward<U>(u) : ::cuda::std::forward<T>(t))
{
return t < u ? ::cuda::std::forward<U>(u) : ::cuda::std::forward<T>(t);
}

#ifndef CUB_MAX
Expand Down

0 comments on commit 12332f0

Please sign in to comment.