From 97b1c82cf6585b6bc8cdfd3f5ae3b081a3cb4177 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 22 May 2024 12:54:14 -0700 Subject: [PATCH 1/2] Adds code to handle edge case of strided input and scalar `needle` in `searchsorted.cpp` --- .../libtensor/source/sorting/searchsorted.cpp | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/dpctl/tensor/libtensor/source/sorting/searchsorted.cpp b/dpctl/tensor/libtensor/source/sorting/searchsorted.cpp index 211ffdcbef..ae7bd5bd9d 100644 --- a/dpctl/tensor/libtensor/source/sorting/searchsorted.cpp +++ b/dpctl/tensor/libtensor/source/sorting/searchsorted.cpp @@ -340,22 +340,29 @@ py_searchsorted(const dpctl::tensor::usm_ndarray &hay, int simplified_nd = needles_nd; using shT = std::vector; - shT simplified_common_shape; shT simplified_needles_strides; shT simplified_positions_strides; py::ssize_t needles_offset(0); py::ssize_t positions_offset(0); - dpctl::tensor::py_internal::simplify_iteration_space( - // modified by refernce - simplified_nd, - // read-only inputs - needles_shape_ptr, needles_strides, positions_strides, - // output, modified by reference - simplified_common_shape, simplified_needles_strides, - simplified_positions_strides, needles_offset, positions_offset); - + if (simplified_nd == 0) { + // needles and positions have same nd + simplified_nd = 1; + simplified_common_shape.push_back(1); + simplified_needles_strides.push_back(0); + simplified_positions_strides.push_back(0); + } + else { + dpctl::tensor::py_internal::simplify_iteration_space( + // modified by reference + simplified_nd, + // read-only inputs + needles_shape_ptr, needles_strides, positions_strides, + // output, modified by reference + simplified_common_shape, simplified_needles_strides, + simplified_positions_strides, needles_offset, positions_offset); + } std::vector host_task_events; host_task_events.reserve(2); From 8c6a57e1bf6b1edf3d71fede1d7e2060107f4c81 Mon Sep 17 00:00:00 2001 From: Nikita 
Grigorian Date: Wed, 22 May 2024 13:12:56 -0700 Subject: [PATCH 2/2] Adds a test for fix to gh-1689 --- dpctl/tests/test_usm_ndarray_searchsorted.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_searchsorted.py b/dpctl/tests/test_usm_ndarray_searchsorted.py index 41f6ecac7a..0e65fcc235 100644 --- a/dpctl/tests/test_usm_ndarray_searchsorted.py +++ b/dpctl/tests/test_usm_ndarray_searchsorted.py @@ -355,3 +355,19 @@ def test_out_of_bound_sorter_values(): p = dpt.searchsorted(x, x2, sorter=sorter) # verify that they were applied with mode="wrap" assert dpt.all(p == dpt.arange(3, dtype=p.dtype)) + + +def test_searchsorted_strided_scalar_needle(): + get_queue_or_skip() + + a_max = 255 + + hay_stack = dpt.flip( + dpt.repeat(dpt.arange(a_max - 1, -1, -1, dtype=dpt.int32), 4) + ) + needles_np = np.squeeze( + np.random.randint(0, a_max, dtype=dpt.int32, size=1), axis=0 + ) + needles = dpt.asarray(needles_np) + + _check(hay_stack, needles, needles_np)