Skip to content

Commit 2a4714f

Browse files
authored
Merge pull request #1894 from IntelPython/fix-indexing-oob-indices
Fix undefined behavior in integer advanced indexing and indexing functions
2 parents dd2812f + ba0bfd9 commit 2a4714f

File tree

4 files changed: +271 additions, -98 deletions

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ The full list of changes that went into this release are:
9292
* Element-wise `tensor.divide` and comparison operations allow greater range of Python integer and integer array combinations [gh-1771](https://github.com/IntelPython/dpctl/pull/1771)
9393
* Fix for unexpected behavior when using floating point types for array indexing [gh-1792](https://github.com/IntelPython/dpctl/pull/1792)
9494
* Enable `pytest --pyargs dpctl.tests` [gh-1833](https://github.com/IntelPython/dpctl/pull/1833)
95+
* Fix for undefined behavior in indexing using integer arrays [gh-1894](https://github.com/IntelPython/dpctl/pull/1894)
9596

9697
### Maintenance
9798

dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp

Lines changed: 93 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <type_traits>
3131

3232
#include "dpctl_tensor_types.hpp"
33+
#include "utils/indexing_utils.hpp"
3334
#include "utils/offset_utils.hpp"
3435
#include "utils/type_utils.hpp"
3536

@@ -42,54 +43,10 @@ namespace kernels
4243
namespace indexing
4344
{
4445

45-
using namespace dpctl::tensor::offset_utils;
46-
47-
template <typename ProjectorT,
48-
typename OrthogStrider,
49-
typename IndicesStrider,
50-
typename AxesStrider,
51-
typename T,
52-
typename indT>
53-
class take_kernel;
5446
template <typename ProjectorT,
55-
typename OrthogStrider,
56-
typename IndicesStrider,
57-
typename AxesStrider,
58-
typename T,
59-
typename indT>
60-
class put_kernel;
61-
62-
class WrapIndex
63-
{
64-
public:
65-
WrapIndex() = default;
66-
67-
void operator()(ssize_t max_item, ssize_t &ind) const
68-
{
69-
max_item = std::max<ssize_t>(max_item, 1);
70-
ind = sycl::clamp<ssize_t>(ind, -max_item, max_item - 1);
71-
ind = (ind < 0) ? ind + max_item : ind;
72-
return;
73-
}
74-
};
75-
76-
class ClipIndex
77-
{
78-
public:
79-
ClipIndex() = default;
80-
81-
void operator()(ssize_t max_item, ssize_t &ind) const
82-
{
83-
max_item = std::max<ssize_t>(max_item, 1);
84-
ind = sycl::clamp<ssize_t>(ind, ssize_t(0), max_item - 1);
85-
return;
86-
}
87-
};
88-
89-
template <typename ProjectorT,
90-
typename OrthogStrider,
91-
typename IndicesStrider,
92-
typename AxesStrider,
47+
typename OrthogIndexer,
48+
typename IndicesIndexer,
49+
typename AxesIndexer,
9350
typename T,
9451
typename indT>
9552
class TakeFunctor
@@ -101,9 +58,9 @@ class TakeFunctor
10158
int k_ = 0;
10259
size_t ind_nelems_ = 0;
10360
const ssize_t *axes_shape_and_strides_ = nullptr;
104-
const OrthogStrider orthog_strider;
105-
const IndicesStrider ind_strider;
106-
const AxesStrider axes_strider;
61+
const OrthogIndexer orthog_strider;
62+
const IndicesIndexer ind_strider;
63+
const AxesIndexer axes_strider;
10764

10865
public:
10966
TakeFunctor(const char *src_cp,
@@ -112,9 +69,9 @@ class TakeFunctor
11269
int k,
11370
size_t ind_nelems,
11471
const ssize_t *axes_shape_and_strides,
115-
const OrthogStrider &orthog_strider_,
116-
const IndicesStrider &ind_strider_,
117-
const AxesStrider &axes_strider_)
72+
const OrthogIndexer &orthog_strider_,
73+
const IndicesIndexer &ind_strider_,
74+
const AxesIndexer &axes_strider_)
11875
: src_(src_cp), dst_(dst_cp), ind_(ind_cp), k_(k),
11976
ind_nelems_(ind_nelems),
12077
axes_shape_and_strides_(axes_shape_and_strides),
@@ -136,16 +93,16 @@ class TakeFunctor
13693
ssize_t src_offset = orthog_offsets.get_first_offset();
13794
ssize_t dst_offset = orthog_offsets.get_second_offset();
13895

139-
const ProjectorT proj{};
96+
constexpr ProjectorT proj{};
14097
for (int axis_idx = 0; axis_idx < k_; ++axis_idx) {
14198
indT *ind_data = reinterpret_cast<indT *>(ind_[axis_idx]);
14299

143100
ssize_t ind_offset = ind_strider(i_along, axis_idx);
144-
ssize_t i = static_cast<ssize_t>(ind_data[ind_offset]);
145-
146-
proj(axes_shape_and_strides_[axis_idx], i);
147-
148-
src_offset += i * axes_shape_and_strides_[k_ + axis_idx];
101+
// proj produces an index in the range of the given axis
102+
ssize_t projected_idx =
103+
proj(axes_shape_and_strides_[axis_idx], ind_data[ind_offset]);
104+
src_offset +=
105+
projected_idx * axes_shape_and_strides_[k_ + axis_idx];
149106
}
150107

151108
dst_offset += axes_strider(i_along);
@@ -154,6 +111,14 @@ class TakeFunctor
154111
}
155112
};
156113

114+
template <typename ProjectorT,
115+
typename OrthogIndexer,
116+
typename IndicesIndexer,
117+
typename AxesIndexer,
118+
typename T,
119+
typename indT>
120+
class take_kernel;
121+
157122
typedef sycl::event (*take_fn_ptr_t)(sycl::queue &,
158123
size_t,
159124
size_t,
@@ -194,21 +159,29 @@ sycl::event take_impl(sycl::queue &q,
194159
sycl::event take_ev = q.submit([&](sycl::handler &cgh) {
195160
cgh.depends_on(depends);
196161

197-
const TwoOffsets_StridedIndexer orthog_indexer{
198-
nd, src_offset, dst_offset, orthog_shape_and_strides};
199-
const NthStrideOffset indices_indexer{ind_nd, ind_offsets,
200-
ind_shape_and_strides};
201-
const StridedIndexer axes_indexer{ind_nd, 0,
202-
axes_shape_and_strides + (2 * k)};
162+
using OrthogIndexerT =
163+
dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
164+
const OrthogIndexerT orthog_indexer{nd, src_offset, dst_offset,
165+
orthog_shape_and_strides};
166+
167+
using NthStrideIndexerT = dpctl::tensor::offset_utils::NthStrideOffset;
168+
const NthStrideIndexerT indices_indexer{ind_nd, ind_offsets,
169+
ind_shape_and_strides};
170+
171+
using AxesIndexerT = dpctl::tensor::offset_utils::StridedIndexer;
172+
const AxesIndexerT axes_indexer{ind_nd, 0,
173+
axes_shape_and_strides + (2 * k)};
174+
175+
using KernelName =
176+
take_kernel<ProjectorT, OrthogIndexerT, NthStrideIndexerT,
177+
AxesIndexerT, Ty, indT>;
203178

204179
const size_t gws = orthog_nelems * ind_nelems;
205180

206-
cgh.parallel_for<
207-
take_kernel<ProjectorT, TwoOffsets_StridedIndexer, NthStrideOffset,
208-
StridedIndexer, Ty, indT>>(
181+
cgh.parallel_for<KernelName>(
209182
sycl::range<1>(gws),
210-
TakeFunctor<ProjectorT, TwoOffsets_StridedIndexer, NthStrideOffset,
211-
StridedIndexer, Ty, indT>(
183+
TakeFunctor<ProjectorT, OrthogIndexerT, NthStrideIndexerT,
184+
AxesIndexerT, Ty, indT>(
212185
src_p, dst_p, ind_p, k, ind_nelems, axes_shape_and_strides,
213186
orthog_indexer, indices_indexer, axes_indexer));
214187
});
@@ -217,9 +190,9 @@ sycl::event take_impl(sycl::queue &q,
217190
}
218191

219192
template <typename ProjectorT,
220-
typename OrthogStrider,
221-
typename IndicesStrider,
222-
typename AxesStrider,
193+
typename OrthogIndexer,
194+
typename IndicesIndexer,
195+
typename AxesIndexer,
223196
typename T,
224197
typename indT>
225198
class PutFunctor
@@ -231,9 +204,9 @@ class PutFunctor
231204
int k_ = 0;
232205
size_t ind_nelems_ = 0;
233206
const ssize_t *axes_shape_and_strides_ = nullptr;
234-
const OrthogStrider orthog_strider;
235-
const IndicesStrider ind_strider;
236-
const AxesStrider axes_strider;
207+
const OrthogIndexer orthog_strider;
208+
const IndicesIndexer ind_strider;
209+
const AxesIndexer axes_strider;
237210

238211
public:
239212
PutFunctor(char *dst_cp,
@@ -242,9 +215,9 @@ class PutFunctor
242215
int k,
243216
size_t ind_nelems,
244217
const ssize_t *axes_shape_and_strides,
245-
const OrthogStrider &orthog_strider_,
246-
const IndicesStrider &ind_strider_,
247-
const AxesStrider &axes_strider_)
218+
const OrthogIndexer &orthog_strider_,
219+
const IndicesIndexer &ind_strider_,
220+
const AxesIndexer &axes_strider_)
248221
: dst_(dst_cp), val_(val_cp), ind_(ind_cp), k_(k),
249222
ind_nelems_(ind_nelems),
250223
axes_shape_and_strides_(axes_shape_and_strides),
@@ -266,16 +239,17 @@ class PutFunctor
266239
ssize_t dst_offset = orthog_offsets.get_first_offset();
267240
ssize_t val_offset = orthog_offsets.get_second_offset();
268241

269-
const ProjectorT proj{};
242+
constexpr ProjectorT proj{};
270243
for (int axis_idx = 0; axis_idx < k_; ++axis_idx) {
271244
indT *ind_data = reinterpret_cast<indT *>(ind_[axis_idx]);
272245

273246
ssize_t ind_offset = ind_strider(i_along, axis_idx);
274-
ssize_t i = static_cast<ssize_t>(ind_data[ind_offset]);
275-
276-
proj(axes_shape_and_strides_[axis_idx], i);
277247

278-
dst_offset += i * axes_shape_and_strides_[k_ + axis_idx];
248+
// proj produces an index in the range of the given axis
249+
ssize_t projected_idx =
250+
proj(axes_shape_and_strides_[axis_idx], ind_data[ind_offset]);
251+
dst_offset +=
252+
projected_idx * axes_shape_and_strides_[k_ + axis_idx];
279253
}
280254

281255
val_offset += axes_strider(i_along);
@@ -284,6 +258,14 @@ class PutFunctor
284258
}
285259
};
286260

261+
template <typename ProjectorT,
262+
typename OrthogIndexer,
263+
typename IndicesIndexer,
264+
typename AxesIndexer,
265+
typename T,
266+
typename indT>
267+
class put_kernel;
268+
287269
typedef sycl::event (*put_fn_ptr_t)(sycl::queue &,
288270
size_t,
289271
size_t,
@@ -324,20 +306,29 @@ sycl::event put_impl(sycl::queue &q,
324306
sycl::event put_ev = q.submit([&](sycl::handler &cgh) {
325307
cgh.depends_on(depends);
326308

327-
const TwoOffsets_StridedIndexer orthog_indexer{
328-
nd, dst_offset, val_offset, orthog_shape_and_strides};
329-
const NthStrideOffset indices_indexer{ind_nd, ind_offsets,
330-
ind_shape_and_strides};
331-
const StridedIndexer axes_indexer{ind_nd, 0,
332-
axes_shape_and_strides + (2 * k)};
309+
using OrthogIndexerT =
310+
dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
311+
const OrthogIndexerT orthog_indexer{nd, dst_offset, val_offset,
312+
orthog_shape_and_strides};
313+
314+
using NthStrideIndexerT = dpctl::tensor::offset_utils::NthStrideOffset;
315+
const NthStrideIndexerT indices_indexer{ind_nd, ind_offsets,
316+
ind_shape_and_strides};
317+
318+
using AxesIndexerT = dpctl::tensor::offset_utils::StridedIndexer;
319+
const AxesIndexerT axes_indexer{ind_nd, 0,
320+
axes_shape_and_strides + (2 * k)};
321+
322+
using KernelName =
323+
put_kernel<ProjectorT, OrthogIndexerT, NthStrideIndexerT,
324+
AxesIndexerT, Ty, indT>;
333325

334326
const size_t gws = orthog_nelems * ind_nelems;
335327

336-
cgh.parallel_for<put_kernel<ProjectorT, TwoOffsets_StridedIndexer,
337-
NthStrideOffset, StridedIndexer, Ty, indT>>(
328+
cgh.parallel_for<KernelName>(
338329
sycl::range<1>(gws),
339-
PutFunctor<ProjectorT, TwoOffsets_StridedIndexer, NthStrideOffset,
340-
StridedIndexer, Ty, indT>(
330+
PutFunctor<ProjectorT, OrthogIndexerT, NthStrideIndexerT,
331+
AxesIndexerT, Ty, indT>(
341332
dst_p, val_p, ind_p, k, ind_nelems, axes_shape_and_strides,
342333
orthog_indexer, indices_indexer, axes_indexer));
343334
});
@@ -352,7 +343,8 @@ template <typename fnT, typename T, typename indT> struct TakeWrapFactory
352343
if constexpr (std::is_integral<indT>::value &&
353344
!std::is_same<indT, bool>::value)
354345
{
355-
fnT fn = take_impl<WrapIndex, T, indT>;
346+
using dpctl::tensor::indexing_utils::WrapIndex;
347+
fnT fn = take_impl<WrapIndex<indT>, T, indT>;
356348
return fn;
357349
}
358350
else {
@@ -369,7 +361,8 @@ template <typename fnT, typename T, typename indT> struct TakeClipFactory
369361
if constexpr (std::is_integral<indT>::value &&
370362
!std::is_same<indT, bool>::value)
371363
{
372-
fnT fn = take_impl<ClipIndex, T, indT>;
364+
using dpctl::tensor::indexing_utils::ClipIndex;
365+
fnT fn = take_impl<ClipIndex<indT>, T, indT>;
373366
return fn;
374367
}
375368
else {
@@ -386,7 +379,8 @@ template <typename fnT, typename T, typename indT> struct PutWrapFactory
386379
if constexpr (std::is_integral<indT>::value &&
387380
!std::is_same<indT, bool>::value)
388381
{
389-
fnT fn = put_impl<WrapIndex, T, indT>;
382+
using dpctl::tensor::indexing_utils::WrapIndex;
383+
fnT fn = put_impl<WrapIndex<indT>, T, indT>;
390384
return fn;
391385
}
392386
else {
@@ -403,7 +397,8 @@ template <typename fnT, typename T, typename indT> struct PutClipFactory
403397
if constexpr (std::is_integral<indT>::value &&
404398
!std::is_same<indT, bool>::value)
405399
{
406-
fnT fn = put_impl<ClipIndex, T, indT>;
400+
using dpctl::tensor::indexing_utils::ClipIndex;
401+
fnT fn = put_impl<ClipIndex<indT>, T, indT>;
407402
return fn;
408403
}
409404
else {

0 commit comments

Comments
 (0)