diff --git a/dpctl/tensor/_flags.pyx b/dpctl/tensor/_flags.pyx index 54198ad726..0c1beeb025 100644 --- a/dpctl/tensor/_flags.pyx +++ b/dpctl/tensor/_flags.pyx @@ -75,6 +75,12 @@ cdef class Flags: """ return _check_bit(self.flags_, USM_ARRAY_WRITABLE) + @writable.setter + def writable(self, new_val): + if not isinstance(new_val, bool): + raise TypeError("Expecting a boolean value") + self.arr_._set_writable_flag(new_val) + @property def fc(self): """ @@ -129,6 +135,14 @@ cdef class Flags: elif name == "CONTIGUOUS": return self.forc + def __setitem__(self, name, val): + if name in ["WRITABLE", "W"]: + self.writable = val + else: + raise ValueError( + "Only writable ('W' or 'WRITABLE') flag can be set" + ) + def __repr__(self): out = [] for name in "C_CONTIGUOUS", "F_CONTIGUOUS", "WRITABLE": diff --git a/dpctl/tensor/_usmarray.pxd b/dpctl/tensor/_usmarray.pxd index 97d72204cd..603b296103 100644 --- a/dpctl/tensor/_usmarray.pxd +++ b/dpctl/tensor/_usmarray.pxd @@ -72,4 +72,6 @@ cdef api class usm_ndarray [object PyUSMArrayObject, type PyUSMArrayType]: cdef dpctl.DPCTLSyclQueueRef get_queue_ref(self) except * cdef dpctl.SyclQueue get_sycl_queue(self) + cdef _set_writable_flag(self, int) + cdef __cythonbufferdefaults__ = {"mode": "strided"} diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index ea57959cb6..a5e9ac7315 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -532,6 +532,12 @@ cdef class usm_ndarray: """ return _flags.Flags(self, self.flags_) + cdef _set_writable_flag(self, int flag): + cdef int arr_fl = self.flags_ + arr_fl ^= (arr_fl & USM_ARRAY_WRITABLE) # unset WRITABLE flag + arr_fl |= (USM_ARRAY_WRITABLE if flag else 0) + self.flags_ = arr_fl + @property def usm_type(self): """ @@ -1390,12 +1396,10 @@ cdef api Py_ssize_t UsmNDArray_GetOffset(usm_ndarray arr): allocation""" return arr.get_offset() + cdef api void UsmNDArray_SetWritableFlag(usm_ndarray arr, int flag): """Set/unset USM_ARRAY_WRITABLE in the given array `arr`.""" - cdef int arr_fl = arr.flags_ - arr_fl ^= (arr_fl & USM_ARRAY_WRITABLE) # unset WRITABLE flag - arr_fl |= (USM_ARRAY_WRITABLE if flag else 0) - arr.flags_ = arr_fl + arr._set_writable_flag(flag) cdef api object UsmNDArray_MakeSimpleFromMemory( int nd, const Py_ssize_t *shape, int typenum, diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index 54ccc85916..45b43383fe 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -57,6 +57,7 @@ def test_allocate_usm_ndarray(shape, usm_type): def test_usm_ndarray_flags(): + get_queue_or_skip() assert dpt.usm_ndarray((5,), dtype="i4").flags.fc assert dpt.usm_ndarray((5, 2), dtype="i4").flags.c_contiguous assert dpt.usm_ndarray((5, 2), dtype="i4", order="F").flags.f_contiguous @@ -68,6 +69,17 @@ def test_usm_ndarray_flags(): (5, 1, 2), dtype="i4", strides=(1, 0, 5) ).flags.f_contiguous assert dpt.usm_ndarray((5, 1, 1), dtype="i4", strides=(1, 0, 1)).flags.fc + x = dpt.empty(5, dtype="u2") + assert x.flags.writable is True + x.flags.writable = False + assert x.flags.writable is False + with pytest.raises(ValueError): + x[:] = 0 + x.flags["W"] = True + assert x.flags.writable is True + x.flags["WRITABLE"] = True + assert x.flags.writable is True + x[:] = 0 @pytest.mark.parametrize(