@@ -392,18 +392,17 @@ class reducer<T, BinaryOperation,
392
392
// / implementation classes. It is needed to detect the reduction classes.
393
393
class reduction_impl_base {};
394
394
395
- // / Predicate returning true if and only if 'FirstT' is a reduction class and
396
- // / all types except the last one from 'RestT' are reductions as well.
397
- template <typename FirstT, typename ... RestT>
398
- struct are_all_but_last_reductions {
395
+ // / Predicate returning true if all template type parameters except the last one
396
+ // / are reductions.
397
+ template <typename FirstT, typename ... RestT> struct AreAllButLastReductions {
399
398
static constexpr bool value =
400
399
std::is_base_of<reduction_impl_base, FirstT>::value &&
401
- are_all_but_last_reductions <RestT...>::value;
400
+ AreAllButLastReductions <RestT...>::value;
402
401
};
403
402
404
- // / Helper specialization of are_all_but_last_reductions for one element only.
405
- // / Returns true if the last and only typename is not a reduction.
406
- template <typename T> struct are_all_but_last_reductions <T> {
403
+ // / Helper specialization of AreAllButLastReductions for one element only.
404
+ // / Returns true if the template parameter is not a reduction.
405
+ template <typename T> struct AreAllButLastReductions <T> {
407
406
static constexpr bool value = !std::is_base_of<reduction_impl_base, T>::value;
408
407
};
409
408
@@ -1097,9 +1096,11 @@ reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
1097
1096
// / the reductions for which a local accessors are needed, this function creates
1098
1097
// / those local accessors and returns a tuple consisting of them.
1099
1098
template <typename ... Reductions, size_t ... Is>
1100
- std::tuple<typename Reductions::local_accessor_type...>
1101
- createReduLocalAccs (size_t Size, handler &CGH, std::index_sequence<Is...>) {
1102
- return {Reductions::getReadWriteLocalAcc (Size, CGH)...};
1099
+ auto createReduLocalAccs (size_t Size, handler &CGH,
1100
+ std::index_sequence<Is...>) {
1101
+ return std::make_tuple (
1102
+ std::tuple_element_t <Is, std::tuple<Reductions...>>::getReadWriteLocalAcc (
1103
+ Size, CGH)...);
1103
1104
}
1104
1105
1105
1106
// / For the given 'Reductions' types pack and indices enumerating them this
@@ -1154,7 +1155,7 @@ void callReduUserKernelFunc(KernelType KernelFunc, nd_item<Dims> NDIt,
1154
1155
KernelFunc (NDIt, std::get<Is>(Reducers)...);
1155
1156
}
1156
1157
1157
- template <bool UniformPow2WG , typename ... LocalAccT, typename ... ReducerT,
1158
+ template <bool Pow2WG , typename ... LocalAccT, typename ... ReducerT,
1158
1159
typename ... ResultT, size_t ... Is>
1159
1160
void initReduLocalAccs (size_t LID, size_t WGSize,
1160
1161
std::tuple<LocalAccT...> LocalAccs,
@@ -1163,7 +1164,11 @@ void initReduLocalAccs(size_t LID, size_t WGSize,
1163
1164
std::index_sequence<Is...>) {
1164
1165
std::tie (std::get<Is>(LocalAccs)[LID]...) =
1165
1166
std::make_tuple (std::get<Is>(Reducers).MValue ...);
1166
- if (!UniformPow2WG)
1167
+
1168
+ // For work-groups, which size is not power of two, local accessors have
1169
+ // an additional element with index WGSize that is used by the tree-reduction
1170
+ // algorithm. Initialize those additional elements with identity values here.
1171
+ if (!Pow2WG)
1167
1172
std::tie (std::get<Is>(LocalAccs)[WGSize]...) =
1168
1173
std::make_tuple (std::get<Is>(Identities)...);
1169
1174
}
@@ -1175,12 +1180,22 @@ void initReduLocalAccs(size_t LID, size_t GID, size_t NWorkItems, size_t WGSize,
1175
1180
std::tuple<LocalAccT...> InputAccs,
1176
1181
const std::tuple<ResultT...> Identities,
1177
1182
std::index_sequence<Is...>) {
1183
+ // Normally, the local accessors are initialized with elements from the input
1184
+ // accessors. The exception is the case when (GID >= NWorkItems), which
1185
+ // possible only when UniformPow2WG is false. For that case the elements of
1186
+ // local accessors are initialized with identity value, so they would not
1187
+ // give any impact into the final partial sums during the tree-reduction
1188
+ // algorithm work.
1178
1189
if (UniformPow2WG || GID < NWorkItems)
1179
1190
std::tie (std::get<Is>(LocalAccs)[LID]...) =
1180
1191
std::make_tuple (std::get<Is>(InputAccs)[GID]...);
1181
1192
else
1182
1193
std::tie (std::get<Is>(LocalAccs)[LID]...) =
1183
1194
std::make_tuple (std::get<Is>(Identities)...);
1195
+
1196
+ // For work-groups, which size is not power of two, local accessors have
1197
+ // an additional element with index WGSize that is used by the tree-reduction
1198
+ // algorithm. Initialize those additional elements with identity values here.
1184
1199
if (!UniformPow2WG)
1185
1200
std::tie (std::get<Is>(LocalAccs)[WGSize]...) =
1186
1201
std::make_tuple (std::get<Is>(Identities)...);
@@ -1196,7 +1211,7 @@ void reduceReduLocalAccs(size_t IndexA, size_t IndexB,
1196
1211
std::get<Is>(LocalAccs)[IndexB]))...);
1197
1212
}
1198
1213
1199
- template <bool UniformPow2WG , typename ... Reductions, typename ... OutAccT,
1214
+ template <bool Pow2WG , typename ... Reductions, typename ... OutAccT,
1200
1215
typename ... LocalAccT, typename ... BOPsT, size_t ... Is,
1201
1216
size_t ... RWIs>
1202
1217
void writeReduSumsToOutAccs (size_t OutAccIndex, size_t WGSize,
@@ -1214,11 +1229,16 @@ void writeReduSumsToOutAccs(size_t OutAccIndex, size_t WGSize,
1214
1229
std::tuple_element_t <RWIs, std::tuple<Reductions...>>::getOutPointer (
1215
1230
std::get<RWIs>(OutAccs))[OutAccIndex])...);
1216
1231
1217
- if (UniformPow2WG) {
1232
+ if (Pow2WG) {
1233
+ // The partial sums for the work-group are stored in 0-th elements of local
1234
+ // accessors. Simply write those sums to output accessors.
1218
1235
std::tie (std::tuple_element_t <Is, std::tuple<Reductions...>>::getOutPointer (
1219
1236
std::get<Is>(OutAccs))[OutAccIndex]...) =
1220
1237
std::make_tuple (std::get<Is>(LocalAccs)[0 ]...);
1221
1238
} else {
1239
+ // Each of local accessors keeps two partial sums: in 0-th and WGsize-th
1240
+ // elements. Combine them into final partial sums and write to output
1241
+ // accessors.
1222
1242
std::tie (std::tuple_element_t <Is, std::tuple<Reductions...>>::getOutPointer (
1223
1243
std::get<Is>(OutAccs))[OutAccIndex]...) =
1224
1244
std::make_tuple (std::get<Is>(BOPs)(std::get<Is>(LocalAccs)[0 ],
@@ -1300,15 +1320,15 @@ constexpr auto filterSequence(FunctorT F, std::index_sequence<Is...> Indices) {
1300
1320
return filterSequenceHelper<T...>(F, Indices);
1301
1321
}
1302
1322
1303
- template <typename KernelName, bool UniformPow2WG , bool IsOneWG,
1304
- typename KernelType, int Dims, typename ... Reductions, size_t ... Is>
1323
+ template <typename KernelName, bool Pow2WG , bool IsOneWG, typename KernelType ,
1324
+ int Dims, typename ... Reductions, size_t ... Is>
1305
1325
void reduCGFuncImpl (handler &CGH, KernelType KernelFunc,
1306
1326
const nd_range<Dims> &Range,
1307
1327
std::tuple<Reductions...> &ReduTuple,
1308
1328
std::index_sequence<Is...> ReduIndices) {
1309
1329
1310
1330
size_t WGSize = Range.get_local_range ().size ();
1311
- size_t LocalAccSize = WGSize + (UniformPow2WG ? 0 : 1 );
1331
+ size_t LocalAccSize = WGSize + (Pow2WG ? 0 : 1 );
1312
1332
auto LocalAccsTuple =
1313
1333
createReduLocalAccs<Reductions...>(LocalAccSize, CGH, ReduIndices);
1314
1334
@@ -1318,10 +1338,8 @@ void reduCGFuncImpl(handler &CGH, KernelType KernelFunc,
1318
1338
auto IdentitiesTuple = getReduIdentities (ReduTuple, ReduIndices);
1319
1339
auto BOPsTuple = getReduBOPs (ReduTuple, ReduIndices);
1320
1340
1321
- using Name =
1322
- typename get_reduction_main_kernel_name_t <KernelName, KernelType,
1323
- UniformPow2WG, IsOneWG,
1324
- decltype (OutAccsTuple)>::name;
1341
+ using Name = typename get_reduction_main_kernel_name_t <
1342
+ KernelName, KernelType, Pow2WG, IsOneWG, decltype (OutAccsTuple)>::name;
1325
1343
CGH.parallel_for <Name>(Range, [=](nd_item<Dims> NDIt) {
1326
1344
auto ReduIndices = std::index_sequence_for<Reductions...>();
1327
1345
auto ReducersTuple =
@@ -1332,8 +1350,8 @@ void reduCGFuncImpl(handler &CGH, KernelType KernelFunc,
1332
1350
1333
1351
size_t WGSize = NDIt.get_local_range ().size ();
1334
1352
size_t LID = NDIt.get_local_linear_id ();
1335
- initReduLocalAccs<UniformPow2WG >(LID, WGSize, LocalAccsTuple, ReducersTuple,
1336
- IdentitiesTuple, ReduIndices);
1353
+ initReduLocalAccs<Pow2WG >(LID, WGSize, LocalAccsTuple, ReducersTuple,
1354
+ IdentitiesTuple, ReduIndices);
1337
1355
NDIt.barrier ();
1338
1356
1339
1357
size_t PrevStep = WGSize;
@@ -1342,7 +1360,7 @@ void reduCGFuncImpl(handler &CGH, KernelType KernelFunc,
1342
1360
// LocalReds[LID] = BOp(LocalReds[LID], LocalReds[LID + CurStep]);
1343
1361
reduceReduLocalAccs (LID, LID + CurStep, LocalAccsTuple, BOPsTuple,
1344
1362
ReduIndices);
1345
- } else if (!UniformPow2WG && LID == CurStep && (PrevStep & 0x1 )) {
1363
+ } else if (!Pow2WG && LID == CurStep && (PrevStep & 0x1 )) {
1346
1364
// LocalReds[WGSize] = BOp(LocalReds[WGSize], LocalReds[PrevStep - 1]);
1347
1365
reduceReduLocalAccs (WGSize, PrevStep - 1 , LocalAccsTuple, BOPsTuple,
1348
1366
ReduIndices);
@@ -1363,7 +1381,7 @@ void reduCGFuncImpl(handler &CGH, KernelType KernelFunc,
1363
1381
Predicate;
1364
1382
auto RWReduIndices =
1365
1383
filterSequence<Reductions...>(Predicate, ReduIndices);
1366
- writeReduSumsToOutAccs<UniformPow2WG >(
1384
+ writeReduSumsToOutAccs<Pow2WG >(
1367
1385
GrID, WGSize, (std::tuple<Reductions...> *)nullptr , OutAccsTuple,
1368
1386
LocalAccsTuple, BOPsTuple, ReduIndices, RWReduIndices);
1369
1387
}
@@ -1376,21 +1394,18 @@ void reduCGFunc(handler &CGH, KernelType KernelFunc,
1376
1394
const nd_range<Dims> &Range,
1377
1395
std::tuple<Reductions...> &ReduTuple,
1378
1396
std::index_sequence<Is...> ReduIndices) {
1379
- size_t NWorkItems = Range.get_global_range ().size ();
1380
1397
size_t WGSize = Range.get_local_range ().size ();
1381
1398
size_t NWorkGroups = Range.get_group_range ().size ();
1382
-
1383
1399
bool Pow2WG = (WGSize & (WGSize - 1 )) == 0 ;
1384
- bool HasUniformWG = Pow2WG && (NWorkGroups * WGSize == NWorkItems);
1385
1400
if (NWorkGroups == 1 ) {
1386
- if (HasUniformWG )
1401
+ if (Pow2WG )
1387
1402
reduCGFuncImpl<KernelName, true , true >(CGH, KernelFunc, Range, ReduTuple,
1388
1403
ReduIndices);
1389
1404
else
1390
1405
reduCGFuncImpl<KernelName, false , true >(CGH, KernelFunc, Range, ReduTuple,
1391
1406
ReduIndices);
1392
1407
} else {
1393
- if (HasUniformWG )
1408
+ if (Pow2WG )
1394
1409
reduCGFuncImpl<KernelName, true , false >(CGH, KernelFunc, Range, ReduTuple,
1395
1410
ReduIndices);
1396
1411
else
0 commit comments