Skip to content

Commit 36233fd

Browse files
Fix remarks
1 parent a23b6bf commit 36233fd

File tree

2 files changed

+32
-60
lines changed

2 files changed

+32
-60
lines changed

dpctl/tensor/_data_types.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,13 @@ def isdtype(dtype_, kind):
5757
elif kind == "unsigned integer":
5858
return dtype_.kind == "u"
5959
elif kind == "integral":
60-
return dtype_.kind in ("u", "i")
60+
return dtype_.kind in "iu"
6161
elif kind == "real floating":
6262
return dtype_.kind == "f"
6363
elif kind == "complex floating":
6464
return dtype_.kind == "c"
6565
elif kind == "numeric":
66-
return isdtype(
67-
dtype_, ("integral", "real floating", "complex floating")
68-
)
66+
return dtype_.kind in "iufc"
6967
else:
7068
raise ValueError(f"Unrecognized data type kind: {kind}")
7169

dpctl/tests/test_tensor_dtype_routines.py

Lines changed: 30 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -60,84 +60,58 @@
6060
@pytest.mark.parametrize("kind_str", dtype_categories.keys())
6161
@pytest.mark.parametrize("dtype_str", list_dtypes)
6262
def test_isdtype_kind_str(dtype_str, kind_str):
63-
if dtype_str in dtype_categories[kind_str]:
64-
assert dpt.isdtype(dpt.dtype(dtype_str), kind_str)
65-
else:
66-
assert not dpt.isdtype(dpt.dtype(dtype_str), kind_str)
63+
dt = dpt.dtype(dtype_str)
64+
is_in_kind = dpt.isdtype(dt, kind_str)
65+
expected = dtype_str in dtype_categories[kind_str]
66+
assert is_in_kind == expected
6767

6868

6969
@pytest.mark.parametrize("dtype_str", list_dtypes)
7070
def test_isdtype_kind_tuple(dtype_str):
71+
dt = dpt.dtype(dtype_str)
7172
if dtype_str.startswith("bool"):
72-
assert dpt.isdtype(dpt.dtype(dtype_str), ("real floating", "bool"))
73+
assert dpt.isdtype(dt, ("real floating", "bool"))
7374
assert not dpt.isdtype(
74-
dpt.dtype(dtype_str),
75-
("integral", "real floating", "complex floating"),
75+
dt, ("integral", "real floating", "complex floating")
7676
)
7777
elif dtype_str.startswith("int"):
78-
assert dpt.isdtype(
79-
dpt.dtype(dtype_str), ("real floating", "signed integer")
80-
)
78+
assert dpt.isdtype(dt, ("real floating", "signed integer"))
8179
assert not dpt.isdtype(
82-
dpt.dtype(dtype_str), ("bool", "unsigned integer", "real floating")
80+
dt, ("bool", "unsigned integer", "real floating")
8381
)
8482
elif dtype_str.startswith("uint"):
85-
assert dpt.isdtype(dpt.dtype(dtype_str), ("bool", "unsigned integer"))
86-
assert not dpt.isdtype(
87-
dpt.dtype(dtype_str), ("real floating", "complex floating")
88-
)
83+
assert dpt.isdtype(dt, ("bool", "unsigned integer"))
84+
assert not dpt.isdtype(dt, ("real floating", "complex floating"))
8985
elif dtype_str.startswith("float"):
90-
assert dpt.isdtype(
91-
dpt.dtype(dtype_str), ("complex floating", "real floating")
92-
)
93-
assert not dpt.isdtype(
94-
dpt.dtype(dtype_str), ("integral", "complex floating")
95-
)
86+
assert dpt.isdtype(dt, ("complex floating", "real floating"))
87+
assert not dpt.isdtype(dt, ("integral", "complex floating", "bool"))
9688
else:
97-
assert dpt.isdtype(
98-
dpt.dtype(dtype_str), ("integral", "complex floating")
99-
)
100-
assert not dpt.isdtype(
101-
dpt.dtype(dtype_str), ("bool", "integral", "real floating")
102-
)
89+
assert dpt.isdtype(dt, ("integral", "complex floating"))
90+
assert not dpt.isdtype(dt, ("bool", "integral", "real floating"))
10391

10492

10593
@pytest.mark.parametrize("dtype_str", list_dtypes)
10694
def test_isdtype_kind_tuple_dtypes(dtype_str):
95+
dt = dpt.dtype(dtype_str)
10796
if dtype_str.startswith("bool"):
108-
assert dpt.isdtype(dpt.dtype(dtype_str), (dpt.int32, dpt.bool))
109-
assert not dpt.isdtype(
110-
dpt.dtype(dtype_str), (dpt.int16, dpt.uint32, dpt.float64)
111-
)
97+
assert dpt.isdtype(dt, (dpt.int32, dpt.bool))
98+
assert not dpt.isdtype(dt, (dpt.int16, dpt.uint32, dpt.float64))
99+
112100
elif dtype_str.startswith("int"):
113-
assert dpt.isdtype(
114-
dpt.dtype(dtype_str), (dpt.int8, dpt.int16, dpt.int32, dpt.int64)
115-
)
116-
assert not dpt.isdtype(
117-
dpt.dtype(dtype_str), (dpt.bool, dpt.float32, dpt.complex64)
118-
)
101+
assert dpt.isdtype(dt, (dpt.int8, dpt.int16, dpt.int32, dpt.int64))
102+
assert not dpt.isdtype(dt, (dpt.bool, dpt.float32, dpt.complex64))
103+
119104
elif dtype_str.startswith("uint"):
120-
assert dpt.isdtype(
121-
dpt.dtype(dtype_str),
122-
(dpt.uint8, dpt.uint16, dpt.uint32, dpt.uint64),
123-
)
124-
assert not dpt.isdtype(
125-
dpt.dtype(dtype_str), (dpt.bool, dpt.int32, dpt.float32)
126-
)
105+
assert dpt.isdtype(dt, (dpt.uint8, dpt.uint16, dpt.uint32, dpt.uint64))
106+
assert not dpt.isdtype(dt, (dpt.bool, dpt.int32, dpt.float32))
107+
127108
elif dtype_str.startswith("float"):
128-
assert dpt.isdtype(
129-
dpt.dtype(dtype_str), (dpt.float16, dpt.float32, dpt.float64)
130-
)
131-
assert not dpt.isdtype(
132-
dpt.dtype(dtype_str), (dpt.bool, dpt.complex64, dpt.int8)
133-
)
109+
assert dpt.isdtype(dt, (dpt.float16, dpt.float32, dpt.float64))
110+
assert not dpt.isdtype(dt, (dpt.bool, dpt.complex64, dpt.int8))
111+
134112
else:
135-
assert dpt.isdtype(
136-
dpt.dtype(dtype_str), (dpt.complex64, dpt.complex128)
137-
)
138-
assert not dpt.isdtype(
139-
dpt.dtype(dtype_str), (dpt.bool, dpt.uint64, dpt.int8)
140-
)
113+
assert dpt.isdtype(dt, (dpt.complex64, dpt.complex128))
114+
assert not dpt.isdtype(dt, (dpt.bool, dpt.uint64, dpt.int8))
141115

142116

143117
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)