From 4a0976637a19d20ae9f9340d023787f13e9cad47 Mon Sep 17 00:00:00 2001 From: Natalia Polina Date: Tue, 6 Jun 2023 22:14:39 -0500 Subject: [PATCH] Fixed behavior of mathematical functions for floating-point and complex floating-point scalars. --- dpctl/tensor/_elementwise_common.py | 62 +++++++++++++++++------- dpctl/tests/elementwise/test_multiply.py | 21 +++++--- 2 files changed, 57 insertions(+), 26 deletions(-) diff --git a/dpctl/tensor/_elementwise_common.py b/dpctl/tensor/_elementwise_common.py index a775376d95..ebe6670425 100644 --- a/dpctl/tensor/_elementwise_common.py +++ b/dpctl/tensor/_elementwise_common.py @@ -162,9 +162,18 @@ def get(self): return self.o_ -class WeakInexactType: - """Python type representing type of Python real- or - complex-valued floating point objects""" +class WeakFloatingType: + """Python type representing type of Python floating point objects""" + + def __init__(self, o): + self.o_ = o + + def get(self): + return self.o_ + + +class WeakComplexType: + """Python type representing type of Python complex floating point objects""" def __init__(self, o): self.o_ = o @@ -189,14 +198,17 @@ def _get_dtype(o, dev): return WeakBooleanType(o) if isinstance(o, int): return WeakIntegralType(o) - if isinstance(o, (float, complex)): - return WeakInexactType(o) + if isinstance(o, float): + return WeakFloatingType(o) + if isinstance(o, complex): + return WeakComplexType(o) return np.object_ def _validate_dtype(dt) -> bool: return isinstance( - dt, (WeakBooleanType, WeakInexactType, WeakIntegralType) + dt, + (WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType), ) or ( isinstance(dt, dpt.dtype) and dt @@ -220,22 +232,24 @@ def _validate_dtype(dt) -> bool: def _weak_type_num_kind(o): - _map = {"?": 0, "i": 1, "f": 2} + _map = {"?": 0, "i": 1, "f": 2, "c": 3} if isinstance(o, WeakBooleanType): return _map["?"] if isinstance(o, WeakIntegralType): return _map["i"] - if isinstance(o, WeakInexactType): + if isinstance(o, WeakFloatingType): return _map["f"] + if isinstance(o, WeakComplexType): + return _map["c"] raise TypeError( f"Unexpected type {o} while expecting " - "`WeakBooleanType`, `WeakIntegralType`, or " - "`WeakInexactType`." + "`WeakBooleanType`, `WeakIntegralType`," + "`WeakFloatingType`, or `WeakComplexType`." ) def _strong_dtype_num_kind(o): - _map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 2} + _map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 3} if not isinstance(o, dpt.dtype): raise TypeError k = o.kind @@ -247,20 +261,29 @@ def _strong_dtype_num_kind(o): def _resolve_weak_types(o1_dtype, o2_dtype, dev): "Resolves weak data type per NEP-0050" if isinstance( - o1_dtype, (WeakBooleanType, WeakInexactType, WeakIntegralType) + o1_dtype, + (WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType), ): if isinstance( - o2_dtype, (WeakBooleanType, WeakInexactType, WeakIntegralType) + o2_dtype, + ( + WeakBooleanType, + WeakIntegralType, + WeakFloatingType, + WeakComplexType, + ), ): raise ValueError o1_kind_num = _weak_type_num_kind(o1_dtype) o2_kind_num = _strong_dtype_num_kind(o2_dtype) - if o1_kind_num > o2_kind_num or o1_kind_num == 2: + if o1_kind_num > o2_kind_num: if isinstance(o1_dtype, WeakBooleanType): return dpt.bool, o2_dtype if isinstance(o1_dtype, WeakIntegralType): return dpt.int64, o2_dtype - if isinstance(o1_dtype.get(), complex): + if isinstance(o1_dtype, WeakComplexType): + if o2_dtype is dpt.float16 or o2_dtype is dpt.float32: + return dpt.complex64, o2_dtype return ( _to_device_supported_dtype(dpt.complex128, dev), o2_dtype, @@ -269,16 +292,19 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev): else: return o2_dtype, o2_dtype elif isinstance( - o2_dtype, (WeakBooleanType, WeakInexactType, WeakIntegralType) + o2_dtype, + (WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType), ): o1_kind_num = _strong_dtype_num_kind(o1_dtype) o2_kind_num = _weak_type_num_kind(o2_dtype) - if o2_kind_num > o1_kind_num or o2_kind_num == 2: + if o2_kind_num > o1_kind_num: if isinstance(o2_dtype, WeakBooleanType): return o1_dtype, dpt.bool if isinstance(o2_dtype, WeakIntegralType): return o1_dtype, dpt.int64 - if isinstance(o2_dtype.get(), complex): + if isinstance(o2_dtype, WeakComplexType): + if o1_dtype is dpt.float16 or o1_dtype is dpt.float32: + return o1_dtype, dpt.complex64 return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev) return ( o1_dtype, diff --git a/dpctl/tests/elementwise/test_multiply.py b/dpctl/tests/elementwise/test_multiply.py index 1305154021..17c0f905e7 100644 --- a/dpctl/tests/elementwise/test_multiply.py +++ b/dpctl/tests/elementwise/test_multiply.py @@ -154,16 +154,21 @@ def test_multiply_python_scalar(arr_dt): assert isinstance(R, dpt.usm_ndarray) -def test_multiply_python_scalar_gh1219(): +@pytest.mark.parametrize("arr_dt", _all_dtypes) +@pytest.mark.parametrize("sc", [bool(1), int(1), float(1), complex(1)]) +def test_multiply_python_scalar_gh1219(arr_dt, sc): q = get_queue_or_skip() + skip_if_dtype_not_supported(arr_dt, q) - X = dpt.ones(4, dtype="f4", sycl_queue=q) + Xnp = np.ones(4, dtype=arr_dt) - r = dpt.multiply(X, 2j) - expected = dpt.multiply(X, dpt.asarray(2j, sycl_queue=q)) - assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q) + X = dpt.ones(4, dtype=arr_dt, sycl_queue=q) + + R = dpt.multiply(X, sc) + Rnp = np.multiply(Xnp, sc) + assert _compare_dtypes(R.dtype, Rnp.dtype, sycl_queue=q) # symmetric case - r = dpt.multiply(2j, X) - expected = dpt.multiply(dpt.asarray(2j, sycl_queue=q), X) - assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q) + R = dpt.multiply(sc, X) + Rnp = np.multiply(sc, Xnp) + assert _compare_dtypes(R.dtype, Rnp.dtype, sycl_queue=q)