diff --git a/dpctl/tensor/_searchsorted.py b/dpctl/tensor/_searchsorted.py index 3091c8fc5b..d9408e072e 100644 --- a/dpctl/tensor/_searchsorted.py +++ b/dpctl/tensor/_searchsorted.py @@ -22,7 +22,7 @@ def searchsorted( *, side: Literal["left", "right"] = "left", sorter: Union[usm_ndarray, None] = None, -): +) -> usm_ndarray: """searchsorted(x1, x2, side='left', sorter=None) Finds the indices into `x1` such that, if the corresponding elements @@ -50,6 +50,8 @@ def searchsorted( sorter (Optional[usm_ndarray]): array of indices that sort `x1` in ascending order. The array must have the same shape as `x1` and have an integral data type. + Out of bound index values of `sorter` array are treated using + `"wrap"` mode documented in :py:func:`dpctl.tensor.take`. Default: `None`. """ if not isinstance(x1, usm_ndarray): diff --git a/dpctl/tests/test_usm_ndarray_searchsorted.py b/dpctl/tests/test_usm_ndarray_searchsorted.py index 5004f71b70..4d2e899fe1 100644 --- a/dpctl/tests/test_usm_ndarray_searchsorted.py +++ b/dpctl/tests/test_usm_ndarray_searchsorted.py @@ -340,3 +340,18 @@ def test_pw_linear_interpolation_example(): exp = dpt.vecdot(vals[1:] + vals[:-1], bins[1:] - bins[:-1]) / 2 assert dpt.abs(av - exp) < 0.1 + + +def test_out_of_bound_sorter_values(): + get_queue_or_skip() + + x = dpt.asarray([1, 2, 0], dtype="i4") + n = x.shape[0] + + # use out-of-bounds indices in sorter + sorter = dpt.asarray([2, 0 - n, 1 - n], dtype="i8") + + x2 = dpt.arange(3, dtype=x.dtype) + p = dpt.searchsorted(x, x2, sorter=sorter) + # verify that they were applied with mode="wrap" + assert dpt.all(p == dpt.arange(3, dtype=p.dtype))