Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL][Reduction] Hide reducer non-standard members and add identity #8215

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 93 additions & 30 deletions sycl/include/sycl/reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,51 @@ struct ReducerTraits<reducer<T, BinaryOperation, Dims, Extent, View, Subst>> {
static constexpr size_t extent = Extent;
};

/// Helper class for accessing internal reducer member functions.
template <typename ReducerT> class ReducerAccess {
public:
ReducerAccess(ReducerT &ReducerRef) : MReducerRef(ReducerRef) {}

template <typename ReducerRelayT = ReducerT> auto &getElement(size_t E) {
return MReducerRef.getElement(E);
}

template <typename ReducerRelayT = ReducerT>
enable_if_t<
IsKnownIdentityOp<typename ReducerRelayT::value_type,
typename ReducerRelayT::binary_operation>::value,
typename ReducerRelayT::value_type> constexpr getIdentity() {
return getIdentityStatic();
}

template <typename ReducerRelayT = ReducerT>
enable_if_t<
!IsKnownIdentityOp<typename ReducerRelayT::value_type,
typename ReducerRelayT::binary_operation>::value,
typename ReducerRelayT::value_type>
getIdentity() {
return MReducerRef.identity();
}

// MSVC does not like static overloads of non-static functions, even if they
// are made mutually exclusive through SFINAE. Instead we use a new static
// function to be used when a static function is needed.
template <typename ReducerRelayT = ReducerT>
enable_if_t<
IsKnownIdentityOp<typename ReducerRelayT::value_type,
typename ReducerRelayT::binary_operation>::value,
typename ReducerRelayT::value_type> static constexpr getIdentityStatic() {
return ReducerT::getIdentity();
}

private:
ReducerT &MReducerRef;
};

// Deduction guide to simplify the use of ReducerAccess.
template <typename ReducerT>
ReducerAccess(ReducerT &) -> ReducerAccess<ReducerT>;

/// Use CRTP to avoid redefining shorthand operators in terms of combine
///
/// Also, for many types with known identity the operation 'atomic_combine()'
Expand Down Expand Up @@ -238,7 +283,7 @@ template <class Reducer> class combiner {
auto AtomicRef = sycl::atomic_ref<T, memory_order::relaxed,
getMemoryScope<Space>(), Space>(
address_space_cast<Space, access::decorated::no>(ReduVarPtr)[E]);
Functor(std::move(AtomicRef), reducer->getElement(E));
Functor(std::move(AtomicRef), ReducerAccess{*reducer}.getElement(E));
}
}

Expand Down Expand Up @@ -355,13 +400,15 @@ class reducer<
return *this;
}

T getIdentity() const { return MIdentity; }
T identity() const { return MIdentity; }

private:
template <typename ReducerT> friend class detail::ReducerAccess;

T &getElement(size_t) { return MValue; }
const T &getElement(size_t) const { return MValue; }
T MValue;

private:
T MValue;
const T MIdentity;
BinaryOperation MBinaryOp;
};
Expand Down Expand Up @@ -392,7 +439,12 @@ class reducer<
return *this;
}

static T getIdentity() {
T identity() const { return getIdentity(); }

private:
template <typename ReducerT> friend class detail::ReducerAccess;

static constexpr T getIdentity() {
return detail::known_identity_impl<BinaryOperation, T>::value;
}

Expand All @@ -419,6 +471,8 @@ class reducer<T, BinaryOperation, Dims, Extent, View,
}

private:
template <typename ReducerT> friend class detail::ReducerAccess;

T &MElement;
BinaryOperation MBinaryOp;
};
Expand All @@ -444,11 +498,14 @@ class reducer<
return {MValue[Index], MBinaryOp};
}

T getIdentity() const { return MIdentity; }
T identity() const { return MIdentity; }

private:
template <typename ReducerT> friend class detail::ReducerAccess;

T &getElement(size_t E) { return MValue[E]; }
const T &getElement(size_t E) const { return MValue[E]; }

private:
marray<T, Extent> MValue;
const T MIdentity;
BinaryOperation MBinaryOp;
Expand Down Expand Up @@ -477,14 +534,18 @@ class reducer<
return {MValue[Index], BinaryOperation()};
}

static T getIdentity() {
T identity() const { return getIdentity(); }

private:
template <typename ReducerT> friend class detail::ReducerAccess;

static constexpr T getIdentity() {
return detail::known_identity_impl<BinaryOperation, T>::value;
}

T &getElement(size_t E) { return MValue[E]; }
const T &getElement(size_t E) const { return MValue[E]; }

private:
marray<T, Extent> MValue;
};

Expand Down Expand Up @@ -769,8 +830,7 @@ class reduction_impl
// list of known operations does not break the existing programs.
if constexpr (is_known_identity) {
(void)Identity;
return reducer_type::getIdentity();

return ReducerAccess<reducer_type>::getIdentityStatic();
} else {
return Identity;
}
Expand All @@ -788,8 +848,8 @@ class reduction_impl
template <typename _self = self,
enable_if_t<_self::is_known_identity> * = nullptr>
reduction_impl(RedOutVar Var, bool InitializeToIdentity = false)
: algo(reducer_type::getIdentity(), BinaryOperation(),
InitializeToIdentity, Var) {
: algo(ReducerAccess<reducer_type>::getIdentityStatic(),
BinaryOperation(), InitializeToIdentity, Var) {
if constexpr (!is_usm)
if (Var.size() != 1)
throw sycl::runtime_error(errc::invalid,
Expand Down Expand Up @@ -896,7 +956,7 @@ struct NDRangeReduction<reduction::strategy::local_atomic_and_atomic_cross_wg> {
// Work-group cooperates to initialize multiple reduction variables
auto LID = NDId.get_local_id(0);
for (size_t E = LID; E < NElements; E += NDId.get_local_range(0)) {
GroupSum[E] = Reducer.getIdentity();
GroupSum[E] = ReducerAccess(Reducer).getIdentity();
}
workGroupBarrier();

Expand All @@ -909,7 +969,7 @@ struct NDRangeReduction<reduction::strategy::local_atomic_and_atomic_cross_wg> {
workGroupBarrier();
if (LID == 0) {
for (size_t E = 0; E < NElements; ++E) {
Reducer.getElement(E) = GroupSum[E];
ReducerAccess{Reducer}.getElement(E) = GroupSum[E];
}
Reducer.template atomic_combine(&Out[0]);
}
Expand Down Expand Up @@ -959,7 +1019,7 @@ struct NDRangeReduction<
// reduce_over_group is only defined for each T, not for span<T, ...>
size_t LID = NDId.get_local_id(0);
for (int E = 0; E < NElements; ++E) {
auto &RedElem = Reducer.getElement(E);
auto &RedElem = ReducerAccess{Reducer}.getElement(E);
RedElem = reduce_over_group(Group, RedElem, BOp);
if (LID == 0) {
if (NWorkGroups == 1) {
Expand All @@ -970,7 +1030,7 @@ struct NDRangeReduction<
Out[E] = RedElem;
} else {
PartialSums[NDId.get_group_linear_id() * NElements + E] =
Reducer.getElement(E);
ReducerAccess{Reducer}.getElement(E);
}
}
}
Expand All @@ -993,7 +1053,7 @@ struct NDRangeReduction<
// Reduce each result separately
// TODO: Opportunity to parallelize across elements.
for (int E = 0; E < NElements; ++E) {
auto LocalSum = Reducer.getIdentity();
auto LocalSum = ReducerAccess{Reducer}.getIdentity();
for (size_t I = LID; I < NWorkGroups; I += WGSize)
LocalSum = BOp(LocalSum, PartialSums[I * NElements + E]);
auto Result = reduce_over_group(Group, LocalSum, BOp);
Expand Down Expand Up @@ -1083,7 +1143,7 @@ template <> struct NDRangeReduction<reduction::strategy::range_basic> {
for (int E = 0; E < NElements; ++E) {

// Copy the element to local memory to prepare it for tree-reduction.
LocalReds[LID] = Reducer.getElement(E);
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);

doTreeReduction(WGSize, LID, false, Identity, LocalReds, BOp,
[&]() { workGroupBarrier(); });
Expand Down Expand Up @@ -1158,8 +1218,8 @@ struct NDRangeReduction<reduction::strategy::group_reduce_and_atomic_cross_wg> {

typename Reduction::binary_operation BOp;
for (int E = 0; E < NElements; ++E) {
Reducer.getElement(E) =
reduce_over_group(NDIt.get_group(), Reducer.getElement(E), BOp);
ReducerAccess{Reducer}.getElement(E) = reduce_over_group(
NDIt.get_group(), ReducerAccess{Reducer}.getElement(E), BOp);
}
if (NDIt.get_local_linear_id() == 0)
Reducer.atomic_combine(&Out[0]);
Expand Down Expand Up @@ -1207,14 +1267,15 @@ struct NDRangeReduction<
for (int E = 0; E < NElements; ++E) {

// Copy the element to local memory to prepare it for tree-reduction.
LocalReds[LID] = Reducer.getElement(E);
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);

typename Reduction::binary_operation BOp;
doTreeReduction(WGSize, LID, IsPow2WG, Reducer.getIdentity(),
LocalReds, BOp, [&]() { NDIt.barrier(); });
doTreeReduction(WGSize, LID, IsPow2WG,
ReducerAccess{Reducer}.getIdentity(), LocalReds, BOp,
[&]() { NDIt.barrier(); });

if (LID == 0) {
Reducer.getElement(E) =
ReducerAccess{Reducer}.getElement(E) =
IsPow2WG ? LocalReds[0] : BOp(LocalReds[0], LocalReds[WGSize]);
}

Expand Down Expand Up @@ -1282,7 +1343,7 @@ struct NDRangeReduction<
typename Reduction::binary_operation BOp;
for (int E = 0; E < NElements; ++E) {
typename Reduction::result_type PSum;
PSum = Reducer.getElement(E);
PSum = ReducerAccess{Reducer}.getElement(E);
PSum = reduce_over_group(NDIt.get_group(), PSum, BOp);
if (NDIt.get_local_linear_id() == 0) {
if (IsUpdateOfUserVar)
Expand Down Expand Up @@ -1346,7 +1407,8 @@ struct NDRangeReduction<
typename Reduction::result_type PSum =
(HasUniformWG || (GID < NWorkItems))
? In[GID * NElements + E]
: Reduction::reducer_type::getIdentity();
: ReducerAccess<typename Reduction::reducer_type>::
getIdentityStatic();
PSum = reduce_over_group(NDIt.get_group(), PSum, BOp);
if (NDIt.get_local_linear_id() == 0) {
if (IsUpdateOfUserVar)
Expand Down Expand Up @@ -1420,7 +1482,7 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
for (int E = 0; E < NElements; ++E) {

// Copy the element to local memory to prepare it for tree-reduction.
LocalReds[LID] = Reducer.getElement(E);
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);

doTreeReduction(WGSize, LID, IsPow2WG, ReduIdentity, LocalReds, BOp,
[&]() { NDIt.barrier(); });
Expand Down Expand Up @@ -1693,7 +1755,8 @@ void reduCGFuncImplScalar(
size_t WGSize = NDIt.get_local_range().size();
size_t LID = NDIt.get_local_linear_id();

((std::get<Is>(LocalAccsTuple)[LID] = std::get<Is>(ReducersTuple).MValue),
((std::get<Is>(LocalAccsTuple)[LID] =
ReducerAccess{std::get<Is>(ReducersTuple)}.getElement(0)),
...);

// For work-groups, which size is not power of two, local accessors have
Expand Down Expand Up @@ -1744,7 +1807,7 @@ void reduCGFuncImplArrayHelper(bool Pow2WG, bool IsOneWG, nd_item<Dims> NDIt,
for (size_t E = 0; E < NElements; ++E) {

// Copy the element to local memory to prepare it for tree-reduction.
LocalReds[LID] = Reducer.getElement(E);
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);

doTreeReduction(WGSize, LID, Pow2WG, Identity, LocalReds, BOp,
[&]() { NDIt.barrier(); });
Expand Down