Skip to content

Commit 94e1be0

Browse files
committed
Address reviewer's comments. Fixes are mostly NFC and do not change code-gen at this moment.
Signed-off-by: Vyacheslav N Klochkov <[email protected]>
1 parent 7a291f9 commit 94e1be0

File tree

2 files changed

+49
-36
lines changed

2 files changed

+49
-36
lines changed

sycl/include/CL/sycl/ONEAPI/reduction.hpp

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -392,18 +392,17 @@ class reducer<T, BinaryOperation,
392392
/// implementation classes. It is needed to detect the reduction classes.
393393
class reduction_impl_base {};
394394

395-
/// Predicate returning true if and only if 'FirstT' is a reduction class and
396-
/// all types except the last one from 'RestT' are reductions as well.
397-
template <typename FirstT, typename... RestT>
398-
struct are_all_but_last_reductions {
395+
/// Predicate returning true if all template type parameters except the last one
396+
/// are reductions.
397+
template <typename FirstT, typename... RestT> struct AreAllButLastReductions {
399398
static constexpr bool value =
400399
std::is_base_of<reduction_impl_base, FirstT>::value &&
401-
are_all_but_last_reductions<RestT...>::value;
400+
AreAllButLastReductions<RestT...>::value;
402401
};
403402

404-
/// Helper specialization of are_all_but_last_reductions for one element only.
405-
/// Returns true if the last and only typename is not a reduction.
406-
template <typename T> struct are_all_but_last_reductions<T> {
403+
/// Helper specialization of AreAllButLastReductions for one element only.
404+
/// Returns true if the template parameter is not a reduction.
405+
template <typename T> struct AreAllButLastReductions<T> {
407406
static constexpr bool value = !std::is_base_of<reduction_impl_base, T>::value;
408407
};
409408

@@ -1097,9 +1096,11 @@ reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
10971096
/// the reductions for which a local accessors are needed, this function creates
10981097
/// those local accessors and returns a tuple consisting of them.
10991098
template <typename... Reductions, size_t... Is>
1100-
std::tuple<typename Reductions::local_accessor_type...>
1101-
createReduLocalAccs(size_t Size, handler &CGH, std::index_sequence<Is...>) {
1102-
return {Reductions::getReadWriteLocalAcc(Size, CGH)...};
1099+
auto createReduLocalAccs(size_t Size, handler &CGH,
1100+
std::index_sequence<Is...>) {
1101+
return std::make_tuple(
1102+
std::tuple_element_t<Is, std::tuple<Reductions...>>::getReadWriteLocalAcc(
1103+
Size, CGH)...);
11031104
}
11041105

11051106
/// For the given 'Reductions' types pack and indices enumerating them this
@@ -1154,7 +1155,7 @@ void callReduUserKernelFunc(KernelType KernelFunc, nd_item<Dims> NDIt,
11541155
KernelFunc(NDIt, std::get<Is>(Reducers)...);
11551156
}
11561157

1157-
template <bool UniformPow2WG, typename... LocalAccT, typename... ReducerT,
1158+
template <bool Pow2WG, typename... LocalAccT, typename... ReducerT,
11581159
typename... ResultT, size_t... Is>
11591160
void initReduLocalAccs(size_t LID, size_t WGSize,
11601161
std::tuple<LocalAccT...> LocalAccs,
@@ -1163,7 +1164,11 @@ void initReduLocalAccs(size_t LID, size_t WGSize,
11631164
std::index_sequence<Is...>) {
11641165
std::tie(std::get<Is>(LocalAccs)[LID]...) =
11651166
std::make_tuple(std::get<Is>(Reducers).MValue...);
1166-
if (!UniformPow2WG)
1167+
1168+
// For work-groups, which size is not power of two, local accessors have
1169+
// an additional element with index WGSize that is used by the tree-reduction
1170+
// algorithm. Initialize those additional elements with identity values here.
1171+
if (!Pow2WG)
11671172
std::tie(std::get<Is>(LocalAccs)[WGSize]...) =
11681173
std::make_tuple(std::get<Is>(Identities)...);
11691174
}
@@ -1175,12 +1180,22 @@ void initReduLocalAccs(size_t LID, size_t GID, size_t NWorkItems, size_t WGSize,
11751180
std::tuple<LocalAccT...> InputAccs,
11761181
const std::tuple<ResultT...> Identities,
11771182
std::index_sequence<Is...>) {
1183+
// Normally, the local accessors are initialized with elements from the input
1184+
// accessors. The exception is the case when (GID >= NWorkItems), which
1185+
// possible only when UniformPow2WG is false. For that case the elements of
1186+
// local accessors are initialized with identity value, so they would not
1187+
// give any impact into the final partial sums during the tree-reduction
1188+
// algorithm work.
11781189
if (UniformPow2WG || GID < NWorkItems)
11791190
std::tie(std::get<Is>(LocalAccs)[LID]...) =
11801191
std::make_tuple(std::get<Is>(InputAccs)[GID]...);
11811192
else
11821193
std::tie(std::get<Is>(LocalAccs)[LID]...) =
11831194
std::make_tuple(std::get<Is>(Identities)...);
1195+
1196+
// For work-groups, which size is not power of two, local accessors have
1197+
// an additional element with index WGSize that is used by the tree-reduction
1198+
// algorithm. Initialize those additional elements with identity values here.
11841199
if (!UniformPow2WG)
11851200
std::tie(std::get<Is>(LocalAccs)[WGSize]...) =
11861201
std::make_tuple(std::get<Is>(Identities)...);
@@ -1196,7 +1211,7 @@ void reduceReduLocalAccs(size_t IndexA, size_t IndexB,
11961211
std::get<Is>(LocalAccs)[IndexB]))...);
11971212
}
11981213

1199-
template <bool UniformPow2WG, typename... Reductions, typename... OutAccT,
1214+
template <bool Pow2WG, typename... Reductions, typename... OutAccT,
12001215
typename... LocalAccT, typename... BOPsT, size_t... Is,
12011216
size_t... RWIs>
12021217
void writeReduSumsToOutAccs(size_t OutAccIndex, size_t WGSize,
@@ -1214,11 +1229,16 @@ void writeReduSumsToOutAccs(size_t OutAccIndex, size_t WGSize,
12141229
std::tuple_element_t<RWIs, std::tuple<Reductions...>>::getOutPointer(
12151230
std::get<RWIs>(OutAccs))[OutAccIndex])...);
12161231

1217-
if (UniformPow2WG) {
1232+
if (Pow2WG) {
1233+
// The partial sums for the work-group are stored in 0-th elements of local
1234+
// accessors. Simply write those sums to output accessors.
12181235
std::tie(std::tuple_element_t<Is, std::tuple<Reductions...>>::getOutPointer(
12191236
std::get<Is>(OutAccs))[OutAccIndex]...) =
12201237
std::make_tuple(std::get<Is>(LocalAccs)[0]...);
12211238
} else {
1239+
// Each of local accessors keeps two partial sums: in 0-th and WGsize-th
1240+
// elements. Combine them into final partial sums and write to output
1241+
// accessors.
12221242
std::tie(std::tuple_element_t<Is, std::tuple<Reductions...>>::getOutPointer(
12231243
std::get<Is>(OutAccs))[OutAccIndex]...) =
12241244
std::make_tuple(std::get<Is>(BOPs)(std::get<Is>(LocalAccs)[0],
@@ -1300,15 +1320,15 @@ constexpr auto filterSequence(FunctorT F, std::index_sequence<Is...> Indices) {
13001320
return filterSequenceHelper<T...>(F, Indices);
13011321
}
13021322

1303-
template <typename KernelName, bool UniformPow2WG, bool IsOneWG,
1304-
typename KernelType, int Dims, typename... Reductions, size_t... Is>
1323+
template <typename KernelName, bool Pow2WG, bool IsOneWG, typename KernelType,
1324+
int Dims, typename... Reductions, size_t... Is>
13051325
void reduCGFuncImpl(handler &CGH, KernelType KernelFunc,
13061326
const nd_range<Dims> &Range,
13071327
std::tuple<Reductions...> &ReduTuple,
13081328
std::index_sequence<Is...> ReduIndices) {
13091329

13101330
size_t WGSize = Range.get_local_range().size();
1311-
size_t LocalAccSize = WGSize + (UniformPow2WG ? 0 : 1);
1331+
size_t LocalAccSize = WGSize + (Pow2WG ? 0 : 1);
13121332
auto LocalAccsTuple =
13131333
createReduLocalAccs<Reductions...>(LocalAccSize, CGH, ReduIndices);
13141334

@@ -1318,10 +1338,8 @@ void reduCGFuncImpl(handler &CGH, KernelType KernelFunc,
13181338
auto IdentitiesTuple = getReduIdentities(ReduTuple, ReduIndices);
13191339
auto BOPsTuple = getReduBOPs(ReduTuple, ReduIndices);
13201340

1321-
using Name =
1322-
typename get_reduction_main_kernel_name_t<KernelName, KernelType,
1323-
UniformPow2WG, IsOneWG,
1324-
decltype(OutAccsTuple)>::name;
1341+
using Name = typename get_reduction_main_kernel_name_t<
1342+
KernelName, KernelType, Pow2WG, IsOneWG, decltype(OutAccsTuple)>::name;
13251343
CGH.parallel_for<Name>(Range, [=](nd_item<Dims> NDIt) {
13261344
auto ReduIndices = std::index_sequence_for<Reductions...>();
13271345
auto ReducersTuple =
@@ -1332,8 +1350,8 @@ void reduCGFuncImpl(handler &CGH, KernelType KernelFunc,
13321350

13331351
size_t WGSize = NDIt.get_local_range().size();
13341352
size_t LID = NDIt.get_local_linear_id();
1335-
initReduLocalAccs<UniformPow2WG>(LID, WGSize, LocalAccsTuple, ReducersTuple,
1336-
IdentitiesTuple, ReduIndices);
1353+
initReduLocalAccs<Pow2WG>(LID, WGSize, LocalAccsTuple, ReducersTuple,
1354+
IdentitiesTuple, ReduIndices);
13371355
NDIt.barrier();
13381356

13391357
size_t PrevStep = WGSize;
@@ -1342,7 +1360,7 @@ void reduCGFuncImpl(handler &CGH, KernelType KernelFunc,
13421360
// LocalReds[LID] = BOp(LocalReds[LID], LocalReds[LID + CurStep]);
13431361
reduceReduLocalAccs(LID, LID + CurStep, LocalAccsTuple, BOPsTuple,
13441362
ReduIndices);
1345-
} else if (!UniformPow2WG && LID == CurStep && (PrevStep & 0x1)) {
1363+
} else if (!Pow2WG && LID == CurStep && (PrevStep & 0x1)) {
13461364
// LocalReds[WGSize] = BOp(LocalReds[WGSize], LocalReds[PrevStep - 1]);
13471365
reduceReduLocalAccs(WGSize, PrevStep - 1, LocalAccsTuple, BOPsTuple,
13481366
ReduIndices);
@@ -1363,7 +1381,7 @@ void reduCGFuncImpl(handler &CGH, KernelType KernelFunc,
13631381
Predicate;
13641382
auto RWReduIndices =
13651383
filterSequence<Reductions...>(Predicate, ReduIndices);
1366-
writeReduSumsToOutAccs<UniformPow2WG>(
1384+
writeReduSumsToOutAccs<Pow2WG>(
13671385
GrID, WGSize, (std::tuple<Reductions...> *)nullptr, OutAccsTuple,
13681386
LocalAccsTuple, BOPsTuple, ReduIndices, RWReduIndices);
13691387
}
@@ -1376,21 +1394,18 @@ void reduCGFunc(handler &CGH, KernelType KernelFunc,
13761394
const nd_range<Dims> &Range,
13771395
std::tuple<Reductions...> &ReduTuple,
13781396
std::index_sequence<Is...> ReduIndices) {
1379-
size_t NWorkItems = Range.get_global_range().size();
13801397
size_t WGSize = Range.get_local_range().size();
13811398
size_t NWorkGroups = Range.get_group_range().size();
1382-
13831399
bool Pow2WG = (WGSize & (WGSize - 1)) == 0;
1384-
bool HasUniformWG = Pow2WG && (NWorkGroups * WGSize == NWorkItems);
13851400
if (NWorkGroups == 1) {
1386-
if (HasUniformWG)
1401+
if (Pow2WG)
13871402
reduCGFuncImpl<KernelName, true, true>(CGH, KernelFunc, Range, ReduTuple,
13881403
ReduIndices);
13891404
else
13901405
reduCGFuncImpl<KernelName, false, true>(CGH, KernelFunc, Range, ReduTuple,
13911406
ReduIndices);
13921407
} else {
1393-
if (HasUniformWG)
1408+
if (Pow2WG)
13941409
reduCGFuncImpl<KernelName, true, false>(CGH, KernelFunc, Range, ReduTuple,
13951410
ReduIndices);
13961411
else

sycl/include/CL/sycl/handler.hpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,7 @@ template <typename TupleT, std::size_t... Is>
269269
std::tuple<std::tuple_element_t<Is, TupleT>...>
270270
tuple_select_elements(TupleT Tuple, std::index_sequence<Is...>);
271271

272-
template <typename FirstT, typename... RestT>
273-
struct are_all_but_last_reductions;
272+
template <typename FirstT, typename... RestT> struct AreAllButLastReductions;
274273

275274
} // namespace detail
276275
} // namespace ONEAPI
@@ -1288,9 +1287,8 @@ class __SYCL_EXPORT handler {
12881287
// versions handling 1 reduction variable are more efficient right now.
12891288
template <typename KernelName = detail::auto_name, int Dims,
12901289
typename... RestT>
1291-
std::enable_if_t<
1292-
(sizeof...(RestT) >= 3 &&
1293-
ONEAPI::detail::are_all_but_last_reductions<RestT...>::value)>
1290+
std::enable_if_t<(sizeof...(RestT) >= 3 &&
1291+
ONEAPI::detail::AreAllButLastReductions<RestT...>::value)>
12941292
parallel_for(nd_range<Dims> Range, RestT... Rest) {
12951293
std::tuple<RestT...> ArgsTuple(Rest...);
12961294
constexpr size_t NumArgs = sizeof...(RestT);

0 commit comments

Comments
 (0)