diff --git a/CHANGELOG.md b/CHANGELOG.md index 97c06affac..586b65652c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Improved performance of `tensor.sort` and `tensor.argsort` for short arrays in the range [16, 64] elements [gh-1866](https://github.com/IntelPython/dpctl/pull/1866) ### Fixed +* Fix for `tensor.result_type` when all inputs are Python built-in scalars [gh-1877](https://github.com/IntelPython/dpctl/pull/1877) ### Maintenance diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index 5defd154df..f279052f94 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -767,6 +767,9 @@ def result_type(*arrays_and_dtypes): target_dev = d inspected = True + if not dtypes and weak_dtypes: + dtypes.append(weak_dtypes[0].get()) + if not (has_fp16 and has_fp64): for dt in dtypes: if not _dtype_supported_by_device_impl(dt, has_fp16, has_fp64): diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index 882a001827..4bfd6dab9f 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -15,6 +15,8 @@ # limitations under the License. +import itertools + import numpy as np import pytest from numpy.testing import assert_, assert_array_equal, assert_raises_regex @@ -1555,3 +1557,26 @@ def test_repeat_0_size(): res = dpt.repeat(x, repetitions, axis=1) axis_sz = 2 * x.shape[1] assert res.shape == (0, axis_sz, 0) + + +def test_result_type_bug_1874(): + py_sc = True + np_sc = np.asarray([py_sc])[0] + dts_bool = [py_sc, np_sc] + py_sc = int(1) + np_sc = np.asarray([py_sc])[0] + dts_ints = [py_sc, np_sc] + dts_floats = [float(1), np.float64(1)] + dts_complexes = [complex(1), np.complex128(1)] + + # iterate over two categories + for dts1, dts2 in itertools.product( + [dts_bool, dts_ints, dts_floats, dts_complexes], repeat=2 + ): + res_dts = [] + # iterate over Python scalar/NumPy scalar choices within categories + for dt1, dt2 in itertools.product(dts1, dts2): + res_dt = dpt.result_type(dt1, dt2) + res_dts.append(res_dt) + # check that all results are the same + assert res_dts and all(res_dts[0] == el for el in res_dts[1:])