diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4d373efa75..2fa6434fe2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -37,6 +37,7 @@ and complies with revision [2023.12](https://data-apis.org/array-api/2023.12/) o
 * Fixed bug in basic slicing of empty arrays: [gh-1680](https://github.com/IntelPython/dpctl/pull/1680)
 * Fixed bug in `tensor.bitwise_invert` for boolean input array: [gh-1681](https://github.com/IntelPython/dpctl/pull/1681)
 * Fixed bug in `tensor.repeat` on zero-size input arrays: [gh-1682](https://github.com/IntelPython/dpctl/pull/1682)
+* Fixed bug in `tensor.searchsorted` for 0d needle vector and strided hay: [gh-1694](https://github.com/IntelPython/dpctl/pull/1694)
 
 ## [0.16.1] - Apr. 10, 2024
 
diff --git a/dpctl/tensor/libtensor/source/sorting/searchsorted.cpp b/dpctl/tensor/libtensor/source/sorting/searchsorted.cpp
index 211ffdcbef..ae7bd5bd9d 100644
--- a/dpctl/tensor/libtensor/source/sorting/searchsorted.cpp
+++ b/dpctl/tensor/libtensor/source/sorting/searchsorted.cpp
@@ -340,22 +340,30 @@ py_searchsorted(const dpctl::tensor::usm_ndarray &hay,
     int simplified_nd = needles_nd;
 
     using shT = std::vector<py::ssize_t>;
-
     shT simplified_common_shape;
     shT simplified_needles_strides;
     shT simplified_positions_strides;
     py::ssize_t needles_offset(0);
     py::ssize_t positions_offset(0);
 
-    dpctl::tensor::py_internal::simplify_iteration_space(
-        // modified by refernce
-        simplified_nd,
-        // read-only inputs
-        needles_shape_ptr, needles_strides, positions_strides,
-        // output, modified by reference
-        simplified_common_shape, simplified_needles_strides,
-        simplified_positions_strides, needles_offset, positions_offset);
-
+    if (simplified_nd == 0) {
+        // needles and positions have same nd
+        // represent the 0d case as a single element with zero strides
+        simplified_nd = 1;
+        simplified_common_shape.push_back(1);
+        simplified_needles_strides.push_back(0);
+        simplified_positions_strides.push_back(0);
+    }
+    else {
+        dpctl::tensor::py_internal::simplify_iteration_space(
+            // modified by reference
+            simplified_nd,
+            // read-only inputs
+            needles_shape_ptr, needles_strides, positions_strides,
+            // output, modified by reference
+            simplified_common_shape, simplified_needles_strides,
+            simplified_positions_strides, needles_offset, positions_offset);
+    }
     std::vector<sycl::event> host_task_events;
     host_task_events.reserve(2);
 
diff --git a/dpctl/tests/test_usm_ndarray_searchsorted.py b/dpctl/tests/test_usm_ndarray_searchsorted.py
index 41f6ecac7a..0e65fcc235 100644
--- a/dpctl/tests/test_usm_ndarray_searchsorted.py
+++ b/dpctl/tests/test_usm_ndarray_searchsorted.py
@@ -355,3 +355,20 @@ def test_out_of_bound_sorter_values():
     p = dpt.searchsorted(x, x2, sorter=sorter)
     # verify that they were applied with mode="wrap"
     assert dpt.all(p == dpt.arange(3, dtype=p.dtype))
+
+
+def test_searchsorted_strided_scalar_needle():
+    """Check searchsorted with a 0d needle and a strided hay."""
+    get_queue_or_skip()
+
+    a_max = 255
+
+    hay_stack = dpt.flip(
+        dpt.repeat(dpt.arange(a_max - 1, -1, -1, dtype=dpt.int32), 4)
+    )
+    needles_np = np.squeeze(
+        np.random.randint(0, a_max, dtype=dpt.int32, size=1), axis=0
+    )
+    needles = dpt.asarray(needles_np)
+
+    _check(hay_stack, needles, needles_np)