From 8ab41e941ba9ee61837ae3faf4620e4966febbc9 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 26 Jun 2024 21:54:27 +0000 Subject: [PATCH 1/3] Add support to `where` for Python scalars `x1` and `x2` can now both be Python scalars. As `condition` has no impact on the data type of the result, when both are scalars, the default data type for the scalar kind is used. --- dpctl/tensor/_search_functions.py | 216 ++++++++++++++---- .../test_usm_ndarray_search_functions.py | 4 +- 2 files changed, 179 insertions(+), 41 deletions(-) diff --git a/dpctl/tensor/_search_functions.py b/dpctl/tensor/_search_functions.py index c0fdfb7861..76e6630dff 100644 --- a/dpctl/tensor/_search_functions.py +++ b/dpctl/tensor/_search_functions.py @@ -17,11 +17,81 @@ import dpctl import dpctl.tensor as dpt import dpctl.tensor._tensor_impl as ti -from dpctl.tensor._manipulation_functions import _broadcast_shapes +from dpctl.tensor._elementwise_common import ( + _get_dtype, + _get_queue_usm_type, + _get_shape, + _validate_dtype, +) +from dpctl.tensor._manipulation_functions import _broadcast_shape_impl from dpctl.utils import ExecutionPlacementError, SequentialOrderManager from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK -from ._type_utils import _all_data_types, _can_cast +from ._type_utils import ( + WeakBooleanType, + WeakComplexType, + WeakFloatingType, + WeakIntegralType, + _all_data_types, + _can_cast, + _is_weak_dtype, + _strong_dtype_num_kind, + _to_device_supported_dtype, + _weak_type_num_kind, +) + + +def _default_dtype_from_weak_type(dt, dev): + if isinstance(dt, WeakBooleanType): + return dpt.bool + if isinstance(dt, WeakIntegralType): + return dpt.dtype(ti.default_device_int_type(dev)) + if isinstance(dt, WeakFloatingType): + return dpt.dtype(ti.default_device_fp_type(dev)) + if isinstance(dt, WeakComplexType): + return dpt.dtype(ti.default_device_complex_type(dev)) + + +def _resolve_two_weak_types(o1_dtype, o2_dtype, dev): + "Resolves two weak data types per NEP-0050" + if _is_weak_dtype(o1_dtype): + if _is_weak_dtype(o2_dtype): + return _default_dtype_from_weak_type( + o1_dtype, dev + ), _default_dtype_from_weak_type(o2_dtype, dev) + o1_kind_num = _weak_type_num_kind(o1_dtype) + o2_kind_num = _strong_dtype_num_kind(o2_dtype) + if o1_kind_num > o2_kind_num: + if isinstance(o1_dtype, WeakIntegralType): + return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype + if isinstance(o1_dtype, WeakComplexType): + if o2_dtype is dpt.float16 or o2_dtype is dpt.float32: + return dpt.complex64, o2_dtype + return ( + _to_device_supported_dtype(dpt.complex128, dev), + o2_dtype, + ) + return _to_device_supported_dtype(dpt.float64, dev), o2_dtype + else: + return o2_dtype, o2_dtype + elif _is_weak_dtype(o2_dtype): + o1_kind_num = _strong_dtype_num_kind(o1_dtype) + o2_kind_num = _weak_type_num_kind(o2_dtype) + if o2_kind_num > o1_kind_num: + if isinstance(o2_dtype, WeakIntegralType): + return o1_dtype, dpt.dtype(ti.default_device_int_type(dev)) + if isinstance(o2_dtype, WeakComplexType): + if o1_dtype is dpt.float16 or o1_dtype is dpt.float32: + return o1_dtype, dpt.complex64 + return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev) + return ( + o1_dtype, + _to_device_supported_dtype(dpt.float64, dev), + ) + else: + return o1_dtype, o1_dtype + else: + return o1_dtype, o2_dtype def _where_result_type(dt1, dt2, dev): @@ -81,36 +151,90 @@ def where(condition, x1, x2, /, *, order="K", out=None): raise TypeError( "Expecting dpctl.tensor.usm_ndarray type, " f"got {type(condition)}" ) - if not isinstance(x1, dpt.usm_ndarray): - raise TypeError( - "Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x1)}" + if order not in ["K", "C", "F", "A"]: + order = "K" + q1, condition_usm_type = condition.sycl_queue, condition.usm_type + q2, x1_usm_type = _get_queue_usm_type(x1) + q3, x2_usm_type = _get_queue_usm_type(x2) + if q2 is None and q3 is None: + exec_q = q1 + out_usm_type = condition_usm_type + elif q3 is None: + exec_q = dpctl.utils.get_execution_queue((q1, q2)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + out_usm_type = dpctl.utils.get_coerced_usm_type( + ( + condition_usm_type, + x1_usm_type, + ) ) - if not isinstance(x2, dpt.usm_ndarray): + elif q2 is None: + exec_q = dpctl.utils.get_execution_queue((q1, q3)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + out_usm_type = dpctl.utils.get_coerced_usm_type( + ( + condition_usm_type, + x2_usm_type, + ) + ) + else: + exec_q = dpctl.utils.get_execution_queue((q1, q2, q3)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + out_usm_type = dpctl.utils.get_coerced_usm_type( + ( + condition_usm_type, + x1_usm_type, + x2_usm_type, + ) + ) + dpctl.utils.validate_usm_type(out_usm_type, allow_none=False) + condition_shape = condition.shape + x1_shape = _get_shape(x1) + x2_shape = _get_shape(x2) + if not all( + isinstance(s, (tuple, list)) + for s in ( + x1_shape, + x2_shape, + ) + ): raise TypeError( - "Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x2)}" + "Shape of arguments can not be inferred. " + "Arguments are expected to be " + "lists, tuples, or both" ) - if order not in ["K", "C", "F", "A"]: - order = "K" - exec_q = dpctl.utils.get_execution_queue( - ( - condition.sycl_queue, - x1.sycl_queue, - x2.sycl_queue, + try: + res_shape = _broadcast_shape_impl( + [ + condition_shape, + x1_shape, + x2_shape, + ] ) - ) - if exec_q is None: - raise dpctl.utils.ExecutionPlacementError - out_usm_type = dpctl.utils.get_coerced_usm_type( - ( - condition.usm_type, - x1.usm_type, - x2.usm_type, + except ValueError: + raise ValueError( + "operands could not be broadcast together with shapes " + f"{condition_shape}, {x1_shape}, and {x2_shape}" ) - ) - - x1_dtype = x1.dtype - x2_dtype = x2.dtype - out_dtype = _where_result_type(x1_dtype, x2_dtype, exec_q.sycl_device) + sycl_dev = exec_q.sycl_device + x1_dtype = _get_dtype(x1, sycl_dev) + x2_dtype = _get_dtype(x2, sycl_dev) + if not all(_validate_dtype(o) for o in (x1_dtype, x2_dtype)): + raise ValueError("Operands have unsupported data types") + x1_dtype, x2_dtype = _resolve_two_weak_types(x1_dtype, x2_dtype, sycl_dev) + out_dtype = _where_result_type(x1_dtype, x2_dtype, sycl_dev) if out_dtype is None: raise TypeError( "function 'where' does not support input " @@ -119,8 +243,6 @@ def where(condition, x1, x2, /, *, order="K", out=None): "to any supported types according to the casting rule ''safe''." ) - res_shape = _broadcast_shapes(condition, x1, x2) - orig_out = out if out is not None: if not isinstance(out, dpt.usm_ndarray): @@ -149,16 +271,25 @@ def where(condition, x1, x2, /, *, order="K", out=None): "Input and output allocation queues are not compatible" ) - if ti._array_overlap(condition, out): - if not ti._same_logical_tensors(condition, out): - out = dpt.empty_like(out) + if ti._array_overlap(condition, out) and not ti._same_logical_tensors( + condition, out + ): + out = dpt.empty_like(out) - if ti._array_overlap(x1, out): - if not ti._same_logical_tensors(x1, out): + if isinstance(x1, dpt.usm_ndarray): + if ( + ti._array_overlap(x1, out) + and not ti._same_logical_tensors(x1, out) + and x1_dtype == out_dtype + ): out = dpt.empty_like(out) - if ti._array_overlap(x2, out): - if not ti._same_logical_tensors(x2, out): + if isinstance(x2, dpt.usm_ndarray): + if ( + ti._array_overlap(x2, out) + and not ti._same_logical_tensors(x2, out) + and x2_dtype == out_dtype + ): out = dpt.empty_like(out) if order == "A": @@ -174,6 +305,10 @@ def where(condition, x1, x2, /, *, order="K", out=None): ) else "C" ) + if not isinstance(x1, dpt.usm_ndarray): + x1 = dpt.asarray(x1, dtype=x1_dtype, sycl_queue=exec_q) + if not isinstance(x2, dpt.usm_ndarray): + x2 = dpt.asarray(x2, dtype=x2_dtype, sycl_queue=exec_q) if condition.size == 0: if out is not None: @@ -236,9 +371,12 @@ def where(condition, x1, x2, /, *, order="K", out=None): sycl_queue=exec_q, ) - condition = dpt.broadcast_to(condition, res_shape) - x1 = dpt.broadcast_to(x1, res_shape) - x2 = dpt.broadcast_to(x2, res_shape) + if condition_shape != res_shape: + condition = dpt.broadcast_to(condition, res_shape) + if x1_shape != res_shape: + x1 = dpt.broadcast_to(x1, res_shape) + if x2_shape != res_shape: + x2 = dpt.broadcast_to(x2, res_shape) dep_evs = _manager.submitted_events hev, where_ev = ti._where( diff --git a/dpctl/tests/test_usm_ndarray_search_functions.py b/dpctl/tests/test_usm_ndarray_search_functions.py index 38e106fb9f..a6552d0678 100644 --- a/dpctl/tests/test_usm_ndarray_search_functions.py +++ b/dpctl/tests/test_usm_ndarray_search_functions.py @@ -350,9 +350,9 @@ def test_where_arg_validation(): with pytest.raises(TypeError): dpt.where(check, x1, x2) - with pytest.raises(TypeError): + with pytest.raises(ValueError): dpt.where(x1, check, x2) - with pytest.raises(TypeError): + with pytest.raises(ValueError): dpt.where(x1, x2, check) From 20929bec233d2525701234f2e9610ac36a44a161 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 26 Jun 2024 22:24:01 +0000 Subject: [PATCH 2/3] Adds tests for `where` behavior with scalars --- .../test_usm_ndarray_search_functions.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_search_functions.py b/dpctl/tests/test_usm_ndarray_search_functions.py index a6552d0678..aba0142ee0 100644 --- a/dpctl/tests/test_usm_ndarray_search_functions.py +++ b/dpctl/tests/test_usm_ndarray_search_functions.py @@ -14,6 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ctypes +import itertools + import numpy as np import pytest from helper import get_queue_or_skip, skip_if_dtype_not_supported @@ -522,3 +525,54 @@ def test_where_out_arg_validation(): dpt.where(condition, x1, x2, out=out_wrong_shape) with pytest.raises(ValueError): dpt.where(condition, x1, x2, out=out_not_writable) + + +@pytest.mark.parametrize("arr_dt", _all_dtypes) +def test_where_python_scalar(arr_dt): + q = get_queue_or_skip() + skip_if_dtype_not_supported(arr_dt, q) + + n1, n2 = 10, 10 + condition = dpt.tile( + dpt.reshape( + dpt.asarray([True, False], dtype="?", sycl_queue=q), (1, 2) + ), + (n1, n2 // 2), + ) + x = dpt.zeros((n1, n2), dtype=arr_dt, sycl_queue=q) + py_scalars = ( + bool(0), + int(0), + float(0), + complex(0), + np.float32(0), + ctypes.c_int(0), + ) + for sc in py_scalars: + r = dpt.where(condition, x, sc) + assert isinstance(r, dpt.usm_ndarray) + r = dpt.where(condition, sc, x) + assert isinstance(r, dpt.usm_ndarray) + + +def test_where_two_python_scalars(): + get_queue_or_skip() + + n1, n2 = 10, 10 + condition = dpt.tile( + dpt.reshape(dpt.asarray([True, False], dtype="?"), (1, 2)), + (n1, n2 // 2), + ) + + py_scalars = [ + bool(0), + int(0), + float(0), + complex(0), + np.float32(0), + ctypes.c_int(0), + ] + + for sc1, sc2 in itertools.product(py_scalars, repeat=2): + r = dpt.where(condition, sc1, sc2) + assert isinstance(r, dpt.usm_ndarray) From 9f449d1f1f03ae0a447f2575b1626821716717b9 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 26 Jun 2024 23:36:24 +0000 Subject: [PATCH 3/3] Update docstring to reflect change in `where` functionality --- dpctl/tensor/_search_functions.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/dpctl/tensor/_search_functions.py b/dpctl/tensor/_search_functions.py index 76e6630dff..4a6c32c3f4 100644 --- a/dpctl/tensor/_search_functions.py +++ b/dpctl/tensor/_search_functions.py @@ -121,16 +121,17 @@ def where(condition, x1, x2, /, *, order="K", out=None): and otherwise yields from ``x2``. Must be compatible with ``x1`` and ``x2`` according to broadcasting rules. - x1 (usm_ndarray): Array from which values are chosen when - ``condition`` is ``True``. + x1 (Union[usm_ndarray, bool, int, float, complex]): + Array from which values are chosen when ``condition`` is ``True``. Must be compatible with ``condition`` and ``x2`` according to broadcasting rules. - x2 (usm_ndarray): Array from which values are chosen when - ``condition`` is not ``True``. + x2 (Union[usm_ndarray, bool, int, float, complex]): + Array from which values are chosen when ``condition`` is not + ``True``. Must be compatible with ``condition`` and ``x2`` according to broadcasting rules. order (``"K"``, ``"C"``, ``"F"``, ``"A"``, optional): - Memory layout of the new output arra, + Memory layout of the new output array, if parameter ``out`` is ``None``. Default: ``"K"``. out (Optional[usm_ndarray]):