diff --git a/.flake8 b/.flake8 index 5a20d20b6e..798db28c42 100644 --- a/.flake8 +++ b/.flake8 @@ -25,6 +25,7 @@ per-file-ignores = dpctl/program/_program.pyx: E999, E225, E226, E227 dpctl/tensor/_usmarray.pyx: E999, E225, E226, E227 dpctl/tensor/_dlpack.pyx: E999, E225, E226, E227 + dpctl/tensor/_flags.pyx: E999, E225, E226, E227 dpctl/tensor/numpy_usm_shared.py: F821 dpctl/tests/_cython_api.pyx: E999, E225, E227, E402 dpctl/utils/_compute_follows_data.pyx: E999, E225, E227 diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index 495f66616e..af4266e783 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -261,18 +261,18 @@ def copy(usm_ary, order="K"): elif order == "F": copy_order = order elif order == "A": - if usm_ary.flags & 2: + if usm_ary.flags.f_contiguous: copy_order = "F" elif order == "K": - if usm_ary.flags & 2: + if usm_ary.flags.f_contiguous: copy_order = "F" else: raise ValueError( "Unrecognized value of the order keyword. " "Recognized values are 'A', 'C', 'F', or 'K'" ) - c_contig = usm_ary.flags & 1 - f_contig = usm_ary.flags & 2 + c_contig = usm_ary.flags.c_contiguous + f_contig = usm_ary.flags.f_contiguous R = dpt.usm_ndarray( usm_ary.shape, dtype=usm_ary.dtype, @@ -325,8 +325,8 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True): ary_dtype, newdtype, casting ) ) - c_contig = usm_ary.flags & 1 - f_contig = usm_ary.flags & 2 + c_contig = usm_ary.flags.c_contiguous + f_contig = usm_ary.flags.f_contiguous needs_copy = copy or not (ary_dtype == target_dtype) if not needs_copy and (order != "K"): needs_copy = (c_contig and order not in ["A", "C"]) or ( @@ -339,10 +339,10 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True): elif order == "F": copy_order = order elif order == "A": - if usm_ary.flags & 2: + if usm_ary.flags.f_contiguous: copy_order = "F" elif order == "K": - if usm_ary.flags & 2: + if usm_ary.flags.f_contiguous: copy_order = "F" else: raise ValueError( diff --git a/dpctl/tensor/_ctors.py b/dpctl/tensor/_ctors.py index 5d2ca8e303..a21d8d7459 100644 --- a/dpctl/tensor/_ctors.py +++ b/dpctl/tensor/_ctors.py @@ -133,9 +133,9 @@ def _asarray_from_usm_ndarray( # sycl_queue is unchanged can_zero_copy = can_zero_copy and copy_q is usm_ndary.sycl_queue # order is unchanged - c_contig = usm_ndary.flags & 1 - f_contig = usm_ndary.flags & 2 - fc_contig = usm_ndary.flags & 3 + c_contig = usm_ndary.flags.c_contiguous + f_contig = usm_ndary.flags.f_contiguous + fc_contig = usm_ndary.flags.forc if can_zero_copy: if order == "C" and c_contig: pass @@ -1130,7 +1130,7 @@ def tril(X, k=0): k = operator.index(k) # F_CONTIGUOUS = 2 - order = "F" if (X.flags & 2) else "C" + order = "F" if (X.flags.f_contiguous) else "C" shape = X.shape nd = X.ndim @@ -1171,7 +1171,7 @@ def triu(X, k=0): k = operator.index(k) # F_CONTIGUOUS = 2 - order = "F" if (X.flags & 2) else "C" + order = "F" if (X.flags.f_contiguous) else "C" shape = X.shape nd = X.ndim diff --git a/dpctl/tensor/_flags.pyx b/dpctl/tensor/_flags.pyx new file mode 100644 index 0000000000..74774e3f2c --- /dev/null +++ b/dpctl/tensor/_flags.pyx @@ -0,0 +1,111 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# distutils: language = c++ +# cython: language_level=3 +# cython: linetrace=True + +from libcpp cimport bool as cpp_bool + +from dpctl.tensor._usmarray cimport ( + USM_ARRAY_C_CONTIGUOUS, + USM_ARRAY_F_CONTIGUOUS, + USM_ARRAY_WRITEABLE, + usm_ndarray, +) + + +cdef cpp_bool _check_bit(int flag, int mask): + return (flag & mask) == mask + + +cdef class Flags: + """Helper class to represent flags of :class:`dpctl.tensor.usm_ndarray`.""" + cdef int flags_ + cdef usm_ndarray arr_ + + def __cinit__(self, usm_ndarray arr, int flags): + self.arr_ = arr + self.flags_ = flags + + @property + def flags(self): + return self.flags_ + + @property + def c_contiguous(self): + return _check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS) + + @property + def f_contiguous(self): + return _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS) + + @property + def writable(self): + return _check_bit(self.flags_, USM_ARRAY_WRITEABLE) + + @property + def fc(self): + return ( + _check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS) + and _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS) + ) + + @property + def forc(self): + return ( + _check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS) + or _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS) + ) + + @property + def fnc(self): + return ( + _check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS) + and not _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS) + ) + + @property + def contiguous(self): + return self.forc + + def __getitem__(self, name): + if name in ["C_CONTIGUOUS", "C"]: + return self.c_contiguous + elif name in ["F_CONTIGUOUS", "F"]: + return self.f_contiguous + elif name == "WRITABLE": + return self.writable + elif name == "FC": + return self.fc + elif name == "CONTIGUOUS": + return self.forc + + def __repr__(self): + out = [] + for name in "C_CONTIGUOUS", "F_CONTIGUOUS", "WRITABLE": + out.append(" {} : {}".format(name, self[name])) + return '\n'.join(out) + + def __eq__(self, other): + cdef Flags other_ + if isinstance(other, self.__class__): + other_ = other + return self.flags_ == other_.flags_ + elif isinstance(other, int): + return self.flags_ == other + else: + return False diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 48f60b76b8..2e12af2d94 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -33,6 +33,7 @@ from cpython.tuple cimport PyTuple_New, PyTuple_SetItem cimport dpctl as c_dpctl cimport dpctl.memory as c_dpmem cimport dpctl.tensor._dlpack as c_dlpack +import dpctl.tensor._flags as _flags include "_stride_utils.pxi" include "_types.pxi" @@ -503,9 +504,9 @@ cdef class usm_ndarray: @property def flags(self): """ - Currently returns integer whose bits correspond to the flags. + Returns dpctl.tensor._flags object. """ - return self.flags_ + return _flags.Flags(self, self.flags_) @property def usm_type(self): @@ -663,7 +664,7 @@ cdef class usm_ndarray: strides=self.strides, offset=self.get_offset() ) - res.flags_ = self.flags + res.flags_ = self.flags.flags return res else: nbytes = self.usm_data.nbytes @@ -678,7 +679,7 @@ cdef class usm_ndarray: strides=self.strides, offset=self.get_offset() ) - res.flags_ = self.flags + res.flags_ = self.flags.flags return res def _set_namespace(self, mod): diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index e4d41a11e3..5a1fca5470 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -59,13 +59,13 @@ def test_allocate_usm_ndarray(shape, usm_type): def test_usm_ndarray_flags(): - assert dpt.usm_ndarray((5,)).flags == 3 - assert dpt.usm_ndarray((5, 2)).flags == 1 - assert dpt.usm_ndarray((5, 2), order="F").flags == 2 - assert dpt.usm_ndarray((5, 1, 2), order="F").flags == 2 - assert dpt.usm_ndarray((5, 1, 2), strides=(2, 0, 1)).flags == 1 - assert dpt.usm_ndarray((5, 1, 2), strides=(1, 0, 5)).flags == 2 - assert dpt.usm_ndarray((5, 1, 1), strides=(1, 0, 1)).flags == 3 + assert dpt.usm_ndarray((5,)).flags.fc + assert dpt.usm_ndarray((5, 2)).flags.c_contiguous + assert dpt.usm_ndarray((5, 2), order="F").flags.f_contiguous + assert dpt.usm_ndarray((5, 1, 2), order="F").flags.f_contiguous + assert dpt.usm_ndarray((5, 1, 2), strides=(2, 0, 1)).flags.c_contiguous + assert dpt.usm_ndarray((5, 1, 2), strides=(1, 0, 5)).flags.f_contiguous + assert dpt.usm_ndarray((5, 1, 1), strides=(1, 0, 1)).flags.fc @pytest.mark.parametrize( @@ -465,7 +465,7 @@ def test_pyx_capi_get_flags(): fn_restype=ctypes.c_int, ) flags = get_flags_fn(X) - assert type(flags) is int and flags == X.flags + assert type(flags) is int and X.flags == flags def test_pyx_capi_get_offset(): @@ -753,7 +753,7 @@ def relaxed_strides_equal(st1, st2, sh): X.shape = sh_f assert X.shape == sh_f assert relaxed_strides_equal(X.strides, cc_strides(sh_f), sh_f) - assert X.flags & 1, "reshaped array expected to be C-contiguous" + assert X.flags.c_contiguous, "reshaped array expected to be C-contiguous" sh_s = ( 2, @@ -1516,3 +1516,18 @@ def test_common_arg_validation(): dpt.triu(X) with pytest.raises(TypeError): dpt.meshgrid(X) + + +def test_flags(): + x = dpt.empty(tuple(), "i4") + f = x.flags + f.__repr__() + f.c_contiguous + f.f_contiguous + f.contiguous + f.fc + f.fnc + f.forc + f.writable + # check comparison with generic types + f == Ellipsis