From 7a33dd282c83de632370a25140682eea8b26c83f Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Mon, 20 Mar 2023 22:17:07 +0100 Subject: [PATCH 1/3] Add an implementation of dpt.isdtype() --- dpctl/tensor/__init__.py | 2 + dpctl/tensor/_data_types.py | 45 +++++++ dpctl/tests/test_tensor_dtype_routines.py | 155 ++++++++++++++++++++++ 3 files changed, 202 insertions(+) create mode 100644 dpctl/tests/test_tensor_dtype_routines.py diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index 2a2afd60a4..77f102fa56 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -51,6 +51,7 @@ int16, int32, int64, + isdtype, uint8, uint16, uint32, @@ -125,6 +126,7 @@ "tril", "triu", "dtype", + "isdtype", "bool", "int8", "uint8", diff --git a/dpctl/tensor/_data_types.py b/dpctl/tensor/_data_types.py index c97afe37be..d4c0866cb8 100644 --- a/dpctl/tensor/_data_types.py +++ b/dpctl/tensor/_data_types.py @@ -31,8 +31,53 @@ complex64 = dtype("complex64") complex128 = dtype("complex128") + +def isdtype(dtype_, kind): + """ + Returns a boolean indicating whether a provided dtype is + of a specified data type `kind`. + + See + https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html + for more information + """ + + if not isinstance(dtype_, dtype): + raise TypeError("Expected instance of `dpt.dtype`, got {dtype_}") + + if isinstance(kind, dtype): + return dtype_ == kind + + elif isinstance(kind, str): + if kind == "bool": + return dtype_ == dtype("bool") + elif kind == "signed integer": + return dtype_.kind == "i" + elif kind == "unsigned integer": + return dtype_.kind == "u" + elif kind == "integral": + return dtype_.kind in ("u", "i") + elif kind == "real floating": + return dtype_.kind == "f" + elif kind == "complex floating": + return dtype_.kind == "c" + elif kind == "numeric": + return isdtype( + dtype_, ("integral", "real floating", "complex floating") + ) + else: + raise ValueError(f"Unrecognized data type kind: {kind}") + + elif isinstance(kind, tuple): + return any(isdtype(dtype_, k) for k in kind) + + else: + raise TypeError(f"Unsupported data type kind: {kind}") + + __all__ = [ "dtype", + "isdtype", "bool", "int8", "uint8", diff --git a/dpctl/tests/test_tensor_dtype_routines.py b/dpctl/tests/test_tensor_dtype_routines.py new file mode 100644 index 0000000000..50becdd976 --- /dev/null +++ b/dpctl/tests/test_tensor_dtype_routines.py @@ -0,0 +1,155 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +import dpctl.tensor as dpt + +list_dtypes = [ + "bool", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float16", + "float32", + "float64", + "complex64", + "complex128", +] + + +dtype_categories = { + "bool": ["bool"], + "signed integer": ["int8", "int16", "int32", "int64"], + "unsigned integer": ["uint8", "uint16", "uint32", "uint64"], + "integral": [ + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + ], + "real floating": ["float16", "float32", "float64"], + "complex floating": ["complex64", "complex128"], + "numeric": [d for d in list_dtypes if d != "bool"], +} + + +@pytest.mark.parametrize("kind_str", dtype_categories.keys()) +@pytest.mark.parametrize("dtype_str", list_dtypes) +def test_isdtype_kind_str(dtype_str, kind_str): + if dtype_str in dtype_categories[kind_str]: + assert dpt.isdtype(dpt.dtype(dtype_str), kind_str) + else: + assert not dpt.isdtype(dpt.dtype(dtype_str), kind_str) + + +@pytest.mark.parametrize("dtype_str", list_dtypes) +def test_isdtype_kind_tuple(dtype_str): + if dtype_str.startswith("bool"): + assert dpt.isdtype(dpt.dtype(dtype_str), ("real floating", "bool")) + assert not dpt.isdtype( + dpt.dtype(dtype_str), + ("integral", "real floating", "complex floating"), + ) + elif dtype_str.startswith("int"): + assert dpt.isdtype( + dpt.dtype(dtype_str), ("real floating", "signed integer") + ) + assert not dpt.isdtype( + dpt.dtype(dtype_str), ("bool", "unsigned integer", "real floating") + ) + elif dtype_str.startswith("uint"): + assert dpt.isdtype(dpt.dtype(dtype_str), ("bool", "unsigned integer")) + assert not dpt.isdtype( + dpt.dtype(dtype_str), ("real floating", "complex floating") + ) + elif dtype_str.startswith("float"): + assert dpt.isdtype( + dpt.dtype(dtype_str), ("complex floating", "real floating") + ) + assert not dpt.isdtype( + dpt.dtype(dtype_str), ("integral", "complex floating") + ) + else: + assert dpt.isdtype( + dpt.dtype(dtype_str), ("integral", "complex floating") + ) + assert not dpt.isdtype( + dpt.dtype(dtype_str), ("bool", "integral", "real floating") + ) + + +@pytest.mark.parametrize("dtype_str", list_dtypes) +def test_isdtype_kind_tuple_dtypes(dtype_str): + if dtype_str.startswith("bool"): + assert dpt.isdtype(dpt.dtype(dtype_str), (dpt.int32, dpt.bool)) + assert not dpt.isdtype( + dpt.dtype(dtype_str), (dpt.int16, dpt.uint32, dpt.float64) + ) + elif dtype_str.startswith("int"): + assert dpt.isdtype( + dpt.dtype(dtype_str), (dpt.int8, dpt.int16, dpt.int32, dpt.int64) + ) + assert not dpt.isdtype( + dpt.dtype(dtype_str), (dpt.bool, dpt.float32, dpt.complex64) + ) + elif dtype_str.startswith("uint"): + assert dpt.isdtype( + dpt.dtype(dtype_str), + (dpt.uint8, dpt.uint16, dpt.uint32, dpt.uint64), + ) + assert not dpt.isdtype( + dpt.dtype(dtype_str), (dpt.bool, dpt.int32, dpt.float32) + ) + elif dtype_str.startswith("float"): + assert dpt.isdtype( + dpt.dtype(dtype_str), (dpt.float16, dpt.float32, dpt.float64) + ) + assert not dpt.isdtype( + dpt.dtype(dtype_str), (dpt.bool, dpt.complex64, dpt.int8) + ) + else: + assert dpt.isdtype( + dpt.dtype(dtype_str), (dpt.complex64, dpt.complex128) + ) + assert not dpt.isdtype( + dpt.dtype(dtype_str), (dpt.bool, dpt.uint64, dpt.int8) + ) + + +@pytest.mark.parametrize( + "kind", + [ + [dpt.int32, dpt.bool], + "f4", + float, + 123, + "complex", + ], +) +def test_isdtype_invalid_kind(kind): + with pytest.raises((TypeError, ValueError)): + dpt.isdtype(dpt.int32, kind) From 2c7c0978c2850a4aac3c0d3da3fe648c27dbabfb Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Wed, 22 Mar 2023 12:06:17 +0100 Subject: [PATCH 2/3] =?UTF-8?q?Fix=20do=D1=81strings?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dpctl/tensor/_data_types.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/dpctl/tensor/_data_types.py b/dpctl/tensor/_data_types.py index d4c0866cb8..0203df7687 100644 --- a/dpctl/tensor/_data_types.py +++ b/dpctl/tensor/_data_types.py @@ -33,13 +33,14 @@ def isdtype(dtype_, kind): - """ - Returns a boolean indicating whether a provided dtype is + """isdtype(dtype, kind) + + Returns a boolean indicating whether a provided `dtype` is of a specified data type `kind`. - See - https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html - for more information + See [array API](array_api) for more information. + + [array_api]: https://data-apis.org/array-api/latest/ """ if not isinstance(dtype_, dtype): From 36233fdb4524f372bf4e2b5fff1694ae70536d23 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Wed, 22 Mar 2023 15:30:23 +0100 Subject: [PATCH 3/3] Fix remarks --- dpctl/tensor/_data_types.py | 6 +- dpctl/tests/test_tensor_dtype_routines.py | 86 ++++++++--------------- 2 files changed, 32 insertions(+), 60 deletions(-) diff --git a/dpctl/tensor/_data_types.py b/dpctl/tensor/_data_types.py index 0203df7687..70129363de 100644 --- a/dpctl/tensor/_data_types.py +++ b/dpctl/tensor/_data_types.py @@ -57,15 +57,13 @@ def isdtype(dtype_, kind): elif kind == "unsigned integer": return dtype_.kind == "u" elif kind == "integral": - return dtype_.kind in ("u", "i") + return dtype_.kind in "iu" elif kind == "real floating": return dtype_.kind == "f" elif kind == "complex floating": return dtype_.kind == "c" elif kind == "numeric": - return isdtype( - dtype_, ("integral", "real floating", "complex floating") - ) + return dtype_.kind in "iufc" else: raise ValueError(f"Unrecognized data type kind: {kind}") diff --git a/dpctl/tests/test_tensor_dtype_routines.py b/dpctl/tests/test_tensor_dtype_routines.py index 50becdd976..acb1bb6d8b 100644 --- a/dpctl/tests/test_tensor_dtype_routines.py +++ b/dpctl/tests/test_tensor_dtype_routines.py @@ -60,84 +60,58 @@ @pytest.mark.parametrize("kind_str", dtype_categories.keys()) @pytest.mark.parametrize("dtype_str", list_dtypes) def test_isdtype_kind_str(dtype_str, kind_str): - if dtype_str in dtype_categories[kind_str]: - assert dpt.isdtype(dpt.dtype(dtype_str), kind_str) - else: - assert not dpt.isdtype(dpt.dtype(dtype_str), kind_str) + dt = dpt.dtype(dtype_str) + is_in_kind = dpt.isdtype(dt, kind_str) + expected = dtype_str in dtype_categories[kind_str] + assert is_in_kind == expected @pytest.mark.parametrize("dtype_str", list_dtypes) def test_isdtype_kind_tuple(dtype_str): + dt = dpt.dtype(dtype_str) if dtype_str.startswith("bool"): - assert dpt.isdtype(dpt.dtype(dtype_str), ("real floating", "bool")) + assert dpt.isdtype(dt, ("real floating", "bool")) assert not dpt.isdtype( - dpt.dtype(dtype_str), - ("integral", "real floating", "complex floating"), + dt, ("integral", "real floating", "complex floating") ) elif dtype_str.startswith("int"): - assert dpt.isdtype( - dpt.dtype(dtype_str), ("real floating", "signed integer") - ) + assert dpt.isdtype(dt, ("real floating", "signed integer")) assert not dpt.isdtype( - dpt.dtype(dtype_str), ("bool", "unsigned integer", "real floating") + dt, ("bool", "unsigned integer", "real floating") ) elif dtype_str.startswith("uint"): - assert dpt.isdtype(dpt.dtype(dtype_str), ("bool", "unsigned integer")) - assert not dpt.isdtype( - dpt.dtype(dtype_str), ("real floating", "complex floating") - ) + assert dpt.isdtype(dt, ("bool", "unsigned integer")) + assert not dpt.isdtype(dt, ("real floating", "complex floating")) elif dtype_str.startswith("float"): - assert dpt.isdtype( - dpt.dtype(dtype_str), ("complex floating", "real floating") - ) - assert not dpt.isdtype( - dpt.dtype(dtype_str), ("integral", "complex floating") - ) + assert dpt.isdtype(dt, ("complex floating", "real floating")) + assert not dpt.isdtype(dt, ("integral", "complex floating", "bool")) else: - assert dpt.isdtype( - dpt.dtype(dtype_str), ("integral", "complex floating") - ) - assert not dpt.isdtype( - dpt.dtype(dtype_str), ("bool", "integral", "real floating") - ) + assert dpt.isdtype(dt, ("integral", "complex floating")) + assert not dpt.isdtype(dt, ("bool", "integral", "real floating")) @pytest.mark.parametrize("dtype_str", list_dtypes) def test_isdtype_kind_tuple_dtypes(dtype_str): + dt = dpt.dtype(dtype_str) if dtype_str.startswith("bool"): - assert dpt.isdtype(dpt.dtype(dtype_str), (dpt.int32, dpt.bool)) - assert not dpt.isdtype( - dpt.dtype(dtype_str), (dpt.int16, dpt.uint32, dpt.float64) - ) + assert dpt.isdtype(dt, (dpt.int32, dpt.bool)) + assert not dpt.isdtype(dt, (dpt.int16, dpt.uint32, dpt.float64)) + elif dtype_str.startswith("int"): - assert dpt.isdtype( - dpt.dtype(dtype_str), (dpt.int8, dpt.int16, dpt.int32, dpt.int64) - ) - assert not dpt.isdtype( - dpt.dtype(dtype_str), (dpt.bool, dpt.float32, dpt.complex64) - ) + assert dpt.isdtype(dt, (dpt.int8, dpt.int16, dpt.int32, dpt.int64)) + assert not dpt.isdtype(dt, (dpt.bool, dpt.float32, dpt.complex64)) + elif dtype_str.startswith("uint"): - assert dpt.isdtype( - dpt.dtype(dtype_str), - (dpt.uint8, dpt.uint16, dpt.uint32, dpt.uint64), - ) - assert not dpt.isdtype( - dpt.dtype(dtype_str), (dpt.bool, dpt.int32, dpt.float32) - ) + assert dpt.isdtype(dt, (dpt.uint8, dpt.uint16, dpt.uint32, dpt.uint64)) + assert not dpt.isdtype(dt, (dpt.bool, dpt.int32, dpt.float32)) + elif dtype_str.startswith("float"): - assert dpt.isdtype( - dpt.dtype(dtype_str), (dpt.float16, dpt.float32, dpt.float64) - ) - assert not dpt.isdtype( - dpt.dtype(dtype_str), (dpt.bool, dpt.complex64, dpt.int8) - ) + assert dpt.isdtype(dt, (dpt.float16, dpt.float32, dpt.float64)) + assert not dpt.isdtype(dt, (dpt.bool, dpt.complex64, dpt.int8)) + else: - assert dpt.isdtype( - dpt.dtype(dtype_str), (dpt.complex64, dpt.complex128) - ) - assert not dpt.isdtype( - dpt.dtype(dtype_str), (dpt.bool, dpt.uint64, dpt.int8) - ) + assert dpt.isdtype(dt, (dpt.complex64, dpt.complex128)) + assert not dpt.isdtype(dt, (dpt.bool, dpt.uint64, dpt.int8)) @pytest.mark.parametrize(