Skip to content

Add Python scalar support to dpt.where #1719

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 183 additions & 44 deletions dpctl/tensor/_search_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,81 @@
import dpctl
import dpctl.tensor as dpt
import dpctl.tensor._tensor_impl as ti
from dpctl.tensor._manipulation_functions import _broadcast_shapes
from dpctl.tensor._elementwise_common import (
_get_dtype,
_get_queue_usm_type,
_get_shape,
_validate_dtype,
)
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager

from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK
from ._type_utils import _all_data_types, _can_cast
from ._type_utils import (
WeakBooleanType,
WeakComplexType,
WeakFloatingType,
WeakIntegralType,
_all_data_types,
_can_cast,
_is_weak_dtype,
_strong_dtype_num_kind,
_to_device_supported_dtype,
_weak_type_num_kind,
)


def _default_dtype_from_weak_type(dt, dev):
if isinstance(dt, WeakBooleanType):
return dpt.bool
if isinstance(dt, WeakIntegralType):
return dpt.dtype(ti.default_device_int_type(dev))
if isinstance(dt, WeakFloatingType):
return dpt.dtype(ti.default_device_fp_type(dev))
if isinstance(dt, WeakComplexType):
return dpt.dtype(ti.default_device_complex_type(dev))


def _resolve_two_weak_types(o1_dtype, o2_dtype, dev):
"Resolves two weak data types per NEP-0050"
if _is_weak_dtype(o1_dtype):
if _is_weak_dtype(o2_dtype):
return _default_dtype_from_weak_type(
o1_dtype, dev
), _default_dtype_from_weak_type(o2_dtype, dev)
o1_kind_num = _weak_type_num_kind(o1_dtype)
o2_kind_num = _strong_dtype_num_kind(o2_dtype)
if o1_kind_num > o2_kind_num:
if isinstance(o1_dtype, WeakIntegralType):
return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
if isinstance(o1_dtype, WeakComplexType):
if o2_dtype is dpt.float16 or o2_dtype is dpt.float32:
return dpt.complex64, o2_dtype
return (
_to_device_supported_dtype(dpt.complex128, dev),
o2_dtype,
)
return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
else:
return o2_dtype, o2_dtype
elif _is_weak_dtype(o2_dtype):
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
o2_kind_num = _weak_type_num_kind(o2_dtype)
if o2_kind_num > o1_kind_num:
if isinstance(o2_dtype, WeakIntegralType):
return o1_dtype, dpt.dtype(ti.default_device_int_type(dev))
if isinstance(o2_dtype, WeakComplexType):
if o1_dtype is dpt.float16 or o1_dtype is dpt.float32:
return o1_dtype, dpt.complex64
return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev)
return (
o1_dtype,
_to_device_supported_dtype(dpt.float64, dev),
)
else:
return o1_dtype, o1_dtype
else:
return o1_dtype, o2_dtype


def _where_result_type(dt1, dt2, dev):
Expand Down Expand Up @@ -51,16 +121,17 @@ def where(condition, x1, x2, /, *, order="K", out=None):
and otherwise yields from ``x2``.
Must be compatible with ``x1`` and ``x2`` according
to broadcasting rules.
x1 (usm_ndarray): Array from which values are chosen when
``condition`` is ``True``.
x1 (Union[usm_ndarray, bool, int, float, complex]):
Array from which values are chosen when ``condition`` is ``True``.
Must be compatible with ``condition`` and ``x2`` according
to broadcasting rules.
x2 (usm_ndarray): Array from which values are chosen when
``condition`` is not ``True``.
x2 (Union[usm_ndarray, bool, int, float, complex]):
Array from which values are chosen when ``condition`` is not
``True``.
Must be compatible with ``condition`` and ``x2`` according
to broadcasting rules.
order (``"K"``, ``"C"``, ``"F"``, ``"A"``, optional):
Memory layout of the new output arra,
Memory layout of the new output array,
if parameter ``out`` is ``None``.
Default: ``"K"``.
out (Optional[usm_ndarray]):
Expand All @@ -81,36 +152,90 @@ def where(condition, x1, x2, /, *, order="K", out=None):
raise TypeError(
"Expecting dpctl.tensor.usm_ndarray type, " f"got {type(condition)}"
)
if not isinstance(x1, dpt.usm_ndarray):
raise TypeError(
"Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x1)}"
if order not in ["K", "C", "F", "A"]:
order = "K"
q1, condition_usm_type = condition.sycl_queue, condition.usm_type
q2, x1_usm_type = _get_queue_usm_type(x1)
q3, x2_usm_type = _get_queue_usm_type(x2)
if q2 is None and q3 is None:
exec_q = q1
out_usm_type = condition_usm_type
elif q3 is None:
exec_q = dpctl.utils.get_execution_queue((q1, q2))
if exec_q is None:
raise ExecutionPlacementError(
"Execution placement can not be unambiguously inferred "
"from input arguments."
)
out_usm_type = dpctl.utils.get_coerced_usm_type(
(
condition_usm_type,
x1_usm_type,
)
)
if not isinstance(x2, dpt.usm_ndarray):
elif q2 is None:
exec_q = dpctl.utils.get_execution_queue((q1, q3))
if exec_q is None:
raise ExecutionPlacementError(
"Execution placement can not be unambiguously inferred "
"from input arguments."
)
out_usm_type = dpctl.utils.get_coerced_usm_type(
(
condition_usm_type,
x2_usm_type,
)
)
else:
exec_q = dpctl.utils.get_execution_queue((q1, q2, q3))
if exec_q is None:
raise ExecutionPlacementError(
"Execution placement can not be unambiguously inferred "
"from input arguments."
)
out_usm_type = dpctl.utils.get_coerced_usm_type(
(
condition_usm_type,
x1_usm_type,
x2_usm_type,
)
)
dpctl.utils.validate_usm_type(out_usm_type, allow_none=False)
condition_shape = condition.shape
x1_shape = _get_shape(x1)
x2_shape = _get_shape(x2)
if not all(
isinstance(s, (tuple, list))
for s in (
x1_shape,
x2_shape,
)
):
raise TypeError(
"Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x2)}"
"Shape of arguments can not be inferred. "
"Arguments are expected to be "
"lists, tuples, or both"
)
if order not in ["K", "C", "F", "A"]:
order = "K"
exec_q = dpctl.utils.get_execution_queue(
(
condition.sycl_queue,
x1.sycl_queue,
x2.sycl_queue,
try:
res_shape = _broadcast_shape_impl(
[
condition_shape,
x1_shape,
x2_shape,
]
)
)
if exec_q is None:
raise dpctl.utils.ExecutionPlacementError
out_usm_type = dpctl.utils.get_coerced_usm_type(
(
condition.usm_type,
x1.usm_type,
x2.usm_type,
except ValueError:
raise ValueError(
"operands could not be broadcast together with shapes "
f"{condition_shape}, {x1_shape}, and {x2_shape}"
)
)

x1_dtype = x1.dtype
x2_dtype = x2.dtype
out_dtype = _where_result_type(x1_dtype, x2_dtype, exec_q.sycl_device)
sycl_dev = exec_q.sycl_device
x1_dtype = _get_dtype(x1, sycl_dev)
x2_dtype = _get_dtype(x2, sycl_dev)
if not all(_validate_dtype(o) for o in (x1_dtype, x2_dtype)):
raise ValueError("Operands have unsupported data types")
x1_dtype, x2_dtype = _resolve_two_weak_types(x1_dtype, x2_dtype, sycl_dev)
out_dtype = _where_result_type(x1_dtype, x2_dtype, sycl_dev)
if out_dtype is None:
raise TypeError(
"function 'where' does not support input "
Expand All @@ -119,8 +244,6 @@ def where(condition, x1, x2, /, *, order="K", out=None):
"to any supported types according to the casting rule ''safe''."
)

res_shape = _broadcast_shapes(condition, x1, x2)

orig_out = out
if out is not None:
if not isinstance(out, dpt.usm_ndarray):
Expand Down Expand Up @@ -149,16 +272,25 @@ def where(condition, x1, x2, /, *, order="K", out=None):
"Input and output allocation queues are not compatible"
)

if ti._array_overlap(condition, out):
if not ti._same_logical_tensors(condition, out):
out = dpt.empty_like(out)
if ti._array_overlap(condition, out) and not ti._same_logical_tensors(
condition, out
):
out = dpt.empty_like(out)

if ti._array_overlap(x1, out):
if not ti._same_logical_tensors(x1, out):
if isinstance(x1, dpt.usm_ndarray):
if (
ti._array_overlap(x1, out)
and not ti._same_logical_tensors(x1, out)
and x1_dtype == out_dtype
):
out = dpt.empty_like(out)

if ti._array_overlap(x2, out):
if not ti._same_logical_tensors(x2, out):
if isinstance(x2, dpt.usm_ndarray):
if (
ti._array_overlap(x2, out)
and not ti._same_logical_tensors(x2, out)
and x2_dtype == out_dtype
):
out = dpt.empty_like(out)

if order == "A":
Expand All @@ -174,6 +306,10 @@ def where(condition, x1, x2, /, *, order="K", out=None):
)
else "C"
)
if not isinstance(x1, dpt.usm_ndarray):
x1 = dpt.asarray(x1, dtype=x1_dtype, sycl_queue=exec_q)
if not isinstance(x2, dpt.usm_ndarray):
x2 = dpt.asarray(x2, dtype=x2_dtype, sycl_queue=exec_q)

if condition.size == 0:
if out is not None:
Expand Down Expand Up @@ -236,9 +372,12 @@ def where(condition, x1, x2, /, *, order="K", out=None):
sycl_queue=exec_q,
)

condition = dpt.broadcast_to(condition, res_shape)
x1 = dpt.broadcast_to(x1, res_shape)
x2 = dpt.broadcast_to(x2, res_shape)
if condition_shape != res_shape:
condition = dpt.broadcast_to(condition, res_shape)
if x1_shape != res_shape:
x1 = dpt.broadcast_to(x1, res_shape)
if x2_shape != res_shape:
x2 = dpt.broadcast_to(x2, res_shape)

dep_evs = _manager.submitted_events
hev, where_ev = ti._where(
Expand Down
58 changes: 56 additions & 2 deletions dpctl/tests/test_usm_ndarray_search_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import ctypes
import itertools

import numpy as np
import pytest
from helper import get_queue_or_skip, skip_if_dtype_not_supported
Expand Down Expand Up @@ -350,9 +353,9 @@ def test_where_arg_validation():

with pytest.raises(TypeError):
dpt.where(check, x1, x2)
with pytest.raises(TypeError):
with pytest.raises(ValueError):
dpt.where(x1, check, x2)
with pytest.raises(TypeError):
with pytest.raises(ValueError):
dpt.where(x1, x2, check)


Expand Down Expand Up @@ -522,3 +525,54 @@ def test_where_out_arg_validation():
dpt.where(condition, x1, x2, out=out_wrong_shape)
with pytest.raises(ValueError):
dpt.where(condition, x1, x2, out=out_not_writable)


@pytest.mark.parametrize("arr_dt", _all_dtypes)
def test_where_python_scalar(arr_dt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(arr_dt, q)

n1, n2 = 10, 10
condition = dpt.tile(
dpt.reshape(
dpt.asarray([True, False], dtype="?", sycl_queue=q), (1, 2)
),
(n1, n2 // 2),
)
x = dpt.zeros((n1, n2), dtype=arr_dt, sycl_queue=q)
py_scalars = (
bool(0),
int(0),
float(0),
complex(0),
np.float32(0),
ctypes.c_int(0),
)
for sc in py_scalars:
r = dpt.where(condition, x, sc)
assert isinstance(r, dpt.usm_ndarray)
r = dpt.where(condition, sc, x)
assert isinstance(r, dpt.usm_ndarray)


def test_where_two_python_scalars():
get_queue_or_skip()

n1, n2 = 10, 10
condition = dpt.tile(
dpt.reshape(dpt.asarray([True, False], dtype="?"), (1, 2)),
(n1, n2 // 2),
)

py_scalars = [
bool(0),
int(0),
float(0),
complex(0),
np.float32(0),
ctypes.c_int(0),
]

for sc1, sc2 in itertools.product(py_scalars, repeat=2):
r = dpt.where(condition, sc1, sc2)
assert isinstance(r, dpt.usm_ndarray)
Loading