Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL][Reduction] Hide reducer non-standard members and add identity #8215

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 103 additions & 34 deletions sycl/include/sycl/reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,40 @@ struct ReducerTraits<reducer<T, BinaryOperation, Dims, Extent, View, Subst>> {
static constexpr size_t extent = Extent;
};

/// Helper class for accessing internal reducer member functions.
///
/// Reducer specializations keep getElement() and getIdentity() private and
/// declare this class template a friend, so implementation code goes through
/// ReducerAccess instead of calling those members on the reducer directly.
/// Only a reference to the reducer is stored; the reducer must outlive the
/// accessor.
template <typename ReducerT> class ReducerAccess {
public:
  /// Binds the accessor to \p ReducerRef.
  ///
  /// The constructor is explicit so that a reducer is never silently
  /// converted into an accessor; `ReducerAccess{R}` and `ReducerAccess(R)`
  /// are direct-initializations and keep working unchanged.
  explicit ReducerAccess(ReducerT &ReducerRef) : MReducerRef(ReducerRef) {}

  /// Relays to the wrapped reducer's getElement() and returns a reference to
  /// its element \p E.
  template <typename ReducerRelayT = ReducerT> auto &getElement(size_t E) {
    return MReducerRef.getElement(E);
  }

  /// Identity value for reducers whose value type / binary operation pair
  /// satisfies IsKnownIdentityOp.
  ///
  /// This overload is static, so it can be called as
  /// `ReducerAccess<ReducerT>::getIdentity()` without a reducer object. It
  /// relays to the reducer's static getIdentity().
  template <typename ReducerRelayT = ReducerT>
  static constexpr enable_if_t<
      IsKnownIdentityOp<typename ReducerRelayT::value_type,
                        typename ReducerRelayT::binary_operation>::value,
      typename ReducerRelayT::value_type>
  getIdentity() {
    return ReducerT::getIdentity();
  }

  /// Identity value for reducers whose value type / binary operation pair
  /// does not satisfy IsKnownIdentityOp.
  ///
  /// This overload needs a reducer instance and relays to its identity()
  /// member.
  template <typename ReducerRelayT = ReducerT>
  enable_if_t<
      !IsKnownIdentityOp<typename ReducerRelayT::value_type,
                         typename ReducerRelayT::binary_operation>::value,
      typename ReducerRelayT::value_type>
  getIdentity() {
    return MReducerRef.identity();
  }

private:
  // Non-owning reference to the reducer whose internals are being accessed.
  ReducerT &MReducerRef;
};

// Deduction guide to simplify the use of ReducerAccess: the template argument
// is deduced from the reducer passed to the constructor, so call sites can
// write `ReducerAccess{Reducer}` without spelling out the reducer type.
template <typename ReducerT>
ReducerAccess(ReducerT &) -> ReducerAccess<ReducerT>;

/// Use CRTP to avoid redefining shorthand operators in terms of combine
///
/// Also, for many types with known identity the operation 'atomic_combine()'
Expand Down Expand Up @@ -238,7 +272,7 @@ template <class Reducer> class combiner {
auto AtomicRef = sycl::atomic_ref<T, memory_order::relaxed,
getMemoryScope<Space>(), Space>(
address_space_cast<Space, access::decorated::no>(ReduVarPtr)[E]);
Functor(std::move(AtomicRef), reducer->getElement(E));
Functor(std::move(AtomicRef), ReducerAccess{*reducer}.getElement(E));
}
}

Expand Down Expand Up @@ -320,6 +354,14 @@ template <class Reducer> class combiner {
ReduVarPtr, [](auto &&Ref, auto Val) { return Ref.fetch_max(Val); });
}
};

/// Common base of the reducer specializations. It only publishes the
/// compile-time properties of a reducer: the element type, the combining
/// operation type and the number of dimensions.
template <typename T, typename BinaryOperation, int Dims>
class reducer_common {
public:
  /// Number of dimensions of the reducer, taken from the Dims parameter.
  static constexpr int dimensions = Dims;

  /// Type of the binary operation used to combine values.
  using binary_operation = BinaryOperation;

  /// Type of the values being reduced.
  using value_type = T;
};

} // namespace detail

/// Specialization of the generic class 'reducer'. It is used for reductions
Expand All @@ -336,7 +378,8 @@ class reducer<
reducer<T, BinaryOperation, Dims, Extent, View,
std::enable_if_t<
Dims == 0 && Extent == 1 && View == false &&
!detail::IsKnownIdentityOp<T, BinaryOperation>::value>>> {
!detail::IsKnownIdentityOp<T, BinaryOperation>::value>>>,
public detail::reducer_common<T, BinaryOperation, Dims> {
public:
reducer(const T &Identity, BinaryOperation BOp)
: MValue(Identity), MIdentity(Identity), MBinaryOp(BOp) {}
Expand All @@ -346,13 +389,15 @@ class reducer<
return *this;
}

T getIdentity() const { return MIdentity; }
T identity() const { return MIdentity; }

private:
template <typename ReducerT> friend class detail::ReducerAccess;

T &getElement(size_t) { return MValue; }
const T &getElement(size_t) const { return MValue; }
T MValue;

private:
T MValue;
const T MIdentity;
BinaryOperation MBinaryOp;
};
Expand All @@ -371,7 +416,8 @@ class reducer<
reducer<T, BinaryOperation, Dims, Extent, View,
std::enable_if_t<
Dims == 0 && Extent == 1 && View == false &&
detail::IsKnownIdentityOp<T, BinaryOperation>::value>>> {
detail::IsKnownIdentityOp<T, BinaryOperation>::value>>>,
public detail::reducer_common<T, BinaryOperation, Dims> {
public:
reducer() : MValue(getIdentity()) {}
reducer(const T & /* Identity */, BinaryOperation) : MValue(getIdentity()) {}
Expand All @@ -382,7 +428,14 @@ class reducer<
return *this;
}

static T getIdentity() {
T identity() const {
return getIdentity();
}

private:
template <typename ReducerT> friend class detail::ReducerAccess;

static constexpr T getIdentity() {
return detail::known_identity_impl<BinaryOperation, T>::value;
}

Expand All @@ -398,7 +451,8 @@ class reducer<T, BinaryOperation, Dims, Extent, View,
std::enable_if_t<Dims == 0 && View == true>>
: public detail::combiner<
reducer<T, BinaryOperation, Dims, Extent, View,
std::enable_if_t<Dims == 0 && View == true>>> {
std::enable_if_t<Dims == 0 && View == true>>>,
public detail::reducer_common<T, BinaryOperation, Dims> {
public:
reducer(T &Ref, BinaryOperation BOp) : MElement(Ref), MBinaryOp(BOp) {}

Expand All @@ -408,6 +462,8 @@ class reducer<T, BinaryOperation, Dims, Extent, View,
}

private:
template <typename ReducerT> friend class detail::ReducerAccess;

T &MElement;
BinaryOperation MBinaryOp;
};
Expand All @@ -423,7 +479,8 @@ class reducer<
reducer<T, BinaryOperation, Dims, Extent, View,
std::enable_if_t<
Dims == 1 && View == false &&
!detail::IsKnownIdentityOp<T, BinaryOperation>::value>>> {
!detail::IsKnownIdentityOp<T, BinaryOperation>::value>>>,
public detail::reducer_common<T, BinaryOperation, Dims> {
public:
reducer(const T &Identity, BinaryOperation BOp)
: MValue(Identity), MIdentity(Identity), MBinaryOp(BOp) {}
Expand All @@ -432,11 +489,14 @@ class reducer<
return {MValue[Index], MBinaryOp};
}

T getIdentity() const { return MIdentity; }
T identity() const { return MIdentity; }

private:
template <typename ReducerT> friend class detail::ReducerAccess;

T &getElement(size_t E) { return MValue[E]; }
const T &getElement(size_t E) const { return MValue[E]; }

private:
marray<T, Extent> MValue;
const T MIdentity;
BinaryOperation MBinaryOp;
Expand All @@ -453,7 +513,8 @@ class reducer<
reducer<T, BinaryOperation, Dims, Extent, View,
std::enable_if_t<
Dims == 1 && View == false &&
detail::IsKnownIdentityOp<T, BinaryOperation>::value>>> {
detail::IsKnownIdentityOp<T, BinaryOperation>::value>>>,
public detail::reducer_common<T, BinaryOperation, Dims> {
public:
reducer() : MValue(getIdentity()) {}
reducer(const T & /* Identity */, BinaryOperation) : MValue(getIdentity()) {}
Expand All @@ -464,14 +525,20 @@ class reducer<
return {MValue[Index], BinaryOperation()};
}

static T getIdentity() {
T identity() const {
return getIdentity();
}

private:
template <typename ReducerT> friend class detail::ReducerAccess;

static constexpr T getIdentity() {
return detail::known_identity_impl<BinaryOperation, T>::value;
}

T &getElement(size_t E) { return MValue[E]; }
const T &getElement(size_t E) const { return MValue[E]; }

private:
marray<T, Extent> MValue;
};

Expand Down Expand Up @@ -756,8 +823,7 @@ class reduction_impl
// list of known operations does not break the existing programs.
if constexpr (is_known_identity) {
(void)Identity;
return reducer_type::getIdentity();

return ReducerAccess<reducer_type>::getIdentity();
} else {
return Identity;
}
Expand All @@ -775,7 +841,7 @@ class reduction_impl
template <typename _self = self,
enable_if_t<_self::is_known_identity> * = nullptr>
reduction_impl(RedOutVar Var, bool InitializeToIdentity = false)
: algo(reducer_type::getIdentity(), BinaryOperation(),
: algo(ReducerAccess<reducer_type>::getIdentity(), BinaryOperation(),
InitializeToIdentity, Var) {
if constexpr (!is_usm)
if (Var.size() != 1)
Expand Down Expand Up @@ -883,7 +949,7 @@ struct NDRangeReduction<reduction::strategy::local_atomic_and_atomic_cross_wg> {
// Work-group cooperates to initialize multiple reduction variables
auto LID = NDId.get_local_id(0);
for (size_t E = LID; E < NElements; E += NDId.get_local_range(0)) {
GroupSum[E] = Reducer.getIdentity();
GroupSum[E] = ReducerAccess(Reducer).getIdentity();
}
workGroupBarrier();

Expand All @@ -896,7 +962,7 @@ struct NDRangeReduction<reduction::strategy::local_atomic_and_atomic_cross_wg> {
workGroupBarrier();
if (LID == 0) {
for (size_t E = 0; E < NElements; ++E) {
Reducer.getElement(E) = GroupSum[E];
ReducerAccess{Reducer}.getElement(E) = GroupSum[E];
}
Reducer.template atomic_combine(&Out[0]);
}
Expand Down Expand Up @@ -946,7 +1012,7 @@ struct NDRangeReduction<
// reduce_over_group is only defined for each T, not for span<T, ...>
size_t LID = NDId.get_local_id(0);
for (int E = 0; E < NElements; ++E) {
auto &RedElem = Reducer.getElement(E);
auto &RedElem = ReducerAccess{Reducer}.getElement(E);
RedElem = reduce_over_group(Group, RedElem, BOp);
if (LID == 0) {
if (NWorkGroups == 1) {
Expand All @@ -957,7 +1023,7 @@ struct NDRangeReduction<
Out[E] = RedElem;
} else {
PartialSums[NDId.get_group_linear_id() * NElements + E] =
Reducer.getElement(E);
ReducerAccess{Reducer}.getElement(E);
}
}
}
Expand All @@ -980,7 +1046,7 @@ struct NDRangeReduction<
// Reduce each result separately
// TODO: Opportunity to parallelize across elements.
for (int E = 0; E < NElements; ++E) {
auto LocalSum = Reducer.getIdentity();
auto LocalSum = ReducerAccess{Reducer}.getIdentity();
for (size_t I = LID; I < NWorkGroups; I += WGSize)
LocalSum = BOp(LocalSum, PartialSums[I * NElements + E]);
auto Result = reduce_over_group(Group, LocalSum, BOp);
Expand Down Expand Up @@ -1070,7 +1136,7 @@ template <> struct NDRangeReduction<reduction::strategy::range_basic> {
for (int E = 0; E < NElements; ++E) {

// Copy the element to local memory to prepare it for tree-reduction.
LocalReds[LID] = Reducer.getElement(E);
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);

doTreeReduction(WGSize, LID, false, Identity, LocalReds, BOp,
[&]() { workGroupBarrier(); });
Expand Down Expand Up @@ -1145,8 +1211,8 @@ struct NDRangeReduction<reduction::strategy::group_reduce_and_atomic_cross_wg> {

typename Reduction::binary_operation BOp;
for (int E = 0; E < NElements; ++E) {
Reducer.getElement(E) =
reduce_over_group(NDIt.get_group(), Reducer.getElement(E), BOp);
ReducerAccess{Reducer}.getElement(E) = reduce_over_group(
NDIt.get_group(), ReducerAccess{Reducer}.getElement(E), BOp);
}
if (NDIt.get_local_linear_id() == 0)
Reducer.atomic_combine(&Out[0]);
Expand Down Expand Up @@ -1194,14 +1260,15 @@ struct NDRangeReduction<
for (int E = 0; E < NElements; ++E) {

// Copy the element to local memory to prepare it for tree-reduction.
LocalReds[LID] = Reducer.getElement(E);
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);

typename Reduction::binary_operation BOp;
doTreeReduction(WGSize, LID, IsPow2WG, Reducer.getIdentity(),
LocalReds, BOp, [&]() { NDIt.barrier(); });
doTreeReduction(WGSize, LID, IsPow2WG,
ReducerAccess{Reducer}.getIdentity(), LocalReds, BOp,
[&]() { NDIt.barrier(); });

if (LID == 0) {
Reducer.getElement(E) =
ReducerAccess{Reducer}.getElement(E) =
IsPow2WG ? LocalReds[0] : BOp(LocalReds[0], LocalReds[WGSize]);
}

Expand Down Expand Up @@ -1269,7 +1336,7 @@ struct NDRangeReduction<
typename Reduction::binary_operation BOp;
for (int E = 0; E < NElements; ++E) {
typename Reduction::result_type PSum;
PSum = Reducer.getElement(E);
PSum = ReducerAccess{Reducer}.getElement(E);
PSum = reduce_over_group(NDIt.get_group(), PSum, BOp);
if (NDIt.get_local_linear_id() == 0) {
if (IsUpdateOfUserVar)
Expand Down Expand Up @@ -1333,7 +1400,8 @@ struct NDRangeReduction<
typename Reduction::result_type PSum =
(HasUniformWG || (GID < NWorkItems))
? In[GID * NElements + E]
: Reduction::reducer_type::getIdentity();
: ReducerAccess<
typename Reduction::reducer_type>::getIdentity();
PSum = reduce_over_group(NDIt.get_group(), PSum, BOp);
if (NDIt.get_local_linear_id() == 0) {
if (IsUpdateOfUserVar)
Expand Down Expand Up @@ -1407,7 +1475,7 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
for (int E = 0; E < NElements; ++E) {

// Copy the element to local memory to prepare it for tree-reduction.
LocalReds[LID] = Reducer.getElement(E);
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);

doTreeReduction(WGSize, LID, IsPow2WG, ReduIdentity, LocalReds, BOp,
[&]() { NDIt.barrier(); });
Expand Down Expand Up @@ -1680,7 +1748,8 @@ void reduCGFuncImplScalar(
size_t WGSize = NDIt.get_local_range().size();
size_t LID = NDIt.get_local_linear_id();

((std::get<Is>(LocalAccsTuple)[LID] = std::get<Is>(ReducersTuple).MValue),
((std::get<Is>(LocalAccsTuple)[LID] =
ReducerAccess{std::get<Is>(ReducersTuple)}.getElement(0)),
...);

// For work-groups whose size is not a power of two, local accessors have
Expand Down Expand Up @@ -1731,7 +1800,7 @@ void reduCGFuncImplArrayHelper(bool Pow2WG, bool IsOneWG, nd_item<Dims> NDIt,
for (size_t E = 0; E < NElements; ++E) {

// Copy the element to local memory to prepare it for tree-reduction.
LocalReds[LID] = Reducer.getElement(E);
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);

doTreeReduction(WGSize, LID, Pow2WG, Identity, LocalReds, BOp,
[&]() { NDIt.barrier(); });
Expand Down
Loading