diff --git a/sycl/include/sycl/reduction.hpp b/sycl/include/sycl/reduction.hpp
index 8a5084661103d..2901c2112fba2 100644
--- a/sycl/include/sycl/reduction.hpp
+++ b/sycl/include/sycl/reduction.hpp
@@ -159,6 +159,51 @@ struct ReducerTraits<reducer<T, BinaryOperation, Dims, Extent, View, Subst>> {
   static constexpr size_t extent = Extent;
 };
 
+/// Helper class for accessing internal reducer member functions.
+template <typename ReducerT> class ReducerAccess {
+public:
+  ReducerAccess(ReducerT &ReducerRef) : MReducerRef(ReducerRef) {}
+
+  template <typename ReducerRelayT = ReducerT> auto &getElement(size_t E) {
+    return MReducerRef.getElement(E);
+  }
+
+  template <typename ReducerRelayT = ReducerT>
+  enable_if_t<
+      IsKnownIdentityOp<typename ReducerRelayT::value_type,
+                        typename ReducerRelayT::binary_operation>::value,
+      typename ReducerRelayT::value_type> constexpr getIdentity() {
+    return getIdentityStatic();
+  }
+
+  template <typename ReducerRelayT = ReducerT>
+  enable_if_t<
+      !IsKnownIdentityOp<typename ReducerRelayT::value_type,
+                         typename ReducerRelayT::binary_operation>::value,
+      typename ReducerRelayT::value_type>
+  getIdentity() {
+    return MReducerRef.identity();
+  }
+
+  // MSVC does not like static overloads of non-static functions, even if they
+  // are made mutually exclusive through SFINAE. Instead we use a new static
+  // function to be used when a static function is needed.
+  template <typename ReducerRelayT = ReducerT>
+  enable_if_t<
+      IsKnownIdentityOp<typename ReducerRelayT::value_type,
+                        typename ReducerRelayT::binary_operation>::value,
+      typename ReducerRelayT::value_type> static constexpr getIdentityStatic() {
+    return ReducerT::getIdentity();
+  }
+
+private:
+  ReducerT &MReducerRef;
+};
+
+// Deduction guide to simplify the use of ReducerAccess.
+template <typename ReducerT>
+ReducerAccess(ReducerT &) -> ReducerAccess<ReducerT>;
+
 /// Use CRTP to avoid redefining shorthand operators in terms of combine
 ///
 /// Also, for many types with known identity the operation 'atomic_combine()'
@@ -238,7 +283,7 @@ template <class Reducer> class combiner {
       auto AtomicRef = sycl::atomic_ref<T, memory_order::relaxed,
                                         getMemoryScope<Space>(), Space>(
          address_space_cast<Space, access::decorated::no>(ReduVarPtr)[E]);
-      Functor(std::move(AtomicRef), reducer->getElement(E));
+      Functor(std::move(AtomicRef), ReducerAccess{*reducer}.getElement(E));
     }
   }
 
@@ -355,13 +400,15 @@ class reducer<
     return *this;
   }
 
-  T getIdentity() const { return MIdentity; }
+  T identity() const { return MIdentity; }
+
+private:
+  template <typename ReducerT> friend class detail::ReducerAccess;
 
   T &getElement(size_t) { return MValue; }
   const T &getElement(size_t) const { return MValue; }
-  T MValue;
 
-private:
+  T MValue;
   const T MIdentity;
   BinaryOperation MBinaryOp;
 };
@@ -392,7 +439,12 @@ class reducer<
     return *this;
   }
 
-  static T getIdentity() {
+  T identity() const { return getIdentity(); }
+
+private:
+  template <typename ReducerT> friend class detail::ReducerAccess;
+
+  static constexpr T getIdentity() {
     return detail::known_identity_impl<BinaryOperation, T>::value;
   }
 
@@ -419,6 +471,8 @@ class reducer<
   }
 
 private:
+  template <typename ReducerT> friend class detail::ReducerAccess;
+
   T &MElement;
   BinaryOperation MBinaryOp;
 };
@@ -444,11 +498,14 @@ class reducer<
     return {MValue[Index], MBinaryOp};
   }
 
-  T getIdentity() const { return MIdentity; }
+  T identity() const { return MIdentity; }
+
+private:
+  template <typename ReducerT> friend class detail::ReducerAccess;
+
   T &getElement(size_t E) { return MValue[E]; }
   const T &getElement(size_t E) const { return MValue[E]; }
-private:
   marray<T, Extent> MValue;
   const T MIdentity;
   BinaryOperation MBinaryOp;
 };
@@ -477,14 +534,18 @@ class reducer<
     return {MValue[Index], BinaryOperation()};
   }
 
-  static T getIdentity() {
+  T identity() const { return getIdentity(); }
+
+private:
+  template <typename ReducerT> friend class detail::ReducerAccess;
+
+  static constexpr T getIdentity() {
     return detail::known_identity_impl<BinaryOperation, T>::value;
   }
 
   T &getElement(size_t E) { return MValue[E]; }
   const T &getElement(size_t E) const { return MValue[E]; }
 
-private:
   marray<T, Extent> MValue;
 };
 
@@ -769,8 +830,7 @@ class reduction_impl
     // list of known operations does not break the existing programs.
     if constexpr (is_known_identity) {
       (void)Identity;
-      return reducer_type::getIdentity();
-
+      return ReducerAccess<reducer_type>::getIdentityStatic();
     } else {
       return Identity;
     }
@@ -788,8 +848,8 @@ class reduction_impl
   template <typename _T = T, class _BinaryOperation = BinaryOperation,
             enable_if_t<IsKnownIdentityOp<_T, _BinaryOperation>::value> * = nullptr>
   reduction_impl(RedOutVar Var, bool InitializeToIdentity = false)
-      : algo(reducer_type::getIdentity(), BinaryOperation(),
-             InitializeToIdentity, Var) {
+      : algo(ReducerAccess<reducer_type>::getIdentityStatic(),
+             BinaryOperation(), InitializeToIdentity, Var) {
     if constexpr (!is_usm)
       if (Var.size() != 1)
         throw sycl::runtime_error(errc::invalid,
@@ -896,7 +956,7 @@ struct NDRangeReduction {
       // Work-group cooperates to initialize multiple reduction variables
       auto LID = NDId.get_local_id(0);
       for (size_t E = LID; E < NElements; E += NDId.get_local_range(0)) {
-        GroupSum[E] = Reducer.getIdentity();
+        GroupSum[E] = ReducerAccess(Reducer).getIdentity();
       }
       workGroupBarrier();
 
@@ -909,7 +969,7 @@ struct NDRangeReduction {
       workGroupBarrier();
       if (LID == 0) {
         for (size_t E = 0; E < NElements; ++E) {
-          Reducer.getElement(E) = GroupSum[E];
+          ReducerAccess{Reducer}.getElement(E) = GroupSum[E];
         }
         Reducer.template atomic_combine(&Out[0]);
       }
@@ -959,7 +1019,7 @@ struct NDRangeReduction<
       // reduce_over_group is only defined for each T, not for span
       size_t LID = NDId.get_local_id(0);
       for (int E = 0; E < NElements; ++E) {
-        auto &RedElem = Reducer.getElement(E);
+        auto &RedElem = ReducerAccess{Reducer}.getElement(E);
         RedElem = reduce_over_group(Group, RedElem, BOp);
         if (LID == 0) {
           if (NWorkGroups == 1) {
@@ -970,7 +1030,7 @@ struct NDRangeReduction<
             Out[E] = RedElem;
           } else {
             PartialSums[NDId.get_group_linear_id() * NElements + E] =
-                Reducer.getElement(E);
+                ReducerAccess{Reducer}.getElement(E);
           }
         }
       }
@@ -993,7 +1053,7 @@ struct NDRangeReduction<
        // Reduce each result separately
        // TODO: Opportunity to parallelize across elements.
        for (int E = 0; E < NElements; ++E) {
-         auto LocalSum = Reducer.getIdentity();
+         auto LocalSum = ReducerAccess{Reducer}.getIdentity();
         for (size_t I = LID; I < NWorkGroups; I += WGSize)
           LocalSum = BOp(LocalSum, PartialSums[I * NElements + E]);
         auto Result = reduce_over_group(Group, LocalSum, BOp);
@@ -1083,7 +1143,7 @@ template <> struct NDRangeReduction {
 
       for (int E = 0; E < NElements; ++E) {
         // Copy the element to local memory to prepare it for tree-reduction.
-        LocalReds[LID] = Reducer.getElement(E);
+        LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);
 
         doTreeReduction(WGSize, LID, false, Identity, LocalReds, BOp,
                         [&]() { workGroupBarrier(); });
@@ -1158,8 +1218,8 @@ struct NDRangeReduction {
 
       typename Reduction::binary_operation BOp;
       for (int E = 0; E < NElements; ++E) {
-        Reducer.getElement(E) =
-            reduce_over_group(NDIt.get_group(), Reducer.getElement(E), BOp);
+        ReducerAccess{Reducer}.getElement(E) = reduce_over_group(
+            NDIt.get_group(), ReducerAccess{Reducer}.getElement(E), BOp);
       }
       if (NDIt.get_local_linear_id() == 0)
         Reducer.atomic_combine(&Out[0]);
@@ -1207,14 +1267,15 @@ struct NDRangeReduction<
 
       for (int E = 0; E < NElements; ++E) {
         // Copy the element to local memory to prepare it for tree-reduction.
-        LocalReds[LID] = Reducer.getElement(E);
+        LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);
 
         typename Reduction::binary_operation BOp;
-        doTreeReduction(WGSize, LID, IsPow2WG, Reducer.getIdentity(),
-                        LocalReds, BOp, [&]() { NDIt.barrier(); });
+        doTreeReduction(WGSize, LID, IsPow2WG,
+                        ReducerAccess{Reducer}.getIdentity(), LocalReds, BOp,
+                        [&]() { NDIt.barrier(); });
 
         if (LID == 0) {
-          Reducer.getElement(E) =
+          ReducerAccess{Reducer}.getElement(E) =
               IsPow2WG ? LocalReds[0] : BOp(LocalReds[0], LocalReds[WGSize]);
         }
 
@@ -1282,7 +1343,7 @@ struct NDRangeReduction<
       typename Reduction::binary_operation BOp;
       for (int E = 0; E < NElements; ++E) {
         typename Reduction::result_type PSum;
-        PSum = Reducer.getElement(E);
+        PSum = ReducerAccess{Reducer}.getElement(E);
         PSum = reduce_over_group(NDIt.get_group(), PSum, BOp);
         if (NDIt.get_local_linear_id() == 0) {
           if (IsUpdateOfUserVar)
@@ -1346,7 +1407,8 @@ struct NDRangeReduction<
         typename Reduction::result_type PSum =
             (HasUniformWG || (GID < NWorkItems))
                 ? In[GID * NElements + E]
-                : Reduction::reducer_type::getIdentity();
+                : ReducerAccess<typename Reduction::reducer_type>::
+                      getIdentityStatic();
         PSum = reduce_over_group(NDIt.get_group(), PSum, BOp);
         if (NDIt.get_local_linear_id() == 0) {
           if (IsUpdateOfUserVar)
@@ -1420,7 +1482,7 @@ template <> struct NDRangeReduction {
 
       for (int E = 0; E < NElements; ++E) {
         // Copy the element to local memory to prepare it for tree-reduction.
-        LocalReds[LID] = Reducer.getElement(E);
+        LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);
 
         doTreeReduction(WGSize, LID, IsPow2WG, ReduIdentity, LocalReds, BOp,
                         [&]() { NDIt.barrier(); });
@@ -1693,7 +1755,8 @@ void reduCGFuncImplScalar(
   size_t WGSize = NDIt.get_local_range().size();
   size_t LID = NDIt.get_local_linear_id();
 
-  ((std::get<Is>(LocalAccsTuple)[LID] = std::get<Is>(ReducersTuple).MValue),
+  ((std::get<Is>(LocalAccsTuple)[LID] =
+        ReducerAccess{std::get<Is>(ReducersTuple)}.getElement(0)),
    ...);
 
   // For work-groups, which size is not power of two, local accessors have
@@ -1744,7 +1807,7 @@ void reduCGFuncImplArrayHelper(bool Pow2WG, bool IsOneWG, nd_item<Dims> NDIt,
 
   for (size_t E = 0; E < NElements; ++E) {
     // Copy the element to local memory to prepare it for tree-reduction.
-    LocalReds[LID] = Reducer.getElement(E);
+    LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);
 
     doTreeReduction(WGSize, LID, Pow2WG, Identity, LocalReds, BOp,
                     [&]() { NDIt.barrier(); });
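
For context, the standalone sketch below illustrates the pattern the patch relies on: a ReducerAccess-style helper that is befriended by the reducer, a deduction guide so call sites can write ReducerAccess{Reducer}, and SFINAE-selected getIdentity() overloads with a separately named static entry point (getIdentityStatic) for use where no reducer instance exists. It is not part of the patch; Reducer, has_known_identity, and the other names are simplified stand-ins for the real sycl::detail types, which differ in detail.

// Illustrative sketch only; names are simplified stand-ins, not SYCL types.
#include <cstddef>
#include <iostream>
#include <type_traits>

// A reducer whose identity is known at compile time keeps getIdentity()
// static and private; the runtime-identity case exposes identity() instead.
template <typename T, bool KnownIdentity> class Reducer {
public:
  using value_type = T;
  static constexpr bool has_known_identity = KnownIdentity;

  explicit Reducer(T Identity) : MIdentity(Identity) {}

  T identity() const { return MIdentity; }

private:
  template <typename ReducerT> friend class ReducerAccess;

  // Stand-in for the known-identity specializations' static getIdentity().
  static constexpr T getIdentity() { return T{}; }

  T &getElement(std::size_t) { return MValue; }

  T MValue{};
  T MIdentity;
};

// Accessor befriended by Reducer so reduction internals can reach private
// members without making them part of the user-facing reducer API.
template <typename ReducerT> class ReducerAccess {
public:
  ReducerAccess(ReducerT &ReducerRef) : MReducerRef(ReducerRef) {}

  auto &getElement(std::size_t E) { return MReducerRef.getElement(E); }

  // Selected when the identity is known: forwards to the static path.
  template <typename RelayT = ReducerT>
  std::enable_if_t<RelayT::has_known_identity, typename RelayT::value_type>
  getIdentity() {
    return getIdentityStatic();
  }

  // Selected otherwise: reads the identity stored in the reducer object.
  template <typename RelayT = ReducerT>
  std::enable_if_t<!RelayT::has_known_identity, typename RelayT::value_type>
  getIdentity() {
    return MReducerRef.identity();
  }

  // Separate static entry point, usable without a reducer instance; a
  // distinct name sidesteps static/non-static overloading issues.
  template <typename RelayT = ReducerT>
  static std::enable_if_t<RelayT::has_known_identity,
                          typename RelayT::value_type>
  getIdentityStatic() {
    return ReducerT::getIdentity();
  }

private:
  ReducerT &MReducerRef;
};

// Deduction guide so call sites can write ReducerAccess{R} without spelling
// out the reducer type.
template <typename ReducerT>
ReducerAccess(ReducerT &) -> ReducerAccess<ReducerT>;

int main() {
  Reducer<int, true> Known(0);
  Reducer<int, false> Runtime(42);

  std::cout << ReducerAccess{Known}.getIdentity() << '\n';   // 0, static path
  std::cout << ReducerAccess{Runtime}.getIdentity() << '\n'; // 42, stored value

  ReducerAccess{Known}.getElement(0) = 7; // private storage via the friend
  std::cout << ReducerAccess{Known}.getElement(0) << '\n';   // 7

  // Static use without an instance, mirroring the patch's call sites in
  // reduction_impl and the NDRangeReduction strategies.
  std::cout << ReducerAccess<Reducer<int, true>>::getIdentityStatic() << '\n';
}

The separate getIdentityStatic name mirrors the design choice called out in the patch: mutually exclusive SFINAE overloads of a static and a non-static getIdentity() are rejected by MSVC, so the static variant gets its own name.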