diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index 3f1d093fe1..9768d9ea7d 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -58,6 +58,11 @@ squeeze, stack, ) +from dpctl.tensor._print import ( + get_print_options, + print_options, + set_print_options, +) from dpctl.tensor._reshape import reshape from dpctl.tensor._usmarray import usm_ndarray @@ -129,4 +134,7 @@ "can_cast", "result_type", "meshgrid", + "get_print_options", + "set_print_options", + "print_options", ] diff --git a/dpctl/tensor/_print.py b/dpctl/tensor/_print.py new file mode 100644 index 0000000000..f1e20a12c4 --- /dev/null +++ b/dpctl/tensor/_print.py @@ -0,0 +1,323 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import operator + +import numpy as np + +import dpctl.tensor as dpt + +__doc__ = "Print functions for :class:`dpctl.tensor.usm_ndarray`." 
# Module-level state holding the current print options for usm_ndarray.
# Keys mirror the subset of NumPy print options that dpctl supports.
_print_options = {
    "linewidth": 75,
    "edgeitems": 3,
    "threshold": 1000,
    "precision": 8,
    "floatmode": "maxprec",
    "suppress": False,
    "nanstr": "nan",
    "infstr": "inf",
    "sign": "-",
}


def _options_dict(
    linewidth=None,
    edgeitems=None,
    threshold=None,
    precision=None,
    floatmode=None,
    suppress=None,
    nanstr=None,
    infstr=None,
    sign=None,
    numpy=False,
):
    """
    Build a validated dictionary of print options.

    The result starts from a copy of the current dpctl options (or from
    NumPy's current options when ``numpy`` is true) and is then overridden
    by every argument that is not ``None``. The module-level options are
    not modified.

    Raises:
        TypeError: if an integer option is not an integer, or if
            ``nanstr``/``infstr`` is not a string.
        ValueError: if ``sign`` or ``floatmode`` is not a supported value.
    """
    if numpy:
        numpy_options = np.get_printoptions()
        options = {k: numpy_options[k] for k in _print_options}
    else:
        options = _print_options.copy()

    # ``None`` means "leave unchanged". An explicit ``False`` must be able
    # to switch suppression off again, otherwise a previously enabled
    # ``suppress`` could never be reset or restored.
    if suppress is not None:
        options["suppress"] = bool(suppress)

    int_args = {
        "linewidth": linewidth,
        "precision": precision,
        "threshold": threshold,
        "edgeitems": edgeitems,
    }
    for name, val in int_args.items():
        if val is not None:
            # operator.index raises TypeError for non-integer values
            options[name] = operator.index(val)

    str_args = {"nanstr": nanstr, "infstr": infstr}
    for name, val in str_args.items():
        if val is not None:
            if not isinstance(val, str):
                raise TypeError(
                    "`{}` ".format(name) + "must be of `string` type."
                )
            options[name] = val

    signs = ["-", "+", " "]
    if sign is not None:
        if sign not in signs:
            raise ValueError(
                "`sign` must be one of "
                + ", ".join("`{}`".format(s) for s in signs)
            )
        options["sign"] = sign

    floatmodes = ["fixed", "unique", "maxprec", "maxprec_equal"]
    if floatmode is not None:
        if floatmode not in floatmodes:
            raise ValueError(
                "`floatmode` must be one of "
                + ", ".join("`{}`".format(m) for m in floatmodes)
            )
        options["floatmode"] = floatmode

    return options


def set_print_options(
    linewidth=None,
    edgeitems=None,
    threshold=None,
    precision=None,
    floatmode=None,
    suppress=None,
    nanstr=None,
    infstr=None,
    sign=None,
    numpy=False,
):
    """
    set_print_options(linewidth=None, edgeitems=None, threshold=None,
                      precision=None, floatmode=None, suppress=None,
                      nanstr=None, infstr=None, sign=None, numpy=False)

    Set options for printing ``dpctl.tensor.usm_ndarray`` class.

    Arguments left as ``None`` keep their current value. All arguments
    are validated before any option is changed.

    Args:
        linewidth (int, optional): Number of characters printed per line.
            Raises `TypeError` if linewidth is not an integer.
            Default: `75`.
        edgeitems (int, optional): Number of elements at the beginning and end
            when the printed array is abbreviated.
            Raises `TypeError` if edgeitems is not an integer.
            Default: `3`.
        threshold (int, optional): Number of elements that triggers array
            abbreviation.
            Raises `TypeError` if threshold is not an integer.
            Default: `1000`.
        precision (int or None, optional): Number of digits printed for
            floating point numbers.
            Raises `TypeError` if precision is not an integer.
            Default: `8`.
        floatmode (str, optional): Controls how floating point
            numbers are interpreted.

            `"fixed"`: Always prints exactly `precision` digits.
            `"unique"`: Ignores precision, prints the number of
                digits necessary to uniquely specify each number.
            `"maxprec"`: Prints `precision` digits or fewer,
                if fewer will uniquely represent a number.
            `"maxprec_equal"`: Prints an equal number of digits
                for each number. This number is `precision` digits or fewer,
                if fewer will uniquely represent each number.
            Raises `ValueError` if floatmode is not one of
            `fixed`, `unique`, `maxprec`, or `maxprec_equal`.
            Default: `"maxprec"`.
        suppress (bool, optional): If `True`, numbers equal to zero
            in the current precision will print as zero. If `False`,
            suppression is switched off.
            Default: `False`.
        nanstr (str, optional): String used to represent nan.
            Raises `TypeError` if nanstr is not a string.
            Default: `"nan"`.
        infstr (str, optional): String used to represent infinity.
            Raises `TypeError` if infstr is not a string.
            Default: `"inf"`.
        sign (str, optional): Controls the sign of floating point
            numbers.
            `"-"`: Omit the sign of positive numbers.
            `"+"`: Always print the sign of positive numbers.
            `" "`: Always print a whitespace in place of the
                sign of positive numbers.
            Raises `ValueError` if sign is not one of
            `"-"`, `"+"`, or `" "`.
            Default: `"-"`.
        numpy (bool, optional): If `True`, then before other specified print
            options are set, a dictionary of Numpy's print options
            will be used to initialize dpctl's print options.
            Default: `False`.
    """
    options = _options_dict(
        linewidth=linewidth,
        edgeitems=edgeitems,
        threshold=threshold,
        precision=precision,
        floatmode=floatmode,
        suppress=suppress,
        nanstr=nanstr,
        infstr=infstr,
        sign=sign,
        numpy=numpy,
    )
    _print_options.update(options)


def get_print_options():
    """
    get_print_options() -> dict

    Returns a copy of current options for printing
    ``dpctl.tensor.usm_ndarray`` class.

    Options:
        - "linewidth" : int, default 75
        - "edgeitems" : int, default 3
        - "threshold" : int, default 1000
        - "precision" : int, default 8
        - "floatmode" : str, default "maxprec"
        - "suppress"  : bool, default False
        - "nanstr"    : str, default "nan"
        - "infstr"    : str, default "inf"
        - "sign"      : str, default "-"
    """
    return _print_options.copy()


@contextlib.contextmanager
def print_options(*args, **kwargs):
    """
    Context manager for print options.

    Set print options for the scope of a `with` block.
    `as` yields dictionary of print options.

    Arguments are forwarded to ``set_print_options``. The options that were
    in effect on entry are restored on exit, including when the block or
    the option validation raises.
    """
    saved = get_print_options()
    try:
        set_print_options(*args, **kwargs)
        yield get_print_options()
    finally:
        set_print_options(**saved)


def _nd_corners(x, edge_items, slices=()):
    """
    Return an array made of the leading and trailing ``edge_items``
    elements of ``x`` along every axis.

    The function recurses over the axes, extending ``slices`` by one entry
    per level. An axis longer than ``2 * edge_items`` contributes its head
    and tail joined with ``dpt.concat``; a shorter axis is taken whole.
    """
    axes_reduced = len(slices)
    if axes_reduced == x.ndim:
        # one slice per axis collected: extract this corner
        return x[slices]

    if x.shape[axes_reduced] > 2 * edge_items:
        head = slices + (slice(None, edge_items, None),)
        tail = slices + (slice(-edge_items, None, None),)
        return dpt.concat(
            (
                _nd_corners(x, edge_items, head),
                _nd_corners(x, edge_items, tail),
            ),
            axis=axes_reduced,
        )
    return _nd_corners(x, edge_items, slices + (slice(None, None, None),))


def _usm_ndarray_str(
    x,
    line_width=None,
    edge_items=None,
    threshold=None,
    precision=None,
    floatmode=None,
    suppress=None,
    sign=None,
    numpy=False,
    separator=" ",
    prefix="",
    suffix="",
):
    """
    Return the string form of ``x``; ``usm_ndarray.__str__`` delegates here.

    Arguments that are not ``None`` override the current print options for
    this call only. ``separator``, ``prefix`` and ``suffix`` are passed to
    ``numpy.array2string``.

    Raises:
        TypeError: if ``x`` is not a ``dpctl.tensor.usm_ndarray``.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")

    options = get_print_options()
    options.update(
        _options_dict(
            linewidth=line_width,
            edgeitems=edge_items,
            threshold=threshold,
            precision=precision,
            floatmode=floatmode,
            suppress=suppress,
            sign=sign,
            numpy=numpy,
        )
    )

    threshold = options["threshold"]
    edge_items = options["edgeitems"]

    if x.size > threshold:
        # need edge_items + 1 elements for np.array2string to abbreviate
        data = dpt.asnumpy(_nd_corners(x, edge_items + 1))
        # only the corners were copied, so the threshold is dropped to 0
        # for the array2string call below
        options["threshold"] = 0
    else:
        data = dpt.asnumpy(x)
    with np.printoptions(**options):
        s = np.array2string(
            data, separator=separator, prefix=prefix, suffix=suffix
        )
    return s


def _usm_ndarray_repr(x, line_width=None, precision=None, suppress=None):
    """
    Return the ``repr`` of ``x``; ``usm_ndarray.__repr__`` delegates here.

    The result has the form ``usm_ndarray(<data>[, dtype=<name>])``. The
    dtype is appended unless it is bool, int64, float64 or complex128.

    Raises:
        TypeError: if ``x`` is not a ``dpctl.tensor.usm_ndarray``.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")

    if line_width is None:
        line_width = _print_options["linewidth"]

    show_dtype = x.dtype not in [
        dpt.bool,
        dpt.int64,
        dpt.float64,
        dpt.complex128,
    ]

    prefix = "usm_ndarray("
    suffix = ")"

    s = _usm_ndarray_str(
        x,
        line_width=line_width,
        precision=precision,
        suppress=suppress,
        separator=", ",
        prefix=prefix,
        suffix=suffix,
    )

    if show_dtype:
        dtype_str = "dtype={}".format(x.dtype.name)
        # length of the last line of the data string
        bottom_len = len(s) - (s.rfind("\n") + 1)
        # move the dtype to its own line when it would overflow line_width
        next_line = bottom_len + len(dtype_str) + 1 > line_width
        dtype_str = ",\n" + dtype_str if next_line else ", " + dtype_str
    else:
        dtype_str = ""

    return prefix + s + dtype_str + suffix
import numpy as np
import pytest
from helper import get_queue_or_skip, skip_if_dtype_not_supported

import dpctl.tensor as dpt


class TestPrint:
    """Base class that saves the print options before each test and
    restores them afterwards."""

    def setup_method(self):
        self._retain_options = dpt.get_print_options()

    def teardown_method(self):
        dpt.set_print_options(**self._retain_options)


class TestArgValidation(TestPrint):
    """Invalid option values must raise the documented exception type."""

    @pytest.mark.parametrize(
        "arg,err",
        [
            ({"linewidth": "I"}, TypeError),
            ({"edgeitems": "I"}, TypeError),
            ({"threshold": "I"}, TypeError),
            ({"precision": "I"}, TypeError),
            ({"floatmode": "I"}, ValueError),
            ({"sign": "I"}, ValueError),
            ({"nanstr": np.nan}, TypeError),
            ({"infstr": np.nan}, TypeError),
        ],
    )
    def test_print_option_arg_validation(self, arg, err):
        with pytest.raises(err):
            dpt.set_print_options(**arg)


class TestSetPrintOptions(TestPrint):
    """Each print option must be reflected in ``str`` of an array."""

    def test_set_linewidth(self):
        q = get_queue_or_skip()

        dpt.set_print_options(linewidth=1)
        x = dpt.asarray([0, 1], sycl_queue=q)
        assert str(x) == "[0\n 1]"

    def test_set_precision(self):
        q = get_queue_or_skip()

        dpt.set_print_options(precision=4)
        x = dpt.asarray([1.23450], sycl_queue=q)
        assert str(x) == "[1.2345]"

    def test_threshold_edgeitems(self):
        q = get_queue_or_skip()

        dpt.set_print_options(threshold=1, edgeitems=1)
        x = dpt.arange(9, sycl_queue=q)
        assert str(x) == "[0 ... 8]"
        dpt.set_print_options(edgeitems=9)
        assert str(x) == "[0 1 2 3 4 5 6 7 8]"

    def test_floatmodes(self):
        q = get_queue_or_skip()

        x = dpt.asarray([0.1234, 0.1234678], sycl_queue=q)
        dpt.set_print_options(floatmode="fixed", precision=4)
        assert str(x) == "[0.1234 0.1235]"

        # elements are padded to a common column width
        dpt.set_print_options(floatmode="unique")
        assert str(x) == "[0.1234    0.1234678]"

        dpt.set_print_options(floatmode="maxprec")
        assert str(x) == "[0.1234 0.1235]"

        dpt.set_print_options(floatmode="maxprec", precision=8)
        assert str(x) == "[0.1234    0.1234678]"

        dpt.set_print_options(floatmode="maxprec_equal", precision=4)
        assert str(x) == "[0.1234 0.1235]"

        dpt.set_print_options(floatmode="maxprec_equal", precision=8)
        assert str(x) == "[0.1234000 0.1234678]"

    def test_nan_inf_suppress(self):
        q = get_queue_or_skip()

        dpt.set_print_options(nanstr="nan1", infstr="inf1")
        x = dpt.asarray([np.nan, np.inf], sycl_queue=q)
        assert str(x) == "[nan1 inf1]"

    def test_suppress_small(self):
        q = get_queue_or_skip()

        dpt.set_print_options(suppress=True)
        x = dpt.asarray(5e-10, sycl_queue=q)
        assert str(x) == "0."

    def test_sign(self):
        q = get_queue_or_skip()

        x = dpt.asarray([0.0, 1.0, 2.0], sycl_queue=q)
        y = dpt.asarray(1.0, sycl_queue=q)
        z = dpt.asarray([1.0 + 1.0j], sycl_queue=q)
        assert str(x) == "[0. 1. 2.]"
        assert str(y) == "1."
        assert str(z) == "[1.+1.j]"

        dpt.set_print_options(sign="+")
        assert str(x) == "[+0. +1. +2.]"
        assert str(y) == "+1."
        assert str(z) == "[+1.+1.j]"

        # a blank takes the place of the sign of positive numbers
        dpt.set_print_options(sign=" ")
        assert str(x) == "[ 0.  1.  2.]"
        assert str(y) == " 1."
        assert str(z) == "[ 1.+1.j]"

    def test_numpy(self):
        dpt.set_print_options(numpy=True)
        options = dpt.get_print_options()
        np_options = np.get_printoptions()
        assert all(np_options[k] == options[k] for k in options.keys())


class TestPrintFns(TestPrint):
    """``str`` and ``repr`` output for various dtypes and shapes."""

    @pytest.mark.parametrize(
        "dtype,x_str",
        [
            ("b1", "[False  True  True  True]"),
            ("i1", "[0 1 2 3]"),
            ("u1", "[0 1 2 3]"),
            ("i2", "[0 1 2 3]"),
            ("u2", "[0 1 2 3]"),
            ("i4", "[0 1 2 3]"),
            ("u4", "[0 1 2 3]"),
            ("i8", "[0 1 2 3]"),
            ("u8", "[0 1 2 3]"),
            ("f2", "[0. 1. 2. 3.]"),
            ("f4", "[0. 1. 2. 3.]"),
            ("f8", "[0. 1. 2. 3.]"),
            ("c8", "[0.+0.j 1.+0.j 2.+0.j 3.+0.j]"),
            ("c16", "[0.+0.j 1.+0.j 2.+0.j 3.+0.j]"),
        ],
    )
    def test_print_types(self, dtype, x_str):
        q = get_queue_or_skip()
        skip_if_dtype_not_supported(dtype, q)

        x = dpt.asarray([0, 1, 2, 3], dtype=dtype, sycl_queue=q)
        assert str(x) == x_str

    def test_print_str(self):
        q = get_queue_or_skip()

        x = dpt.asarray(0, sycl_queue=q)
        assert str(x) == "0"

        x = dpt.asarray([np.nan, np.inf], sycl_queue=q)
        assert str(x) == "[nan inf]"

        x = dpt.arange(9, sycl_queue=q)
        assert str(x) == "[0 1 2 3 4 5 6 7 8]"

        y = dpt.reshape(x, (3, 3), copy=True)
        assert str(y) == "[[0 1 2]\n [3 4 5]\n [6 7 8]]"

    def test_print_str_abbreviated(self):
        q = get_queue_or_skip()

        dpt.set_print_options(threshold=0, edgeitems=1)
        x = dpt.arange(9, sycl_queue=q)
        assert str(x) == "[0 ... 8]"

        x = dpt.reshape(x, (3, 3))
        assert str(x) == "[[0 ... 2]\n ...\n [6 ... 8]]"

    def test_print_repr(self):
        q = get_queue_or_skip()

        x = dpt.asarray(0, dtype="int64", sycl_queue=q)
        assert repr(x) == "usm_ndarray(0)"

        x = dpt.asarray([np.nan, np.inf], sycl_queue=q)
        assert repr(x) == "usm_ndarray([nan, inf])"

        x = dpt.arange(9, sycl_queue=q, dtype="int64")
        assert repr(x) == "usm_ndarray([0, 1, 2, 3, 4, 5, 6, 7, 8])"

        # continuation rows are indented past the "usm_ndarray(" prefix
        x = dpt.reshape(x, (3, 3))
        np.testing.assert_equal(
            repr(x),
            "usm_ndarray([[0, 1, 2],"
            "\n             [3, 4, 5],"
            "\n             [6, 7, 8]])",
        )

        x = dpt.arange(4, dtype="i4", sycl_queue=q)
        assert repr(x) == "usm_ndarray([0, 1, 2, 3], dtype=int32)"

    def test_print_repr_abbreviated(self):
        q = get_queue_or_skip()

        dpt.set_print_options(threshold=0, edgeitems=1)
        x = dpt.arange(9, dtype="int64", sycl_queue=q)
        assert repr(x) == "usm_ndarray([0, ..., 8])"

        y = dpt.asarray(x, dtype="i4", copy=True)
        assert repr(y) == "usm_ndarray([0, ..., 8], dtype=int32)"

        x = dpt.reshape(x, (3, 3))
        np.testing.assert_equal(
            repr(x),
            "usm_ndarray([[0, ..., 2],"
            "\n             ...,"
            "\n             [6, ..., 8]])",
        )

        y = dpt.reshape(y, (3, 3))
        np.testing.assert_equal(
            repr(y),
            "usm_ndarray([[0, ..., 2],"
            "\n             ...,"
            "\n             [6, ..., 8]], dtype=int32)",
        )

    @pytest.mark.parametrize(
        "dtype",
        [
            "i1",
            "u1",
            "i2",
            "u2",
            "i4",
            "u4",
            "u8",
            "f2",
            "f4",
            "c8",
        ],
    )
    def test_repr_appended_dtype(self, dtype):
        q = get_queue_or_skip()
        skip_if_dtype_not_supported(dtype, q)

        # allocate on the queue whose dtype support was just checked
        x = dpt.empty(4, dtype=dtype, sycl_queue=q)
        assert repr(x).split("=")[-1][:-1] == x.dtype.name


class TestContextManager(TestPrint):
    """``dpt.print_options`` applies options inside ``with`` only."""

    def test_context_manager_basic(self):
        # skip rather than fail when no device is available
        q = get_queue_or_skip()

        options = dpt.get_print_options()
        with dpt.print_options(precision=4):
            s = str(dpt.asarray(1.234567, sycl_queue=q))
        assert s == "1.2346"
        assert options == dpt.get_print_options()

    def test_context_manager_as(self):
        with dpt.print_options(precision=4) as x:
            options = x.copy()
            assert options["precision"] == 4