diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index 654d31e136..efb7f9dabf 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -29,8 +29,16 @@ """ +from dpctl.tensor._copy_utils import astype +from dpctl.tensor._copy_utils import copy_from_numpy as from_numpy +from dpctl.tensor._copy_utils import copy_to_numpy as to_numpy +from dpctl.tensor._reshape import reshape from dpctl.tensor._usmarray import usm_ndarray __all__ = [ "usm_ndarray", + "astype", + "reshape", + "from_numpy", + "to_numpy", ] diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py new file mode 100644 index 0000000000..a9efad3a80 --- /dev/null +++ b/dpctl/tensor/_copy_utils.py @@ -0,0 +1,306 @@ +import operator + +import numpy as np + +import dpctl.memory as dpm +import dpctl.tensor as dpt + + +def contract_iter2(shape, strides1, strides2): + p = np.argsort(np.abs(strides1))[::-1] + sh = [operator.index(shape[i]) for i in p] + disp1 = 0 + disp2 = 0 + st1 = [] + st2 = [] + contractable = True + for i in p: + this_stride1 = operator.index(strides1[i]) + this_stride2 = operator.index(strides2[i]) + if this_stride1 < 0 and this_stride2 < 0: + disp1 += this_stride1 * (shape[i] - 1) + this_stride1 = -this_stride1 + disp2 += this_stride2 * (shape[i] - 1) + this_stride2 = -this_stride2 + if this_stride1 < 0 or this_stride2 < 0: + contractable = False + st1.append(this_stride1) + st2.append(this_stride2) + while contractable: + changed = False + k = len(sh) - 1 + for i in range(k): + step1 = st1[i + 1] + jump1 = st1[i] - (sh[i + 1] - 1) * step1 + step2 = st2[i + 1] + jump2 = st2[i] - (sh[i + 1] - 1) * step2 + if jump1 == step1 and jump2 == step2: + changed = True + st1[i:-1] = st1[i + 1 :] + st2[i:-1] = st2[i + 1 :] + sh[i] *= sh[i + 1] + sh[i + 1 : -1] = sh[i + 2 :] + sh = sh[:-1] + st1 = st1[:-1] + st2 = st2[:-1] + break + if not changed: + break + return (sh, st1, disp1, st2, disp2) + + +def has_memory_overlap(x1, x2): + m1 = dpm.as_usm_memory(x1) + m2 = dpm.as_usm_memory(x2) + if m1.sycl_device == m2.sycl_device: + p1_beg = m1._pointer + p1_end = p1_beg + m1.nbytes + p2_beg = m2._pointer + p2_end = p2_beg + m2.nbytes + return p1_beg > p2_end or p2_beg < p1_end + else: + return False + + +def copy_to_numpy(ary): + if type(ary) is not dpt.usm_ndarray: + raise TypeError + h = ary.usm_data.copy_to_host().view(ary.dtype) + itsz = ary.itemsize + strides_bytes = tuple(si * itsz for si in ary.strides) + offset = ary.__sycl_usm_array_interface__.get("offset", 0) * itsz + return np.ndarray( + ary.shape, + dtype=ary.dtype, + buffer=h, + strides=strides_bytes, + offset=offset, + ) + + +def copy_from_numpy(np_ary, usm_type="device", queue=None): + "Copies numpy array `np_ary` into a new usm_ndarray" + # This may peform a copy to meet stated requirements + Xnp = np.require(np_ary, requirements=["A", "O", "C", "E"]) + if queue: + ctor_kwargs = {"queue": queue} + else: + ctor_kwargs = dict() + Xusm = dpt.usm_ndarray( + Xnp.shape, + dtype=Xnp.dtype, + buffer=usm_type, + buffer_ctor_kwargs=ctor_kwargs, + ) + Xusm.usm_data.copy_from_host(Xnp.reshape((-1)).view("u1")) + return Xusm + + +def copy_from_numpy_into(dst, np_ary): + if not isinstance(np_ary, np.ndarray): + raise TypeError("Expected numpy.ndarray, got {}".format(type(np_ary))) + src_ary = np.broadcast_to(np.asarray(np_ary, dtype=dst.dtype), dst.shape) + for i in range(dst.size): + mi = np.unravel_index(i, dst.shape) + host_buf = np.array(src_ary[mi], ndmin=1).view("u1") + usm_mem = dpm.as_usm_memory(dst[mi]) + usm_mem.copy_from_host(host_buf) + + +class Dummy: + def __init__(self, iface): + self.__sycl_usm_array_interface__ = iface + + +def copy_same_dtype(dst, src): + if type(dst) is not dpt.usm_ndarray or type(src) is not dpt.usm_ndarray: + raise TypeError + + if dst.shape != src.shape: + raise ValueError + + if dst.dtype != src.dtype: + raise ValueError + + # check that memory regions do not overlap + if has_memory_overlap(dst, src): + tmp = copy_to_numpy(src) + copy_from_numpy_into(dst, tmp) + return + + if (dst.flags & 1) and (src.flags & 1): + dst_mem = dpm.as_usm_memory(dst) + src_mem = dpm.as_usm_memory(src) + dst_mem.copy_from_device(src_mem) + return + + # simplify strides + sh_i, dst_st, dst_disp, src_st, src_disp = contract_iter2( + dst.shape, dst.strides, src.strides + ) + # sh_i, dst_st, dst_disp, src_st, src_disp = ( + # dst.shape, dst.strides, 0, src.strides, 0 + # ) + src_iface = src.__sycl_usm_array_interface__ + dst_iface = dst.__sycl_usm_array_interface__ + src_iface["shape"] = tuple() + src_iface.pop("strides", None) + dst_iface["shape"] = tuple() + dst_iface.pop("strides", None) + dst_disp = dst_disp + dst_iface.get("offset", 0) + src_disp = src_disp + src_iface.get("offset", 0) + for i in range(dst.size): + mi = np.unravel_index(i, sh_i) + dst_offset = dst_disp + src_offset = src_disp + for j, dst_stj, src_stj in zip(mi, dst_st, src_st): + dst_offset = dst_offset + j * dst_stj + src_offset = src_offset + j * src_stj + dst_iface["offset"] = dst_offset + src_iface["offset"] = src_offset + msrc = dpm.as_usm_memory(Dummy(src_iface)) + mdst = dpm.as_usm_memory(Dummy(dst_iface)) + mdst.copy_from_device(msrc) + + +def copy_same_shape(dst, src): + if src.dtype == dst.dtype: + copy_same_dtype(dst, src) + + # check that memory regions do not overlap + if has_memory_overlap(dst, src): + tmp = copy_to_numpy(src) + tmp = tmp.astype(dst.dtype) + copy_from_numpy_into(dst, tmp) + return + + # simplify strides + sh_i, dst_st, dst_disp, src_st, src_disp = contract_iter2( + dst.shape, dst.strides, src.strides + ) + # sh_i, dst_st, dst_disp, src_st, src_disp = ( + # dst.shape, dst.strides, 0, src.strides, 0 + # ) + src_iface = src.__sycl_usm_array_interface__ + dst_iface = dst.__sycl_usm_array_interface__ + src_iface["shape"] = tuple() + src_iface.pop("strides", None) + dst_iface["shape"] = tuple() + dst_iface.pop("strides", None) + dst_disp = dst_disp + dst_iface.get("offset", 0) + src_disp = src_disp + src_iface.get("offset", 0) + for i in range(dst.size): + mi = np.unravel_index(i, sh_i) + dst_offset = dst_disp + src_offset = src_disp + for j, dst_stj, src_stj in zip(mi, dst_st, src_st): + dst_offset = dst_offset + j * dst_stj + src_offset = src_offset + j * src_stj + dst_iface["offset"] = dst_offset + src_iface["offset"] = src_offset + msrc = dpm.as_usm_memory(Dummy(src_iface)) + mdst = dpm.as_usm_memory(Dummy(dst_iface)) + tmp = msrc.copy_to_host().view(src.dtype) + tmp = tmp.astype(dst.dtype) + mdst.copy_from_host(tmp.view("u1")) + + +def copy_from_usm_ndarray_to_usm_ndarray(dst, src): + if type(dst) is not dpt.usm_ndarray or type(src) is not dpt.usm_ndarray: + raise TypeError + + if dst.ndim == src.ndim and dst.shape == src.shape: + copy_same_shape(dst, src) + + try: + common_shape = np.broadcast_shapes(dst.shape, src.shape) + except ValueError: + raise ValueError + + if dst.size < src.size: + raise ValueError + + if len(common_shape) > dst.ndim: + ones_count = len(common_shape) - dst.ndim + for k in range(ones_count): + if common_shape[k] != 1: + raise ValueError + common_shape = common_shape[ones_count:] + + if src.ndim < len(common_shape): + new_src_strides = (0,) * (len(common_shape) - src.ndim) + src.strides + src_same_shape = dpt.usm_ndarray( + common_shape, dtype=src.dtype, buffer=src, strides=new_src_strides + ) + else: + src_same_shape = src + + copy_same_shape(dst, src_same_shape) + + +def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True): + """ + astype(usm_array, new_dtype, order="K", casting="unsafe", copy=True) + + Returns a copy of the array, cast to a specified type. + + A view can be returned, if possible, when `copy=False` is used. + """ + if not isinstance(usm_ary, dpt.usm_ndarray): + return TypeError( + "Expected object of type dpt.usm_ndarray, got {}".format( + type(usm_ary) + ) + ) + ary_dtype = usm_ary.dtype + target_dtype = np.dtype(newdtype) + if not np.can_cast(ary_dtype, target_dtype, casting=casting): + raise TypeError( + "Can not cast from {} to {} according to rule {}".format( + ary_dtype, newdtype, casting + ) + ) + c_contig = usm_ary.flags & 1 + f_contig = usm_ary.flags & 2 + needs_copy = copy or not (ary_dtype == target_dtype) + if not needs_copy and (order != "K"): + needs_copy = (c_contig and order not in ["A", "C"]) or ( + f_contig and order not in ["A", "F"] + ) + if needs_copy: + copy_order = "C" + if order == "C": + pass + elif order == "F": + copy_order = order + elif order == "A": + if usm_ary.flags & 2: + copy_order = "F" + elif order == "K": + if usm_ary.flags & 2: + copy_order = "F" + R = dpt.usm_ndarray( + usm_ary.shape, + dtype=target_dtype, + buffer=usm_ary.usm_type, + order=copy_order, + buffer_ctor_kwargs={"queue": usm_ary.sycl_queue}, + ) + if order == "K" and (not c_contig and not f_contig): + original_strides = usm_ary.strides + ind = sorted( + range(usm_ary.ndim), + key=lambda i: abs(original_strides[i]), + reverse=True, + ) + new_strides = tuple(R.strides[ind[i]] for i in ind) + R = dpt.usm_ndarray( + usm_ary.shape, + dtype=target_dtype, + buffer=R.usm_data, + strides=new_strides, + ) + copy_from_usm_ndarray_to_usm_ndarray(R, usm_ary) + return R + else: + return usm_ary diff --git a/dpctl/tensor/_reshape.py b/dpctl/tensor/_reshape.py new file mode 100644 index 0000000000..9f94351357 --- /dev/null +++ b/dpctl/tensor/_reshape.py @@ -0,0 +1,114 @@ +import operator + +import numpy as np + +import dpctl.tensor as dpt +from dpctl.tensor._copy_utils import copy_same_dtype + + +def _make_unit_indexes(shape): + """ + Construct a diagonal matrix with with one on the diagonal + except if the corresponding element of shape is 1. + """ + nd = len(shape) + mi = np.zeros((nd, nd), dtype="u4") + for i, dim in enumerate(shape): + mi[i, i] = 1 if dim > 1 else 0 + return mi + + +def reshaped_strides(old_sh, old_sts, new_sh, order="C"): + """ + When reshaping array with `old_sh` shape and `old_sts` strides + into the new shape `new_sh`, returns the new stride if the reshape + can be a view, otherwise returns `None`. + """ + eye_new_mi = _make_unit_indexes(new_sh) + new_sts = [ + sum( + st_i * ind_i + for st_i, ind_i in zip( + old_sts, np.unravel_index(flat_index, old_sh, order=order) + ) + ) + for flat_index in [ + np.ravel_multi_index(unitvec, new_sh, order=order) + for unitvec in eye_new_mi + ] + ] + eye_old_mi = _make_unit_indexes(old_sh) + check_sts = [ + sum( + st_i * ind_i + for st_i, ind_i in zip( + new_sts, np.unravel_index(flat_index, new_sh, order=order) + ) + ) + for flat_index in [ + np.ravel_multi_index(unitvec, old_sh, order=order) + for unitvec in eye_old_mi + ] + ] + valid = all( + [check_st == old_st for check_st, old_st in zip(check_sts, old_sts)] + ) + return new_sts if valid else None + + +def reshape(X, newshape, order="C"): + """ + reshape(X: usm_ndarray, newshape: tuple, order="C") -> usm_ndarray + + Reshapes given usm_ndarray into new shape. Returns a view, if possible, + a copy otherwise. Memory layout of the copy is controlled by order keyword. + """ + if type(X) is not dpt.usm_ndarray: + raise TypeError + if not isinstance(newshape, (list, tuple)): + newshape = (newshape,) + if order not in ["C", "F"]: + raise ValueError( + f"Keyword 'order' not recognized. Expecting 'C' or 'F', got {order}" + ) + newshape = [operator.index(d) for d in newshape] + negative_ones_count = 0 + for i in range(len(newshape)): + if newshape[i] == -1: + negative_ones_count = negative_ones_count + 1 + if (newshape[i] < -1) or negative_ones_count > 1: + raise ValueError( + "Target shape should have at most 1 negative " + "value which can only be -1" + ) + if negative_ones_count: + v = X.size // (-np.prod(newshape)) + newshape = [v if d == -1 else d for d in newshape] + if X.size != np.prod(newshape): + raise ValueError("Can not reshape into {}".format(newshape)) + newsts = reshaped_strides(X.shape, X.strides, newshape, order=order) + if newsts is None: + # must perform a copy + flat_res = dpt.usm_ndarray( + (X.size,), + dtype=X.dtype, + buffer=X.usm_type, + buffer_ctor_kwargs={"queue": X.sycl_queue}, + order=order, + ) + for i in range(X.size): + copy_same_dtype( + flat_res[i], X[np.unravel_index(i, X.shape, order=order)] + ) + return dpt.usm_ndarray( + tuple(newshape), dtype=X.dtype, buffer=flat_res, order=order + ) + else: + # can form a view + return dpt.usm_ndarray( + newshape, + dtype=X.dtype, + buffer=X, + strides=tuple(newsts), + offset=X.__sycl_usm_array_interface__.get("offset", 0), + ) diff --git a/dpctl/tensor/_stride_utils.pxi b/dpctl/tensor/_stride_utils.pxi index a3fe92579b..37d5a366b7 100644 --- a/dpctl/tensor/_stride_utils.pxi +++ b/dpctl/tensor/_stride_utils.pxi @@ -138,17 +138,28 @@ cdef int _from_input_shape_strides( max_disp[0] = max_shift if max_shift == min_shift + (elem_count - 1): if nd == 1: - contig[0] = USM_ARRAY_C_CONTIGUOUS + if strides_arr[0] == 1: + contig[0] = USM_ARRAY_C_CONTIGUOUS + else: + contig[0] = 0 return 0 for i in range(0, nd - 1): if all_incr: - all_incr = strides_arr[i] < strides_arr[i + 1] + all_incr = ( + (strides_arr[i] > 0) and + (strides_arr[i+1] > 0) and + (strides_arr[i] <= strides_arr[i + 1]) + ) if all_decr: - all_decr = strides_arr[i] > strides_arr[i + 1] + all_decr = ( + (strides_arr[i] > 0) and + (strides_arr[i+1] > 0) and + (strides_arr[i] >= strides_arr[i + 1]) + ) if all_incr: - contig[0] = USM_ARRAY_C_CONTIGUOUS - elif all_decr: contig[0] = USM_ARRAY_F_CONTIGUOUS + elif all_decr: + contig[0] = USM_ARRAY_C_CONTIGUOUS else: contig[0] = 0 return 0 diff --git a/dpctl/tensor/_usmarray.pxd b/dpctl/tensor/_usmarray.pxd index 1063e9135e..e0d53b05af 100644 --- a/dpctl/tensor/_usmarray.pxd +++ b/dpctl/tensor/_usmarray.pxd @@ -8,16 +8,35 @@ cdef public int USM_ARRAY_C_CONTIGUOUS cdef public int USM_ARRAY_F_CONTIGUOUS cdef public int USM_ARRAY_WRITEABLE - -cdef public api class usm_ndarray [object PyUSMArrayObject, type PyUSMArrayType]: +cdef public int UAR_BOOL +cdef public int UAR_BYTE +cdef public int UAR_UBYTE +cdef public int UAR_SHORT +cdef public int UAR_USHORT +cdef public int UAR_INT +cdef public int UAR_UINT +cdef public int UAR_LONG +cdef public int UAR_ULONG +cdef public int UAR_LONGLONG +cdef public int UAR_ULONGLONG +cdef public int UAR_FLOAT +cdef public int UAR_DOUBLE +cdef public int UAR_CFLOAT +cdef public int UAR_CDOUBLE +cdef public int UAR_TYPE_SENTINEL +cdef public int UAR_HALF + + +cdef api class usm_ndarray [object PyUSMArrayObject, type PyUSMArrayType]: # data fields cdef char* data_ - cdef readonly int nd_ + cdef int nd_ cdef Py_ssize_t *shape_ cdef Py_ssize_t *strides_ - cdef readonly int typenum_ - cdef readonly int flags_ - cdef readonly object base_ + cdef int typenum_ + cdef int flags_ + cdef object base_ + cdef object array_namespace_ # make usm_ndarray weak-referenceable cdef object __weakref__ diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 148847ba2e..38b25b5dce 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -18,6 +18,8 @@ # cython: language_level=3 # cython: linetrace=True +import sys + import numpy as np import dpctl @@ -31,17 +33,63 @@ from cpython.tuple cimport PyTuple_New, PyTuple_SetItem cimport dpctl as c_dpctl cimport dpctl.memory as c_dpmem - -cdef extern from "usm_array.hpp" namespace "usm_array": - cdef cppclass usm_array: - usm_array(char *, int, size_t*, Py_ssize_t *, - int, int, c_dpctl.DPCTLSyclQueueRef) except + - - include "_stride_utils.pxi" include "_types.pxi" include "_slicing.pxi" + +def _dispatch_unary_elementwise(ary, name): + try: + mod = ary.__array_namespace__() + except AttributeError: + return NotImplemented + if mod is None and "dpnp" in sys.modules: + fn = getattr(sys.modules["dpnp"], name) + if callable(fn): + return fn(ary) + elif hasattr(mod, name): + fn = getattr(mod, name) + if callable(fn): + return fn(ary) + + return NotImplemented + + +def _dispatch_binary_elementwise(ary, name, other): + try: + mod = ary.__array_namespace__() + except AttributeError: + return NotImplemented + if mod is None and "dpnp" in sys.modules: + fn = getattr(sys.modules["dpnp"], name) + if callable(fn): + return fn(ary, other) + elif hasattr(mod, name): + fn = getattr(mod, name) + if callable(fn): + return fn(ary, other) + + return NotImplemented + + +def _dispatch_binary_elementwise2(other, name, ary): + try: + mod = ary.__array_namespace__() + except AttributeError: + return NotImplemented + mod = ary.__array_namespace__() + if mod is None and "dpnp" in sys.modules: + fn = getattr(sys.modules["dpnp"], name) + if callable(fn): + return fn(other, ary) + elif hasattr(mod, name): + fn = getattr(mod, name) + if callable(fn): + return fn(other, ary) + + return NotImplemented + + cdef class InternalUSMArrayError(Exception): """ A InternalError exception is raised when internal @@ -55,7 +103,8 @@ cdef class usm_ndarray: usm_ndarray( shape, dtype="|f8", strides=None, buffer='device', offset=0, order='C', - buffer_ctor_kwargs=dict() + buffer_ctor_kwargs=dict(), + array_namespace=None ) See :class:`dpctl.memory.MemoryUSMShared` for allowed @@ -76,6 +125,7 @@ cdef class usm_ndarray: Initializes member fields """ self.base_ = None + self.array_namespace_ = None self.nd_ = -1 self.data_ = 0 self.shape_ = 0 @@ -106,13 +156,16 @@ cdef class usm_ndarray: order=('C' if (self.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F') ) res.flags_ = self.flags_ + res.array_namespace_ = self.array_namespace_ if (res.data_ != self.data_): raise InternalUSMArrayError( "Data pointers of cloned and original objects are different.") return res def __cinit__(self, shape, dtype="|f8", strides=None, buffer='device', - Py_ssize_t offset=0, order='C', buffer_ctor_kwargs=dict()): + Py_ssize_t offset=0, order='C', + buffer_ctor_kwargs=dict(), + array_namespace=None): """ strides and offset must be given in units of array elements. buffer can be strings ('device'|'shared'|'host' to allocate new memory) @@ -208,10 +261,15 @@ cdef class usm_ndarray: self.typenum_ = typenum self.flags_ = contig_flag self.nd_ = nd + self.array_namespace_ = array_namespace def __dealloc__(self): self._cleanup() + @property + def _pointer(self): + return self.get_data() + cdef Py_ssize_t get_offset(self) except *: cdef char *mem_ptr = NULL cdef char *ary_ptr = self.get_data() @@ -303,7 +361,12 @@ cdef class usm_ndarray: cdef char *mem_ptr = NULL cdef char *ary_ptr = NULL if (not isinstance(self.base_, dpmem._memory._Memory)): - raise ValueError("Invalid instance of usm_ndarray ecountered") + raise InternalUSMArrayError( + "Invalid instance of usm_ndarray ecountered. " + "Private field base_ has an unexpected type {}.".format( + type(self.base_) + ) + ) ary_iface = self.base_.__sycl_usm_array_interface__ mem_ptr = ( ary_iface['data'][0]) ary_ptr = ( self.data_) @@ -318,8 +381,9 @@ cdef class usm_ndarray: elif (self.flags_ & USM_ARRAY_F_CONTIGUOUS): ary_iface['strides'] = _f_contig_strides(self.nd_, self.shape_) else: - raise ValueError("USM Array is not contiguous and " - "has empty strides") + raise InternalUSMArrayError( + "USM Array is not contiguous and has empty strides" + ) ary_iface['typestr'] = _make_typestr(self.typenum_) byte_offset = ary_ptr - mem_ptr item_size = self.get_itemsize() @@ -342,7 +406,7 @@ cdef class usm_ndarray: """ Gives USM memory object underlying usm_array instance. """ - return self.base_ + return self.get_base() @property def shape(self): @@ -355,6 +419,68 @@ cdef class usm_ndarray: else: return tuple() + @shape.setter + def shape(self, new_shape): + """ + Setting shape is only allowed when reshaping to the requested + dimensions can be returned as view. Use `dpctl.tensor.reshape` + to reshape the array in all other cases. + """ + cdef int new_nd = -1 + cdef Py_ssize_t nelems = -1 + cdef int err = 0 + cdef Py_ssize_t min_disp = 0 + cdef Py_ssize_t max_disp = 0 + cdef int contig_flag = 0 + cdef Py_ssize_t *shape_ptr = NULL + cdef Py_ssize_t *strides_ptr = NULL + import operator + + from ._reshape import reshaped_strides + + new_nd = len(new_shape) + try: + new_shape = tuple(operator.index(dim) for dim in new_shape) + except TypeError: + raise TypeError( + "Target shape must be a finite iterable of integers" + ) + if not np.prod(new_shape) == shape_to_elem_count(self.nd_, self.shape_): + raise TypeError( + f"Can not reshape array of size {self.size} into {new_shape}" + ) + new_strides = reshaped_strides( + self.shape, + self.strides, + new_shape + ) + if new_strides is None: + raise AttributeError( + "Incompatible shape for in-place modification. " + "Use `reshape()` to make a copy with the desired shape." + ) + err = _from_input_shape_strides( + new_nd, new_shape, new_strides, + self.get_itemsize(), + b"C", + &shape_ptr, &strides_ptr, + &nelems, &min_disp, &max_disp, &contig_flag + ) + if (err == 0): + if (self.shape_): + PyMem_Free(self.shape_) + if (self.strides_): + PyMem_Free(self.strides_) + print(contig_flag) + self.flags_ = contig_flag + self.nd_ = new_nd + self.shape_ = shape_ptr + self.strides_ = strides_ptr + else: + raise InternalUSMArrayError( + "Encountered in shape setter, error code {err}".format(err) + ) + @property def strides(self): """ @@ -489,8 +615,53 @@ cdef class usm_ndarray: offset=_meta[2] ) res.flags_ |= (self.flags_ & USM_ARRAY_WRITEABLE) + res.array_namespace_ = self.array_namespace_ return res + def to_device(self, target_device): + """ + Transfer array to target device + """ + d = Device.create_device(target_device) + if (d.sycl_device == self.sycl_device): + return self + elif (d.sycl_context == self.sycl_context): + res = usm_ndarray( + self.shape, + self.dtype, + buffer=self.usm_data, + strides=self.strides, + offset=self.get_offset() + ) + res.flags_ = self.flags + return res + else: + nbytes = self.usm_data.nbytes + new_buffer = type(self.usm_data)( + nbytes, queue=d.sycl_queue + ) + new_buffer.copy_from_device(self.usm_data) + res = usm_ndarray( + self.shape, + self.dtype, + buffer=new_buffer, + strides=self.strides, + offset=self.get_offset() + ) + res.flags_ = self.flags + return res + + def _set_namespace(self, mod): + """ Sets array namespace to given module `mod`. """ + self.array_namespace_ = mod + + def __array_namespace__(self, api_version=None): + """ + Returns array namespace, member functions of which + implement data API. + """ + return self.array_namespace_ + def __bool__(self): if self.size == 1: mem_view = dpmem.as_usm_memory(self) @@ -539,38 +710,305 @@ cdef class usm_ndarray: raise IndexError("only integer arrays are valid indices") - def to_device(self, target_device): + def __abs__(self): + return _dispatch_unary_elementwise(self, "abs") + + def __add__(first, other): """ - Transfer array to target device + Cython 0.* never calls `__radd__`, always calls `__add__` + but first argument need not be an instance of this class, + so dispatching is needed. + + This changes in Cython 3.0, where first is guaranteed to + be `self`. + + [1] http://docs.cython.org/en/latest/src/userguide/special_methods.html """ - d = Device.create_device(target_device) - if (d.sycl_device == self.sycl_device): - return self - elif (d.sycl_context == self.sycl_context): - res = usm_ndarray( - self.shape, - self.dtype, - buffer=self.usm_data, - strides=self.strides, - offset=self.get_offset() - ) - res.flags_ = self.flags - return res + if isinstance(first, usm_ndarray): + return _dispatch_binary_elementwise(first, "add", other) + elif isinstance(other, usm_ndarray): + return _dispatch_binary_elementwise2(first, "add", other) + return NotImplemented + + def __and__(first, other): + "See comment in __add__" + if isinstance(first, usm_ndarray): + return _dispatch_binary_elementwise(first, "logical_and", other) + elif isinstance(other, usm_ndarray): + return _dispatch_binary_elementwise2(first, "logical_and", other) + return NotImplemented + + def __dlpack__(self, stream=None): + return NotImplemented + + def __dlpack_device__(self): + return NotImplemented + + def __eq__(self, other): + return _dispatch_binary_elementwise(self, "equal", other) + + def __floordiv__(first, other): + "See comment in __add__" + if isinstance(first, usm_ndarray): + return _dispatch_binary_elementwise(first, "floor_divide", other) + elif isinstance(other, usm_ndarray): + return _dispatch_binary_elementwise2(first, "floor_divide", other) + return NotImplemented + + def __ge__(self, other): + return _dispatch_binary_elementwise(self, "greater_equal", other) + + def __gt__(self, other): + return _dispatch_binary_elementwise(self, "greater", other) + + def __invert__(self): + return _dispatch_unary_elementwise(self, "invert") + + def __le__(self, other): + return _dispatch_binary_elementwise(self, "less_equal", other) + + def __len__(self): + if (self.nd_): + return self.shape[0] else: - nbytes = self.usm_data.nbytes - new_buffer = type(self.usm_data)( - nbytes, queue=d.sycl_queue - ) - new_buffer.copy_from_device(self.usm_data) - res = usm_ndarray( - self.shape, - self.dtype, - buffer=new_buffer, - strides=self.strides, - offset=self.get_offset() - ) - res.flags_ = self.flags + raise TypeError("len() of unsized object") + + def __lshift__(first, other): + "See comment in __add__" + if isinstance(first, usm_ndarray): + return _dispatch_binary_elementwise(first, "left_shift", other) + elif isinstance(other, usm_ndarray): + return _dispatch_binary_elementwise2(first, "left_shift", other) + return NotImplemented + + def __lt__(self, other): + return _dispatch_binary_elementwise(self, "less", other) + + def __matmul__(first, other): + "See comment in __add__" + if isinstance(first, usm_ndarray): + return _dispatch_binary_elementwise(first, "matmul", other) + elif isinstance(other, usm_ndarray): + return _dispatch_binary_elementwise2(first, "matmul", other) + return NotImplemented + + def __mod__(first, other): + "See comment in __add__" + if isinstance(first, usm_ndarray): + return _dispatch_binary_elementwise(first, "mod", other) + elif isinstance(other, usm_ndarray): + return _dispatch_binary_elementwise2(first, "mod", other) + return NotImplemented + + def __mul__(first, other): + "See comment in __add__" + if isinstance(first, usm_ndarray): + return _dispatch_binary_elementwise(first, "multiply", other) + elif isinstance(other, usm_ndarray): + return _dispatch_binary_elementwise2(first, "multiply", other) + return NotImplemented + + def __ne__(self, other): + return _dispatch_binary_elementwise(self, "not_equal", other) + + def __neg__(self): + return _dispatch_unary_elementwise(self, "negative") + + def __or__(first, other): + "See comment in __add__" + if isinstance(first, usm_ndarray): + return _dispatch_binary_elementwise(first, "logical_or", other) + elif isinstance(other, usm_ndarray): + return _dispatch_binary_elementwise2(first, "logical_or", other) + return NotImplemented + + def __pos__(self): + return _dispatch_unary_elementwise(self, "positive") + + def __pow__(first, other, mod): + "See comment in __add__" + if mod is None: + if isinstance(first, usm_ndarray): + return _dispatch_binary_elementwise(first, "power", other) + elif isinstance(other, usm_ndarray): + return _dispatch_binary_elementwise(first, "power", other) + return NotImplemented + + def __rshift__(first, other): + "See comment in __add__" + if isinstance(first, usm_ndarray): + return _dispatch_binary_elementwise(first, "right_shift", other) + elif isinstance(other, usm_ndarray): + return _dispatch_binary_elementwise2(first, "right_shift", other) + return NotImplemented + + def __setitem__(self, key, val): + try: + Xv = self.__getitem__(key) + except (ValueError, IndexError) as e: + raise e + from ._copy_utils import ( + copy_from_numpy_into, + copy_from_usm_ndarray_to_usm_ndarray, + ) + if isinstance(val, usm_ndarray): + copy_from_usm_ndarray_to_usm_ndarray(Xv, val) + else: + copy_from_numpy_into(Xv, np.asarray(val)) + + def __sub__(first, other): + "See comment in __add__" + if isinstance(first, usm_ndarray): + return _dispatch_binary_elementwise(first, "subtract", other) + elif isinstance(other, usm_ndarray): + return _dispatch_binary_elementwise2(first, "subtract", other) + return NotImplemented + + def __truediv__(first, other): + "See comment in __add__" + if isinstance(first, usm_ndarray): + return _dispatch_binary_elementwise(first, "true_divide", other) + elif isinstance(other, usm_ndarray): + return _dispatch_binary_elementwise2(first, "true_divide", other) + return NotImplemented + + def __xor__(first, other): + "See comment in __add__" + if isinstance(first, usm_ndarray): + return _dispatch_binary_elementwise(first, "logical_xor", other) + elif isinstance(other, usm_ndarray): + return _dispatch_binary_elementwise2(first, "logical_xor", other) + return NotImplemented + + def __radd__(self, other): + return _dispatch_binary_elementwise(self, "add", other) + + def __rand__(self, other): + return _dispatch_binary_elementwise(self, "logical_and", other) + + def __rfloordiv__(self, other): + return _dispatch_binary_elementwise2(other, "floor_divide", self) + + def __rlshift__(self, other, mod): + return _dispatch_binary_elementwise2(other, "left_shift", self) + + def __rmatmul__(self, other): + return _dispatch_binary_elementwise2(other, "matmul", self) + + def __rmod__(self, other): + return _dispatch_binary_elementwise2(other, "mod", self) + + def __rmul__(self, other): + return _dispatch_binary_elementwise(self, "multiply", other) + + def __ror__(self, other): + return _dispatch_binary_elementwise(self, "logical_or", other) + + def __rpow__(self, other, mod): + return _dispatch_binary_elementwise2(other, "power", self) + + def __rrshift__(self, other, mod): + return _dispatch_binary_elementwise2(other, "right_shift", self) + + def __rsub__(self, other): + return _dispatch_binary_elementwise2(other, "subtract", self) + + def __rtruediv__(self, other): + return _dispatch_binary_elementwise2(other, "true_divide", self) + + def __rxor__(self, other): + return _dispatch_binary_elementwise2(other, "logical_xor", self) + + def __iadd__(self, other): + res = self.__add__(other) + if res is NotImplemented: return res + self.__setitem__(Ellipsis, res) + return self + + def __iand__(self, other): + res = self.__and__(other) + if res is NotImplemented: + return res + self.__setitem__(Ellipsis, res) + return self + + def __ifloordiv__(self, other): + res = self.__floordiv__(other) + if res is NotImplemented: + return res + self.__setitem__(Ellipsis, res) + return self + + def __ilshift__(self, other): + res = self.__lshift__(other) + if res is NotImplemented: + return res + self.__setitem__(Ellipsis, res) + return self + + def __imatmul__(self, other): + res = self.__matmul__(other) + if res is NotImplemented: + return res + self.__setitem__(Ellipsis, res) + return self + + def __imod__(self, other): + res = self.__mod__(other) + if res is NotImplemented: + return res + self.__setitem__(Ellipsis, res) + return self + + def __imul__(self, other): + res = self.__mul__(other) + if res is NotImplemented: + return res + self.__setitem__(Ellipsis, res) + return self + + def __ior__(self, other): + res = self.__or__(other) + if res is NotImplemented: + return res + self.__setitem__(Ellipsis, res) + return self + + def __ipow__(self, other): + res = self.__pow__(other, None) + if res is NotImplemented: + return res + self.__setitem__(Ellipsis, res) + return self + + def __irshift__(self, other): + res = self.__rshift__(other) + if res is NotImplemented: + return res + self.__setitem__(Ellipsis, res) + return self + + def __isub__(self, other): + res = self.__sub__(other) + if res is NotImplemented: + return res + self.__setitem__(Ellipsis, res) + return self + + def __itruediv__(self, other): + res = self.__truediv__(other) + if res is NotImplemented: + return res + self.__setitem__(Ellipsis, res) + return self + + def __ixor__(self, other): + res = self.__xor__(other) + if res is NotImplemented: + return res + self.__setitem__(Ellipsis, res) + return self cdef usm_ndarray _real_view(usm_ndarray ary): @@ -635,3 +1073,39 @@ cdef usm_ndarray _zero_like(usm_ndarray ary): ) # TODO: call function to set array elements to zero return r + + +cdef api char* usm_ndarray_get_data(usm_ndarray arr): + """ + """ + return arr.get_data() + + +cdef api int usm_ndarray_get_ndim(usm_ndarray arr): + """""" + return arr.get_ndim() + + +cdef api Py_ssize_t* usm_ndarray_get_shape(usm_ndarray arr): + """ """ + return arr.get_shape() + + +cdef api Py_ssize_t* usm_ndarray_get_strides(usm_ndarray arr): + """ """ + return arr.get_strides() + + +cdef api int usm_ndarray_get_typenum(usm_ndarray arr): + """ """ + return arr.get_typenum() + + +cdef api int usm_ndarray_get_flags(usm_ndarray arr): + """ """ + return arr.get_flags() + + +cdef api c_dpctl.DPCTLSyclQueueRef usm_ndarray_get_queue_ref(usm_ndarray arr): + """ """ + return arr.get_queue_ref() diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index 7bb71da149..7182c5cf81 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ctypes import numbers import numpy as np @@ -93,11 +94,12 @@ def test_dtypes_invalid(dtype): dpt.usm_ndarray((1,), dtype=dtype) -def test_properties(): +@pytest.mark.parametrize("dt", ["d", "c16"]) +def test_properties(dt): """ Test that properties execute """ - X = dpt.usm_ndarray((3, 4, 5), dtype="c16") + X = dpt.usm_ndarray((3, 4, 5), dtype=dt) assert isinstance(X.sycl_queue, dpctl.SyclQueue) assert isinstance(X.sycl_device, dpctl.SyclDevice) assert isinstance(X.sycl_context, dpctl.SyclContext) @@ -112,6 +114,7 @@ def test_properties(): assert isinstance(X.size, numbers.Integral) assert isinstance(X.nbytes, numbers.Integral) assert isinstance(X.ndim, numbers.Integral) + assert isinstance(X._pointer, numbers.Integral) @pytest.mark.parametrize("func", [bool, float, int, complex]) @@ -144,32 +147,12 @@ def test_copy_scalar_invalid_shape(func, shape): func(X) -@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)]) -@pytest.mark.parametrize("index_dtype", ["|i8"]) -def test_usm_ndarray_as_index(shape, index_dtype): - X = dpt.usm_ndarray(shape, dtype=index_dtype) - Xnp = np.arange(1, X.size + 1, dtype=index_dtype).reshape(shape) - X.usm_data.copy_from_host(Xnp.reshape(-1).view("|u1")) - Y = np.arange(X.size + 1) - assert Y[X] == Y[1] +def test_index_noninteger(): + import operator - -@pytest.mark.parametrize("shape", [(2,), (1, 2), (3, 4, 5), (0,)]) -@pytest.mark.parametrize("index_dtype", ["|i8"]) -def test_usm_ndarray_as_index_invalid_shape(shape, index_dtype): - X = dpt.usm_ndarray(shape, dtype=index_dtype) - Y = np.arange(X.size + 1) - with pytest.raises(IndexError): - Y[X] - - -@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)]) -@pytest.mark.parametrize("index_dtype", ["|f8"]) -def test_usm_ndarray_as_index_invalid_dtype(shape, index_dtype): - X = dpt.usm_ndarray(shape, dtype=index_dtype) - Y = np.arange(X.size + 1) + X = dpt.usm_ndarray(1, "d") with pytest.raises(IndexError): - Y[X] + operator.index(X) @pytest.mark.parametrize( @@ -373,3 +356,441 @@ def test_datapi_device(): X.device.sycl_queue X.device.sycl_device repr(X.device) + + +def _pyx_capi_fnptr_to_callable( + X, pyx_capi_name, caps_name, fn_restype=ctypes.c_void_p +): + import sys + + mod = sys.modules[X.__class__.__module__] + cap = mod.__pyx_capi__.get(pyx_capi_name, None) + if cap is None: + raise ValueError( + "__pyx_capi__ does not export {} capsule".format(pyx_capi_name) + ) + # construct Python callable to invoke these functions + cap_ptr_fn = ctypes.pythonapi.PyCapsule_GetPointer + cap_ptr_fn.restype = ctypes.c_void_p + cap_ptr_fn.argtypes = [ctypes.py_object, ctypes.c_char_p] + fn_ptr = cap_ptr_fn(cap, caps_name) + callable_maker_ptr = ctypes.PYFUNCTYPE(fn_restype, ctypes.py_object) + return callable_maker_ptr(fn_ptr) + + +def test_pyx_capi_get_data(): + X = dpt.usm_ndarray(17)[1::2] + get_data_fn = _pyx_capi_fnptr_to_callable( + X, + "usm_ndarray_get_data", + b"char *(struct PyUSMArrayObject *)", + fn_restype=ctypes.c_void_p, + ) + r1 = get_data_fn(X) + sua_iface = X.__sycl_usm_array_interface__ + assert r1 == sua_iface["data"][0] + sua_iface.get("offset") * X.itemsize + + +def test_pyx_capi_get_shape(): + X = dpt.usm_ndarray(17)[1::2] + get_shape_fn = _pyx_capi_fnptr_to_callable( + X, + "usm_ndarray_get_shape", + b"Py_ssize_t *(struct PyUSMArrayObject *)", + fn_restype=ctypes.c_void_p, + ) + c_longlong_p = ctypes.POINTER(ctypes.c_longlong) + shape0 = ctypes.cast(get_shape_fn(X), c_longlong_p).contents.value + assert shape0 == X.shape[0] + + +def test_pyx_capi_get_strides(): + X = dpt.usm_ndarray(17)[1::2] + get_strides_fn = _pyx_capi_fnptr_to_callable( + X, + "usm_ndarray_get_strides", + b"Py_ssize_t *(struct PyUSMArrayObject *)", + fn_restype=ctypes.c_void_p, + ) + c_longlong_p = ctypes.POINTER(ctypes.c_longlong) + strides0_p = get_strides_fn(X) + if strides0_p: + strides0_p = ctypes.cast(strides0_p, c_longlong_p).contents + strides0_p = strides0_p.value + assert strides0_p == 0 or strides0_p == X.strides[0] + + +def test_pyx_capi_get_ndim(): + X = dpt.usm_ndarray(17)[1::2] + get_ndim_fn = _pyx_capi_fnptr_to_callable( + X, + "usm_ndarray_get_ndim", + b"int (struct PyUSMArrayObject *)", + fn_restype=ctypes.c_int, + ) + assert get_ndim_fn(X) == X.ndim + + +def test_pyx_capi_get_typenum(): + X = dpt.usm_ndarray(17)[1::2] + get_typenum_fn = _pyx_capi_fnptr_to_callable( + X, + "usm_ndarray_get_typenum", + b"int (struct PyUSMArrayObject *)", + fn_restype=ctypes.c_int, + ) + typenum = get_typenum_fn(X) + assert type(typenum) is int + assert typenum == X.dtype.num + + +def test_pyx_capi_get_flags(): + X = dpt.usm_ndarray(17)[1::2] + get_flags_fn = _pyx_capi_fnptr_to_callable( + X, + "usm_ndarray_get_flags", + b"int (struct PyUSMArrayObject *)", + fn_restype=ctypes.c_int, + ) + flags = get_flags_fn(X) + assert type(flags) is int and flags == X.flags + + +def test_pyx_capi_get_queue_ref(): + X = dpt.usm_ndarray(17)[1::2] + get_queue_ref_fn = _pyx_capi_fnptr_to_callable( + X, + "usm_ndarray_get_queue_ref", + b"DPCTLSyclQueueRef (struct PyUSMArrayObject *)", + fn_restype=ctypes.c_void_p, + ) + queue_ref = get_queue_ref_fn(X) # address of a copy, should be unequal + assert queue_ref != X.sycl_queue.addressof_ref() + + +def _pyx_capi_int(X, pyx_capi_name, caps_name=b"int", val_restype=ctypes.c_int): + import sys + + mod = sys.modules[X.__class__.__module__] + cap = mod.__pyx_capi__.get(pyx_capi_name, None) + if cap is None: + raise ValueError( + "__pyx_capi__ does not export {} capsule".format(pyx_capi_name) + ) + # construct Python callable to invoke these functions + cap_ptr_fn = ctypes.pythonapi.PyCapsule_GetPointer + cap_ptr_fn.restype = ctypes.c_void_p + cap_ptr_fn.argtypes = [ctypes.py_object, ctypes.c_char_p] + cap_ptr = cap_ptr_fn(cap, caps_name) + val_ptr = ctypes.cast(cap_ptr, ctypes.POINTER(val_restype)) + return val_ptr.contents.value + + +def test_pyx_capi_check_constants(): + X = dpt.usm_ndarray(17)[1::2] + cc_flag = _pyx_capi_int(X, "USM_ARRAY_C_CONTIGUOUS") + assert cc_flag > 0 and 0 == (cc_flag & (cc_flag - 1)) + fc_flag = _pyx_capi_int(X, "USM_ARRAY_F_CONTIGUOUS") + assert fc_flag > 0 and 0 == (fc_flag & (fc_flag - 1)) + w_flag = _pyx_capi_int(X, "USM_ARRAY_WRITEABLE") + assert w_flag > 0 and 0 == (w_flag & (w_flag - 1)) + + bool_typenum = _pyx_capi_int(X, "UAR_BOOL") + assert bool_typenum == np.dtype("bool_").num + + byte_typenum = _pyx_capi_int(X, "UAR_BYTE") + assert byte_typenum == np.dtype(np.byte).num + ubyte_typenum = _pyx_capi_int(X, "UAR_UBYTE") + assert ubyte_typenum == np.dtype(np.ubyte).num + + short_typenum = _pyx_capi_int(X, "UAR_SHORT") + assert short_typenum == np.dtype(np.short).num + ushort_typenum = _pyx_capi_int(X, "UAR_USHORT") + assert ushort_typenum == np.dtype(np.ushort).num + + int_typenum = _pyx_capi_int(X, "UAR_INT") + assert int_typenum == np.dtype(np.intc).num + uint_typenum = _pyx_capi_int(X, "UAR_UINT") + assert uint_typenum == np.dtype(np.uintc).num + + long_typenum = _pyx_capi_int(X, "UAR_LONG") + assert long_typenum == np.dtype(np.int_).num + ulong_typenum = _pyx_capi_int(X, "UAR_ULONG") + assert ulong_typenum == np.dtype(np.uint).num + + longlong_typenum = _pyx_capi_int(X, "UAR_LONGLONG") + assert longlong_typenum == np.dtype(np.longlong).num + ulonglong_typenum = _pyx_capi_int(X, "UAR_ULONGLONG") + assert ulonglong_typenum == np.dtype(np.ulonglong).num + + half_typenum = _pyx_capi_int(X, "UAR_HALF") + assert half_typenum == np.dtype(np.half).num + float_typenum = _pyx_capi_int(X, "UAR_FLOAT") + assert float_typenum == np.dtype(np.single).num + double_typenum = _pyx_capi_int(X, "UAR_DOUBLE") + assert double_typenum == np.dtype(np.double).num + + cfloat_typenum = _pyx_capi_int(X, "UAR_CFLOAT") + assert cfloat_typenum == np.dtype(np.csingle).num + cdouble_typenum = _pyx_capi_int(X, "UAR_CDOUBLE") + assert cdouble_typenum == np.dtype(np.cdouble).num + + +@pytest.mark.parametrize( + "shape", [tuple(), (1,), (5,), (2, 3), (2, 3, 4), (2, 2, 2, 2, 2)] +) +@pytest.mark.parametrize( + "dtype", + [ + "b1", + "i1", + "u1", + "i2", + "u2", + "i4", + "u4", + "i8", + "u8", + "f2", + "f4", + "f8", + "c8", + "c16", + ], +) +@pytest.mark.parametrize("usm_type", ["device", "shared", "host"]) +def test_tofrom_numpy(shape, dtype, usm_type): + q = dpctl.SyclQueue() + Xnp = np.zeros(shape, dtype=dtype) + Xusm = dpt.from_numpy(Xnp, usm_type=usm_type, queue=q) + Ynp = np.ones(shape, dtype=dtype) + ind = (slice(None, None, None),) * Ynp.ndim + Xusm[ind] = Ynp + assert np.array_equal(dpt.to_numpy(Xusm), Ynp) + + +@pytest.mark.parametrize( + "dtype", + [ + "b1", + "i1", + "u1", + "i2", + "u2", + "i4", + "u4", + "i8", + "u8", + "f2", + "f4", + "f8", + "c8", + "c16", + ], +) +@pytest.mark.parametrize("src_usm_type", ["device", "shared", "host"]) +@pytest.mark.parametrize("dst_usm_type", ["device", "shared", "host"]) +def test_setitem_same_dtype(dtype, src_usm_type, dst_usm_type): + Xnp = ( + np.random.randint(-10, 10, size=2 * 3 * 4) + .astype(dtype) + .reshape((2, 4, 3)) + ) + Znp = np.zeros( + ( + 2, + 4, + 3, + ), + dtype=dtype, + ) + Zusm_0d = dpt.from_numpy(Znp[0, 0, 0], usm_type=dst_usm_type) + ind = (-1, -1, -1) + Xusm_0d = dpt.from_numpy(Xnp[ind], usm_type=src_usm_type) + Zusm_0d[Ellipsis] = Xusm_0d + assert np.array_equal(dpt.to_numpy(Zusm_0d), Xnp[ind]) + Zusm_1d = dpt.from_numpy(Znp[0, 1:3, 0], usm_type=dst_usm_type) + ind = (-1, slice(0, 2, None), -1) + Xusm_1d = dpt.from_numpy(Xnp[ind], usm_type=src_usm_type) + Zusm_1d[Ellipsis] = Xusm_1d + assert np.array_equal(dpt.to_numpy(Zusm_1d), Xnp[ind]) + Zusm_2d = dpt.from_numpy(Znp[:, 1:3, 0], usm_type=dst_usm_type)[::-1] + Xusm_2d = dpt.from_numpy(Xnp[:, 1:4, -1], usm_type=src_usm_type) + Zusm_2d[:] = Xusm_2d[:, 0:2] + assert np.array_equal(dpt.to_numpy(Zusm_2d), Xnp[:, 1:3, -1]) + Zusm_3d = dpt.from_numpy(Znp, usm_type=dst_usm_type) + Xusm_3d = dpt.from_numpy(Xnp, usm_type=src_usm_type) + Zusm_3d[:] = Xusm_3d + assert np.array_equal(dpt.to_numpy(Zusm_3d), Xnp) + Zusm_3d[::-1] = Xusm_3d[::-1] + assert np.array_equal(dpt.to_numpy(Zusm_3d), Xnp) + Zusm_3d[:] = Xusm_3d[0] + R1 = dpt.to_numpy(Zusm_3d) + R2 = np.broadcast_to(Xnp[0], R1.shape) + assert R1.shape == R2.shape + assert np.allclose(R1, R2) + + +@pytest.mark.parametrize( + "dtype", + [ + "b1", + "i1", + "u1", + "i2", + "u2", + "i4", + "u4", + "i8", + "u8", + "f2", + "f4", + "f8", + "c8", + "c16", + ], +) +@pytest.mark.parametrize("usm_type", ["device", "shared", "host"]) +def test_setitem_scalar(dtype, usm_type): + X = dpt.usm_ndarray((6, 6), dtype=dtype, buffer=usm_type) + for i in range(X.size): + X[np.unravel_index(i, X.shape)] = np.asarray(i, dtype=dtype) + assert np.array_equal( + dpt.to_numpy(X), np.arange(X.size).astype(dtype).reshape(X.shape) + ) + Y = dpt.usm_ndarray((2, 3), dtype=dtype, buffer=usm_type) + for i in range(Y.size): + Y[np.unravel_index(i, Y.shape)] = i + assert np.array_equal( + dpt.to_numpy(Y), np.arange(Y.size).astype(dtype).reshape(Y.shape) + ) + + +def test_shape_setter(): + def cc_strides(sh): + return np.empty(sh, dtype="u1").strides + + def relaxed_strides_equal(st1, st2, sh): + eq_ = True + for s1, s2, d in zip(st1, st2, sh): + eq_ = eq_ and ((d == 1) or (s1 == s2)) + return eq_ + + sh_s = (2 * 3 * 4 * 5,) + sh_f = ( + 2, + 3, + 4, + 5, + ) + X = dpt.usm_ndarray(sh_s, dtype="d") + expected_flags = X.flags + X.shape = sh_f + assert X.shape == sh_f + assert relaxed_strides_equal(X.strides, cc_strides(sh_f), sh_f) + assert X.flags == expected_flags + + sh_s = ( + 2, + 12, + 5, + ) + sh_f = ( + 2, + 3, + 4, + 5, + ) + X = dpt.usm_ndarray(sh_s, dtype="d", order="C") + X.shape = sh_f + assert X.shape == sh_f + assert relaxed_strides_equal(X.strides, cc_strides(sh_f), sh_f) + + sh_s = (2, 3, 4, 5) + sh_f = (4, 3, 2, 5) + X = dpt.usm_ndarray(sh_s, dtype="d") + X.shape = sh_f + assert relaxed_strides_equal(X.strides, cc_strides(sh_f), sh_f) + + sh_s = (2, 3, 4, 5) + sh_f = (4, 3, 1, 2, 5) + X = dpt.usm_ndarray(sh_s, dtype="d") + X.shape = sh_f + assert relaxed_strides_equal(X.strides, cc_strides(sh_f), sh_f) + + X = dpt.usm_ndarray(sh_s, dtype="d") + with pytest.raises(TypeError): + X.shape = "abcbe" + X = dpt.usm_ndarray((4, 4), dtype="d")[::2, ::2] + with pytest.raises(AttributeError): + X.shape = (4,) + + +def test_len(): + X = dpt.usm_ndarray(1, "i4") + assert len(X) == 1 + X = dpt.usm_ndarray((2, 1), "i4") + assert len(X) == 2 + X = dpt.usm_ndarray(tuple(), "i4") + with pytest.raises(TypeError): + len(X) + + +def test_array_namespace(): + X = dpt.usm_ndarray(1, "i4") + X.__array_namespace__() + X._set_namespace(dpt) + assert X.__array_namespace__() is dpt + + +def test_dlpack(): + X = dpt.usm_ndarray(1, "i4") + X.__dlpack_device__() + X.__dlpack__(stream=None) + + +def test_to_device(): + X = dpt.usm_ndarray(1, "d") + for dev in dpctl.get_devices(): + if dev.default_selector_score > 0: + Y = X.to_device(dev) + assert Y.sycl_device == dev + + +def test_astype(): + X = dpt.usm_ndarray((5, 5), "i4") + X[:] = np.full((5, 5), 7, dtype="i4") + Y = dpt.astype(X, "c16", order="C") + assert np.allclose(dpt.to_numpy(Y), np.full((5, 5), 7, dtype="c16")) + Y = dpt.astype(X, "f2", order="K") + assert np.allclose(dpt.to_numpy(Y), np.full((5, 5), 7, dtype="f2")) + Y = dpt.astype(X, "i4", order="K", copy=False) + assert Y.usm_data is X.usm_data + + +def test_ctor_invalid(): + m = dpm.MemoryUSMShared(12) + with pytest.raises(ValueError): + dpt.usm_ndarray((4,), dtype="i4", buffer=m) + m = dpm.MemoryUSMShared(64) + with pytest.raises(ValueError): + dpt.usm_ndarray((4,), dtype="u1", buffer=m, strides={"not": "valid"}) + + +def test_reshape(): + X = dpt.usm_ndarray((5, 5), "i4") + # can be done as views + Y = dpt.reshape(X, (25,)) + assert Y.shape == (25,) + Z = X[::2, ::2] + # requires a copy + W = dpt.reshape(Z, (Z.size,), order="F") + assert W.shape == (Z.size,) + with pytest.raises(TypeError): + dpt.reshape("invalid") + with pytest.raises(ValueError): + dpt.reshape(Z, (2, 2, 2, 2, 2)) + with pytest.raises(ValueError): + dpt.reshape(Z, Z.shape, order="invalid") + W = dpt.reshape(Z, (-1,), order="C") + assert W.shape == (Z.size,) diff --git a/dpctl/tests/test_usm_ndarray_operators.py b/dpctl/tests/test_usm_ndarray_operators.py new file mode 100644 index 0000000000..47d484c203 --- /dev/null +++ b/dpctl/tests/test_usm_ndarray_operators.py @@ -0,0 +1,111 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +import dpctl.tensor as dpt + + +class Dummy: + @staticmethod + def abs(a): + return a + + @staticmethod + def add(a, b): + if isinstance(a, dpt.usm_ndarray): + return a + else: + return b + + @staticmethod + def subtract(a, b): + if isinstance(a, dpt.usm_ndarray): + return a + else: + return b + + @staticmethod + def multiply(a, b): + if isinstance(a, dpt.usm_ndarray): + return a + else: + return b + + +@pytest.mark.parametrize("namespace", [None, Dummy()]) +def test_fp_ops(namespace): + X = dpt.usm_ndarray(1, "d") + X._set_namespace(namespace) + assert X.__array_namespace__() is namespace + X[0] = -2.5 + X.__abs__() + X.__add__(1.0) + X.__radd__(1.0) + X.__sub__(1.0) + X.__rsub__(1.0) + X.__mul__(1.0) + X.__rmul__(1.0) + X.__truediv__(1.0) + X.__rtruediv__(1.0) + X.__floordiv__(1.0) + X.__rfloordiv__(1.0) + X.__pos__() + X.__neg__() + X.__eq__(-2.5) + X.__ne__(-2.5) + X.__le__(-2.5) + X.__ge__(-2.5) + X.__gt__(-2.0) + X.__iadd__(X) + X.__isub__(X) + X.__imul__(X) + X.__itruediv__(1.0) + X.__ifloordiv__(1.0) + + X = dpt.usm_ndarray(1, "i4") + X._set_namespace(namespace) + assert X.__array_namespace__() is namespace + X.__lshift__(2) + X.__rshift__(2) + X.__rlshift__(2) + X.__rrshift__(2) + X.__ilshift__(2) + X.__irshift__(2) + X.__and__(X) + X.__rand__(X) + X.__iand__(X) + X.__or__(X) + X.__ror__(X) + X.__ior__(X) + X.__xor__(X) + X.__rxor__(X) + X.__ixor__(X) + X.__invert__() + X.__mod__(5) + X.__rmod__(5) + X.__imod__(5) + X.__pow__(2) + X.__rpow__(2) + X.__ipow__(2) + + M = dpt.from_numpy(np.eye(3, 3, dtype="d")) + X._set_namespace(namespace) + assert X.__array_namespace__() is namespace + M.__matmul__(M) + M.__imatmul__(M) + M.__rmatmul__(M)