Skip to content

Commit 8dcec9a

Browse files
committed
Added array support for fill_value for full() function.
1 parent c78974a commit 8dcec9a

File tree

2 files changed

+22
-9
lines changed

2 files changed

+22
-9
lines changed

dpctl/tensor/_ctors.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -726,7 +726,7 @@ def full(
726726
dtype=None,
727727
order="C",
728728
device=None,
729-
usm_type="device",
729+
usm_type=None,
730730
sycl_queue=None,
731731
):
732732
"""
@@ -761,10 +761,20 @@ def full(
761761
)
762762
else:
763763
order = order[0].upper()
764-
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
764+
dpctl.utils.validate_usm_type(usm_type, allow_none=True)
765765
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
766-
if dtype is None and isinstance(fill_value, (dpt.usm_ndarray, np.ndarray)):
767-
dtype = fill_value.dtype
766+
767+
if isinstance(fill_value, (dpt.usm_ndarray, np.ndarray, tuple, list)):
768+
X = dpt.asarray(
769+
fill_value,
770+
dtype=dtype,
771+
device=device,
772+
usm_type=usm_type,
773+
sycl_queue=sycl_queue,
774+
)
775+
return dpt.broadcast_to(X, sh)
776+
777+
usm_type = usm_type if usm_type is not None else "device"
768778
dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value))
769779
res = dpt.usm_ndarray(
770780
sh,

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -992,12 +992,15 @@ def test_full_dtype_inference():
992992
def test_full_fill_array():
993993
q = get_queue_or_skip()
994994

995-
dtype = np.int32
996-
X = dpt.full(10, dpt.usm_ndarray(1, dtype=dtype), sycl_queue=q)
997-
assert dtype == X.dtype
995+
Xnp = np.array([1, 2, 3], dtype=np.int32)
996+
X = dpt.asarray(Xnp, sycl_queue=q)
998997

999-
X = dpt.full(10, np.ndarray(1, dtype=dtype), sycl_queue=q)
1000-
assert dtype == X.dtype
998+
shape = (3, 3)
999+
Y = dpt.full(shape, X)
1000+
Ynp = np.full(shape, Xnp)
1001+
1002+
assert np.array_equal(dpt.asnumpy(Y), Ynp)
1003+
assert Ynp.dtype == Y.dtype
10011004

10021005

10031006
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)