Skip to content

Commit 0985218

Browse files
committed
Added tests for full() function
1 parent 8dcec9a commit 0985218

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

dpctl/tensor/_ctors.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -762,18 +762,27 @@ def full(
762762
else:
763763
order = order[0].upper()
764764
dpctl.utils.validate_usm_type(usm_type, allow_none=True)
765-
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
766765

767766
if isinstance(fill_value, (dpt.usm_ndarray, np.ndarray, tuple, list)):
767+
if (
768+
isinstance(fill_value, dpt.usm_ndarray)
769+
and sycl_queue is None
770+
and device is None
771+
):
772+
sycl_queue = fill_value.sycl_queue
773+
else:
774+
sycl_queue = normalize_queue_device(
775+
sycl_queue=sycl_queue, device=device
776+
)
768777
X = dpt.asarray(
769778
fill_value,
770779
dtype=dtype,
771-
device=device,
772780
usm_type=usm_type,
773781
sycl_queue=sycl_queue,
774782
)
775783
return dpt.broadcast_to(X, sh)
776784

785+
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
777786
usm_type = usm_type if usm_type is not None else "device"
778787
dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value))
779788
res = dpt.usm_ndarray(

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -992,15 +992,36 @@ def test_full_dtype_inference():
992992
def test_full_fill_array():
993993
q = get_queue_or_skip()
994994

995-
Xnp = np.array([1, 2, 3], dtype=np.int32)
995+
Xnp = np.array([1, 2, 3], dtype="i4")
996996
X = dpt.asarray(Xnp, sycl_queue=q)
997997

998998
shape = (3, 3)
999999
Y = dpt.full(shape, X)
10001000
Ynp = np.full(shape, Xnp)
10011001

1002+
assert Y.dtype == Ynp.dtype
1003+
assert Y.usm_type == "device"
10021004
assert np.array_equal(dpt.asnumpy(Y), Ynp)
1003-
assert Ynp.dtype == Y.dtype
1005+
1006+
1007+
def test_full_compute_follows_data():
1008+
q1 = get_queue_or_skip()
1009+
q2 = get_queue_or_skip()
1010+
1011+
X = dpt.arange(10, dtype="i4", sycl_queue=q1, usm_type="shared")
1012+
Y = dpt.full(10, X[3])
1013+
1014+
assert Y.dtype == X.dtype
1015+
assert Y.usm_type == X.usm_type
1016+
assert dpctl.utils.get_execution_queue((Y.sycl_queue, X.sycl_queue))
1017+
assert np.array_equal(dpt.asnumpy(Y), np.full(10, 3, dtype="i4"))
1018+
1019+
Y = dpt.full(10, X[3], dtype="f4", sycl_queue=q2, usm_type="host")
1020+
1021+
assert Y.dtype == dpt.dtype("f4")
1022+
assert Y.usm_type == "host"
1023+
assert dpctl.utils.get_execution_queue((Y.sycl_queue, q2))
1024+
assert np.array_equal(dpt.asnumpy(Y), np.full(10, 3, dtype="f4"))
10041025

10051026

10061027
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)