17
17
import dpctl
18
18
import dpctl .tensor as dpt
19
19
import dpctl .tensor ._tensor_impl as ti
20
- from dpctl .tensor ._manipulation_functions import _broadcast_shapes
20
+ from dpctl .tensor ._elementwise_common import (
21
+ _get_dtype ,
22
+ _get_queue_usm_type ,
23
+ _get_shape ,
24
+ _validate_dtype ,
25
+ )
26
+ from dpctl .tensor ._manipulation_functions import _broadcast_shape_impl
21
27
from dpctl .utils import ExecutionPlacementError , SequentialOrderManager
22
28
23
29
from ._copy_utils import _empty_like_orderK , _empty_like_triple_orderK
24
- from ._type_utils import _all_data_types , _can_cast
30
+ from ._type_utils import (
31
+ WeakBooleanType ,
32
+ WeakComplexType ,
33
+ WeakFloatingType ,
34
+ WeakIntegralType ,
35
+ _all_data_types ,
36
+ _can_cast ,
37
+ _is_weak_dtype ,
38
+ _strong_dtype_num_kind ,
39
+ _to_device_supported_dtype ,
40
+ _weak_type_num_kind ,
41
+ )
42
+
43
+
44
+ def _default_dtype_from_weak_type (dt , dev ):
45
+ if isinstance (dt , WeakBooleanType ):
46
+ return dpt .bool
47
+ if isinstance (dt , WeakIntegralType ):
48
+ return dpt .dtype (ti .default_device_int_type (dev ))
49
+ if isinstance (dt , WeakFloatingType ):
50
+ return dpt .dtype (ti .default_device_fp_type (dev ))
51
+ if isinstance (dt , WeakComplexType ):
52
+ return dpt .dtype (ti .default_device_complex_type (dev ))
53
+
54
+
55
+ def _resolve_two_weak_types (o1_dtype , o2_dtype , dev ):
56
+ "Resolves two weak data types per NEP-0050"
57
+ if _is_weak_dtype (o1_dtype ):
58
+ if _is_weak_dtype (o2_dtype ):
59
+ return _default_dtype_from_weak_type (
60
+ o1_dtype , dev
61
+ ), _default_dtype_from_weak_type (o2_dtype , dev )
62
+ o1_kind_num = _weak_type_num_kind (o1_dtype )
63
+ o2_kind_num = _strong_dtype_num_kind (o2_dtype )
64
+ if o1_kind_num > o2_kind_num :
65
+ if isinstance (o1_dtype , WeakIntegralType ):
66
+ return dpt .dtype (ti .default_device_int_type (dev )), o2_dtype
67
+ if isinstance (o1_dtype , WeakComplexType ):
68
+ if o2_dtype is dpt .float16 or o2_dtype is dpt .float32 :
69
+ return dpt .complex64 , o2_dtype
70
+ return (
71
+ _to_device_supported_dtype (dpt .complex128 , dev ),
72
+ o2_dtype ,
73
+ )
74
+ return _to_device_supported_dtype (dpt .float64 , dev ), o2_dtype
75
+ else :
76
+ return o2_dtype , o2_dtype
77
+ elif _is_weak_dtype (o2_dtype ):
78
+ o1_kind_num = _strong_dtype_num_kind (o1_dtype )
79
+ o2_kind_num = _weak_type_num_kind (o2_dtype )
80
+ if o2_kind_num > o1_kind_num :
81
+ if isinstance (o2_dtype , WeakIntegralType ):
82
+ return o1_dtype , dpt .dtype (ti .default_device_int_type (dev ))
83
+ if isinstance (o2_dtype , WeakComplexType ):
84
+ if o1_dtype is dpt .float16 or o1_dtype is dpt .float32 :
85
+ return o1_dtype , dpt .complex64
86
+ return o1_dtype , _to_device_supported_dtype (dpt .complex128 , dev )
87
+ return (
88
+ o1_dtype ,
89
+ _to_device_supported_dtype (dpt .float64 , dev ),
90
+ )
91
+ else :
92
+ return o1_dtype , o1_dtype
93
+ else :
94
+ return o1_dtype , o2_dtype
25
95
26
96
27
97
def _where_result_type (dt1 , dt2 , dev ):
@@ -81,36 +151,90 @@ def where(condition, x1, x2, /, *, order="K", out=None):
81
151
raise TypeError (
82
152
"Expecting dpctl.tensor.usm_ndarray type, " f"got { type (condition )} "
83
153
)
84
- if not isinstance (x1 , dpt .usm_ndarray ):
85
- raise TypeError (
86
- "Expecting dpctl.tensor.usm_ndarray type, " f"got { type (x1 )} "
154
+ if order not in ["K" , "C" , "F" , "A" ]:
155
+ order = "K"
156
+ q1 , condition_usm_type = condition .sycl_queue , condition .usm_type
157
+ q2 , x1_usm_type = _get_queue_usm_type (x1 )
158
+ q3 , x2_usm_type = _get_queue_usm_type (x2 )
159
+ if q2 is None and q3 is None :
160
+ exec_q = q1
161
+ out_usm_type = condition_usm_type
162
+ elif q3 is None :
163
+ exec_q = dpctl .utils .get_execution_queue ((q1 , q2 ))
164
+ if exec_q is None :
165
+ raise ExecutionPlacementError (
166
+ "Execution placement can not be unambiguously inferred "
167
+ "from input arguments."
168
+ )
169
+ out_usm_type = dpctl .utils .get_coerced_usm_type (
170
+ (
171
+ condition_usm_type ,
172
+ x1_usm_type ,
173
+ )
87
174
)
88
- if not isinstance (x2 , dpt .usm_ndarray ):
175
+ elif q2 is None :
176
+ exec_q = dpctl .utils .get_execution_queue ((q1 , q3 ))
177
+ if exec_q is None :
178
+ raise ExecutionPlacementError (
179
+ "Execution placement can not be unambiguously inferred "
180
+ "from input arguments."
181
+ )
182
+ out_usm_type = dpctl .utils .get_coerced_usm_type (
183
+ (
184
+ condition_usm_type ,
185
+ x2_usm_type ,
186
+ )
187
+ )
188
+ else :
189
+ exec_q = dpctl .utils .get_execution_queue ((q1 , q2 , q3 ))
190
+ if exec_q is None :
191
+ raise ExecutionPlacementError (
192
+ "Execution placement can not be unambiguously inferred "
193
+ "from input arguments."
194
+ )
195
+ out_usm_type = dpctl .utils .get_coerced_usm_type (
196
+ (
197
+ condition_usm_type ,
198
+ x1_usm_type ,
199
+ x2_usm_type ,
200
+ )
201
+ )
202
+ dpctl .utils .validate_usm_type (out_usm_type , allow_none = False )
203
+ condition_shape = condition .shape
204
+ x1_shape = _get_shape (x1 )
205
+ x2_shape = _get_shape (x2 )
206
+ if not all (
207
+ isinstance (s , (tuple , list ))
208
+ for s in (
209
+ x1_shape ,
210
+ x2_shape ,
211
+ )
212
+ ):
89
213
raise TypeError (
90
- "Expecting dpctl.tensor.usm_ndarray type, " f"got { type (x2 )} "
214
+ "Shape of arguments can not be inferred. "
215
+ "Arguments are expected to be "
216
+ "lists, tuples, or both"
91
217
)
92
- if order not in [ "K" , "C" , "F" , "A" ] :
93
- order = "K"
94
- exec_q = dpctl . utils . get_execution_queue (
95
- (
96
- condition . sycl_queue ,
97
- x1 . sycl_queue ,
98
- x2 . sycl_queue ,
218
+ try :
219
+ res_shape = _broadcast_shape_impl (
220
+ [
221
+ condition_shape ,
222
+ x1_shape ,
223
+ x2_shape ,
224
+ ]
99
225
)
100
- )
101
- if exec_q is None :
102
- raise dpctl .utils .ExecutionPlacementError
103
- out_usm_type = dpctl .utils .get_coerced_usm_type (
104
- (
105
- condition .usm_type ,
106
- x1 .usm_type ,
107
- x2 .usm_type ,
226
+ except ValueError :
227
+ raise ValueError (
228
+ "operands could not be broadcast together with shapes "
229
+ f"{ condition_shape } , { x1_shape } , and { x2_shape } "
108
230
)
109
- )
110
-
111
- x1_dtype = x1 .dtype
112
- x2_dtype = x2 .dtype
113
- out_dtype = _where_result_type (x1_dtype , x2_dtype , exec_q .sycl_device )
231
+ sycl_dev = exec_q .sycl_device
232
+ x1_dtype = _get_dtype (x1 , sycl_dev )
233
+ x2_dtype = _get_dtype (x2 , sycl_dev )
234
+ if not all (_validate_dtype (o ) for o in (x1_dtype , x2_dtype )):
235
+ raise ValueError ("Operands have unsupported data types" )
236
+ x1_dtype , x2_dtype = _resolve_two_weak_types (x1_dtype , x2_dtype , sycl_dev )
237
+ out_dtype = _where_result_type (x1_dtype , x2_dtype , sycl_dev )
114
238
if out_dtype is None :
115
239
raise TypeError (
116
240
"function 'where' does not support input "
@@ -119,8 +243,6 @@ def where(condition, x1, x2, /, *, order="K", out=None):
119
243
"to any supported types according to the casting rule ''safe''."
120
244
)
121
245
122
- res_shape = _broadcast_shapes (condition , x1 , x2 )
123
-
124
246
orig_out = out
125
247
if out is not None :
126
248
if not isinstance (out , dpt .usm_ndarray ):
@@ -149,16 +271,25 @@ def where(condition, x1, x2, /, *, order="K", out=None):
149
271
"Input and output allocation queues are not compatible"
150
272
)
151
273
152
- if ti ._array_overlap (condition , out ):
153
- if not ti ._same_logical_tensors (condition , out ):
154
- out = dpt .empty_like (out )
274
+ if ti ._array_overlap (condition , out ) and not ti ._same_logical_tensors (
275
+ condition , out
276
+ ):
277
+ out = dpt .empty_like (out )
155
278
156
- if ti ._array_overlap (x1 , out ):
157
- if not ti ._same_logical_tensors (x1 , out ):
279
+ if isinstance (x1 , dpt .usm_ndarray ):
280
+ if (
281
+ ti ._array_overlap (x1 , out )
282
+ and not ti ._same_logical_tensors (x1 , out )
283
+ and x1_dtype == out_dtype
284
+ ):
158
285
out = dpt .empty_like (out )
159
286
160
- if ti ._array_overlap (x2 , out ):
161
- if not ti ._same_logical_tensors (x2 , out ):
287
+ if isinstance (x2 , dpt .usm_ndarray ):
288
+ if (
289
+ ti ._array_overlap (x2 , out )
290
+ and not ti ._same_logical_tensors (x2 , out )
291
+ and x2_dtype == out_dtype
292
+ ):
162
293
out = dpt .empty_like (out )
163
294
164
295
if order == "A" :
@@ -174,6 +305,10 @@ def where(condition, x1, x2, /, *, order="K", out=None):
174
305
)
175
306
else "C"
176
307
)
308
+ if not isinstance (x1 , dpt .usm_ndarray ):
309
+ x1 = dpt .asarray (x1 , dtype = x1_dtype , sycl_queue = exec_q )
310
+ if not isinstance (x2 , dpt .usm_ndarray ):
311
+ x2 = dpt .asarray (x2 , dtype = x2_dtype , sycl_queue = exec_q )
177
312
178
313
if condition .size == 0 :
179
314
if out is not None :
@@ -236,9 +371,12 @@ def where(condition, x1, x2, /, *, order="K", out=None):
236
371
sycl_queue = exec_q ,
237
372
)
238
373
239
- condition = dpt .broadcast_to (condition , res_shape )
240
- x1 = dpt .broadcast_to (x1 , res_shape )
241
- x2 = dpt .broadcast_to (x2 , res_shape )
374
+ if condition_shape != res_shape :
375
+ condition = dpt .broadcast_to (condition , res_shape )
376
+ if x1_shape != res_shape :
377
+ x1 = dpt .broadcast_to (x1 , res_shape )
378
+ if x2_shape != res_shape :
379
+ x2 = dpt .broadcast_to (x2 , res_shape )
242
380
243
381
dep_evs = _manager .submitted_events
244
382
hev , where_ev = ti ._where (
0 commit comments