Skip to content

Commit f9abe3e

Browse files
Functors for masked extract/place changed to store typed pointers
Also implement get_lws to choose the local work-group size from a descending list of choices I0 > I1 > I2 > ...: if n >= I0, use I0; otherwise if n >= I1, use I1; and so on, falling back to the smallest choice when n is below all of them.
1 parent ff93cfc commit f9abe3e

File tree

1 file changed

+83
-48
lines changed

1 file changed

+83
-48
lines changed

dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp

Lines changed: 83 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,15 @@ template <typename OrthogIndexerT,
5252
typename LocalAccessorT>
5353
struct MaskedExtractStridedFunctor
5454
{
55-
MaskedExtractStridedFunctor(const char *src_data_p,
56-
const char *cumsum_data_p,
57-
char *dst_data_p,
55+
MaskedExtractStridedFunctor(const dataT *src_data_p,
56+
const indT *cumsum_data_p,
57+
dataT *dst_data_p,
5858
size_t masked_iter_size,
5959
const OrthogIndexerT &orthog_src_dst_indexer_,
6060
const MaskedSrcIndexerT &masked_src_indexer_,
6161
const MaskedDstIndexerT &masked_dst_indexer_,
6262
const LocalAccessorT &lacc_)
63-
: src_cp(src_data_p), cumsum_cp(cumsum_data_p), dst_cp(dst_data_p),
63+
: src(src_data_p), cumsum(cumsum_data_p), dst(dst_data_p),
6464
masked_nelems(masked_iter_size),
6565
orthog_src_dst_indexer(orthog_src_dst_indexer_),
6666
masked_src_indexer(masked_src_indexer_),
@@ -72,24 +72,20 @@ struct MaskedExtractStridedFunctor
7272

7373
void operator()(sycl::nd_item<2> ndit) const
7474
{
75-
const dataT *src_data = reinterpret_cast<const dataT *>(src_cp);
76-
dataT *dst_data = reinterpret_cast<dataT *>(dst_cp);
77-
const indT *cumsum_data = reinterpret_cast<const indT *>(cumsum_cp);
78-
79-
const size_t orthog_i = ndit.get_global_id(0);
80-
const size_t group_i = ndit.get_group(1);
75+
const std::size_t orthog_i = ndit.get_global_id(0);
8176
const std::uint32_t l_i = ndit.get_local_id(1);
8277
const std::uint32_t lws = ndit.get_local_range(1);
8378

84-
const size_t masked_block_start = group_i * lws;
85-
const size_t masked_i = masked_block_start + l_i;
79+
const std::size_t masked_i = ndit.get_global_id(1);
80+
const std::size_t masked_block_start = masked_i - l_i;
8681

82+
const std::size_t max_offset = masked_nelems + 1;
8783
for (std::uint32_t i = l_i; i < lacc.size(); i += lws) {
8884
const size_t offset = masked_block_start + i;
8985
lacc[i] = (offset == 0) ? indT(0)
90-
: (offset - 1 < masked_nelems)
91-
? cumsum_data[offset - 1]
92-
: cumsum_data[masked_nelems - 1] + 1;
86+
: (offset < max_offset)
87+
? cumsum[offset - 1]
88+
: cumsum[masked_nelems - 1] + 1;
9389
}
9490

9591
sycl::group_barrier(ndit.get_group());
@@ -110,14 +106,14 @@ struct MaskedExtractStridedFunctor
110106
masked_dst_indexer(current_running_count - 1) +
111107
orthog_offsets.get_second_offset();
112108

113-
dst_data[total_dst_offset] = src_data[total_src_offset];
109+
dst[total_dst_offset] = src[total_src_offset];
114110
}
115111
}
116112

117113
private:
118-
const char *src_cp = nullptr;
119-
const char *cumsum_cp = nullptr;
120-
char *dst_cp = nullptr;
114+
const dataT *src = nullptr;
115+
const indT *cumsum = nullptr;
116+
dataT *dst = nullptr;
121117
const size_t masked_nelems = 0;
122118
// has nd, shape, src_strides, dst_strides for
123119
// dimensions that ARE NOT masked
@@ -138,15 +134,15 @@ template <typename OrthogIndexerT,
138134
typename LocalAccessorT>
139135
struct MaskedPlaceStridedFunctor
140136
{
141-
MaskedPlaceStridedFunctor(char *dst_data_p,
142-
const char *cumsum_data_p,
143-
const char *rhs_data_p,
137+
MaskedPlaceStridedFunctor(dataT *dst_data_p,
138+
const indT *cumsum_data_p,
139+
const dataT *rhs_data_p,
144140
size_t masked_iter_size,
145141
const OrthogIndexerT &orthog_dst_rhs_indexer_,
146142
const MaskedDstIndexerT &masked_dst_indexer_,
147143
const MaskedRhsIndexerT &masked_rhs_indexer_,
148144
const LocalAccessorT &lacc_)
149-
: dst_cp(dst_data_p), cumsum_cp(cumsum_data_p), rhs_cp(rhs_data_p),
145+
: dst(dst_data_p), cumsum(cumsum_data_p), rhs(rhs_data_p),
150146
masked_nelems(masked_iter_size),
151147
orthog_dst_rhs_indexer(orthog_dst_rhs_indexer_),
152148
masked_dst_indexer(masked_dst_indexer_),
@@ -158,24 +154,20 @@ struct MaskedPlaceStridedFunctor
158154

159155
void operator()(sycl::nd_item<2> ndit) const
160156
{
161-
dataT *dst_data = reinterpret_cast<dataT *>(dst_cp);
162-
const indT *cumsum_data = reinterpret_cast<const indT *>(cumsum_cp);
163-
const dataT *rhs_data = reinterpret_cast<const dataT *>(rhs_cp);
164-
165157
const std::size_t orthog_i = ndit.get_global_id(0);
166-
const std::size_t group_i = ndit.get_group(1);
167158
const std::uint32_t l_i = ndit.get_local_id(1);
168159
const std::uint32_t lws = ndit.get_local_range(1);
169160

170-
const size_t masked_block_start = group_i * lws;
171-
const size_t masked_i = masked_block_start + l_i;
161+
const size_t masked_i = ndit.get_global_id(1);
162+
const size_t masked_block_start = masked_i - l_i;
172163

164+
const std::size_t max_offset = masked_nelems + 1;
173165
for (std::uint32_t i = l_i; i < lacc.size(); i += lws) {
174166
const size_t offset = masked_block_start + i;
175167
lacc[i] = (offset == 0) ? indT(0)
176-
: (offset - 1 < masked_nelems)
177-
? cumsum_data[offset - 1]
178-
: cumsum_data[masked_nelems - 1] + 1;
168+
: (offset < max_offset)
169+
? cumsum[offset - 1]
170+
: cumsum[masked_nelems - 1] + 1;
179171
}
180172

181173
sycl::group_barrier(ndit.get_group());
@@ -196,14 +188,14 @@ struct MaskedPlaceStridedFunctor
196188
masked_rhs_indexer(current_running_count - 1) +
197189
orthog_offsets.get_second_offset();
198190

199-
dst_data[total_dst_offset] = rhs_data[total_rhs_offset];
191+
dst[total_dst_offset] = rhs[total_rhs_offset];
200192
}
201193
}
202194

203195
private:
204-
char *dst_cp = nullptr;
205-
const char *cumsum_cp = nullptr;
206-
const char *rhs_cp = nullptr;
196+
dataT *dst = nullptr;
197+
const indT *cumsum = nullptr;
198+
const dataT *rhs = nullptr;
207199
const size_t masked_nelems = 0;
208200
// has nd, shape, dst_strides, rhs_strides for
209201
// dimensions that ARE NOT masked
@@ -218,6 +210,26 @@ struct MaskedPlaceStridedFunctor
218210

219211
// ======= Masked extraction ================================
220212

213+
namespace {

/// @brief Selects a work-group size from the compile-time candidates
///        `I, IR...` for a problem of size `n`.
///
/// Candidates are tried in the order given. The first candidate that does
/// not exceed `n` is returned; if every candidate is larger than `n`, the
/// last (smallest) candidate is returned.
///
/// @tparam I   Candidate tried first; must be greater than every one in `IR`.
/// @tparam IR  Remaining candidates, in strictly decreasing order.
/// @param  n   Number of elements to be processed.
/// @return The selected candidate.
template <std::size_t I, std::size_t... IR>
constexpr std::size_t _get_lws_impl(std::size_t n) noexcept
{
    // Applied at every recursion level, this enforces I0 > I1 > I2 > ...;
    // with any other ordering later candidates could never be selected
    // sensibly.
    static_assert(((I > IR) && ...),
                  "work-group size candidates must be strictly decreasing");

    if constexpr (sizeof...(IR) == 0) {
        // Last candidate: used unconditionally as the fallback.
        return I;
    }
    else {
        return (n < I) ? _get_lws_impl<IR...>(n) : I;
    }
}

/// @brief Returns the work-group size to use for `n` elements.
///
/// Chooses among 256, 128 and 64: 256 when `n >= 256`, 128 when
/// `n >= 128`, and 64 otherwise. The result is never smaller than 64, so it
/// may exceed `n` for small inputs.
///
/// @param n  Number of elements to be processed.
/// @return One of 256, 128 or 64.
constexpr std::size_t get_lws(std::size_t n) noexcept
{
    constexpr std::size_t lws0 = 256u;
    constexpr std::size_t lws1 = 128u;
    constexpr std::size_t lws2 = 64u;
    return _get_lws_impl<lws0, lws1, lws2>(n);
}

} // end of anonymous namespace
232+
221233
template <typename MaskedDstIndexerT, typename dataT, typename indT>
222234
class masked_extract_all_slices_contig_impl_krn;
223235

@@ -258,16 +270,21 @@ sycl::event masked_extract_all_slices_contig_impl(
258270
Strided1DIndexer, dataT, indT,
259271
LocalAccessorT>;
260272

261-
constexpr std::size_t nominal_lws = 256;
262273
const std::size_t masked_extent = iteration_size;
263-
const std::size_t lws = std::min(masked_extent, nominal_lws);
274+
275+
const std::size_t lws = get_lws(masked_extent);
276+
264277
const std::size_t n_groups = (iteration_size + lws - 1) / lws;
265278

266279
sycl::range<2> gRange{1, n_groups * lws};
267280
sycl::range<2> lRange{1, lws};
268281

269282
sycl::nd_range<2> ndRange(gRange, lRange);
270283

284+
const dataT *src_tp = reinterpret_cast<const dataT *>(src_p);
285+
const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
286+
dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
287+
271288
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
272289
cgh.depends_on(depends);
273290

@@ -276,7 +293,7 @@ sycl::event masked_extract_all_slices_contig_impl(
276293

277294
cgh.parallel_for<KernelName>(
278295
ndRange,
279-
Impl(src_p, cumsum_p, dst_p, masked_extent, orthog_src_dst_indexer,
296+
Impl(src_tp, cumsum_tp, dst_tp, masked_extent, orthog_src_dst_indexer,
280297
masked_src_indexer, masked_dst_indexer, lacc));
281298
});
282299

@@ -332,16 +349,21 @@ sycl::event masked_extract_all_slices_strided_impl(
332349
StridedIndexer, Strided1DIndexer,
333350
dataT, indT, LocalAccessorT>;
334351

335-
constexpr std::size_t nominal_lws = 256;
336352
const std::size_t masked_nelems = iteration_size;
337-
const std::size_t lws = std::min(masked_nelems, nominal_lws);
353+
354+
const std::size_t lws = get_lws(masked_nelems);
355+
338356
const std::size_t n_groups = (masked_nelems + lws - 1) / lws;
339357

340358
sycl::range<2> gRange{1, n_groups * lws};
341359
sycl::range<2> lRange{1, lws};
342360

343361
sycl::nd_range<2> ndRange(gRange, lRange);
344362

363+
const dataT *src_tp = reinterpret_cast<const dataT *>(src_p);
364+
const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
365+
dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
366+
345367
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
346368
cgh.depends_on(depends);
347369

@@ -350,7 +372,7 @@ sycl::event masked_extract_all_slices_strided_impl(
350372

351373
cgh.parallel_for<KernelName>(
352374
ndRange,
353-
Impl(src_p, cumsum_p, dst_p, iteration_size, orthog_src_dst_indexer,
375+
Impl(src_tp, cumsum_tp, dst_tp, iteration_size, orthog_src_dst_indexer,
354376
masked_src_indexer, masked_dst_indexer, lacc));
355377
});
356378

@@ -422,9 +444,10 @@ sycl::event masked_extract_some_slices_strided_impl(
422444
StridedIndexer, Strided1DIndexer,
423445
dataT, indT, LocalAccessorT>;
424446

425-
const size_t nominal_lws = 256;
426447
const std::size_t masked_extent = masked_nelems;
427-
const size_t lws = std::min(masked_extent, nominal_lws);
448+
449+
const std::size_t lws = get_lws(masked_extent);
450+
428451
const size_t n_groups = ((masked_extent + lws - 1) / lws);
429452
const size_t orthog_extent = static_cast<size_t>(orthog_nelems);
430453

@@ -433,6 +456,10 @@ sycl::event masked_extract_some_slices_strided_impl(
433456

434457
sycl::nd_range<2> ndRange(gRange, lRange);
435458

459+
const dataT *src_tp = reinterpret_cast<const dataT *>(src_p);
460+
const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
461+
dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
462+
436463
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
437464
cgh.depends_on(depends);
438465

@@ -442,7 +469,7 @@ sycl::event masked_extract_some_slices_strided_impl(
442469

443470
cgh.parallel_for<KernelName>(
444471
ndRange,
445-
Impl(src_p, cumsum_p, dst_p, masked_nelems, orthog_src_dst_indexer,
472+
Impl(src_tp, cumsum_tp, dst_tp, masked_nelems, orthog_src_dst_indexer,
446473
masked_src_indexer, masked_dst_indexer, lacc));
447474
});
448475

@@ -567,6 +594,10 @@ sycl::event masked_place_all_slices_strided_impl(
567594

568595
using LocalAccessorT = sycl::local_accessor<indT, 1>;
569596

597+
dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
598+
const dataT *rhs_tp = reinterpret_cast<const dataT *>(rhs_p);
599+
const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
600+
570601
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
571602
cgh.depends_on(depends);
572603

@@ -578,7 +609,7 @@ sycl::event masked_place_all_slices_strided_impl(
578609
MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
579610
Strided1DCyclicIndexer, dataT, indT,
580611
LocalAccessorT>(
581-
dst_p, cumsum_p, rhs_p, iteration_size, orthog_dst_rhs_indexer,
612+
dst_tp, cumsum_tp, rhs_tp, iteration_size, orthog_dst_rhs_indexer,
582613
masked_dst_indexer, masked_rhs_indexer, lacc));
583614
});
584615

@@ -659,6 +690,10 @@ sycl::event masked_place_some_slices_strided_impl(
659690

660691
using LocalAccessorT = sycl::local_accessor<indT, 1>;
661692

693+
dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
694+
const dataT *rhs_tp = reinterpret_cast<const dataT *>(rhs_p);
695+
const indT* cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
696+
662697
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
663698
cgh.depends_on(depends);
664699

@@ -670,7 +705,7 @@ sycl::event masked_place_some_slices_strided_impl(
670705
MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
671706
Strided1DCyclicIndexer, dataT, indT,
672707
LocalAccessorT>(
673-
dst_p, cumsum_p, rhs_p, masked_nelems, orthog_dst_rhs_indexer,
708+
dst_tp, cumsum_tp, rhs_tp, masked_nelems, orthog_dst_rhs_indexer,
674709
masked_dst_indexer, masked_rhs_indexer, lacc));
675710
});
676711

0 commit comments

Comments
 (0)