diff --git a/sycl/include/sycl/ext/oneapi/reduction.hpp b/sycl/include/sycl/ext/oneapi/reduction.hpp
index a8898ea82ff3b..425e69d54ef7d 100644
--- a/sycl/include/sycl/ext/oneapi/reduction.hpp
+++ b/sycl/include/sycl/ext/oneapi/reduction.hpp
@@ -192,57 +192,52 @@ struct ReducerTraits<reducer<T, BinaryOperation, Dims, Extent, View, Subst>> {
 /// Also, for int32/64 types the atomic_combine() is lowered to
 /// sycl::atomic::fetch_add().
 template <class Reducer> class combiner {
-  using T = typename ReducerTraits<Reducer>::type;
-  using BinaryOperation = typename ReducerTraits<Reducer>::op;
+  using Ty = typename ReducerTraits<Reducer>::type;
+  using BinaryOp = typename ReducerTraits<Reducer>::op;
   static constexpr int Dims = ReducerTraits<Reducer>::dims;
   static constexpr size_t Extent = ReducerTraits<Reducer>::extent;
 
 public:
-  template <typename _T = T, int _Dims = Dims>
-  enable_if_t<(_Dims == 0) &&
-              sycl::detail::IsPlus<_T, BinaryOperation>::value &&
+  template <typename _T = Ty, int _Dims = Dims>
+  enable_if_t<(_Dims == 0) && sycl::detail::IsPlus<_T, BinaryOp>::value &&
               sycl::detail::is_geninteger<_T>::value>
   operator++() {
-    static_cast<Reducer *>(this)->combine(static_cast<T>(1));
+    static_cast<Reducer *>(this)->combine(static_cast<_T>(1));
   }
 
-  template <typename _T = T, int _Dims = Dims>
-  enable_if_t<(_Dims == 0) &&
-              sycl::detail::IsPlus<_T, BinaryOperation>::value &&
+  template <typename _T = Ty, int _Dims = Dims>
+  enable_if_t<(_Dims == 0) && sycl::detail::IsPlus<_T, BinaryOp>::value &&
              sycl::detail::is_geninteger<_T>::value>
   operator++(int) {
-    static_cast<Reducer *>(this)->combine(static_cast<T>(1));
+    static_cast<Reducer *>(this)->combine(static_cast<_T>(1));
   }
 
-  template <typename _T = T, int _Dims = Dims>
-  enable_if_t<(_Dims == 0) && sycl::detail::IsPlus<_T, BinaryOperation>::value>
+  template <typename _T = Ty, int _Dims = Dims>
+  enable_if_t<(_Dims == 0) && sycl::detail::IsPlus<_T, BinaryOp>::value>
   operator+=(const _T &Partial) {
     static_cast<Reducer *>(this)->combine(Partial);
   }
 
-  template <typename _T = T, int _Dims = Dims>
-  enable_if_t<(_Dims == 0) &&
-              sycl::detail::IsMultiplies<_T, BinaryOperation>::value>
+  template <typename _T = Ty, int _Dims = Dims>
+  enable_if_t<(_Dims == 0) && sycl::detail::IsMultiplies<_T, BinaryOp>::value>
   operator*=(const _T &Partial) {
     static_cast<Reducer *>(this)->combine(Partial);
   }
 
-  template <typename _T = T, int _Dims = Dims>
-  enable_if_t<(_Dims == 0) && sycl::detail::IsBitOR<_T, BinaryOperation>::value>
+  template <typename _T = Ty, int _Dims = Dims>
+  enable_if_t<(_Dims == 0) && sycl::detail::IsBitOR<_T, BinaryOp>::value>
   operator|=(const _T &Partial) {
     static_cast<Reducer *>(this)->combine(Partial);
   }
 
-  template <typename _T = T, int _Dims = Dims>
-  enable_if_t<(_Dims == 0) &&
-              sycl::detail::IsBitXOR<_T, BinaryOperation>::value>
+  template <typename _T = Ty, int _Dims = Dims>
+  enable_if_t<(_Dims == 0) && sycl::detail::IsBitXOR<_T, BinaryOp>::value>
   operator^=(const _T &Partial) {
     static_cast<Reducer *>(this)->combine(Partial);
   }
 
-  template <typename _T = T, int _Dims = Dims>
-  enable_if_t<(_Dims == 0) &&
-              sycl::detail::IsBitAND<_T, BinaryOperation>::value>
+  template <typename _T = Ty, int _Dims = Dims>
+  enable_if_t<(_Dims == 0) && sycl::detail::IsBitAND<_T, BinaryOp>::value>
   operator&=(const _T &Partial) {
     static_cast<Reducer *>(this)->combine(Partial);
   }
@@ -266,20 +261,20 @@ template <class Reducer> class combiner {
     }
   }
 
-  template <class _T, access::address_space Space, class _BinaryOperation>
+  template <typename _T, access::address_space Space, class _BinaryOperation>
   static constexpr bool BasicCheck =
-      std::is_same<typename remove_AS<_T>::type, T>::value &&
+      std::is_same<typename remove_AS<_T>::type, Ty>::value &&
       (Space == access::address_space::global_space ||
        Space == access::address_space::local_space);
 
 public:
   /// Atomic ADD operation: *ReduVarPtr += MValue;
   template <access::address_space Space = access::address_space::global_space,
-            typename _T = T, class _BinaryOperation = BinaryOperation>
+            typename _T = Ty, class _BinaryOperation = BinaryOp>
   enable_if_t<BasicCheck<_T, Space, _BinaryOperation> &&
-              (IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value ||
-               IsReduOptForAtomic64Op<T, _BinaryOperation>::value) &&
-              sycl::detail::IsPlus<T, _BinaryOperation>::value>
+              (IsReduOptForFastAtomicFetch<_T, _BinaryOperation>::value ||
+               IsReduOptForAtomic64Op<_T, _BinaryOperation>::value) &&
+              sycl::detail::IsPlus<_T, _BinaryOperation>::value>
   atomic_combine(_T *ReduVarPtr) const {
     atomic_combine_impl<Space>(
         ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_add(Val); });
@@ -287,10 +282,10 @@ template <class Reducer> class combiner {
 
   /// Atomic BITWISE OR operation: *ReduVarPtr |= MValue;
   template <access::address_space Space = access::address_space::global_space,
-            typename _T = T, class _BinaryOperation = BinaryOperation>
+            typename _T = Ty, class _BinaryOperation = BinaryOp>
   enable_if_t<BasicCheck<_T, Space, _BinaryOperation> &&
-              IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
-              sycl::detail::IsBitOR<T, _BinaryOperation>::value>
+              IsReduOptForFastAtomicFetch<_T, _BinaryOperation>::value &&
+              sycl::detail::IsBitOR<_T, _BinaryOperation>::value>
   atomic_combine(_T *ReduVarPtr) const {
     atomic_combine_impl<Space>(
         ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_or(Val); });
@@ -298,10 +293,10 @@ template <class Reducer> class combiner {
 
   /// Atomic BITWISE XOR operation: *ReduVarPtr ^= MValue;
   template <access::address_space Space = access::address_space::global_space,
-            typename _T = T, class _BinaryOperation = BinaryOperation>
+            typename _T = Ty, class _BinaryOperation = BinaryOp>
   enable_if_t<BasicCheck<_T, Space, _BinaryOperation> &&
-              IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
-              sycl::detail::IsBitXOR<T, _BinaryOperation>::value>
+              IsReduOptForFastAtomicFetch<_T, _BinaryOperation>::value &&
+              sycl::detail::IsBitXOR<_T, _BinaryOperation>::value>
   atomic_combine(_T *ReduVarPtr) const {
     atomic_combine_impl<Space>(
         ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_xor(Val); });
@@ -309,10 +304,10 @@ template <class Reducer> class combiner {
 
   /// Atomic BITWISE AND operation: *ReduVarPtr &= MValue;
   template <access::address_space Space = access::address_space::global_space,
-            typename _T = T, class _BinaryOperation = BinaryOperation>
-  enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
-              IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
-              sycl::detail::IsBitAND<T, _BinaryOperation>::value &&
+            typename _T = Ty, class _BinaryOperation = BinaryOp>
+  enable_if_t<std::is_same<typename remove_AS<_T>::type, _T>::value &&
+              IsReduOptForFastAtomicFetch<_T, _BinaryOperation>::value &&
+              sycl::detail::IsBitAND<_T, _BinaryOperation>::value &&
               (Space == access::address_space::global_space ||
               Space == access::address_space::local_space)>
   atomic_combine(_T *ReduVarPtr) const {
@@ -322,11 +317,11 @@ template <class Reducer> class combiner {
 
   /// Atomic MIN operation: *ReduVarPtr = sycl::minimum(*ReduVarPtr, MValue);
   template <access::address_space Space = access::address_space::global_space,
-            typename _T = T, class _BinaryOperation = BinaryOperation>
+            typename _T = Ty, class _BinaryOperation = BinaryOp>
   enable_if_t<BasicCheck<_T, Space, _BinaryOperation> &&
-              (IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value ||
-               IsReduOptForAtomic64Op<T, _BinaryOperation>::value) &&
-              sycl::detail::IsMinimum<T, _BinaryOperation>::value>
+              (IsReduOptForFastAtomicFetch<_T, _BinaryOperation>::value ||
+               IsReduOptForAtomic64Op<_T, _BinaryOperation>::value) &&
+              sycl::detail::IsMinimum<_T, _BinaryOperation>::value>
   atomic_combine(_T *ReduVarPtr) const {
     atomic_combine_impl<Space>(
         ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_min(Val); });
@@ -334,11 +329,11 @@ template <class Reducer> class combiner {
 
   /// Atomic MAX operation: *ReduVarPtr = sycl::maximum(*ReduVarPtr, MValue);
   template <access::address_space Space = access::address_space::global_space,
-            typename _T = T, class _BinaryOperation = BinaryOperation>
+            typename _T = Ty, class _BinaryOperation = BinaryOp>
   enable_if_t<BasicCheck<_T, Space, _BinaryOperation> &&
-              (IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value ||
-               IsReduOptForAtomic64Op<T, _BinaryOperation>::value) &&
-              sycl::detail::IsMaximum<T, _BinaryOperation>::value>
+              (IsReduOptForFastAtomicFetch<_T, _BinaryOperation>::value ||
+               IsReduOptForAtomic64Op<_T, _BinaryOperation>::value) &&
+              sycl::detail::IsMaximum<_T, _BinaryOperation>::value>
   atomic_combine(_T *ReduVarPtr) const {
     atomic_combine_impl<Space>(
         ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_max(Val); });
@@ -928,7 +923,7 @@ bool reduCGFuncForRangeFastAtomics(handler &CGH, KernelType KernelFunc,
                                    const range<Dims> &Range,
                                    const nd_range<1> &NDRange,
                                    Reduction &Redu) {
-  constexpr size_t NElements = Reduction::num_elements;
+  size_t NElements = Reduction::num_elements;
   auto Out = Redu.getReadWriteAccessorToInitializedMem(CGH);
   auto GroupSum = Reduction::getReadWriteLocalAcc(NElements, CGH);
   using Name = __sycl_reduction_kernel<reduction::main_krn::RangeFastAtomics,
@@ -1035,7 +1030,7 @@ template <typename KernelName, typename KernelType, int Dims, class Reduction>
 bool reduCGFuncForRangeFastReduce(handler &CGH, KernelType KernelFunc,
                                   const range<Dims> &Range,
                                   const nd_range<1> &NDRange, Reduction &Redu) {
-  constexpr size_t NElements = Reduction::num_elements;
+  size_t NElements = Reduction::num_elements;
   size_t WGSize = NDRange.get_local_range().size();
   size_t NWorkGroups = NDRange.get_group_range().size();
 
@@ -1078,7 +1073,7 @@ template <typename KernelName, typename KernelType, int Dims, class Reduction>
 bool reduCGFuncForRangeBasic(handler &CGH, KernelType KernelFunc,
                              const range<Dims> &Range,
                              const nd_range<1> &NDRange, Reduction &Redu) {
-  constexpr size_t NElements = Reduction::num_elements;
+  size_t NElements = Reduction::num_elements;
   size_t WGSize = NDRange.get_local_range().size();
   size_t NWorkGroups = NDRange.get_group_range().size();
 
@@ -1230,7 +1225,7 @@ template <typename KernelName, typename KernelType, int Dims, class Reduction>
 void reduCGFuncForNDRangeBothFastReduceAndAtomics(
     handler &CGH, KernelType KernelFunc, const nd_range<Dims> &Range,
     Reduction &, typename Reduction::rw_accessor_type Out) {
-  constexpr size_t NElements = Reduction::num_elements;
+  size_t NElements = Reduction::num_elements;
   using Name = __sycl_reduction_kernel<
       reduction::main_krn::NDRangeBothFastReduceAndAtomics, KernelName>;
   CGH.parallel_for<Name>(Range, [=](nd_item<Dims> NDIt) {
@@ -1266,7 +1261,7 @@ void reduCGFuncForNDRangeFastAtomicsOnly(
     handler &CGH, bool IsPow2WG, KernelType KernelFunc,
     const nd_range<Dims> &Range, Reduction &,
     typename Reduction::rw_accessor_type Out) {
-  constexpr size_t NElements = Reduction::num_elements;
+  size_t NElements = Reduction::num_elements;
   size_t WGSize = Range.get_local_range().size();
 
   // Use local memory to reduce elements in work-groups into zero-th element.
@@ -1345,7 +1340,7 @@ template <typename KernelName, typename KernelType, int Dims, class Reduction>
 void reduCGFuncForNDRangeFastReduceOnly(
     handler &CGH, KernelType KernelFunc, const nd_range<Dims> &Range,
     Reduction &Redu, typename Reduction::rw_accessor_type Out) {
-  constexpr size_t NElements = Reduction::num_elements;
+  size_t NElements = Reduction::num_elements;
   size_t NWorkGroups = Range.get_group_range().size();
   bool IsUpdateOfUserVar =
       !Reduction::is_usm && !Redu.initializeToIdentity() && NWorkGroups == 1;
@@ -1392,7 +1387,7 @@ void reduCGFuncForNDRangeBasic(handler &CGH, bool IsPow2WG,
                                KernelType KernelFunc,
                                const nd_range<Dims> &Range, Reduction &Redu,
                                typename Reduction::rw_accessor_type Out) {
-  constexpr size_t NElements = Reduction::num_elements;
+  size_t NElements = Reduction::num_elements;
   size_t WGSize = Range.get_local_range().size();
   size_t NWorkGroups = Range.get_group_range().size();
 
@@ -1477,7 +1472,7 @@ void reduAuxCGFuncFastReduceImpl(handler &CGH, bool UniformWG,
                                  size_t NWorkItems, size_t NWorkGroups,
                                  size_t WGSize, Reduction &Redu, InputT In,
                                  OutputT Out) {
-  constexpr size_t NElements = Reduction::num_elements;
+  size_t NElements = Reduction::num_elements;
   using Name =
       __sycl_reduction_kernel<reduction::aux_krn::FastReduce, KernelName>;
   bool IsUpdateOfUserVar =
@@ -1523,7 +1518,7 @@ void reduAuxCGFuncNoFastReduceNorAtomicImpl(handler &CGH, bool UniformPow2WG,
                                             size_t NWorkGroups, size_t WGSize,
                                             Reduction &Redu, InputT In,
                                             OutputT Out) {
-  constexpr size_t NElements = Reduction::num_elements;
+  size_t NElements = Reduction::num_elements;
   bool IsUpdateOfUserVar =
       !Reduction::is_usm && !Redu.initializeToIdentity() && NWorkGroups == 1;
 
@@ -1642,7 +1637,7 @@ reduSaveFinalResultToUserMem(handler &CGH, Reduction &Redu) {
 template <typename Reduction>
 std::enable_if_t<Reduction::is_usm>
 reduSaveFinalResultToUserMem(handler &CGH, Reduction &Redu) {
-  constexpr size_t NElements = Reduction::num_elements;
+  size_t NElements = Reduction::num_elements;
   auto InAcc = Redu.getReadAccToPreviousPartialReds(CGH);
   auto UserVarPtr = Redu.getUserRedVar();
   bool IsUpdateOfUserVar = !Redu.initializeToIdentity();
@@ -2120,7 +2115,7 @@ void reduCGFuncAtomic64(handler &CGH, KernelType KernelFunc,
   static_assert(
       Reduction::has_float64_atomics,
       "Only suitable for reductions that have FP64 atomic operations.");
-  constexpr size_t NElements = Reduction::num_elements;
+  size_t NElements = Reduction::num_elements;
   using Name =
       __sycl_reduction_kernel<reduction::main_krn::Atomic64, KernelName>;
   CGH.parallel_for<Name>(Range, [=](nd_item<Dims> NDIt) {
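
Note (not part of the patch): the change is mechanical — the combiner's dependent aliases `T`/`BinaryOperation` become `Ty`/`BinaryOp`, the `enable_if_t` guards use the deduced `_T` instead of the class-level alias, and the `NElements` locals captured into device lambdas lose `constexpr`. As a minimal usage sketch of the code paths this touches (my own example, not from the PR; file name and build line are assumptions for a DPC++/intel-llvm toolchain, e.g. `clang++ -fsycl example.cpp`), an integer plus-reduction routes `+=` and `++` through the combiner's shorthand operators, and the final update of the reduction variable is eligible for the `IsReduOptForFastAtomicFetch` path, i.e. an atomic `fetch_add`:

```cpp
#include <sycl/sycl.hpp>

#include <iostream>

int main() {
  sycl::queue Q;

  // USM reduction variable; by default the reduction combines with the
  // existing value instead of re-initializing it to the identity.
  int *Sum = sycl::malloc_shared<int>(1, Q);
  *Sum = 0;

  Q.submit([&](sycl::handler &CGH) {
     auto Redu = sycl::reduction(Sum, sycl::plus<int>());
     CGH.parallel_for(sycl::range<1>{1024}, Redu,
                      [=](sycl::id<1> I, auto &Reducer) {
                        Reducer += static_cast<int>(I[0]); // combiner::operator+=
                        ++Reducer; // combiner::operator++ (integral plus only)
                      });
   }).wait();

  // 0 + 1 + ... + 1023, plus one increment per work-item.
  std::cout << *Sum << " (expected " << 1023 * 1024 / 2 + 1024 << ")\n";
  sycl::free(Sum, Q);
  return 0;
}
```

The shorthand operators (`++`, `+=`, `|=`, ...) are only SFINAE-enabled when the reducer's binary operation matches, which is why each overload above guards on the corresponding `sycl::detail::Is*` trait.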