Skip to content

Commit

Permalink
[SYCL][Reduction] Hide reducer non-standard members and add identity (i…
Browse files Browse the repository at this point in the history
…ntel#8215)

This commit hides the members in reducer that are not mentioned in the
SYCL 2020 specification and introduces the identity member function.

---------

Signed-off-by: Larsen, Steffen <[email protected]>
  • Loading branch information
steffenlarsen authored Feb 10, 2023
1 parent 680c1b3 commit 505aa7d
Showing 1 changed file with 93 additions and 30 deletions.
123 changes: 93 additions & 30 deletions sycl/include/sycl/reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,51 @@ struct ReducerTraits<reducer<T, BinaryOperation, Dims, Extent, View, Subst>> {
static constexpr size_t extent = Extent;
};

/// Helper class for accessing internal reducer member functions.
template <typename ReducerT> class ReducerAccess {
public:
ReducerAccess(ReducerT &ReducerRef) : MReducerRef(ReducerRef) {}

template <typename ReducerRelayT = ReducerT> auto &getElement(size_t E) {
return MReducerRef.getElement(E);
}

template <typename ReducerRelayT = ReducerT>
enable_if_t<
IsKnownIdentityOp<typename ReducerRelayT::value_type,
typename ReducerRelayT::binary_operation>::value,
typename ReducerRelayT::value_type> constexpr getIdentity() {
return getIdentityStatic();
}

template <typename ReducerRelayT = ReducerT>
enable_if_t<
!IsKnownIdentityOp<typename ReducerRelayT::value_type,
typename ReducerRelayT::binary_operation>::value,
typename ReducerRelayT::value_type>
getIdentity() {
return MReducerRef.identity();
}

// MSVC does not like static overloads of non-static functions, even if they
// are made mutually exclusive through SFINAE. Instead we use a new static
// function to be used when a static function is needed.
template <typename ReducerRelayT = ReducerT>
enable_if_t<
IsKnownIdentityOp<typename ReducerRelayT::value_type,
typename ReducerRelayT::binary_operation>::value,
typename ReducerRelayT::value_type> static constexpr getIdentityStatic() {
return ReducerT::getIdentity();
}

private:
ReducerT &MReducerRef;
};

// Deduction guide to simplify the use of ReducerAccess.
template <typename ReducerT>
ReducerAccess(ReducerT &) -> ReducerAccess<ReducerT>;

/// Use CRTP to avoid redefining shorthand operators in terms of combine
///
/// Also, for many types with known identity the operation 'atomic_combine()'
Expand Down Expand Up @@ -238,7 +283,7 @@ template <class Reducer> class combiner {
auto AtomicRef = sycl::atomic_ref<T, memory_order::relaxed,
getMemoryScope<Space>(), Space>(
address_space_cast<Space, access::decorated::no>(ReduVarPtr)[E]);
Functor(std::move(AtomicRef), reducer->getElement(E));
Functor(std::move(AtomicRef), ReducerAccess{*reducer}.getElement(E));
}
}

Expand Down Expand Up @@ -355,13 +400,15 @@ class reducer<
return *this;
}

T getIdentity() const { return MIdentity; }
T identity() const { return MIdentity; }

private:
template <typename ReducerT> friend class detail::ReducerAccess;

T &getElement(size_t) { return MValue; }
const T &getElement(size_t) const { return MValue; }
T MValue;

private:
T MValue;
const T MIdentity;
BinaryOperation MBinaryOp;
};
Expand Down Expand Up @@ -392,7 +439,12 @@ class reducer<
return *this;
}

static T getIdentity() {
T identity() const { return getIdentity(); }

private:
template <typename ReducerT> friend class detail::ReducerAccess;

static constexpr T getIdentity() {
return detail::known_identity_impl<BinaryOperation, T>::value;
}

Expand All @@ -419,6 +471,8 @@ class reducer<T, BinaryOperation, Dims, Extent, View,
}

private:
template <typename ReducerT> friend class detail::ReducerAccess;

T &MElement;
BinaryOperation MBinaryOp;
};
Expand All @@ -444,11 +498,14 @@ class reducer<
return {MValue[Index], MBinaryOp};
}

T getIdentity() const { return MIdentity; }
T identity() const { return MIdentity; }

private:
template <typename ReducerT> friend class detail::ReducerAccess;

T &getElement(size_t E) { return MValue[E]; }
const T &getElement(size_t E) const { return MValue[E]; }

private:
marray<T, Extent> MValue;
const T MIdentity;
BinaryOperation MBinaryOp;
Expand Down Expand Up @@ -477,14 +534,18 @@ class reducer<
return {MValue[Index], BinaryOperation()};
}

static T getIdentity() {
T identity() const { return getIdentity(); }

private:
template <typename ReducerT> friend class detail::ReducerAccess;

static constexpr T getIdentity() {
return detail::known_identity_impl<BinaryOperation, T>::value;
}

T &getElement(size_t E) { return MValue[E]; }
const T &getElement(size_t E) const { return MValue[E]; }

private:
marray<T, Extent> MValue;
};

Expand Down Expand Up @@ -769,8 +830,7 @@ class reduction_impl
// list of known operations does not break the existing programs.
if constexpr (is_known_identity) {
(void)Identity;
return reducer_type::getIdentity();

return ReducerAccess<reducer_type>::getIdentityStatic();
} else {
return Identity;
}
Expand All @@ -788,8 +848,8 @@ class reduction_impl
template <typename _self = self,
enable_if_t<_self::is_known_identity> * = nullptr>
reduction_impl(RedOutVar Var, bool InitializeToIdentity = false)
: algo(reducer_type::getIdentity(), BinaryOperation(),
InitializeToIdentity, Var) {
: algo(ReducerAccess<reducer_type>::getIdentityStatic(),
BinaryOperation(), InitializeToIdentity, Var) {
if constexpr (!is_usm)
if (Var.size() != 1)
throw sycl::runtime_error(errc::invalid,
Expand Down Expand Up @@ -896,7 +956,7 @@ struct NDRangeReduction<reduction::strategy::local_atomic_and_atomic_cross_wg> {
// Work-group cooperates to initialize multiple reduction variables
auto LID = NDId.get_local_id(0);
for (size_t E = LID; E < NElements; E += NDId.get_local_range(0)) {
GroupSum[E] = Reducer.getIdentity();
GroupSum[E] = ReducerAccess(Reducer).getIdentity();
}
workGroupBarrier();

Expand All @@ -909,7 +969,7 @@ struct NDRangeReduction<reduction::strategy::local_atomic_and_atomic_cross_wg> {
workGroupBarrier();
if (LID == 0) {
for (size_t E = 0; E < NElements; ++E) {
Reducer.getElement(E) = GroupSum[E];
ReducerAccess{Reducer}.getElement(E) = GroupSum[E];
}
Reducer.template atomic_combine(&Out[0]);
}
Expand Down Expand Up @@ -959,7 +1019,7 @@ struct NDRangeReduction<
// reduce_over_group is only defined for each T, not for span<T, ...>
size_t LID = NDId.get_local_id(0);
for (int E = 0; E < NElements; ++E) {
auto &RedElem = Reducer.getElement(E);
auto &RedElem = ReducerAccess{Reducer}.getElement(E);
RedElem = reduce_over_group(Group, RedElem, BOp);
if (LID == 0) {
if (NWorkGroups == 1) {
Expand All @@ -970,7 +1030,7 @@ struct NDRangeReduction<
Out[E] = RedElem;
} else {
PartialSums[NDId.get_group_linear_id() * NElements + E] =
Reducer.getElement(E);
ReducerAccess{Reducer}.getElement(E);
}
}
}
Expand All @@ -993,7 +1053,7 @@ struct NDRangeReduction<
// Reduce each result separately
// TODO: Opportunity to parallelize across elements.
for (int E = 0; E < NElements; ++E) {
auto LocalSum = Reducer.getIdentity();
auto LocalSum = ReducerAccess{Reducer}.getIdentity();
for (size_t I = LID; I < NWorkGroups; I += WGSize)
LocalSum = BOp(LocalSum, PartialSums[I * NElements + E]);
auto Result = reduce_over_group(Group, LocalSum, BOp);
Expand Down Expand Up @@ -1083,7 +1143,7 @@ template <> struct NDRangeReduction<reduction::strategy::range_basic> {
for (int E = 0; E < NElements; ++E) {

// Copy the element to local memory to prepare it for tree-reduction.
LocalReds[LID] = Reducer.getElement(E);
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);

doTreeReduction(WGSize, LID, false, Identity, LocalReds, BOp,
[&]() { workGroupBarrier(); });
Expand Down Expand Up @@ -1158,8 +1218,8 @@ struct NDRangeReduction<reduction::strategy::group_reduce_and_atomic_cross_wg> {

typename Reduction::binary_operation BOp;
for (int E = 0; E < NElements; ++E) {
Reducer.getElement(E) =
reduce_over_group(NDIt.get_group(), Reducer.getElement(E), BOp);
ReducerAccess{Reducer}.getElement(E) = reduce_over_group(
NDIt.get_group(), ReducerAccess{Reducer}.getElement(E), BOp);
}
if (NDIt.get_local_linear_id() == 0)
Reducer.atomic_combine(&Out[0]);
Expand Down Expand Up @@ -1207,14 +1267,15 @@ struct NDRangeReduction<
for (int E = 0; E < NElements; ++E) {

// Copy the element to local memory to prepare it for tree-reduction.
LocalReds[LID] = Reducer.getElement(E);
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);

typename Reduction::binary_operation BOp;
doTreeReduction(WGSize, LID, IsPow2WG, Reducer.getIdentity(),
LocalReds, BOp, [&]() { NDIt.barrier(); });
doTreeReduction(WGSize, LID, IsPow2WG,
ReducerAccess{Reducer}.getIdentity(), LocalReds, BOp,
[&]() { NDIt.barrier(); });

if (LID == 0) {
Reducer.getElement(E) =
ReducerAccess{Reducer}.getElement(E) =
IsPow2WG ? LocalReds[0] : BOp(LocalReds[0], LocalReds[WGSize]);
}

Expand Down Expand Up @@ -1282,7 +1343,7 @@ struct NDRangeReduction<
typename Reduction::binary_operation BOp;
for (int E = 0; E < NElements; ++E) {
typename Reduction::result_type PSum;
PSum = Reducer.getElement(E);
PSum = ReducerAccess{Reducer}.getElement(E);
PSum = reduce_over_group(NDIt.get_group(), PSum, BOp);
if (NDIt.get_local_linear_id() == 0) {
if (IsUpdateOfUserVar)
Expand Down Expand Up @@ -1346,7 +1407,8 @@ struct NDRangeReduction<
typename Reduction::result_type PSum =
(HasUniformWG || (GID < NWorkItems))
? In[GID * NElements + E]
: Reduction::reducer_type::getIdentity();
: ReducerAccess<typename Reduction::reducer_type>::
getIdentityStatic();
PSum = reduce_over_group(NDIt.get_group(), PSum, BOp);
if (NDIt.get_local_linear_id() == 0) {
if (IsUpdateOfUserVar)
Expand Down Expand Up @@ -1420,7 +1482,7 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
for (int E = 0; E < NElements; ++E) {

// Copy the element to local memory to prepare it for tree-reduction.
LocalReds[LID] = Reducer.getElement(E);
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);

doTreeReduction(WGSize, LID, IsPow2WG, ReduIdentity, LocalReds, BOp,
[&]() { NDIt.barrier(); });
Expand Down Expand Up @@ -1693,7 +1755,8 @@ void reduCGFuncImplScalar(
size_t WGSize = NDIt.get_local_range().size();
size_t LID = NDIt.get_local_linear_id();

((std::get<Is>(LocalAccsTuple)[LID] = std::get<Is>(ReducersTuple).MValue),
((std::get<Is>(LocalAccsTuple)[LID] =
ReducerAccess{std::get<Is>(ReducersTuple)}.getElement(0)),
...);

// For work-groups, which size is not power of two, local accessors have
Expand Down Expand Up @@ -1744,7 +1807,7 @@ void reduCGFuncImplArrayHelper(bool Pow2WG, bool IsOneWG, nd_item<Dims> NDIt,
for (size_t E = 0; E < NElements; ++E) {

// Copy the element to local memory to prepare it for tree-reduction.
LocalReds[LID] = Reducer.getElement(E);
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);

doTreeReduction(WGSize, LID, Pow2WG, Identity, LocalReds, BOp,
[&]() { NDIt.barrier(); });
Expand Down

0 comments on commit 505aa7d

Please sign in to comment.