Skip to content

Commit f08413e

Browse files
committed
Added tests for array printing
1 parent 42a9dc6 commit f08413e

File tree

3 files changed

+280
-6
lines changed

3 files changed

+280
-6
lines changed

dpctl/tensor/_print.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def _nd_corners(x, edge_items, slices=()):
157157
return _nd_corners(x, edge_items, slices + (slice(None, None, None),))
158158

159159

160-
def usm_ndarray_str(
160+
def _usm_ndarray_str(
161161
x,
162162
line_width=None,
163163
edge_items=None,
@@ -204,7 +204,7 @@ def usm_ndarray_str(
204204
return s
205205

206206

207-
def usm_ndarray_repr(x, line_width=None, precision=None, suppress=None):
207+
def _usm_ndarray_repr(x, line_width=None, precision=None, suppress=None):
208208
if not isinstance(x, dpt.usm_ndarray):
209209
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
210210

@@ -221,7 +221,7 @@ def usm_ndarray_repr(x, line_width=None, precision=None, suppress=None):
221221
prefix = "usm_ndarray("
222222
suffix = ")"
223223

224-
s = usm_ndarray_str(
224+
s = _usm_ndarray_str(
225225
x,
226226
line_width=line_width,
227227
precision=precision,

dpctl/tensor/_usmarray.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import dpctl
2626
import dpctl.memory as dpmem
2727

2828
from ._device import Device
29-
from ._print import usm_ndarray_repr, usm_ndarray_str
29+
from ._print import _usm_ndarray_repr, _usm_ndarray_str
3030

3131
from cpython.mem cimport PyMem_Free
3232
from cpython.tuple cimport PyTuple_New, PyTuple_SetItem
@@ -1133,10 +1133,10 @@ cdef class usm_ndarray:
11331133
return self
11341134

11351135
def __str__(self):
1136-
return usm_ndarray_str(self)
1136+
return _usm_ndarray_str(self)
11371137

11381138
def __repr__(self):
1139-
return usm_ndarray_repr(self)
1139+
return _usm_ndarray_repr(self)
11401140

11411141

11421142
cdef usm_ndarray _real_view(usm_ndarray ary):

dpctl/tests/test_usm_ndarray_print.py

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2022 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import numpy as np
18+
import pytest
19+
from helper import get_queue_or_skip, skip_if_dtype_not_supported
20+
21+
import dpctl.tensor as dpt
22+
23+
24+
class TestPrint:
25+
def setup_method(self):
26+
self._retain_options = dpt.get_print_options()
27+
28+
def teardown_method(self):
29+
dpt.set_print_options(**self._retain_options)
30+
31+
32+
class TestArgValidation(TestPrint):
33+
@pytest.mark.parametrize(
34+
"arg,err",
35+
[
36+
({"linewidth": "I"}, TypeError),
37+
({"edgeitems": "I"}, TypeError),
38+
({"threshold": "I"}, TypeError),
39+
({"precision": "I"}, TypeError),
40+
({"floatmode": "I"}, ValueError),
41+
({"edgeitems": "I"}, TypeError),
42+
({"sign": "I"}, ValueError),
43+
({"nanstr": np.nan}, TypeError),
44+
({"infstr": np.nan}, TypeError),
45+
],
46+
)
47+
def test_print_option_arg_validation(self, arg, err):
48+
with pytest.raises(err):
49+
dpt.set_print_options(**arg)
50+
51+
52+
class TestSetPrintOptions(TestPrint):
53+
def test_set_linewidth(self):
54+
q = get_queue_or_skip()
55+
56+
dpt.set_print_options(linewidth=1)
57+
x = dpt.asarray([0, 1], sycl_queue=q)
58+
assert str(x) == "[0\n 1]"
59+
60+
def test_set_precision(self):
61+
q = get_queue_or_skip()
62+
63+
dpt.set_print_options(precision=4)
64+
x = dpt.asarray([1.23450], sycl_queue=q)
65+
assert str(x) == "[1.2345]"
66+
67+
def test_threshold_edgeitems(self):
68+
q = get_queue_or_skip()
69+
70+
dpt.set_print_options(threshold=1, edgeitems=1)
71+
x = dpt.arange(9, sycl_queue=q)
72+
assert str(x) == "[0 ... 8]"
73+
dpt.set_print_options(edgeitems=9)
74+
assert str(x) == "[0 1 2 3 4 5 6 7 8]"
75+
76+
def test_floatmodes(self):
77+
q = get_queue_or_skip()
78+
79+
x = dpt.asarray([0.1234, 0.1234678], sycl_queue=q)
80+
dpt.set_print_options(floatmode="fixed", precision=4)
81+
assert str(x) == "[0.1234 0.1235]"
82+
83+
dpt.set_print_options(floatmode="unique")
84+
assert str(x) == "[0.1234 0.1234678]"
85+
86+
dpt.set_print_options(floatmode="maxprec")
87+
assert str(x) == "[0.1234 0.1235]"
88+
89+
dpt.set_print_options(floatmode="maxprec", precision=8)
90+
assert str(x) == "[0.1234 0.1234678]"
91+
92+
dpt.set_print_options(floatmode="maxprec_equal", precision=4)
93+
assert str(x) == "[0.1234 0.1235]"
94+
95+
dpt.set_print_options(floatmode="maxprec_equal", precision=8)
96+
assert str(x) == "[0.1234000 0.1234678]"
97+
98+
def test_nan_inf_suppress(self):
99+
q = get_queue_or_skip()
100+
101+
dpt.set_print_options(nanstr="nan1", infstr="inf1")
102+
x = dpt.asarray([np.nan, np.inf], sycl_queue=q)
103+
assert str(x) == "[nan1 inf1]"
104+
105+
def test_suppress_small(self):
106+
q = get_queue_or_skip()
107+
108+
dpt.set_print_options(suppress=True)
109+
x = dpt.asarray(5e-10, sycl_queue=q)
110+
assert str(x) == "0."
111+
112+
def test_sign(self):
113+
q = get_queue_or_skip()
114+
115+
x = dpt.asarray([0.0, 1.0, 2.0], sycl_queue=q)
116+
y = dpt.asarray(1.0, sycl_queue=q)
117+
z = dpt.asarray([1.0 + 1.0j], sycl_queue=q)
118+
assert str(x) == "[0. 1. 2.]"
119+
assert str(y) == "1."
120+
assert str(z) == "[1.+1.j]"
121+
122+
dpt.set_print_options(sign="+")
123+
assert str(x) == "[+0. +1. +2.]"
124+
assert str(y) == "+1."
125+
assert str(z) == "[+1.+1.j]"
126+
127+
dpt.set_print_options(sign=" ")
128+
assert str(x) == "[ 0. 1. 2.]"
129+
assert str(y) == " 1."
130+
assert str(z) == "[ 1.+1.j]"
131+
132+
def test_numpy(self):
133+
dpt.set_print_options(numpy=True)
134+
options = dpt.get_print_options()
135+
np_options = np.get_printoptions()
136+
assert all(np_options[k] == options[k] for k in options.keys())
137+
138+
139+
class TestPrintFns(TestPrint):
140+
@pytest.mark.parametrize(
141+
"dtype,x_str",
142+
[
143+
("b1", "[False True True True]"),
144+
("i1", "[0 1 2 3]"),
145+
("u1", "[0 1 2 3]"),
146+
("i2", "[0 1 2 3]"),
147+
("u2", "[0 1 2 3]"),
148+
("i4", "[0 1 2 3]"),
149+
("u4", "[0 1 2 3]"),
150+
("i8", "[0 1 2 3]"),
151+
("u8", "[0 1 2 3]"),
152+
("f2", "[0. 1. 2. 3.]"),
153+
("f4", "[0. 1. 2. 3.]"),
154+
("f8", "[0. 1. 2. 3.]"),
155+
("c8", "[0.+0.j 1.+0.j 2.+0.j 3.+0.j]"),
156+
("c16", "[0.+0.j 1.+0.j 2.+0.j 3.+0.j]"),
157+
],
158+
)
159+
def test_print_types(self, dtype, x_str):
160+
q = get_queue_or_skip()
161+
skip_if_dtype_not_supported(dtype, q)
162+
163+
x = dpt.asarray([0, 1, 2, 3], dtype=dtype, sycl_queue=q)
164+
assert str(x) == x_str
165+
166+
def test_print_str(self):
167+
q = get_queue_or_skip()
168+
169+
x = dpt.asarray(0, sycl_queue=q)
170+
assert str(x) == "0"
171+
172+
x = dpt.asarray([np.nan, np.inf], sycl_queue=q)
173+
assert str(x) == "[nan inf]"
174+
175+
x = dpt.arange(9, sycl_queue=q)
176+
assert str(x) == "[0 1 2 3 4 5 6 7 8]"
177+
178+
y = dpt.reshape(x, (3, 3), copy=True)
179+
assert str(y) == "[[0 1 2]\n [3 4 5]\n [6 7 8]]"
180+
181+
def test_print_str_abbreviated(self):
182+
q = get_queue_or_skip()
183+
184+
dpt.set_print_options(threshold=0, edgeitems=1)
185+
x = dpt.arange(9, sycl_queue=q)
186+
assert str(x) == "[0 ... 8]"
187+
188+
x = dpt.reshape(x, (3, 3))
189+
assert str(x) == "[[0 ... 2]\n ...\n [6 ... 8]]"
190+
191+
def test_print_repr(self):
192+
q = get_queue_or_skip()
193+
194+
x = dpt.asarray(0, sycl_queue=q)
195+
assert repr(x) == "usm_ndarray(0)"
196+
197+
x = dpt.asarray([np.nan, np.inf], sycl_queue=q)
198+
assert repr(x) == "usm_ndarray([nan, inf])"
199+
200+
x = dpt.arange(9, sycl_queue=q)
201+
assert repr(x) == "usm_ndarray([0, 1, 2, 3, 4, 5, 6, 7, 8])"
202+
203+
x = dpt.reshape(x, (3, 3))
204+
np.testing.assert_equal(
205+
repr(x),
206+
"usm_ndarray([[0, 1, 2],"
207+
"\n [3, 4, 5],"
208+
"\n [6, 7, 8]])",
209+
)
210+
211+
x = dpt.arange(4, dtype="f2", sycl_queue=q)
212+
assert repr(x) == "usm_ndarray([0., 1., 2., 3.], dtype=float16)"
213+
214+
def test_print_repr_abbreviated(self):
215+
q = get_queue_or_skip()
216+
217+
dpt.set_print_options(threshold=0, edgeitems=1)
218+
x = dpt.arange(9, sycl_queue=q)
219+
assert repr(x) == "usm_ndarray([0, ..., 8])"
220+
221+
y = dpt.asarray(x, dtype="f2", copy=True)
222+
assert repr(y) == "usm_ndarray([0., ..., 8.], dtype=float16)"
223+
224+
x = dpt.reshape(x, (3, 3))
225+
np.testing.assert_equal(
226+
repr(x),
227+
"usm_ndarray([[0, ..., 2],"
228+
"\n ...,"
229+
"\n [6, ..., 8]])",
230+
)
231+
232+
y = dpt.reshape(y, (3, 3))
233+
np.testing.assert_equal(
234+
repr(y),
235+
"usm_ndarray([[0., ..., 2.],"
236+
"\n ...,"
237+
"\n [6., ..., 8.]], dtype=float16)",
238+
)
239+
240+
@pytest.mark.parametrize(
241+
"dtype",
242+
[
243+
"i1",
244+
"u1",
245+
"i2",
246+
"u2",
247+
"i4",
248+
"u4",
249+
"u8",
250+
"f2",
251+
"f4",
252+
"c8",
253+
],
254+
)
255+
def test_repr_appended_dtype(self, dtype):
256+
q = get_queue_or_skip()
257+
skip_if_dtype_not_supported(dtype, q)
258+
259+
x = dpt.empty(4, dtype=dtype)
260+
assert repr(x).split("=")[-1][:-1] == x.dtype.name
261+
262+
263+
class TestContextManager:
264+
def test_context_manager_basic(self):
265+
options = dpt.get_print_options()
266+
with dpt.print_options(precision=4):
267+
s = str(dpt.asarray(1.234567))
268+
assert s == "1.2346"
269+
assert options == dpt.get_print_options()
270+
271+
def test_context_manager_as(self):
272+
with dpt.print_options(precision=4) as x:
273+
options = x.copy()
274+
assert options["precision"] == 4

0 commit comments

Comments
 (0)