@@ -52,15 +52,15 @@ template <typename OrthogIndexerT,
52
52
typename LocalAccessorT>
53
53
struct MaskedExtractStridedFunctor
54
54
{
55
- MaskedExtractStridedFunctor (const char *src_data_p,
56
- const char *cumsum_data_p,
57
- char *dst_data_p,
55
+ MaskedExtractStridedFunctor (const dataT *src_data_p,
56
+ const indT *cumsum_data_p,
57
+ dataT *dst_data_p,
58
58
size_t masked_iter_size,
59
59
const OrthogIndexerT &orthog_src_dst_indexer_,
60
60
const MaskedSrcIndexerT &masked_src_indexer_,
61
61
const MaskedDstIndexerT &masked_dst_indexer_,
62
62
const LocalAccessorT &lacc_)
63
- : src_cp (src_data_p), cumsum_cp (cumsum_data_p), dst_cp (dst_data_p),
63
+ : src (src_data_p), cumsum (cumsum_data_p), dst (dst_data_p),
64
64
masked_nelems (masked_iter_size),
65
65
orthog_src_dst_indexer(orthog_src_dst_indexer_),
66
66
masked_src_indexer(masked_src_indexer_),
@@ -72,24 +72,20 @@ struct MaskedExtractStridedFunctor
72
72
73
73
void operator ()(sycl::nd_item<2 > ndit) const
74
74
{
75
- const dataT *src_data = reinterpret_cast <const dataT *>(src_cp);
76
- dataT *dst_data = reinterpret_cast <dataT *>(dst_cp);
77
- const indT *cumsum_data = reinterpret_cast <const indT *>(cumsum_cp);
78
-
79
- const size_t orthog_i = ndit.get_global_id (0 );
80
- const size_t group_i = ndit.get_group (1 );
75
+ const std::size_t orthog_i = ndit.get_global_id (0 );
81
76
const std::uint32_t l_i = ndit.get_local_id (1 );
82
77
const std::uint32_t lws = ndit.get_local_range (1 );
83
78
84
- const size_t masked_block_start = group_i * lws ;
85
- const size_t masked_i = masked_block_start + l_i;
79
+ const std:: size_t masked_i = ndit. get_global_id ( 1 ) ;
80
+ const std:: size_t masked_block_start = masked_i - l_i;
86
81
82
+ const std::size_t max_offset = masked_nelems + 1 ;
87
83
for (std::uint32_t i = l_i; i < lacc.size (); i += lws) {
88
84
const size_t offset = masked_block_start + i;
89
85
lacc[i] = (offset == 0 ) ? indT (0 )
90
- : (offset - 1 < masked_nelems )
91
- ? cumsum_data [offset - 1 ]
92
- : cumsum_data [masked_nelems - 1 ] + 1 ;
86
+ : (offset < max_offset )
87
+ ? cumsum [offset - 1 ]
88
+ : cumsum [masked_nelems - 1 ] + 1 ;
93
89
}
94
90
95
91
sycl::group_barrier (ndit.get_group ());
@@ -110,14 +106,14 @@ struct MaskedExtractStridedFunctor
110
106
masked_dst_indexer (current_running_count - 1 ) +
111
107
orthog_offsets.get_second_offset ();
112
108
113
- dst_data [total_dst_offset] = src_data [total_src_offset];
109
+ dst [total_dst_offset] = src [total_src_offset];
114
110
}
115
111
}
116
112
117
113
private:
118
- const char *src_cp = nullptr ;
119
- const char *cumsum_cp = nullptr ;
120
- char *dst_cp = nullptr ;
114
+ const dataT *src = nullptr ;
115
+ const indT *cumsum = nullptr ;
116
+ dataT *dst = nullptr ;
121
117
const size_t masked_nelems = 0 ;
122
118
// has nd, shape, src_strides, dst_strides for
123
119
// dimensions that ARE NOT masked
@@ -138,15 +134,15 @@ template <typename OrthogIndexerT,
138
134
typename LocalAccessorT>
139
135
struct MaskedPlaceStridedFunctor
140
136
{
141
- MaskedPlaceStridedFunctor (char *dst_data_p,
142
- const char *cumsum_data_p,
143
- const char *rhs_data_p,
137
+ MaskedPlaceStridedFunctor (dataT *dst_data_p,
138
+ const indT *cumsum_data_p,
139
+ const dataT *rhs_data_p,
144
140
size_t masked_iter_size,
145
141
const OrthogIndexerT &orthog_dst_rhs_indexer_,
146
142
const MaskedDstIndexerT &masked_dst_indexer_,
147
143
const MaskedRhsIndexerT &masked_rhs_indexer_,
148
144
const LocalAccessorT &lacc_)
149
- : dst_cp (dst_data_p), cumsum_cp (cumsum_data_p), rhs_cp (rhs_data_p),
145
+ : dst (dst_data_p), cumsum (cumsum_data_p), rhs (rhs_data_p),
150
146
masked_nelems (masked_iter_size),
151
147
orthog_dst_rhs_indexer(orthog_dst_rhs_indexer_),
152
148
masked_dst_indexer(masked_dst_indexer_),
@@ -158,24 +154,20 @@ struct MaskedPlaceStridedFunctor
158
154
159
155
void operator ()(sycl::nd_item<2 > ndit) const
160
156
{
161
- dataT *dst_data = reinterpret_cast <dataT *>(dst_cp);
162
- const indT *cumsum_data = reinterpret_cast <const indT *>(cumsum_cp);
163
- const dataT *rhs_data = reinterpret_cast <const dataT *>(rhs_cp);
164
-
165
157
const std::size_t orthog_i = ndit.get_global_id (0 );
166
- const std::size_t group_i = ndit.get_group (1 );
167
158
const std::uint32_t l_i = ndit.get_local_id (1 );
168
159
const std::uint32_t lws = ndit.get_local_range (1 );
169
160
170
- const size_t masked_block_start = group_i * lws ;
171
- const size_t masked_i = masked_block_start + l_i;
161
+ const size_t masked_i = ndit. get_global_id ( 1 ) ;
162
+ const size_t masked_block_start = masked_i - l_i;
172
163
164
+ const std::size_t max_offset = masked_nelems + 1 ;
173
165
for (std::uint32_t i = l_i; i < lacc.size (); i += lws) {
174
166
const size_t offset = masked_block_start + i;
175
167
lacc[i] = (offset == 0 ) ? indT (0 )
176
- : (offset - 1 < masked_nelems )
177
- ? cumsum_data [offset - 1 ]
178
- : cumsum_data [masked_nelems - 1 ] + 1 ;
168
+ : (offset < max_offset )
169
+ ? cumsum [offset - 1 ]
170
+ : cumsum [masked_nelems - 1 ] + 1 ;
179
171
}
180
172
181
173
sycl::group_barrier (ndit.get_group ());
@@ -196,14 +188,14 @@ struct MaskedPlaceStridedFunctor
196
188
masked_rhs_indexer (current_running_count - 1 ) +
197
189
orthog_offsets.get_second_offset ();
198
190
199
- dst_data [total_dst_offset] = rhs_data [total_rhs_offset];
191
+ dst [total_dst_offset] = rhs [total_rhs_offset];
200
192
}
201
193
}
202
194
203
195
private:
204
- char *dst_cp = nullptr ;
205
- const char *cumsum_cp = nullptr ;
206
- const char *rhs_cp = nullptr ;
196
+ dataT *dst = nullptr ;
197
+ const indT *cumsum = nullptr ;
198
+ const dataT *rhs = nullptr ;
207
199
const size_t masked_nelems = 0 ;
208
200
// has nd, shape, dst_strides, rhs_strides for
209
201
// dimensions that ARE NOT masked
@@ -218,6 +210,26 @@ struct MaskedPlaceStridedFunctor
218
210
219
211
// ======= Masked extraction ================================
220
212
213
namespace
{

// Helper for get_lws: walks the candidate work-group sizes
// Cand0, CandRest... in order and returns the first candidate C
// with n >= C; if n is smaller than every candidate, the last
// (smallest) candidate is returned.
template <std::size_t Cand0, std::size_t... CandRest>
std::size_t _get_lws_impl(std::size_t n)
{
    if constexpr (sizeof...(CandRest) > 0) {
        // More candidates remain: defer to the next (smaller) one
        // when n does not fill a group of size Cand0.
        if (n < Cand0) {
            return _get_lws_impl<CandRest...>(n);
        }
    }
    return Cand0;
}

// Choose a kernel work-group size for `n` masked elements:
// 256 when n >= 256, 128 when n >= 128, otherwise 64.
std::size_t get_lws(std::size_t n)
{
    return _get_lws_impl<256u, 128u, 64u>(n);
}

} // end of anonymous namespace
232
+
221
233
template <typename MaskedDstIndexerT, typename dataT, typename indT>
222
234
class masked_extract_all_slices_contig_impl_krn ;
223
235
@@ -258,16 +270,21 @@ sycl::event masked_extract_all_slices_contig_impl(
258
270
Strided1DIndexer, dataT, indT,
259
271
LocalAccessorT>;
260
272
261
- constexpr std::size_t nominal_lws = 256 ;
262
273
const std::size_t masked_extent = iteration_size;
263
- const std::size_t lws = std::min (masked_extent, nominal_lws);
274
+
275
+ const std::size_t lws = get_lws (masked_extent);
276
+
264
277
const std::size_t n_groups = (iteration_size + lws - 1 ) / lws;
265
278
266
279
sycl::range<2 > gRange {1 , n_groups * lws};
267
280
sycl::range<2 > lRange{1 , lws};
268
281
269
282
sycl::nd_range<2 > ndRange (gRange , lRange);
270
283
284
+ const dataT *src_tp = reinterpret_cast <const dataT *>(src_p);
285
+ const indT *cumsum_tp = reinterpret_cast <const indT *>(cumsum_p);
286
+ dataT *dst_tp = reinterpret_cast <dataT *>(dst_p);
287
+
271
288
sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
272
289
cgh.depends_on (depends);
273
290
@@ -276,7 +293,7 @@ sycl::event masked_extract_all_slices_contig_impl(
276
293
277
294
cgh.parallel_for <KernelName>(
278
295
ndRange,
279
- Impl (src_p, cumsum_p, dst_p , masked_extent, orthog_src_dst_indexer,
296
+ Impl (src_tp, cumsum_tp, dst_tp , masked_extent, orthog_src_dst_indexer,
280
297
masked_src_indexer, masked_dst_indexer, lacc));
281
298
});
282
299
@@ -332,16 +349,21 @@ sycl::event masked_extract_all_slices_strided_impl(
332
349
StridedIndexer, Strided1DIndexer,
333
350
dataT, indT, LocalAccessorT>;
334
351
335
- constexpr std::size_t nominal_lws = 256 ;
336
352
const std::size_t masked_nelems = iteration_size;
337
- const std::size_t lws = std::min (masked_nelems, nominal_lws);
353
+
354
+ const std::size_t lws = get_lws (masked_nelems);
355
+
338
356
const std::size_t n_groups = (masked_nelems + lws - 1 ) / lws;
339
357
340
358
sycl::range<2 > gRange {1 , n_groups * lws};
341
359
sycl::range<2 > lRange{1 , lws};
342
360
343
361
sycl::nd_range<2 > ndRange (gRange , lRange);
344
362
363
+ const dataT *src_tp = reinterpret_cast <const dataT *>(src_p);
364
+ const indT *cumsum_tp = reinterpret_cast <const indT *>(cumsum_p);
365
+ dataT *dst_tp = reinterpret_cast <dataT *>(dst_p);
366
+
345
367
sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
346
368
cgh.depends_on (depends);
347
369
@@ -350,7 +372,7 @@ sycl::event masked_extract_all_slices_strided_impl(
350
372
351
373
cgh.parallel_for <KernelName>(
352
374
ndRange,
353
- Impl (src_p, cumsum_p, dst_p , iteration_size, orthog_src_dst_indexer,
375
+ Impl (src_tp, cumsum_tp, dst_tp , iteration_size, orthog_src_dst_indexer,
354
376
masked_src_indexer, masked_dst_indexer, lacc));
355
377
});
356
378
@@ -422,9 +444,10 @@ sycl::event masked_extract_some_slices_strided_impl(
422
444
StridedIndexer, Strided1DIndexer,
423
445
dataT, indT, LocalAccessorT>;
424
446
425
- const size_t nominal_lws = 256 ;
426
447
const std::size_t masked_extent = masked_nelems;
427
- const size_t lws = std::min (masked_extent, nominal_lws);
448
+
449
+ const std::size_t lws = get_lws (masked_extent);
450
+
428
451
const size_t n_groups = ((masked_extent + lws - 1 ) / lws);
429
452
const size_t orthog_extent = static_cast <size_t >(orthog_nelems);
430
453
@@ -433,6 +456,10 @@ sycl::event masked_extract_some_slices_strided_impl(
433
456
434
457
sycl::nd_range<2 > ndRange (gRange , lRange);
435
458
459
+ const dataT *src_tp = reinterpret_cast <const dataT *>(src_p);
460
+ const indT *cumsum_tp = reinterpret_cast <const indT *>(cumsum_p);
461
+ dataT *dst_tp = reinterpret_cast <dataT *>(dst_p);
462
+
436
463
sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
437
464
cgh.depends_on (depends);
438
465
@@ -442,7 +469,7 @@ sycl::event masked_extract_some_slices_strided_impl(
442
469
443
470
cgh.parallel_for <KernelName>(
444
471
ndRange,
445
- Impl (src_p, cumsum_p, dst_p , masked_nelems, orthog_src_dst_indexer,
472
+ Impl (src_tp, cumsum_tp, dst_tp , masked_nelems, orthog_src_dst_indexer,
446
473
masked_src_indexer, masked_dst_indexer, lacc));
447
474
});
448
475
@@ -567,6 +594,10 @@ sycl::event masked_place_all_slices_strided_impl(
567
594
568
595
using LocalAccessorT = sycl::local_accessor<indT, 1 >;
569
596
597
+ dataT *dst_tp = reinterpret_cast <dataT *>(dst_p);
598
+ const dataT *rhs_tp = reinterpret_cast <const dataT *>(rhs_p);
599
+ const indT *cumsum_tp = reinterpret_cast <const indT *>(cumsum_p);
600
+
570
601
sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
571
602
cgh.depends_on (depends);
572
603
@@ -578,7 +609,7 @@ sycl::event masked_place_all_slices_strided_impl(
578
609
MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
579
610
Strided1DCyclicIndexer, dataT, indT,
580
611
LocalAccessorT>(
581
- dst_p, cumsum_p, rhs_p , iteration_size, orthog_dst_rhs_indexer,
612
+ dst_tp, cumsum_tp, rhs_tp , iteration_size, orthog_dst_rhs_indexer,
582
613
masked_dst_indexer, masked_rhs_indexer, lacc));
583
614
});
584
615
@@ -659,6 +690,10 @@ sycl::event masked_place_some_slices_strided_impl(
659
690
660
691
using LocalAccessorT = sycl::local_accessor<indT, 1 >;
661
692
693
+ dataT *dst_tp = reinterpret_cast <dataT *>(dst_p);
694
+ const dataT *rhs_tp = reinterpret_cast <const dataT *>(rhs_p);
695
+ const indT* cumsum_tp = reinterpret_cast <const indT *>(cumsum_p);
696
+
662
697
sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
663
698
cgh.depends_on (depends);
664
699
@@ -670,7 +705,7 @@ sycl::event masked_place_some_slices_strided_impl(
670
705
MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
671
706
Strided1DCyclicIndexer, dataT, indT,
672
707
LocalAccessorT>(
673
- dst_p, cumsum_p, rhs_p , masked_nelems, orthog_dst_rhs_indexer,
708
+ dst_tp, cumsum_tp, rhs_tp , masked_nelems, orthog_dst_rhs_indexer,
674
709
masked_dst_indexer, masked_rhs_indexer, lacc));
675
710
});
676
711
0 commit comments