@@ -52,15 +52,15 @@ template <typename OrthogIndexerT,
           typename LocalAccessorT>
 struct MaskedExtractStridedFunctor
 {
-    MaskedExtractStridedFunctor(const char *src_data_p,
-                                const char *cumsum_data_p,
-                                char *dst_data_p,
+    MaskedExtractStridedFunctor(const dataT *src_data_p,
+                                const indT *cumsum_data_p,
+                                dataT *dst_data_p,
                                 size_t masked_iter_size,
                                 const OrthogIndexerT &orthog_src_dst_indexer_,
                                 const MaskedSrcIndexerT &masked_src_indexer_,
                                 const MaskedDstIndexerT &masked_dst_indexer_,
                                 const LocalAccessorT &lacc_)
-        : src_cp(src_data_p), cumsum_cp(cumsum_data_p), dst_cp(dst_data_p),
+        : src(src_data_p), cumsum(cumsum_data_p), dst(dst_data_p),
           masked_nelems(masked_iter_size),
           orthog_src_dst_indexer(orthog_src_dst_indexer_),
           masked_src_indexer(masked_src_indexer_),
@@ -72,24 +72,19 @@ struct MaskedExtractStridedFunctor
 
     void operator()(sycl::nd_item<2> ndit) const
     {
-        const dataT *src_data = reinterpret_cast<const dataT *>(src_cp);
-        dataT *dst_data = reinterpret_cast<dataT *>(dst_cp);
-        const indT *cumsum_data = reinterpret_cast<const indT *>(cumsum_cp);
-
-        const size_t orthog_i = ndit.get_global_id(0);
-        const size_t group_i = ndit.get_group(1);
+        const std::size_t orthog_i = ndit.get_global_id(0);
         const std::uint32_t l_i = ndit.get_local_id(1);
         const std::uint32_t lws = ndit.get_local_range(1);
 
-        const size_t masked_block_start = group_i * lws;
-        const size_t masked_i = masked_block_start + l_i;
+        const std::size_t masked_i = ndit.get_global_id(1);
+        const std::size_t masked_block_start = masked_i - l_i;
 
+        const std::size_t max_offset = masked_nelems + 1;
         for (std::uint32_t i = l_i; i < lacc.size(); i += lws) {
             const size_t offset = masked_block_start + i;
-            lacc[i] = (offset == 0) ? indT(0)
-                      : (offset - 1 < masked_nelems)
-                            ? cumsum_data[offset - 1]
-                            : cumsum_data[masked_nelems - 1] + 1;
+            lacc[i] = (offset == 0)           ? indT(0)
+                      : (offset < max_offset) ? cumsum[offset - 1]
+                                              : cumsum[masked_nelems - 1] + 1;
         }
 
         sycl::group_barrier(ndit.get_group());
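The rewritten loop above drops the per-item `group_i * lws` arithmetic (`get_global_id(1) - get_local_id(1)` yields the same block start, since groups are launched back to back) and, since `offset >= 1` in the second branch, replaces `offset - 1 < masked_nelems` with the equivalent `offset < max_offset`. What the loop stages into local memory is a window of the padded cumulative sum `{0, cumsum[0], ..., cumsum[masked_nelems - 1], cumsum[masked_nelems - 1] + 1}`. A host-side model of that lookup rule (a hypothetical helper for illustration, not part of the patch):

```cpp
#include <cstddef>
#include <vector>

// Element `offset` of the conceptual padded array
// {0, cumsum[0], ..., cumsum[n - 1], cumsum[n - 1] + 1}:
// a leading zero, the cumulative sum itself, and a one-past-the-end
// sentinel so reads at the window edge never leave the valid range.
template <typename indT>
indT padded_cumsum_at(const std::vector<indT> &cumsum, std::size_t offset)
{
    const std::size_t n = cumsum.size();
    const std::size_t max_offset = n + 1;
    return (offset == 0)           ? indT(0)
           : (offset < max_offset) ? cumsum[offset - 1]
                                   : cumsum[n - 1] + 1;
}
```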
@@ -110,14 +105,14 @@ struct MaskedExtractStridedFunctor
                 masked_dst_indexer(current_running_count - 1) +
                 orthog_offsets.get_second_offset();
 
-            dst_data[total_dst_offset] = src_data[total_src_offset];
+            dst[total_dst_offset] = src[total_src_offset];
         }
     }
 
 private:
-    const char *src_cp = nullptr;
-    const char *cumsum_cp = nullptr;
-    char *dst_cp = nullptr;
+    const dataT *src = nullptr;
+    const indT *cumsum = nullptr;
+    dataT *dst = nullptr;
     const size_t masked_nelems = 0;
     // has nd, shape, src_strides, dst_strides for
     // dimensions that ARE NOT masked
@@ -138,15 +133,15 @@ template <typename OrthogIndexerT,
           typename LocalAccessorT>
 struct MaskedPlaceStridedFunctor
 {
-    MaskedPlaceStridedFunctor(char *dst_data_p,
-                              const char *cumsum_data_p,
-                              const char *rhs_data_p,
+    MaskedPlaceStridedFunctor(dataT *dst_data_p,
+                              const indT *cumsum_data_p,
+                              const dataT *rhs_data_p,
                               size_t masked_iter_size,
                               const OrthogIndexerT &orthog_dst_rhs_indexer_,
                               const MaskedDstIndexerT &masked_dst_indexer_,
                               const MaskedRhsIndexerT &masked_rhs_indexer_,
                               const LocalAccessorT &lacc_)
-        : dst_cp(dst_data_p), cumsum_cp(cumsum_data_p), rhs_cp(rhs_data_p),
+        : dst(dst_data_p), cumsum(cumsum_data_p), rhs(rhs_data_p),
          masked_nelems(masked_iter_size),
          orthog_dst_rhs_indexer(orthog_dst_rhs_indexer_),
          masked_dst_indexer(masked_dst_indexer_),
@@ -158,24 +153,19 @@ struct MaskedPlaceStridedFunctor
 
     void operator()(sycl::nd_item<2> ndit) const
    {
-        dataT *dst_data = reinterpret_cast<dataT *>(dst_cp);
-        const indT *cumsum_data = reinterpret_cast<const indT *>(cumsum_cp);
-        const dataT *rhs_data = reinterpret_cast<const dataT *>(rhs_cp);
-
         const std::size_t orthog_i = ndit.get_global_id(0);
-        const std::size_t group_i = ndit.get_group(1);
         const std::uint32_t l_i = ndit.get_local_id(1);
         const std::uint32_t lws = ndit.get_local_range(1);
 
-        const size_t masked_block_start = group_i * lws;
-        const size_t masked_i = masked_block_start + l_i;
+        const size_t masked_i = ndit.get_global_id(1);
+        const size_t masked_block_start = masked_i - l_i;
 
+        const std::size_t max_offset = masked_nelems + 1;
         for (std::uint32_t i = l_i; i < lacc.size(); i += lws) {
             const size_t offset = masked_block_start + i;
-            lacc[i] = (offset == 0) ? indT(0)
-                      : (offset - 1 < masked_nelems)
-                            ? cumsum_data[offset - 1]
-                            : cumsum_data[masked_nelems - 1] + 1;
+            lacc[i] = (offset == 0)           ? indT(0)
+                      : (offset < max_offset) ? cumsum[offset - 1]
+                                              : cumsum[masked_nelems - 1] + 1;
         }
 
         sycl::group_barrier(ndit.get_group());
@@ -196,14 +186,14 @@ struct MaskedPlaceStridedFunctor
                 masked_rhs_indexer(current_running_count - 1) +
                 orthog_offsets.get_second_offset();
 
-            dst_data[total_dst_offset] = rhs_data[total_rhs_offset];
+            dst[total_dst_offset] = rhs[total_rhs_offset];
         }
     }
 
 private:
-    char *dst_cp = nullptr;
-    const char *cumsum_cp = nullptr;
-    const char *rhs_cp = nullptr;
+    dataT *dst = nullptr;
+    const indT *cumsum = nullptr;
+    const dataT *rhs = nullptr;
     const size_t masked_nelems = 0;
     // has nd, shape, dst_strides, rhs_strides for
     // dimensions that ARE NOT masked
@@ -218,6 +208,30 @@ struct MaskedPlaceStridedFunctor
 
 // ======= Masked extraction ================================
 
+namespace
+{
+
+template <std::size_t I, std::size_t... IR>
+std::size_t _get_lws_impl(std::size_t n)
+{
+    if constexpr (sizeof...(IR) == 0) {
+        return I;
+    }
+    else {
+        return (n < I) ? _get_lws_impl<IR...>(n) : I;
+    }
+}
+
+std::size_t get_lws(std::size_t n)
+{
+    constexpr std::size_t lws0 = 256u;
+    constexpr std::size_t lws1 = 128u;
+    constexpr std::size_t lws2 = 64u;
+    return _get_lws_impl<lws0, lws1, lws2>(n);
+}
+
+} // end of anonymous namespace
+
 template <typename MaskedDstIndexerT, typename dataT, typename indT>
 class masked_extract_all_slices_contig_impl_krn;
 
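`get_lws` walks the candidate sizes left to right and returns the first one the extent can fill, falling through to the smallest candidate, so short masked extents no longer launch a mostly idle 256-item group. A quick sanity check of the resulting mapping (the `check_get_lws` harness is illustrative, not part of the patch; it assumes `get_lws` from the hunk above is in scope):

```cpp
#include <cassert>
#include <cstddef>

// Expected extent -> work-group-size mapping of get_lws above.
void check_get_lws()
{
    assert(get_lws(1024) == 256); // n >= 256: widest group
    assert(get_lws(256) == 256);
    assert(get_lws(255) == 128);  // 128 <= n < 256: step down
    assert(get_lws(100) == 64);   // n < 128: smallest candidate
    assert(get_lws(1) == 64);     // last entry is the unconditional fallback
}
```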
@@ -258,26 +272,31 @@ sycl::event masked_extract_all_slices_contig_impl(
                                       Strided1DIndexer, dataT, indT,
                                       LocalAccessorT>;
 
-    constexpr std::size_t nominal_lws = 256;
     const std::size_t masked_extent = iteration_size;
-    const std::size_t lws = std::min(masked_extent, nominal_lws);
+
+    const std::size_t lws = get_lws(masked_extent);
+
     const std::size_t n_groups = (iteration_size + lws - 1) / lws;
 
     sycl::range<2> gRange{1, n_groups * lws};
     sycl::range<2> lRange{1, lws};
 
     sycl::nd_range<2> ndRange(gRange, lRange);
 
+    const dataT *src_tp = reinterpret_cast<const dataT *>(src_p);
+    const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
+    dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
+
     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
 
         const std::size_t lacc_size = std::min(lws, masked_extent) + 1;
         LocalAccessorT lacc(lacc_size, cgh);
 
         cgh.parallel_for<KernelName>(
-            ndRange,
-            Impl(src_p, cumsum_p, dst_p, masked_extent, orthog_src_dst_indexer,
-                 masked_src_indexer, masked_dst_indexer, lacc));
+            ndRange, Impl(src_tp, cumsum_tp, dst_tp, masked_extent,
+                          orthog_src_dst_indexer, masked_src_indexer,
+                          masked_dst_indexer, lacc));
     });
 
     return comp_ev;
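This and the following submit functions all follow the same pattern: the type-erased `char *` arguments are cast to typed pointers once on the host, and the functors now store `dataT *`/`indT *` members, so `operator()` no longer performs a `reinterpret_cast` in every work-item. Reduced to a minimal sketch (hypothetical names, assuming a SYCL 2020 queue):

```cpp
#include <sycl/sycl.hpp>
#include <cstddef>

// Hoist the reinterpret_cast out of device code: cast once at
// submission time, then capture typed pointers in the kernel.
template <typename T>
sycl::event typed_copy(sycl::queue &q, const char *src_p, char *dst_p,
                       std::size_t n)
{
    const T *src_tp = reinterpret_cast<const T *>(src_p);
    T *dst_tp = reinterpret_cast<T *>(dst_p);
    return q.parallel_for(sycl::range<1>{n}, [=](sycl::id<1> i) {
        dst_tp[i] = src_tp[i]; // typed access, no per-item casting
    });
}
```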
@@ -332,26 +351,31 @@ sycl::event masked_extract_all_slices_strided_impl(
                                       StridedIndexer, Strided1DIndexer,
                                       dataT, indT, LocalAccessorT>;
 
-    constexpr std::size_t nominal_lws = 256;
     const std::size_t masked_nelems = iteration_size;
-    const std::size_t lws = std::min(masked_nelems, nominal_lws);
+
+    const std::size_t lws = get_lws(masked_nelems);
+
     const std::size_t n_groups = (masked_nelems + lws - 1) / lws;
 
     sycl::range<2> gRange{1, n_groups * lws};
     sycl::range<2> lRange{1, lws};
 
     sycl::nd_range<2> ndRange(gRange, lRange);
 
+    const dataT *src_tp = reinterpret_cast<const dataT *>(src_p);
+    const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
+    dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
+
     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
 
         const std::size_t lacc_size = std::min(lws, masked_nelems) + 1;
         LocalAccessorT lacc(lacc_size, cgh);
 
         cgh.parallel_for<KernelName>(
-            ndRange,
-            Impl(src_p, cumsum_p, dst_p, iteration_size, orthog_src_dst_indexer,
-                 masked_src_indexer, masked_dst_indexer, lacc));
+            ndRange, Impl(src_tp, cumsum_tp, dst_tp, iteration_size,
+                          orthog_src_dst_indexer, masked_src_indexer,
+                          masked_dst_indexer, lacc));
     });
 
     return comp_ev;
@@ -422,9 +446,10 @@ sycl::event masked_extract_some_slices_strided_impl(
                                       StridedIndexer, Strided1DIndexer,
                                       dataT, indT, LocalAccessorT>;
 
-    const size_t nominal_lws = 256;
     const std::size_t masked_extent = masked_nelems;
-    const size_t lws = std::min(masked_extent, nominal_lws);
+
+    const std::size_t lws = get_lws(masked_extent);
+
     const size_t n_groups = ((masked_extent + lws - 1) / lws);
     const size_t orthog_extent = static_cast<size_t>(orthog_nelems);
@@ -433,6 +458,10 @@ sycl::event masked_extract_some_slices_strided_impl(
 
     sycl::nd_range<2> ndRange(gRange, lRange);
 
+    const dataT *src_tp = reinterpret_cast<const dataT *>(src_p);
+    const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
+    dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
+
     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
@@ -441,9 +470,9 @@ sycl::event masked_extract_some_slices_strided_impl(
         LocalAccessorT lacc(lacc_size, cgh);
 
         cgh.parallel_for<KernelName>(
-            ndRange,
-            Impl(src_p, cumsum_p, dst_p, masked_nelems, orthog_src_dst_indexer,
-                 masked_src_indexer, masked_dst_indexer, lacc));
+            ndRange, Impl(src_tp, cumsum_tp, dst_tp, masked_nelems,
+                          orthog_src_dst_indexer, masked_src_indexer,
+                          masked_dst_indexer, lacc));
     });
 
     return comp_ev;
@@ -567,6 +596,10 @@ sycl::event masked_place_all_slices_strided_impl(
 
     using LocalAccessorT = sycl::local_accessor<indT, 1>;
 
+    dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
+    const dataT *rhs_tp = reinterpret_cast<const dataT *>(rhs_p);
+    const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
+
     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
@@ -578,8 +611,9 @@ sycl::event masked_place_all_slices_strided_impl(
             MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
                                       Strided1DCyclicIndexer, dataT, indT,
                                       LocalAccessorT>(
-                dst_p, cumsum_p, rhs_p, iteration_size, orthog_dst_rhs_indexer,
-                masked_dst_indexer, masked_rhs_indexer, lacc));
+                dst_tp, cumsum_tp, rhs_tp, iteration_size,
+                orthog_dst_rhs_indexer, masked_dst_indexer, masked_rhs_indexer,
+                lacc));
     });
 
     return comp_ev;
@@ -659,6 +693,10 @@ sycl::event masked_place_some_slices_strided_impl(
 
     using LocalAccessorT = sycl::local_accessor<indT, 1>;
 
+    dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
+    const dataT *rhs_tp = reinterpret_cast<const dataT *>(rhs_p);
+    const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
+
     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
@@ -670,8 +708,9 @@ sycl::event masked_place_some_slices_strided_impl(
             MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
                                       Strided1DCyclicIndexer, dataT, indT,
                                       LocalAccessorT>(
-                dst_p, cumsum_p, rhs_p, masked_nelems, orthog_dst_rhs_indexer,
-                masked_dst_indexer, masked_rhs_indexer, lacc));
+                dst_tp, cumsum_tp, rhs_tp, masked_nelems,
+                orthog_dst_rhs_indexer, masked_dst_indexer, masked_rhs_indexer,
+                lacc));
     });
 
     return comp_ev;