From e7e85083c2e430807d0df9ce440161803cb572ed Mon Sep 17 00:00:00 2001 From: Natalia Polina Date: Mon, 28 Nov 2022 16:10:42 -0600 Subject: [PATCH] Fixed error in cast dtype for full() function. --- dpctl/tensor/_ctors.py | 8 +++++++- dpctl/tests/test_usm_ndarray_ctor.py | 4 ++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/dpctl/tensor/_ctors.py b/dpctl/tensor/_ctors.py index 763e1c239a..6e26ecee9b 100644 --- a/dpctl/tensor/_ctors.py +++ b/dpctl/tensor/_ctors.py @@ -773,7 +773,8 @@ def full( sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device) usm_type = usm_type if usm_type is not None else "device" - dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value)) + fill_value_type = type(fill_value) + dtype = _get_dtype(dtype, sycl_queue, ref_type=fill_value_type) res = dpt.usm_ndarray( sh, dtype=dtype, @@ -781,6 +782,11 @@ def full( order=order, buffer_ctor_kwargs={"queue": sycl_queue}, ) + if fill_value_type in [float, complex] and np.issubdtype(dtype, np.integer): + fill_value = int(fill_value.real) + elif fill_value_type is complex and np.issubdtype(dtype, np.floating): + fill_value = fill_value.real + hev, _ = ti._full_usm_ndarray(fill_value, res, sycl_queue) hev.wait() return res diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index 2a41dfbddc..b0bb429841 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -991,6 +991,10 @@ def test_full_dtype_inference(): assert np.issubdtype(dpt.full(10, 12.3).dtype, np.floating) assert np.issubdtype(dpt.full(10, 0.3 - 2j).dtype, np.complexfloating) + assert np.issubdtype(dpt.full(10, 12.3, dtype=int).dtype, np.integer) + assert np.issubdtype(dpt.full(10, 0.3 - 2j, dtype=int).dtype, np.integer) + assert np.issubdtype(dpt.full(10, 0.3 - 2j, dtype=float).dtype, np.floating) + def test_full_fill_array(): q = get_queue_or_skip()