diff --git a/dpctl/tensor/_searchsorted.py b/dpctl/tensor/_searchsorted.py index eb24c2e9b4..3091c8fc5b 100644 --- a/dpctl/tensor/_searchsorted.py +++ b/dpctl/tensor/_searchsorted.py @@ -5,11 +5,13 @@ from ._copy_utils import _empty_like_orderK from ._ctors import empty -from ._data_types import int32, int64 from ._tensor_impl import _copy_usm_ndarray_into_usm_ndarray as ti_copy from ._tensor_impl import _take as ti_take +from ._tensor_impl import ( + default_device_index_type as ti_default_device_index_type, +) from ._tensor_sorting_impl import _searchsorted_left, _searchsorted_right -from ._type_utils import iinfo, isdtype, result_type +from ._type_utils import isdtype, result_type from ._usmarray import usm_ndarray @@ -141,9 +143,9 @@ def searchsorted( x2 = x2_buf dst_usm_type = du.get_coerced_usm_type([x1.usm_type, x2.usm_type]) - dst_dt = int32 if x2.size <= iinfo(int32).max else int64 + index_dt = ti_default_device_index_type(q) - dst = _empty_like_orderK(x2, dst_dt, usm_type=dst_usm_type) + dst = _empty_like_orderK(x2, index_dt, usm_type=dst_usm_type) if side == "left": ht_ev, _ = _searchsorted_left( diff --git a/dpctl/tests/test_usm_ndarray_searchsorted.py b/dpctl/tests/test_usm_ndarray_searchsorted.py index caddfb2550..5004f71b70 100644 --- a/dpctl/tests/test_usm_ndarray_searchsorted.py +++ b/dpctl/tests/test_usm_ndarray_searchsorted.py @@ -11,20 +11,29 @@ def _check(hay_stack, needles, needles_np): assert hay_stack.dtype == needles.dtype assert hay_stack.ndim == 1 + info_ = dpt.__array_namespace_info__() + default_dts_dev = info_.default_dtypes(hay_stack.device) + index_dt = default_dts_dev["indexing"] + p_left = dpt.searchsorted(hay_stack, needles, side="left") + assert p_left.dtype == index_dt hs_np = dpt.asnumpy(hay_stack) ref_left = np.searchsorted(hs_np, needles_np, side="left") assert dpt.all(p_left == dpt.asarray(ref_left)) p_right = dpt.searchsorted(hay_stack, needles, side="right") + assert p_right.dtype == index_dt + ref_right = np.searchsorted(hs_np, needles_np, side="right") assert dpt.all(p_right == dpt.asarray(ref_right)) sorter = dpt.arange(hay_stack.size) ps_left = dpt.searchsorted(hay_stack, needles, side="left", sorter=sorter) + assert ps_left.dtype == index_dt assert dpt.all(ps_left == p_left) ps_right = dpt.searchsorted(hay_stack, needles, side="right", sorter=sorter) + assert ps_right.dtype == index_dt assert dpt.all(ps_right == p_right)