diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index bed326f4b0..f98a09ca72 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -311,7 +311,7 @@ def flip(X, /, *, axis=None): return X[indexer] -def roll(X, /, shift, *, axis=None): +def roll(x, /, shift, *, axis=None): """ roll(x, shift, axis) @@ -343,18 +343,20 @@ def roll(X, /, shift, *, axis=None): `device` attributes as `x` and whose elements are shifted relative to `x`. """ - if not isinstance(X, dpt.usm_ndarray): - raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") - exec_q = X.sycl_queue + if not isinstance(x, dpt.usm_ndarray): + raise TypeError(f"Expected usm_ndarray type, got {type(x)}.") + exec_q = x.sycl_queue _manager = dputils.SequentialOrderManager[exec_q] if axis is None: shift = operator.index(shift) - dep_evs = _manager.submitted_events res = dpt.empty( - X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=exec_q + x.shape, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q ) + sz = operator.index(x.size) + shift = (shift % sz) if sz > 0 else 0 + dep_evs = _manager.submitted_events hev, roll_ev = ti._copy_usm_ndarray_for_roll_1d( - src=X, + src=x, dst=res, shift=shift, sycl_queue=exec_q, @@ -362,22 +364,24 @@ def roll(X, /, shift, *, axis=None): ) _manager.add_event_pair(hev, roll_ev) return res - axis = normalize_axis_tuple(axis, X.ndim, allow_duplicate=True) + axis = normalize_axis_tuple(axis, x.ndim, allow_duplicate=True) broadcasted = np.broadcast(shift, axis) if broadcasted.ndim > 1: raise ValueError("'shift' and 'axis' should be scalars or 1D sequences") shifts = [ 0, - ] * X.ndim + ] * x.ndim + shape = x.shape for sh, ax in broadcasted: - shifts[ax] += sh - + n_i = operator.index(shape[ax]) + shifted = shifts[ax] + operator.index(sh) + shifts[ax] = (shifted % n_i) if n_i > 0 else 0 res = dpt.empty( - X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=exec_q + x.shape, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q ) dep_evs = _manager.submitted_events ht_e, roll_ev = ti._copy_usm_ndarray_for_roll_nd( - src=X, dst=res, shifts=shifts, sycl_queue=exec_q, depends=dep_evs + src=x, dst=res, shifts=shifts, sycl_queue=exec_q, depends=dep_evs ) _manager.add_event_pair(ht_e, roll_ev) return res diff --git a/dpctl/tensor/libtensor/source/copy_for_roll.cpp b/dpctl/tensor/libtensor/source/copy_for_roll.cpp index 774228c6a7..b624d02882 100644 --- a/dpctl/tensor/libtensor/source/copy_for_roll.cpp +++ b/dpctl/tensor/libtensor/source/copy_for_roll.cpp @@ -326,7 +326,7 @@ copy_usm_ndarray_for_roll_nd(const dpctl::tensor::usm_ndarray &src, // normalize shift parameter to be 0 <= offset < dim py::ssize_t dim = src_shape_ptr[i]; size_t offset = - (shifts[i] > 0) ? (shifts[i] % dim) : dim + (shifts[i] % dim); + (shifts[i] >= 0) ? (shifts[i] % dim) : dim + (shifts[i] % dim); normalized_shifts.push_back(offset); } diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index 17262e2141..882a001827 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -657,6 +657,30 @@ def test_roll_2d(data): assert_array_equal(Ynp, dpt.asnumpy(Y)) +def test_roll_out_bounds_shifts(): + "See gh-1857" + get_queue_or_skip() + + x = dpt.arange(4) + y = dpt.roll(x, np.uint64(2**63 + 2)) + expected = dpt.roll(x, 2) + assert dpt.all(y == expected) + + x_empty = x[1:1] + y = dpt.roll(x_empty, 11) + assert y.size == 0 + + x_2d = dpt.reshape(x, (2, 2)) + y = dpt.roll(x_2d, np.uint64(2**63 + 1), axis=1) + expected = dpt.roll(x_2d, 1, axis=1) + assert dpt.all(y == expected) + + x_2d_empty = x_2d[:, 1:1] + y = dpt.roll(x_2d_empty, 3, axis=1) + expected = dpt.empty_like(x_2d_empty) + assert dpt.all(y == expected) + + def test_roll_validation(): get_queue_or_skip()