Skip to content

Commit a5b03ca

Browse files
Functors for masked extract/place changed to store typed pointers.
Also implement get_lws to choose the local work-group size from given choices I0 > I1 > I2 > ...: if n >= I0 use I0, else if n >= I1 use I1, and so on; if n is smaller than every choice, the last (smallest) choice is used.
1 parent ff93cfc commit a5b03ca

File tree

1 file changed

+97
-58
lines changed

1 file changed

+97
-58
lines changed

dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp

Lines changed: 97 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -52,15 +52,15 @@ template <typename OrthogIndexerT,
5252
typename LocalAccessorT>
5353
struct MaskedExtractStridedFunctor
5454
{
55-
MaskedExtractStridedFunctor(const char *src_data_p,
56-
const char *cumsum_data_p,
57-
char *dst_data_p,
55+
MaskedExtractStridedFunctor(const dataT *src_data_p,
56+
const indT *cumsum_data_p,
57+
dataT *dst_data_p,
5858
size_t masked_iter_size,
5959
const OrthogIndexerT &orthog_src_dst_indexer_,
6060
const MaskedSrcIndexerT &masked_src_indexer_,
6161
const MaskedDstIndexerT &masked_dst_indexer_,
6262
const LocalAccessorT &lacc_)
63-
: src_cp(src_data_p), cumsum_cp(cumsum_data_p), dst_cp(dst_data_p),
63+
: src(src_data_p), cumsum(cumsum_data_p), dst(dst_data_p),
6464
masked_nelems(masked_iter_size),
6565
orthog_src_dst_indexer(orthog_src_dst_indexer_),
6666
masked_src_indexer(masked_src_indexer_),
@@ -72,24 +72,19 @@ struct MaskedExtractStridedFunctor
7272

7373
void operator()(sycl::nd_item<2> ndit) const
7474
{
75-
const dataT *src_data = reinterpret_cast<const dataT *>(src_cp);
76-
dataT *dst_data = reinterpret_cast<dataT *>(dst_cp);
77-
const indT *cumsum_data = reinterpret_cast<const indT *>(cumsum_cp);
78-
79-
const size_t orthog_i = ndit.get_global_id(0);
80-
const size_t group_i = ndit.get_group(1);
75+
const std::size_t orthog_i = ndit.get_global_id(0);
8176
const std::uint32_t l_i = ndit.get_local_id(1);
8277
const std::uint32_t lws = ndit.get_local_range(1);
8378

84-
const size_t masked_block_start = group_i * lws;
85-
const size_t masked_i = masked_block_start + l_i;
79+
const std::size_t masked_i = ndit.get_global_id(1);
80+
const std::size_t masked_block_start = masked_i - l_i;
8681

82+
const std::size_t max_offset = masked_nelems + 1;
8783
for (std::uint32_t i = l_i; i < lacc.size(); i += lws) {
8884
const size_t offset = masked_block_start + i;
89-
lacc[i] = (offset == 0) ? indT(0)
90-
: (offset - 1 < masked_nelems)
91-
? cumsum_data[offset - 1]
92-
: cumsum_data[masked_nelems - 1] + 1;
85+
lacc[i] = (offset == 0) ? indT(0)
86+
: (offset < max_offset) ? cumsum[offset - 1]
87+
: cumsum[masked_nelems - 1] + 1;
9388
}
9489

9590
sycl::group_barrier(ndit.get_group());
@@ -110,14 +105,14 @@ struct MaskedExtractStridedFunctor
110105
masked_dst_indexer(current_running_count - 1) +
111106
orthog_offsets.get_second_offset();
112107

113-
dst_data[total_dst_offset] = src_data[total_src_offset];
108+
dst[total_dst_offset] = src[total_src_offset];
114109
}
115110
}
116111

117112
private:
118-
const char *src_cp = nullptr;
119-
const char *cumsum_cp = nullptr;
120-
char *dst_cp = nullptr;
113+
const dataT *src = nullptr;
114+
const indT *cumsum = nullptr;
115+
dataT *dst = nullptr;
121116
const size_t masked_nelems = 0;
122117
// has nd, shape, src_strides, dst_strides for
123118
// dimensions that ARE NOT masked
@@ -138,15 +133,15 @@ template <typename OrthogIndexerT,
138133
typename LocalAccessorT>
139134
struct MaskedPlaceStridedFunctor
140135
{
141-
MaskedPlaceStridedFunctor(char *dst_data_p,
142-
const char *cumsum_data_p,
143-
const char *rhs_data_p,
136+
MaskedPlaceStridedFunctor(dataT *dst_data_p,
137+
const indT *cumsum_data_p,
138+
const dataT *rhs_data_p,
144139
size_t masked_iter_size,
145140
const OrthogIndexerT &orthog_dst_rhs_indexer_,
146141
const MaskedDstIndexerT &masked_dst_indexer_,
147142
const MaskedRhsIndexerT &masked_rhs_indexer_,
148143
const LocalAccessorT &lacc_)
149-
: dst_cp(dst_data_p), cumsum_cp(cumsum_data_p), rhs_cp(rhs_data_p),
144+
: dst(dst_data_p), cumsum(cumsum_data_p), rhs(rhs_data_p),
150145
masked_nelems(masked_iter_size),
151146
orthog_dst_rhs_indexer(orthog_dst_rhs_indexer_),
152147
masked_dst_indexer(masked_dst_indexer_),
@@ -158,24 +153,19 @@ struct MaskedPlaceStridedFunctor
158153

159154
void operator()(sycl::nd_item<2> ndit) const
160155
{
161-
dataT *dst_data = reinterpret_cast<dataT *>(dst_cp);
162-
const indT *cumsum_data = reinterpret_cast<const indT *>(cumsum_cp);
163-
const dataT *rhs_data = reinterpret_cast<const dataT *>(rhs_cp);
164-
165156
const std::size_t orthog_i = ndit.get_global_id(0);
166-
const std::size_t group_i = ndit.get_group(1);
167157
const std::uint32_t l_i = ndit.get_local_id(1);
168158
const std::uint32_t lws = ndit.get_local_range(1);
169159

170-
const size_t masked_block_start = group_i * lws;
171-
const size_t masked_i = masked_block_start + l_i;
160+
const size_t masked_i = ndit.get_global_id(1);
161+
const size_t masked_block_start = masked_i - l_i;
172162

163+
const std::size_t max_offset = masked_nelems + 1;
173164
for (std::uint32_t i = l_i; i < lacc.size(); i += lws) {
174165
const size_t offset = masked_block_start + i;
175-
lacc[i] = (offset == 0) ? indT(0)
176-
: (offset - 1 < masked_nelems)
177-
? cumsum_data[offset - 1]
178-
: cumsum_data[masked_nelems - 1] + 1;
166+
lacc[i] = (offset == 0) ? indT(0)
167+
: (offset < max_offset) ? cumsum[offset - 1]
168+
: cumsum[masked_nelems - 1] + 1;
179169
}
180170

181171
sycl::group_barrier(ndit.get_group());
@@ -196,14 +186,14 @@ struct MaskedPlaceStridedFunctor
196186
masked_rhs_indexer(current_running_count - 1) +
197187
orthog_offsets.get_second_offset();
198188

199-
dst_data[total_dst_offset] = rhs_data[total_rhs_offset];
189+
dst[total_dst_offset] = rhs[total_rhs_offset];
200190
}
201191
}
202192

203193
private:
204-
char *dst_cp = nullptr;
205-
const char *cumsum_cp = nullptr;
206-
const char *rhs_cp = nullptr;
194+
dataT *dst = nullptr;
195+
const indT *cumsum = nullptr;
196+
const dataT *rhs = nullptr;
207197
const size_t masked_nelems = 0;
208198
// has nd, shape, dst_strides, rhs_strides for
209199
// dimensions that ARE NOT masked
@@ -218,6 +208,30 @@ struct MaskedPlaceStridedFunctor
218208

219209
// ======= Masked extraction ================================
220210

211+
namespace
212+
{
213+
214+
template <std::size_t I, std::size_t... IR>
215+
std::size_t _get_lws_impl(std::size_t n)
216+
{
217+
if constexpr (sizeof...(IR) == 0) {
218+
return I;
219+
}
220+
else {
221+
return (n < I) ? _get_lws_impl<IR...>(n) : I;
222+
}
223+
}
224+
225+
std::size_t get_lws(std::size_t n)
226+
{
227+
constexpr std::size_t lws0 = 256u;
228+
constexpr std::size_t lws1 = 128u;
229+
constexpr std::size_t lws2 = 64u;
230+
return _get_lws_impl<lws0, lws1, lws2>(n);
231+
}
232+
233+
} // end of anonymous namespace
234+
221235
template <typename MaskedDstIndexerT, typename dataT, typename indT>
222236
class masked_extract_all_slices_contig_impl_krn;
223237

@@ -258,26 +272,31 @@ sycl::event masked_extract_all_slices_contig_impl(
258272
Strided1DIndexer, dataT, indT,
259273
LocalAccessorT>;
260274

261-
constexpr std::size_t nominal_lws = 256;
262275
const std::size_t masked_extent = iteration_size;
263-
const std::size_t lws = std::min(masked_extent, nominal_lws);
276+
277+
const std::size_t lws = get_lws(masked_extent);
278+
264279
const std::size_t n_groups = (iteration_size + lws - 1) / lws;
265280

266281
sycl::range<2> gRange{1, n_groups * lws};
267282
sycl::range<2> lRange{1, lws};
268283

269284
sycl::nd_range<2> ndRange(gRange, lRange);
270285

286+
const dataT *src_tp = reinterpret_cast<const dataT *>(src_p);
287+
const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
288+
dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
289+
271290
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
272291
cgh.depends_on(depends);
273292

274293
const std::size_t lacc_size = std::min(lws, masked_extent) + 1;
275294
LocalAccessorT lacc(lacc_size, cgh);
276295

277296
cgh.parallel_for<KernelName>(
278-
ndRange,
279-
Impl(src_p, cumsum_p, dst_p, masked_extent, orthog_src_dst_indexer,
280-
masked_src_indexer, masked_dst_indexer, lacc));
297+
ndRange, Impl(src_tp, cumsum_tp, dst_tp, masked_extent,
298+
orthog_src_dst_indexer, masked_src_indexer,
299+
masked_dst_indexer, lacc));
281300
});
282301

283302
return comp_ev;
@@ -332,26 +351,31 @@ sycl::event masked_extract_all_slices_strided_impl(
332351
StridedIndexer, Strided1DIndexer,
333352
dataT, indT, LocalAccessorT>;
334353

335-
constexpr std::size_t nominal_lws = 256;
336354
const std::size_t masked_nelems = iteration_size;
337-
const std::size_t lws = std::min(masked_nelems, nominal_lws);
355+
356+
const std::size_t lws = get_lws(masked_nelems);
357+
338358
const std::size_t n_groups = (masked_nelems + lws - 1) / lws;
339359

340360
sycl::range<2> gRange{1, n_groups * lws};
341361
sycl::range<2> lRange{1, lws};
342362

343363
sycl::nd_range<2> ndRange(gRange, lRange);
344364

365+
const dataT *src_tp = reinterpret_cast<const dataT *>(src_p);
366+
const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
367+
dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
368+
345369
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
346370
cgh.depends_on(depends);
347371

348372
const std::size_t lacc_size = std::min(lws, masked_nelems) + 1;
349373
LocalAccessorT lacc(lacc_size, cgh);
350374

351375
cgh.parallel_for<KernelName>(
352-
ndRange,
353-
Impl(src_p, cumsum_p, dst_p, iteration_size, orthog_src_dst_indexer,
354-
masked_src_indexer, masked_dst_indexer, lacc));
376+
ndRange, Impl(src_tp, cumsum_tp, dst_tp, iteration_size,
377+
orthog_src_dst_indexer, masked_src_indexer,
378+
masked_dst_indexer, lacc));
355379
});
356380

357381
return comp_ev;
@@ -422,9 +446,10 @@ sycl::event masked_extract_some_slices_strided_impl(
422446
StridedIndexer, Strided1DIndexer,
423447
dataT, indT, LocalAccessorT>;
424448

425-
const size_t nominal_lws = 256;
426449
const std::size_t masked_extent = masked_nelems;
427-
const size_t lws = std::min(masked_extent, nominal_lws);
450+
451+
const std::size_t lws = get_lws(masked_extent);
452+
428453
const size_t n_groups = ((masked_extent + lws - 1) / lws);
429454
const size_t orthog_extent = static_cast<size_t>(orthog_nelems);
430455

@@ -433,6 +458,10 @@ sycl::event masked_extract_some_slices_strided_impl(
433458

434459
sycl::nd_range<2> ndRange(gRange, lRange);
435460

461+
const dataT *src_tp = reinterpret_cast<const dataT *>(src_p);
462+
const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
463+
dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
464+
436465
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
437466
cgh.depends_on(depends);
438467

@@ -441,9 +470,9 @@ sycl::event masked_extract_some_slices_strided_impl(
441470
LocalAccessorT lacc(lacc_size, cgh);
442471

443472
cgh.parallel_for<KernelName>(
444-
ndRange,
445-
Impl(src_p, cumsum_p, dst_p, masked_nelems, orthog_src_dst_indexer,
446-
masked_src_indexer, masked_dst_indexer, lacc));
473+
ndRange, Impl(src_tp, cumsum_tp, dst_tp, masked_nelems,
474+
orthog_src_dst_indexer, masked_src_indexer,
475+
masked_dst_indexer, lacc));
447476
});
448477

449478
return comp_ev;
@@ -567,6 +596,10 @@ sycl::event masked_place_all_slices_strided_impl(
567596

568597
using LocalAccessorT = sycl::local_accessor<indT, 1>;
569598

599+
dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
600+
const dataT *rhs_tp = reinterpret_cast<const dataT *>(rhs_p);
601+
const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
602+
570603
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
571604
cgh.depends_on(depends);
572605

@@ -578,8 +611,9 @@ sycl::event masked_place_all_slices_strided_impl(
578611
MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
579612
Strided1DCyclicIndexer, dataT, indT,
580613
LocalAccessorT>(
581-
dst_p, cumsum_p, rhs_p, iteration_size, orthog_dst_rhs_indexer,
582-
masked_dst_indexer, masked_rhs_indexer, lacc));
614+
dst_tp, cumsum_tp, rhs_tp, iteration_size,
615+
orthog_dst_rhs_indexer, masked_dst_indexer, masked_rhs_indexer,
616+
lacc));
583617
});
584618

585619
return comp_ev;
@@ -659,6 +693,10 @@ sycl::event masked_place_some_slices_strided_impl(
659693

660694
using LocalAccessorT = sycl::local_accessor<indT, 1>;
661695

696+
dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
697+
const dataT *rhs_tp = reinterpret_cast<const dataT *>(rhs_p);
698+
const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
699+
662700
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
663701
cgh.depends_on(depends);
664702

@@ -670,8 +708,9 @@ sycl::event masked_place_some_slices_strided_impl(
670708
MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
671709
Strided1DCyclicIndexer, dataT, indT,
672710
LocalAccessorT>(
673-
dst_p, cumsum_p, rhs_p, masked_nelems, orthog_dst_rhs_indexer,
674-
masked_dst_indexer, masked_rhs_indexer, lacc));
711+
dst_tp, cumsum_tp, rhs_tp, masked_nelems,
712+
orthog_dst_rhs_indexer, masked_dst_indexer, masked_rhs_indexer,
713+
lacc));
675714
});
676715

677716
return comp_ev;

0 commit comments

Comments
 (0)