Skip to content

Commit 5bfc097

Browse files
Merge pull request #1106 from vlad-perevezentsev/fix_dpt_place_func
Update empty vals check in dpt.place
2 parents 78875e9 + 10c2022 commit 5bfc097

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

dpctl/tensor/_indexing_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,12 +293,12 @@ def place(arr, mask, vals):
293293
raise dpctl.utils.ExecutionPlacementError
294294
if arr.shape != mask.shape or vals.ndim != 1:
295295
raise ValueError("Array sizes are not as required")
296-
if vals.size == 0:
297-
raise ValueError("Cannot insert from an empty array!")
298296
cumsum = dpt.empty(mask.size, dtype="i8", sycl_queue=exec_q)
299297
nz_count = ti.mask_positions(mask, cumsum, sycl_queue=exec_q)
300298
if nz_count == 0:
301299
return
300+
if vals.size == 0:
301+
raise ValueError("Cannot insert from an empty array!")
302302
if vals.dtype == arr.dtype:
303303
rhs = vals
304304
else:

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,6 +1211,16 @@ def test_place_empty_vals_error():
12111211
dpt.place(x, sel, y)
12121212

12131213

1214+
def test_place_empty_vals_full_false_mask():
1215+
get_queue_or_skip()
1216+
x = dpt.ones(10, dtype="f4")
1217+
y = dpt.empty((0,), dtype=x.dtype)
1218+
sel = dpt.zeros(x.size, dtype="?")
1219+
expected = np.ones(10, dtype=x.dtype)
1220+
dpt.place(x, sel, y)
1221+
assert (dpt.asnumpy(x) == expected).all()
1222+
1223+
12141224
def test_nonzero():
12151225
get_queue_or_skip()
12161226
x = dpt.concat((dpt.zeros(3), dpt.ones(4), dpt.zeros(3)))

0 commit comments

Comments
 (0)