Skip to content

Commit 8ab41e9

Browse files
committed
Add support to where for Python scalars
`x1` and `x2` can now both be Python scalars. As `condition` has no impact on the data type of the result, when both are scalars, the default data type for the scalar kind is used.
1 parent 26d34f5 commit 8ab41e9

File tree

2 files changed

+179
-41
lines changed

2 files changed

+179
-41
lines changed

dpctl/tensor/_search_functions.py

Lines changed: 177 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,81 @@
1717
import dpctl
1818
import dpctl.tensor as dpt
1919
import dpctl.tensor._tensor_impl as ti
20-
from dpctl.tensor._manipulation_functions import _broadcast_shapes
20+
from dpctl.tensor._elementwise_common import (
21+
_get_dtype,
22+
_get_queue_usm_type,
23+
_get_shape,
24+
_validate_dtype,
25+
)
26+
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
2127
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
2228

2329
from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK
24-
from ._type_utils import _all_data_types, _can_cast
30+
from ._type_utils import (
31+
WeakBooleanType,
32+
WeakComplexType,
33+
WeakFloatingType,
34+
WeakIntegralType,
35+
_all_data_types,
36+
_can_cast,
37+
_is_weak_dtype,
38+
_strong_dtype_num_kind,
39+
_to_device_supported_dtype,
40+
_weak_type_num_kind,
41+
)
42+
43+
44+
def _default_dtype_from_weak_type(dt, dev):
45+
if isinstance(dt, WeakBooleanType):
46+
return dpt.bool
47+
if isinstance(dt, WeakIntegralType):
48+
return dpt.dtype(ti.default_device_int_type(dev))
49+
if isinstance(dt, WeakFloatingType):
50+
return dpt.dtype(ti.default_device_fp_type(dev))
51+
if isinstance(dt, WeakComplexType):
52+
return dpt.dtype(ti.default_device_complex_type(dev))
53+
54+
55+
def _resolve_two_weak_types(o1_dtype, o2_dtype, dev):
56+
"Resolves two weak data types per NEP-0050"
57+
if _is_weak_dtype(o1_dtype):
58+
if _is_weak_dtype(o2_dtype):
59+
return _default_dtype_from_weak_type(
60+
o1_dtype, dev
61+
), _default_dtype_from_weak_type(o2_dtype, dev)
62+
o1_kind_num = _weak_type_num_kind(o1_dtype)
63+
o2_kind_num = _strong_dtype_num_kind(o2_dtype)
64+
if o1_kind_num > o2_kind_num:
65+
if isinstance(o1_dtype, WeakIntegralType):
66+
return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
67+
if isinstance(o1_dtype, WeakComplexType):
68+
if o2_dtype is dpt.float16 or o2_dtype is dpt.float32:
69+
return dpt.complex64, o2_dtype
70+
return (
71+
_to_device_supported_dtype(dpt.complex128, dev),
72+
o2_dtype,
73+
)
74+
return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
75+
else:
76+
return o2_dtype, o2_dtype
77+
elif _is_weak_dtype(o2_dtype):
78+
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
79+
o2_kind_num = _weak_type_num_kind(o2_dtype)
80+
if o2_kind_num > o1_kind_num:
81+
if isinstance(o2_dtype, WeakIntegralType):
82+
return o1_dtype, dpt.dtype(ti.default_device_int_type(dev))
83+
if isinstance(o2_dtype, WeakComplexType):
84+
if o1_dtype is dpt.float16 or o1_dtype is dpt.float32:
85+
return o1_dtype, dpt.complex64
86+
return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev)
87+
return (
88+
o1_dtype,
89+
_to_device_supported_dtype(dpt.float64, dev),
90+
)
91+
else:
92+
return o1_dtype, o1_dtype
93+
else:
94+
return o1_dtype, o2_dtype
2595

2696

2797
def _where_result_type(dt1, dt2, dev):
@@ -81,36 +151,90 @@ def where(condition, x1, x2, /, *, order="K", out=None):
81151
raise TypeError(
82152
"Expecting dpctl.tensor.usm_ndarray type, " f"got {type(condition)}"
83153
)
84-
if not isinstance(x1, dpt.usm_ndarray):
85-
raise TypeError(
86-
"Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x1)}"
154+
if order not in ["K", "C", "F", "A"]:
155+
order = "K"
156+
q1, condition_usm_type = condition.sycl_queue, condition.usm_type
157+
q2, x1_usm_type = _get_queue_usm_type(x1)
158+
q3, x2_usm_type = _get_queue_usm_type(x2)
159+
if q2 is None and q3 is None:
160+
exec_q = q1
161+
out_usm_type = condition_usm_type
162+
elif q3 is None:
163+
exec_q = dpctl.utils.get_execution_queue((q1, q2))
164+
if exec_q is None:
165+
raise ExecutionPlacementError(
166+
"Execution placement can not be unambiguously inferred "
167+
"from input arguments."
168+
)
169+
out_usm_type = dpctl.utils.get_coerced_usm_type(
170+
(
171+
condition_usm_type,
172+
x1_usm_type,
173+
)
87174
)
88-
if not isinstance(x2, dpt.usm_ndarray):
175+
elif q2 is None:
176+
exec_q = dpctl.utils.get_execution_queue((q1, q3))
177+
if exec_q is None:
178+
raise ExecutionPlacementError(
179+
"Execution placement can not be unambiguously inferred "
180+
"from input arguments."
181+
)
182+
out_usm_type = dpctl.utils.get_coerced_usm_type(
183+
(
184+
condition_usm_type,
185+
x2_usm_type,
186+
)
187+
)
188+
else:
189+
exec_q = dpctl.utils.get_execution_queue((q1, q2, q3))
190+
if exec_q is None:
191+
raise ExecutionPlacementError(
192+
"Execution placement can not be unambiguously inferred "
193+
"from input arguments."
194+
)
195+
out_usm_type = dpctl.utils.get_coerced_usm_type(
196+
(
197+
condition_usm_type,
198+
x1_usm_type,
199+
x2_usm_type,
200+
)
201+
)
202+
dpctl.utils.validate_usm_type(out_usm_type, allow_none=False)
203+
condition_shape = condition.shape
204+
x1_shape = _get_shape(x1)
205+
x2_shape = _get_shape(x2)
206+
if not all(
207+
isinstance(s, (tuple, list))
208+
for s in (
209+
x1_shape,
210+
x2_shape,
211+
)
212+
):
89213
raise TypeError(
90-
"Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x2)}"
214+
"Shape of arguments can not be inferred. "
215+
"Arguments are expected to be "
216+
"lists, tuples, or both"
91217
)
92-
if order not in ["K", "C", "F", "A"]:
93-
order = "K"
94-
exec_q = dpctl.utils.get_execution_queue(
95-
(
96-
condition.sycl_queue,
97-
x1.sycl_queue,
98-
x2.sycl_queue,
218+
try:
219+
res_shape = _broadcast_shape_impl(
220+
[
221+
condition_shape,
222+
x1_shape,
223+
x2_shape,
224+
]
99225
)
100-
)
101-
if exec_q is None:
102-
raise dpctl.utils.ExecutionPlacementError
103-
out_usm_type = dpctl.utils.get_coerced_usm_type(
104-
(
105-
condition.usm_type,
106-
x1.usm_type,
107-
x2.usm_type,
226+
except ValueError:
227+
raise ValueError(
228+
"operands could not be broadcast together with shapes "
229+
f"{condition_shape}, {x1_shape}, and {x2_shape}"
108230
)
109-
)
110-
111-
x1_dtype = x1.dtype
112-
x2_dtype = x2.dtype
113-
out_dtype = _where_result_type(x1_dtype, x2_dtype, exec_q.sycl_device)
231+
sycl_dev = exec_q.sycl_device
232+
x1_dtype = _get_dtype(x1, sycl_dev)
233+
x2_dtype = _get_dtype(x2, sycl_dev)
234+
if not all(_validate_dtype(o) for o in (x1_dtype, x2_dtype)):
235+
raise ValueError("Operands have unsupported data types")
236+
x1_dtype, x2_dtype = _resolve_two_weak_types(x1_dtype, x2_dtype, sycl_dev)
237+
out_dtype = _where_result_type(x1_dtype, x2_dtype, sycl_dev)
114238
if out_dtype is None:
115239
raise TypeError(
116240
"function 'where' does not support input "
@@ -119,8 +243,6 @@ def where(condition, x1, x2, /, *, order="K", out=None):
119243
"to any supported types according to the casting rule ''safe''."
120244
)
121245

122-
res_shape = _broadcast_shapes(condition, x1, x2)
123-
124246
orig_out = out
125247
if out is not None:
126248
if not isinstance(out, dpt.usm_ndarray):
@@ -149,16 +271,25 @@ def where(condition, x1, x2, /, *, order="K", out=None):
149271
"Input and output allocation queues are not compatible"
150272
)
151273

152-
if ti._array_overlap(condition, out):
153-
if not ti._same_logical_tensors(condition, out):
154-
out = dpt.empty_like(out)
274+
if ti._array_overlap(condition, out) and not ti._same_logical_tensors(
275+
condition, out
276+
):
277+
out = dpt.empty_like(out)
155278

156-
if ti._array_overlap(x1, out):
157-
if not ti._same_logical_tensors(x1, out):
279+
if isinstance(x1, dpt.usm_ndarray):
280+
if (
281+
ti._array_overlap(x1, out)
282+
and not ti._same_logical_tensors(x1, out)
283+
and x1_dtype == out_dtype
284+
):
158285
out = dpt.empty_like(out)
159286

160-
if ti._array_overlap(x2, out):
161-
if not ti._same_logical_tensors(x2, out):
287+
if isinstance(x2, dpt.usm_ndarray):
288+
if (
289+
ti._array_overlap(x2, out)
290+
and not ti._same_logical_tensors(x2, out)
291+
and x2_dtype == out_dtype
292+
):
162293
out = dpt.empty_like(out)
163294

164295
if order == "A":
@@ -174,6 +305,10 @@ def where(condition, x1, x2, /, *, order="K", out=None):
174305
)
175306
else "C"
176307
)
308+
if not isinstance(x1, dpt.usm_ndarray):
309+
x1 = dpt.asarray(x1, dtype=x1_dtype, sycl_queue=exec_q)
310+
if not isinstance(x2, dpt.usm_ndarray):
311+
x2 = dpt.asarray(x2, dtype=x2_dtype, sycl_queue=exec_q)
177312

178313
if condition.size == 0:
179314
if out is not None:
@@ -236,9 +371,12 @@ def where(condition, x1, x2, /, *, order="K", out=None):
236371
sycl_queue=exec_q,
237372
)
238373

239-
condition = dpt.broadcast_to(condition, res_shape)
240-
x1 = dpt.broadcast_to(x1, res_shape)
241-
x2 = dpt.broadcast_to(x2, res_shape)
374+
if condition_shape != res_shape:
375+
condition = dpt.broadcast_to(condition, res_shape)
376+
if x1_shape != res_shape:
377+
x1 = dpt.broadcast_to(x1, res_shape)
378+
if x2_shape != res_shape:
379+
x2 = dpt.broadcast_to(x2, res_shape)
242380

243381
dep_evs = _manager.submitted_events
244382
hev, where_ev = ti._where(

dpctl/tests/test_usm_ndarray_search_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,9 +350,9 @@ def test_where_arg_validation():
350350

351351
with pytest.raises(TypeError):
352352
dpt.where(check, x1, x2)
353-
with pytest.raises(TypeError):
353+
with pytest.raises(ValueError):
354354
dpt.where(x1, check, x2)
355-
with pytest.raises(TypeError):
355+
with pytest.raises(ValueError):
356356
dpt.where(x1, x2, check)
357357

358358

0 commit comments

Comments
 (0)