1
1
# Data Parallel Control (dpctl)
2
2
#
3
- # Copyright 2020-2022 Intel Corporation
3
+ # Copyright 2020-2023 Intel Corporation
4
4
#
5
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
6
# you may not use this file except in compliance with the License.
20
20
from numpy .testing import assert_array_equal
21
21
22
22
import dpctl .tensor as dpt
23
+ from dpctl .tensor ._search_functions import _where_result_type
24
+ from dpctl .tensor ._type_utils import _all_data_types
25
+ from dpctl .utils import ExecutionPlacementError
23
26
24
27
_all_dtypes = [
28
+ "?" ,
25
29
"u1" ,
26
30
"i1" ,
27
31
"u2" ,
38
42
]
39
43
40
44
45
class mock_device:
    """Minimal stand-in for a device object exposing only fp aspect flags.

    Instances carry just the two attributes set below, so tests can
    describe a device's half/double precision support without querying
    real hardware.
    """

    def __init__(self, fp16, fp64):
        """Store the given flags as ``has_aspect_fp16``/``has_aspect_fp64``."""
        self.has_aspect_fp16 = fp16
        self.has_aspect_fp64 = fp64
49
+
50
+
41
51
def test_where_basic ():
42
52
get_queue_or_skip ()
43
53
@@ -54,7 +64,16 @@ def test_where_basic():
54
64
out_expected = dpt .asarray (
55
65
[[1 , 0 , 0 ], [0 , 1 , 0 ], [0 , 0 , 1 ], [0 , 0 , 0 ], [1 , 1 , 1 ]]
56
66
)
67
+ assert (dpt .asnumpy (out ) == dpt .asnumpy (out_expected )).all ()
57
68
69
+ out = dpt .where (cond , dpt .ones (cond .shape ), dpt .zeros (cond .shape ))
70
+ assert (dpt .asnumpy (out ) == dpt .asnumpy (out_expected )).all ()
71
+
72
+ out = dpt .where (
73
+ cond ,
74
+ dpt .ones (cond .shape [0 ], dtype = "i4" )[:, dpt .newaxis ],
75
+ dpt .zeros (cond .shape [0 ], dtype = "i4" )[:, dpt .newaxis ],
76
+ )
58
77
assert (dpt .asnumpy (out ) == dpt .asnumpy (out_expected )).all ()
59
78
60
79
@@ -72,38 +91,98 @@ def _dtype_all_close(x1, x2):
72
91
73
92
@pytest.mark.parametrize("dt1", _all_dtypes)
@pytest.mark.parametrize("dt2", _all_dtypes)
@pytest.mark.parametrize("fp16", [True, False])
@pytest.mark.parametrize("fp64", [True, False])
def test_where_result_types(dt1, dt2, fp16, fp64):
    """Compare ``_where_result_type`` with ``dpt.result_type`` on a mock device.

    The device is a ``mock_device`` built from the ``fp16``/``fp64``
    parameters, so every combination of the two flags is exercised for
    every pair of dtypes in ``_all_dtypes``.
    """
    dev = mock_device(fp16, fp64)

    dt1 = dpt.dtype(dt1)
    dt2 = dpt.dtype(dt2)
    res_t = _where_result_type(dt1, dt2, dev)

    if fp16 and fp64:
        # both aspects present: the result must match exactly
        assert res_t == dpt.result_type(dt1, dt2)
    elif res_t is not None:
        # Compare against None explicitly instead of relying on the
        # truthiness of the returned object; only the kind is required
        # to agree when an fp aspect is missing.
        assert res_t.kind == dpt.result_type(dt1, dt2).kind
    else:
        # some illegal cases are covered above, but
        # this guarantees that _where_result_type
        # produces None only when one of the dtypes
        # is illegal given fp aspects of device
        all_dts = _all_data_types(fp16, fp64)
        assert dt1 not in all_dts or dt2 not in all_dts
115
+
116
+
117
@pytest.mark.parametrize("dt", _all_dtypes)
def test_where_all_dtypes(dt):
    """Run ``dpt.where`` with ``dt`` as the mask dtype, then as the input dtype.

    Each half is checked twice: once with 0-d ``x1``/``x2`` arrays and
    once with arrays created by ``dpt.full`` at the condition's shape.
    The test is skipped when the queue's device does not support ``dt``.
    """
    q = get_queue_or_skip()
    skip_if_dtype_not_supported(dt, q)

    # mask dtype changes
    cond = dpt.asarray([0, 1, 3, 0, 10], dtype=dt, sycl_queue=q)
    x1 = dpt.asarray(0, dtype="f", sycl_queue=q)
    x2 = dpt.asarray(1, dtype="f", sycl_queue=q)
    res = dpt.where(cond, x1, x2)

    # non-zero mask entries select x1 (0), zero entries select x2 (1)
    res_check = np.asarray([1, 0, 0, 1, 0], dtype=res.dtype)
    assert _dtype_all_close(dpt.asnumpy(res), res_check)

    # contiguous cases
    x1 = dpt.full(cond.shape, 0, dtype="f4", sycl_queue=q)
    x2 = dpt.full(cond.shape, 1, dtype="f4", sycl_queue=q)
    res = dpt.where(cond, x1, x2)
    assert _dtype_all_close(dpt.asnumpy(res), res_check)

    # input array dtype changes
    cond = dpt.asarray([False, True, True, False, True], sycl_queue=q)
    x1 = dpt.asarray(0, dtype=dt, sycl_queue=q)
    x2 = dpt.asarray(1, dtype=dt, sycl_queue=q)
    res = dpt.where(cond, x1, x2)

    res_check = np.asarray([1, 0, 0, 1, 0], dtype=res.dtype)
    assert _dtype_all_close(dpt.asnumpy(res), res_check)

    # contiguous cases
    x1 = dpt.full(cond.shape, 0, dtype=dt, sycl_queue=q)
    x2 = dpt.full(cond.shape, 1, dtype=dt, sycl_queue=q)
    res = dpt.where(cond, x1, x2)
    assert _dtype_all_close(dpt.asnumpy(res), res_check)
93
151
94
152
153
def test_where_nan_inf():
    """Check ``dpt.where`` against ``np.where`` on inputs holding NaN and inf.

    The float arrays are used first as the selectable values and then,
    in the second call, ``x1`` is used as the condition with the boolean
    array as a selectable value.
    """
    get_queue_or_skip()

    cond = dpt.asarray([True, False, True, False], dtype="?")
    x1 = dpt.asarray([np.nan, 2.0, np.inf, 3.0], dtype="f4")
    x2 = dpt.asarray([2.0, np.nan, 3.0, np.inf], dtype="f4")

    cond_np = dpt.asnumpy(cond)
    x1_np = dpt.asnumpy(x1)
    x2_np = dpt.asnumpy(x2)

    res = dpt.where(cond, x1, x2)
    res_np = np.where(cond_np, x1_np, x2_np)

    # equal_nan=True so that NaNs in matching positions compare equal
    assert np.allclose(dpt.asnumpy(res), res_np, equal_nan=True)

    # float array as the mask, boolean array as a selectable value
    res = dpt.where(x1, cond, x2)
    res_np = np.where(x1_np, cond_np, x2_np)
    assert _dtype_all_close(dpt.asnumpy(res), res_np)
172
+
173
+
95
174
def test_where_empty ():
96
175
# check that numpy returns same results when
97
176
# handling empty arrays
98
177
get_queue_or_skip ()
99
178
100
- empty = dpt .empty (0 )
179
+ empty = dpt .empty (0 , dtype = "i2" )
101
180
m = dpt .asarray (True )
102
- x1 = dpt .asarray (1 )
103
- x2 = dpt .asarray (2 )
181
+ x1 = dpt .asarray (1 , dtype = "i2" )
182
+ x2 = dpt .asarray (2 , dtype = "i2" )
104
183
res = dpt .where (empty , x1 , x2 )
105
184
106
- empty_np = np .empty (0 )
185
+ empty_np = np .empty (0 , dtype = "i2" )
107
186
m_np = dpt .asnumpy (m )
108
187
x1_np = dpt .asnumpy (x1 )
109
188
x2_np = dpt .asnumpy (x2 )
@@ -116,12 +195,14 @@ def test_where_empty():
116
195
117
196
assert_array_equal (dpt .asnumpy (res ), res_np )
118
197
198
+ # check that broadcasting is performed
199
+ with pytest .raises (ValueError ):
200
+ dpt .where (empty , x1 , dpt .empty ((1 , 2 )))
201
+
119
202
120
- @pytest .mark .parametrize ("dt" , _all_dtypes )
121
203
@pytest .mark .parametrize ("order" , ["C" , "F" ])
122
- def test_where_contiguous (dt , order ):
123
- q = get_queue_or_skip ()
124
- skip_if_dtype_not_supported (dt , q )
204
+ def test_where_contiguous (order ):
205
+ get_queue_or_skip ()
125
206
126
207
cond = dpt .asarray (
127
208
[
@@ -131,14 +212,100 @@ def test_where_contiguous(dt, order):
131
212
[[False , False , False ], [True , False , True ]],
132
213
[[True , True , True ], [True , False , True ]],
133
214
],
134
- sycl_queue = q ,
135
215
order = order ,
136
216
)
137
217
138
- x1 = dpt .full (cond .shape , 2 , dtype = dt , order = order , sycl_queue = q )
139
- x2 = dpt .full (cond .shape , 3 , dtype = dt , order = order , sycl_queue = q )
218
+ x1 = dpt .full (cond .shape , 2 , dtype = "i4" , order = order )
219
+ x2 = dpt .full (cond .shape , 3 , dtype = "i4" , order = order )
220
+ expected = np .where (dpt .asnumpy (cond ), dpt .asnumpy (x1 ), dpt .asnumpy (x2 ))
221
+ res = dpt .where (cond , x1 , x2 )
222
+
223
+ assert _dtype_all_close (dpt .asnumpy (res ), expected )
224
+
225
+
226
def test_where_contiguous1D():
    """Check ``dpt.where`` on 1-D arrays against ``np.where``.

    The comparison is made for ``"i4"`` inputs and again after casting
    both inputs to ``dpt.complex64``.
    """
    get_queue_or_skip()

    cond = dpt.asarray([True, False, True, False, False, True])

    x1 = dpt.full(cond.shape, 2, dtype="i4")
    x2 = dpt.full(cond.shape, 3, dtype="i4")
    expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2))
    res = dpt.where(cond, x1, x2)
    assert_array_equal(dpt.asnumpy(res), expected)

    # test with complex dtype (branch in kernel)
    x1 = dpt.astype(x1, dpt.complex64)
    x2 = dpt.astype(x2, dpt.complex64)
    expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2))
    res = dpt.where(cond, x1, x2)
    assert _dtype_all_close(dpt.asnumpy(res), expected)
243
+
244
+
245
def test_where_strided():
    """Check ``dpt.where`` on sliced and flipped arrays against ``np.where``.

    ``cond``, ``x1`` and ``x2`` are each produced by reshaping a 1-D array
    and taking a stepped column slice; the later cases pass ``dpt.flip``
    of one argument.
    """
    get_queue_or_skip()

    s0, s1 = 4, 9
    cond = dpt.reshape(
        dpt.asarray(
            [True, False, False, False, True, True, False, True, False] * s0
        ),
        (s0, s1),
    )[:, ::3]

    # x1/x2 are built wider than cond and sliced down to cond's shape
    n_rows, n_cols = cond.shape[0], cond.shape[1]
    x1 = dpt.reshape(
        dpt.arange(n_rows * n_cols * 2, dtype="i4"),
        (n_rows, n_cols * 2),
    )[:, ::2]
    x2 = dpt.reshape(
        dpt.arange(n_rows * n_cols * 3, dtype="i4"),
        (n_rows, n_cols * 3),
    )[:, ::3]
    expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2))
    res = dpt.where(cond, x1, x2)

    assert_array_equal(dpt.asnumpy(res), expected)

    # negative strides
    res = dpt.where(cond, dpt.flip(x1), x2)
    expected = np.where(
        dpt.asnumpy(cond), np.flip(dpt.asnumpy(x1)), dpt.asnumpy(x2)
    )
    assert_array_equal(dpt.asnumpy(res), expected)

    res = dpt.where(dpt.flip(cond), x1, x2)
    expected = np.where(
        np.flip(dpt.asnumpy(cond)), dpt.asnumpy(x1), dpt.asnumpy(x2)
    )
    assert_array_equal(dpt.asnumpy(res), expected)
281
+
282
+
283
def test_where_arg_validation():
    """Check that ``dpt.where`` raises ``TypeError`` for a ``dict`` argument.

    The dict is passed in each of the three positional slots in turn,
    with arrays from ``dpt.empty`` in the other two.
    """
    get_queue_or_skip()

    check = {}
    x1 = dpt.empty((1,), dtype="i4")
    x2 = dpt.empty((1,), dtype="i4")

    with pytest.raises(TypeError):
        dpt.where(check, x1, x2)
    with pytest.raises(TypeError):
        dpt.where(x1, check, x2)
    with pytest.raises(TypeError):
        dpt.where(x1, x2, check)
296
+
297
+
298
def test_where_compute_follows_data():
    """Check that ``dpt.where`` raises ``ExecutionPlacementError`` for mixed queues.

    ``x1`` and ``x2`` are allocated with queues from separate
    ``get_queue_or_skip()`` calls, and the condition is allocated on the
    first queue, on a third queue, or is ``x1`` itself.
    """
    # NOTE(review): assumes each get_queue_or_skip() call yields a distinct
    # queue; confirm against the helper's implementation.
    q1 = get_queue_or_skip()
    q2 = get_queue_or_skip()
    q3 = get_queue_or_skip()

    x1 = dpt.empty((1,), dtype="i4", sycl_queue=q1)
    x2 = dpt.empty((1,), dtype="i4", sycl_queue=q2)

    with pytest.raises(ExecutionPlacementError):
        dpt.where(dpt.empty((1,), dtype="i4", sycl_queue=q1), x1, x2)
    with pytest.raises(ExecutionPlacementError):
        dpt.where(dpt.empty((1,), dtype="i4", sycl_queue=q3), x1, x2)
    with pytest.raises(ExecutionPlacementError):
        dpt.where(x1, x1, x2)
0 commit comments