diff --git a/dpctl/tensor/_ctors.py b/dpctl/tensor/_ctors.py index 6baa5e5d66..763e1c239a 100644 --- a/dpctl/tensor/_ctors.py +++ b/dpctl/tensor/_ctors.py @@ -716,7 +716,7 @@ def full( dtype=None, order="C", device=None, - usm_type="device", + usm_type=None, sycl_queue=None, ): """ @@ -750,8 +750,29 @@ def full( "Unrecognized order keyword value, expecting 'F' or 'C'." ) order = order[0].upper() - dpctl.utils.validate_usm_type(usm_type, allow_none=False) + dpctl.utils.validate_usm_type(usm_type, allow_none=True) + + if isinstance(fill_value, (dpt.usm_ndarray, np.ndarray, tuple, list)): + if ( + isinstance(fill_value, dpt.usm_ndarray) + and sycl_queue is None + and device is None + ): + sycl_queue = fill_value.sycl_queue + else: + sycl_queue = normalize_queue_device( + sycl_queue=sycl_queue, device=device + ) + X = dpt.asarray( + fill_value, + dtype=dtype, + usm_type=usm_type, + sycl_queue=sycl_queue, + ) + return dpt.broadcast_to(X, sh) + sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device) + usm_type = usm_type if usm_type is not None else "device" dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value)) res = dpt.usm_ndarray( sh, diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index e00747bfe9..2a41dfbddc 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -992,6 +992,41 @@ def test_full_dtype_inference(): assert np.issubdtype(dpt.full(10, 0.3 - 2j).dtype, np.complexfloating) +def test_full_fill_array(): + q = get_queue_or_skip() + + Xnp = np.array([1, 2, 3], dtype="i4") + X = dpt.asarray(Xnp, sycl_queue=q) + + shape = (3, 3) + Y = dpt.full(shape, X) + Ynp = np.full(shape, Xnp) + + assert Y.dtype == Ynp.dtype + assert Y.usm_type == "device" + assert np.array_equal(dpt.asnumpy(Y), Ynp) + + +def test_full_compute_follows_data(): + q1 = get_queue_or_skip() + q2 = get_queue_or_skip() + + X = dpt.arange(10, dtype="i4", sycl_queue=q1, usm_type="shared") + Y = dpt.full(10, X[3]) + + assert Y.dtype == X.dtype + assert Y.usm_type == X.usm_type + assert dpctl.utils.get_execution_queue((Y.sycl_queue, X.sycl_queue)) + assert np.array_equal(dpt.asnumpy(Y), np.full(10, 3, dtype="i4")) + + Y = dpt.full(10, X[3], dtype="f4", sycl_queue=q2, usm_type="host") + + assert Y.dtype == dpt.dtype("f4") + assert Y.usm_type == "host" + assert dpctl.utils.get_execution_queue((Y.sycl_queue, q2)) + assert np.array_equal(dpt.asnumpy(Y), np.full(10, 3, dtype="f4")) + + @pytest.mark.parametrize( "dt", _all_dtypes[1:],