30
30
#include < type_traits>
31
31
32
32
#include " dpctl_tensor_types.hpp"
33
+ #include " utils/indexing_utils.hpp"
33
34
#include " utils/offset_utils.hpp"
34
35
#include " utils/type_utils.hpp"
35
36
@@ -42,54 +43,10 @@ namespace kernels
42
43
namespace indexing
43
44
{
44
45
45
- using namespace dpctl ::tensor::offset_utils;
46
-
47
- template <typename ProjectorT,
48
- typename OrthogStrider,
49
- typename IndicesStrider,
50
- typename AxesStrider,
51
- typename T,
52
- typename indT>
53
- class take_kernel ;
54
46
template <typename ProjectorT,
55
- typename OrthogStrider,
56
- typename IndicesStrider,
57
- typename AxesStrider,
58
- typename T,
59
- typename indT>
60
- class put_kernel ;
61
-
62
- class WrapIndex
63
- {
64
- public:
65
- WrapIndex () = default ;
66
-
67
- void operator ()(ssize_t max_item, ssize_t &ind) const
68
- {
69
- max_item = std::max<ssize_t >(max_item, 1 );
70
- ind = sycl::clamp<ssize_t >(ind, -max_item, max_item - 1 );
71
- ind = (ind < 0 ) ? ind + max_item : ind;
72
- return ;
73
- }
74
- };
75
-
76
- class ClipIndex
77
- {
78
- public:
79
- ClipIndex () = default ;
80
-
81
- void operator ()(ssize_t max_item, ssize_t &ind) const
82
- {
83
- max_item = std::max<ssize_t >(max_item, 1 );
84
- ind = sycl::clamp<ssize_t >(ind, ssize_t (0 ), max_item - 1 );
85
- return ;
86
- }
87
- };
88
-
89
- template <typename ProjectorT,
90
- typename OrthogStrider,
91
- typename IndicesStrider,
92
- typename AxesStrider,
47
+ typename OrthogIndexer,
48
+ typename IndicesIndexer,
49
+ typename AxesIndexer,
93
50
typename T,
94
51
typename indT>
95
52
class TakeFunctor
@@ -101,9 +58,9 @@ class TakeFunctor
101
58
int k_ = 0 ;
102
59
size_t ind_nelems_ = 0 ;
103
60
const ssize_t *axes_shape_and_strides_ = nullptr ;
104
- const OrthogStrider orthog_strider;
105
- const IndicesStrider ind_strider;
106
- const AxesStrider axes_strider;
61
+ const OrthogIndexer orthog_strider;
62
+ const IndicesIndexer ind_strider;
63
+ const AxesIndexer axes_strider;
107
64
108
65
public:
109
66
TakeFunctor (const char *src_cp,
@@ -112,9 +69,9 @@ class TakeFunctor
112
69
int k,
113
70
size_t ind_nelems,
114
71
const ssize_t *axes_shape_and_strides,
115
- const OrthogStrider &orthog_strider_,
116
- const IndicesStrider &ind_strider_,
117
- const AxesStrider &axes_strider_)
72
+ const OrthogIndexer &orthog_strider_,
73
+ const IndicesIndexer &ind_strider_,
74
+ const AxesIndexer &axes_strider_)
118
75
: src_(src_cp), dst_(dst_cp), ind_(ind_cp), k_(k),
119
76
ind_nelems_ (ind_nelems),
120
77
axes_shape_and_strides_(axes_shape_and_strides),
@@ -136,16 +93,16 @@ class TakeFunctor
136
93
ssize_t src_offset = orthog_offsets.get_first_offset ();
137
94
ssize_t dst_offset = orthog_offsets.get_second_offset ();
138
95
139
- const ProjectorT proj{};
96
+ constexpr ProjectorT proj{};
140
97
for (int axis_idx = 0 ; axis_idx < k_; ++axis_idx) {
141
98
indT *ind_data = reinterpret_cast <indT *>(ind_[axis_idx]);
142
99
143
100
ssize_t ind_offset = ind_strider (i_along, axis_idx);
144
- ssize_t i = static_cast < ssize_t >(ind_data[ind_offset]);
145
-
146
- proj (axes_shape_and_strides_[axis_idx], i );
147
-
148
- src_offset += i * axes_shape_and_strides_[k_ + axis_idx];
101
+ // proj produces an index in the range of the given axis
102
+ ssize_t projected_idx =
103
+ proj (axes_shape_and_strides_[axis_idx], ind_data[ind_offset] );
104
+ src_offset +=
105
+ projected_idx * axes_shape_and_strides_[k_ + axis_idx];
149
106
}
150
107
151
108
dst_offset += axes_strider (i_along);
@@ -154,6 +111,14 @@ class TakeFunctor
154
111
}
155
112
};
156
113
114
+ template <typename ProjectorT,
115
+ typename OrthogIndexer,
116
+ typename IndicesIndexer,
117
+ typename AxesIndexer,
118
+ typename T,
119
+ typename indT>
120
+ class take_kernel ;
121
+
157
122
typedef sycl::event (*take_fn_ptr_t )(sycl::queue &,
158
123
size_t ,
159
124
size_t ,
@@ -194,21 +159,29 @@ sycl::event take_impl(sycl::queue &q,
194
159
sycl::event take_ev = q.submit ([&](sycl::handler &cgh) {
195
160
cgh.depends_on (depends);
196
161
197
- const TwoOffsets_StridedIndexer orthog_indexer{
198
- nd, src_offset, dst_offset, orthog_shape_and_strides};
199
- const NthStrideOffset indices_indexer{ind_nd, ind_offsets,
200
- ind_shape_and_strides};
201
- const StridedIndexer axes_indexer{ind_nd, 0 ,
202
- axes_shape_and_strides + (2 * k)};
162
+ using OrthogIndexerT =
163
+ dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
164
+ const OrthogIndexerT orthog_indexer{nd, src_offset, dst_offset,
165
+ orthog_shape_and_strides};
166
+
167
+ using NthStrideIndexerT = dpctl::tensor::offset_utils::NthStrideOffset;
168
+ const NthStrideIndexerT indices_indexer{ind_nd, ind_offsets,
169
+ ind_shape_and_strides};
170
+
171
+ using AxesIndexerT = dpctl::tensor::offset_utils::StridedIndexer;
172
+ const AxesIndexerT axes_indexer{ind_nd, 0 ,
173
+ axes_shape_and_strides + (2 * k)};
174
+
175
+ using KernelName =
176
+ take_kernel<ProjectorT, OrthogIndexerT, NthStrideIndexerT,
177
+ AxesIndexerT, Ty, indT>;
203
178
204
179
const size_t gws = orthog_nelems * ind_nelems;
205
180
206
- cgh.parallel_for <
207
- take_kernel<ProjectorT, TwoOffsets_StridedIndexer, NthStrideOffset,
208
- StridedIndexer, Ty, indT>>(
181
+ cgh.parallel_for <KernelName>(
209
182
sycl::range<1 >(gws),
210
- TakeFunctor<ProjectorT, TwoOffsets_StridedIndexer, NthStrideOffset ,
211
- StridedIndexer , Ty, indT>(
183
+ TakeFunctor<ProjectorT, OrthogIndexerT, NthStrideIndexerT ,
184
+ AxesIndexerT , Ty, indT>(
212
185
src_p, dst_p, ind_p, k, ind_nelems, axes_shape_and_strides,
213
186
orthog_indexer, indices_indexer, axes_indexer));
214
187
});
@@ -217,9 +190,9 @@ sycl::event take_impl(sycl::queue &q,
217
190
}
218
191
219
192
template <typename ProjectorT,
220
- typename OrthogStrider ,
221
- typename IndicesStrider ,
222
- typename AxesStrider ,
193
+ typename OrthogIndexer ,
194
+ typename IndicesIndexer ,
195
+ typename AxesIndexer ,
223
196
typename T,
224
197
typename indT>
225
198
class PutFunctor
@@ -231,9 +204,9 @@ class PutFunctor
231
204
int k_ = 0 ;
232
205
size_t ind_nelems_ = 0 ;
233
206
const ssize_t *axes_shape_and_strides_ = nullptr ;
234
- const OrthogStrider orthog_strider;
235
- const IndicesStrider ind_strider;
236
- const AxesStrider axes_strider;
207
+ const OrthogIndexer orthog_strider;
208
+ const IndicesIndexer ind_strider;
209
+ const AxesIndexer axes_strider;
237
210
238
211
public:
239
212
PutFunctor (char *dst_cp,
@@ -242,9 +215,9 @@ class PutFunctor
242
215
int k,
243
216
size_t ind_nelems,
244
217
const ssize_t *axes_shape_and_strides,
245
- const OrthogStrider &orthog_strider_,
246
- const IndicesStrider &ind_strider_,
247
- const AxesStrider &axes_strider_)
218
+ const OrthogIndexer &orthog_strider_,
219
+ const IndicesIndexer &ind_strider_,
220
+ const AxesIndexer &axes_strider_)
248
221
: dst_(dst_cp), val_(val_cp), ind_(ind_cp), k_(k),
249
222
ind_nelems_ (ind_nelems),
250
223
axes_shape_and_strides_(axes_shape_and_strides),
@@ -266,16 +239,17 @@ class PutFunctor
266
239
ssize_t dst_offset = orthog_offsets.get_first_offset ();
267
240
ssize_t val_offset = orthog_offsets.get_second_offset ();
268
241
269
- const ProjectorT proj{};
242
+ constexpr ProjectorT proj{};
270
243
for (int axis_idx = 0 ; axis_idx < k_; ++axis_idx) {
271
244
indT *ind_data = reinterpret_cast <indT *>(ind_[axis_idx]);
272
245
273
246
ssize_t ind_offset = ind_strider (i_along, axis_idx);
274
- ssize_t i = static_cast <ssize_t >(ind_data[ind_offset]);
275
-
276
- proj (axes_shape_and_strides_[axis_idx], i);
277
247
278
- dst_offset += i * axes_shape_and_strides_[k_ + axis_idx];
248
+ // proj produces an index in the range of the given axis
249
+ ssize_t projected_idx =
250
+ proj (axes_shape_and_strides_[axis_idx], ind_data[ind_offset]);
251
+ dst_offset +=
252
+ projected_idx * axes_shape_and_strides_[k_ + axis_idx];
279
253
}
280
254
281
255
val_offset += axes_strider (i_along);
@@ -284,6 +258,14 @@ class PutFunctor
284
258
}
285
259
};
286
260
261
+ template <typename ProjectorT,
262
+ typename OrthogIndexer,
263
+ typename IndicesIndexer,
264
+ typename AxesIndexer,
265
+ typename T,
266
+ typename indT>
267
+ class put_kernel ;
268
+
287
269
typedef sycl::event (*put_fn_ptr_t )(sycl::queue &,
288
270
size_t ,
289
271
size_t ,
@@ -324,20 +306,29 @@ sycl::event put_impl(sycl::queue &q,
324
306
sycl::event put_ev = q.submit ([&](sycl::handler &cgh) {
325
307
cgh.depends_on (depends);
326
308
327
- const TwoOffsets_StridedIndexer orthog_indexer{
328
- nd, dst_offset, val_offset, orthog_shape_and_strides};
329
- const NthStrideOffset indices_indexer{ind_nd, ind_offsets,
330
- ind_shape_and_strides};
331
- const StridedIndexer axes_indexer{ind_nd, 0 ,
332
- axes_shape_and_strides + (2 * k)};
309
+ using OrthogIndexerT =
310
+ dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
311
+ const OrthogIndexerT orthog_indexer{nd, dst_offset, val_offset,
312
+ orthog_shape_and_strides};
313
+
314
+ using NthStrideIndexerT = dpctl::tensor::offset_utils::NthStrideOffset;
315
+ const NthStrideIndexerT indices_indexer{ind_nd, ind_offsets,
316
+ ind_shape_and_strides};
317
+
318
+ using AxesIndexerT = dpctl::tensor::offset_utils::StridedIndexer;
319
+ const AxesIndexerT axes_indexer{ind_nd, 0 ,
320
+ axes_shape_and_strides + (2 * k)};
321
+
322
+ using KernelName =
323
+ put_kernel<ProjectorT, OrthogIndexerT, NthStrideIndexerT,
324
+ AxesIndexerT, Ty, indT>;
333
325
334
326
const size_t gws = orthog_nelems * ind_nelems;
335
327
336
- cgh.parallel_for <put_kernel<ProjectorT, TwoOffsets_StridedIndexer,
337
- NthStrideOffset, StridedIndexer, Ty, indT>>(
328
+ cgh.parallel_for <KernelName>(
338
329
sycl::range<1 >(gws),
339
- PutFunctor<ProjectorT, TwoOffsets_StridedIndexer, NthStrideOffset ,
340
- StridedIndexer , Ty, indT>(
330
+ PutFunctor<ProjectorT, OrthogIndexerT, NthStrideIndexerT ,
331
+ AxesIndexerT , Ty, indT>(
341
332
dst_p, val_p, ind_p, k, ind_nelems, axes_shape_and_strides,
342
333
orthog_indexer, indices_indexer, axes_indexer));
343
334
});
@@ -352,7 +343,8 @@ template <typename fnT, typename T, typename indT> struct TakeWrapFactory
352
343
if constexpr (std::is_integral<indT>::value &&
353
344
!std::is_same<indT, bool >::value)
354
345
{
355
- fnT fn = take_impl<WrapIndex, T, indT>;
346
+ using dpctl::tensor::indexing_utils::WrapIndex;
347
+ fnT fn = take_impl<WrapIndex<indT>, T, indT>;
356
348
return fn;
357
349
}
358
350
else {
@@ -369,7 +361,8 @@ template <typename fnT, typename T, typename indT> struct TakeClipFactory
369
361
if constexpr (std::is_integral<indT>::value &&
370
362
!std::is_same<indT, bool >::value)
371
363
{
372
- fnT fn = take_impl<ClipIndex, T, indT>;
364
+ using dpctl::tensor::indexing_utils::ClipIndex;
365
+ fnT fn = take_impl<ClipIndex<indT>, T, indT>;
373
366
return fn;
374
367
}
375
368
else {
@@ -386,7 +379,8 @@ template <typename fnT, typename T, typename indT> struct PutWrapFactory
386
379
if constexpr (std::is_integral<indT>::value &&
387
380
!std::is_same<indT, bool >::value)
388
381
{
389
- fnT fn = put_impl<WrapIndex, T, indT>;
382
+ using dpctl::tensor::indexing_utils::WrapIndex;
383
+ fnT fn = put_impl<WrapIndex<indT>, T, indT>;
390
384
return fn;
391
385
}
392
386
else {
@@ -403,7 +397,8 @@ template <typename fnT, typename T, typename indT> struct PutClipFactory
403
397
if constexpr (std::is_integral<indT>::value &&
404
398
!std::is_same<indT, bool >::value)
405
399
{
406
- fnT fn = put_impl<ClipIndex, T, indT>;
400
+ using dpctl::tensor::indexing_utils::ClipIndex;
401
+ fnT fn = put_impl<ClipIndex<indT>, T, indT>;
407
402
return fn;
408
403
}
409
404
else {
0 commit comments