From 9962dd201923613ab90a9beeebb30d8f1fbb432f Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 19 Mar 2024 23:12:00 -0500 Subject: [PATCH 1/2] Made explicit wrapping behavior of out of bound values of sorter tensor.searchsorter apply sorter array, if present, to first input array using tensor.take with default mode="wrap", which replaces out of bound indices with in-bound ones using modular reduction. --- dpctl/tensor/_searchsorted.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dpctl/tensor/_searchsorted.py b/dpctl/tensor/_searchsorted.py index 3091c8fc5b..d9408e072e 100644 --- a/dpctl/tensor/_searchsorted.py +++ b/dpctl/tensor/_searchsorted.py @@ -22,7 +22,7 @@ def searchsorted( *, side: Literal["left", "right"] = "left", sorter: Union[usm_ndarray, None] = None, -): +) -> usm_ndarray: """searchsorted(x1, x2, side='left', sorter=None) Finds the indices into `x1` such that, if the corresponding elements @@ -50,6 +50,8 @@ def searchsorted( sorter (Optional[usm_ndarray]): array of indices that sort `x1` in ascending order. The array must have the same shape as `x1` and have an integral data type. + Out of bound index values of `sorter` array are treated using + `"wrap"` mode documented in :py:func:`dpctl.tensor.take`. Default: `None`. """ if not isinstance(x1, usm_ndarray): From dd368941b4584aa3a2de1cff788e255e72a2d9f9 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 20 Mar 2024 09:38:01 -0500 Subject: [PATCH 2/2] Added test for out of bounds indices in sorter --- dpctl/tests/test_usm_ndarray_searchsorted.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_searchsorted.py b/dpctl/tests/test_usm_ndarray_searchsorted.py index 5004f71b70..4d2e899fe1 100644 --- a/dpctl/tests/test_usm_ndarray_searchsorted.py +++ b/dpctl/tests/test_usm_ndarray_searchsorted.py @@ -340,3 +340,18 @@ def test_pw_linear_interpolation_example(): exp = dpt.vecdot(vals[1:] + vals[:-1], bins[1:] - bins[:-1]) / 2 assert dpt.abs(av - exp) < 0.1 + + +def test_out_of_bound_sorter_values(): + get_queue_or_skip() + + x = dpt.asarray([1, 2, 0], dtype="i4") + n = x.shape[0] + + # use out-of-bounds indices in sorter + sorter = dpt.asarray([2, 0 - n, 1 - n], dtype="i8") + + x2 = dpt.arange(3, dtype=x.dtype) + p = dpt.searchsorted(x, x2, sorter=sorter) + # verify that they were applied with mode="wrap" + assert dpt.all(p == dpt.arange(3, dtype=p.dtype))