Skip to content

Commit 2e78bbb

Browse files
Modularized common checks, added tests to cover change in functionality
Fixed test_arange and added test_arange_mixed_dtype The test_arange test was specifying out of bounds starting value for certain types. Added tests to based on examples highlighting discrepancies. dpt.arange(-2.5, stop=200, step=100, dtype='int32') now produces the sequence consistent with np.arange output. So does dpt.arange(9.7, stop=200, step=100, dtype='i4') ```ipython In [1]: import dpctl, dpctl.tensor as dpt, numpy as np In [2]: dpt.asnumpy(dpt.arange(9.7, stop=200, step=100, dtype='i4')) Out[2]: array([ 9, 109], dtype=int32) In [3]: np.arange(9.7, stop=200, step=100, dtype='i4') Out[3]: array([ 9, 109], dtype=int32) ``` ``` ty = np.float32 assert ( dpt.arange( ty(0), stop=ty(504.0), step=ty(100), dtype=ty ).shape == (5,) ) ``` ``` In [4]: import numpy as np In [5]: np.arange(-5, stop=10**5, step=2.7, dtype=np.int64).shape Out[5]: (37039,) In [6]: import dpctl.tensor as dpt In [7]: dpt.arange(-5, stop=10**5, step=2.7, dtype=np.int64).shape Out[7]: (50003,) ```
1 parent 9700b94 commit 2e78bbb

File tree

3 files changed

+134
-166
lines changed

3 files changed

+134
-166
lines changed

dpctl/tests/helper/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,18 @@
1919

2020
from ._helper import (
2121
create_invalid_capsule,
22+
get_queue_or_skip,
2223
has_cpu,
2324
has_gpu,
2425
has_sycl_platforms,
26+
skip_if_dtype_not_supported,
2527
)
2628

2729
__all__ = [
2830
"create_invalid_capsule",
2931
"has_cpu",
3032
"has_gpu",
3133
"has_sycl_platforms",
34+
"get_queue_or_skip",
35+
"skip_if_dtype_not_supported",
3236
]

dpctl/tests/helper/_helper.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import pytest
18+
1719
import dpctl
1820

1921

@@ -39,3 +41,38 @@ def create_invalid_capsule():
3941
ctor.restype = ctypes.py_object
4042
ctor.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
4143
return ctor(id(ctor), b"invalid", 0)
44+
45+
46+
def get_queue_or_skip(args=tuple()):
47+
try:
48+
q = dpctl.SyclQueue(*args)
49+
except dpctl.SyclQueueCreationError:
50+
pytest.skip(f"Queue could not be created from {args}")
51+
return q
52+
53+
54+
def skip_if_dtype_not_supported(dt, q_or_dev):
55+
import dpctl.tensor as dpt
56+
57+
dt = dpt.dtype(dt)
58+
if type(q_or_dev) is dpctl.SyclQueue:
59+
dev = q_or_dev.sycl_device
60+
elif type(q_or_dev) is dpctl.SyclDevice:
61+
dev = q_or_dev
62+
else:
63+
raise TypeError(
64+
"Expected dpctl.SyclQueue or dpctl.SyclDevice, "
65+
f"got {type(q_or_dev)}"
66+
)
67+
dev_has_dp = dev.has_aspect_fp64
68+
if dev_has_dp is False and dt in [dpt.float64, dpt.complex128]:
69+
pytest.skip(
70+
f"{dev.name} does not support double precision floating point types"
71+
)
72+
dev_has_hp = dev.has_aspect_fp16
73+
if dev_has_hp is False and dt in [
74+
dpt.float16,
75+
]:
76+
pytest.skip(
77+
f"{dev.name} does not support half precision floating point type"
78+
)

0 commit comments

Comments
 (0)