From d7c2e3b4915c39a9317049565b4e70e7e892b4f3 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 9 Mar 2023 13:53:20 -0800 Subject: [PATCH 1/2] Some function signatures changed to meet array API - tests adjusted for keyword argument changes --- dpctl/tensor/_ctors.py | 37 +++++++++----- dpctl/tensor/_manipulation_functions.py | 52 ++++++++++---------- dpctl/tensor/_reshape.py | 28 +++++------ dpctl/tests/test_usm_ndarray_manipulation.py | 10 ++-- 4 files changed, 71 insertions(+), 56 deletions(-) diff --git a/dpctl/tensor/_ctors.py b/dpctl/tensor/_ctors.py index 16388bda5b..b62816ac95 100644 --- a/dpctl/tensor/_ctors.py +++ b/dpctl/tensor/_ctors.py @@ -472,7 +472,12 @@ def asarray( def empty( - sh, dtype=None, order="C", device=None, usm_type="device", sycl_queue=None + shape, + dtype=None, + order="C", + device=None, + usm_type="device", + sycl_queue=None, ): """ Creates `usm_ndarray` from uninitialized USM allocation. @@ -509,7 +514,7 @@ def empty( dtype = _get_dtype(dtype, sycl_queue) _ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device) res = dpt.usm_ndarray( - sh, + shape, dtype=dtype, buffer=usm_type, order=order, @@ -650,7 +655,12 @@ def arange( def zeros( - sh, dtype=None, order="C", device=None, usm_type="device", sycl_queue=None + shape, + dtype=None, + order="C", + device=None, + usm_type="device", + sycl_queue=None, ): """ Creates `usm_ndarray` with zero elements. @@ -687,7 +697,7 @@ def zeros( dtype = _get_dtype(dtype, sycl_queue) _ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device) res = dpt.usm_ndarray( - sh, + shape, dtype=dtype, buffer=usm_type, order=order, @@ -698,7 +708,12 @@ def zeros( def ones( - sh, dtype=None, order="C", device=None, usm_type="device", sycl_queue=None + shape, + dtype=None, + order="C", + device=None, + usm_type="device", + sycl_queue=None, ): """ Creates `usm_ndarray` with elements of one. @@ -734,7 +749,7 @@ def ones( sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device) dtype = _get_dtype(dtype, sycl_queue) res = dpt.usm_ndarray( - sh, + shape, dtype=dtype, buffer=usm_type, order=order, @@ -746,7 +761,7 @@ def ones( def full( - sh, + shape, fill_value, dtype=None, order="C", @@ -805,14 +820,14 @@ def full( usm_type=usm_type, sycl_queue=sycl_queue, ) - return dpt.copy(dpt.broadcast_to(X, sh), order=order) + return dpt.copy(dpt.broadcast_to(X, shape), order=order) sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device) usm_type = usm_type if usm_type is not None else "device" fill_value_type = type(fill_value) dtype = _get_dtype(dtype, sycl_queue, ref_type=fill_value_type) res = dpt.usm_ndarray( - sh, + shape, dtype=dtype, buffer=usm_type, order=order, @@ -872,11 +887,11 @@ def empty_like( if device is None and sycl_queue is None: device = x.device sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device) - sh = x.shape + shape = x.shape dtype = dpt.dtype(dtype) _ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device) res = dpt.usm_ndarray( - sh, + shape, dtype=dtype, buffer=usm_type, order=order, diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index 094b959efe..139428cd06 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -122,46 +122,46 @@ def permute_dims(X, axes): ) -def expand_dims(X, axes): +def expand_dims(X, axis): """ - expand_dims(X: usm_ndarray, axes: int or tuple or list) -> usm_ndarray + expand_dims(X: usm_ndarray, axis: int or tuple or list) -> usm_ndarray Expands the shape of an array by inserting a new axis (dimension) - of size one at the position specified by axes; returns a view, if possible, + of size one at the position specified by axis; returns a view, if possible, a copy otherwise with the number of dimensions increased. """ if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") - if not isinstance(axes, (tuple, list)): - axes = (axes,) + if not isinstance(axis, (tuple, list)): + axis = (axis,) - out_ndim = len(axes) + X.ndim - axes = normalize_axis_tuple(axes, out_ndim) + out_ndim = len(axis) + X.ndim + axis = normalize_axis_tuple(axis, out_ndim) shape_it = iter(X.shape) - shape = tuple(1 if ax in axes else next(shape_it) for ax in range(out_ndim)) + shape = tuple(1 if ax in axis else next(shape_it) for ax in range(out_ndim)) return dpt.reshape(X, shape) -def squeeze(X, axes=None): +def squeeze(X, axis=None): """ - squeeze(X: usm_ndarray, axes: int or tuple or list) -> usm_ndarray + squeeze(X: usm_ndarray, axis: int or tuple or list) -> usm_ndarray - Removes singleton dimensions (axes) from X; returns a view, if possible, + Removes singleton dimensions (axis) from X; returns a view, if possible, a copy otherwise, but with all or a subset of the dimensions of length 1 removed. """ if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") X_shape = X.shape - if axes is not None: - if not isinstance(axes, (tuple, list)): - axes = (axes,) - axes = normalize_axis_tuple(axes, X.ndim if X.ndim != 0 else X.ndim + 1) + if axis is not None: + if not isinstance(axis, (tuple, list)): + axis = (axis,) + axis = normalize_axis_tuple(axis, X.ndim if X.ndim != 0 else X.ndim + 1) new_shape = [] for i, x in enumerate(X_shape): - if i not in axes: + if i not in axis: new_shape.append(x) else: if x != 1: @@ -222,9 +222,9 @@ def broadcast_arrays(*args): return [broadcast_to(X, shape) for X in args] -def flip(X, axes=None): +def flip(X, axis=None): """ - flip(X: usm_ndarray, axes: int or tuple or list) -> usm_ndarray + flip(X: usm_ndarray, axis: int or tuple or list) -> usm_ndarray Reverses the order of elements in an array along the given axis. The shape of the array is preserved, but the elements are reordered; @@ -233,20 +233,20 @@ def flip(X, axes=None): if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") X_ndim = X.ndim - if axes is None: + if axis is None: indexer = (np.s_[::-1],) * X_ndim else: - axes = normalize_axis_tuple(axes, X_ndim) + axis = normalize_axis_tuple(axis, X_ndim) indexer = tuple( - np.s_[::-1] if i in axes else np.s_[:] for i in range(X.ndim) + np.s_[::-1] if i in axis else np.s_[:] for i in range(X.ndim) ) return X[indexer] -def roll(X, shift, axes=None): +def roll(X, shift, axis=None): """ roll(X: usm_ndarray, shift: int or tuple or list,\ - axes: int or tuple or list) -> usm_ndarray + axis: int or tuple or list) -> usm_ndarray Rolls array elements along a specified axis. Array elements that roll beyond the last position are re-introduced @@ -257,7 +257,7 @@ def roll(X, shift, axes=None): """ if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") - if axes is None: + if axis is None: res = dpt.empty( X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue ) @@ -266,8 +266,8 @@ def roll(X, shift, axes=None): ) hev.wait() return res - axes = normalize_axis_tuple(axes, X.ndim, allow_duplicate=True) - broadcasted = np.broadcast(shift, axes) + axis = normalize_axis_tuple(axis, X.ndim, allow_duplicate=True) + broadcasted = np.broadcast(shift, axis) if broadcasted.ndim > 1: raise ValueError("'shift' and 'axis' should be scalars or 1D sequences") shifts = {ax: 0 for ax in range(X.ndim)} diff --git a/dpctl/tensor/_reshape.py b/dpctl/tensor/_reshape.py index 96214813c9..8b812dcd14 100644 --- a/dpctl/tensor/_reshape.py +++ b/dpctl/tensor/_reshape.py @@ -75,17 +75,17 @@ def reshaped_strides(old_sh, old_sts, new_sh, order="C"): return new_sts if valid else None -def reshape(X, newshape, order="C", copy=None): +def reshape(X, shape, order="C", copy=None): """ - reshape(X: usm_ndarray, newshape: tuple, order="C") -> usm_ndarray + reshape(X: usm_ndarray, shape: tuple, order="C") -> usm_ndarray Reshapes given usm_ndarray into new shape. Returns a view, if possible, a copy otherwise. Memory layout of the copy is controlled by order keyword. """ if not isinstance(X, dpt.usm_ndarray): raise TypeError - if not isinstance(newshape, (list, tuple)): - newshape = (newshape,) + if not isinstance(shape, (list, tuple)): + shape = (shape,) if order in "cfCF": order = order.upper() else: @@ -97,9 +97,9 @@ def reshape(X, newshape, order="C", copy=None): f"Keyword 'copy' not recognized. Expecting True, False, " f"or None, got {copy}" ) - newshape = [operator.index(d) for d in newshape] + shape = [operator.index(d) for d in shape] negative_ones_count = 0 - for nshi in newshape: + for nshi in shape: if nshi == -1: negative_ones_count = negative_ones_count + 1 if (nshi < -1) or negative_ones_count > 1: @@ -108,14 +108,14 @@ def reshape(X, newshape, order="C", copy=None): "value which can only be -1" ) if negative_ones_count: - v = X.size // (-np.prod(newshape)) - newshape = [v if d == -1 else d for d in newshape] - if X.size != np.prod(newshape): - raise ValueError(f"Can not reshape into {newshape}") + v = X.size // (-np.prod(shape)) + shape = [v if d == -1 else d for d in shape] + if X.size != np.prod(shape): + raise ValueError(f"Can not reshape into {shape}") if X.size: - newsts = reshaped_strides(X.shape, X.strides, newshape, order=order) + newsts = reshaped_strides(X.shape, X.strides, shape, order=order) else: - newsts = (1,) * len(newshape) + newsts = (1,) * len(shape) copy_required = newsts is None if copy_required and (copy is False): raise ValueError( @@ -141,11 +141,11 @@ def reshape(X, newshape, order="C", copy=None): flat_res[i], X[np.unravel_index(i, X.shape, order=order)] ) return dpt.usm_ndarray( - tuple(newshape), dtype=X.dtype, buffer=flat_res, order=order + tuple(shape), dtype=X.dtype, buffer=flat_res, order=order ) # can form a view return dpt.usm_ndarray( - newshape, + shape, dtype=X.dtype, buffer=X, strides=tuple(newsts), diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index 967f58cc21..86c3a17dee 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -483,7 +483,7 @@ def test_incompatible_shapes_raise_valueerror(shapes): assert_broadcast_arrays_raise(input_shapes[::-1]) -def test_flip_axes_incorrect(): +def test_flip_axis_incorrect(): try: q = dpctl.SyclQueue() except dpctl.SyclQueueCreationError: @@ -492,10 +492,10 @@ def test_flip_axes_incorrect(): X_np = np.ones((4, 4)) X = dpt.asarray(X_np, sycl_queue=q) - pytest.raises(np.AxisError, dpt.flip, dpt.asarray(np.ones(4)), axes=1) - pytest.raises(np.AxisError, dpt.flip, X, axes=2) - pytest.raises(np.AxisError, dpt.flip, X, axes=-3) - pytest.raises(np.AxisError, dpt.flip, X, axes=(0, 3)) + pytest.raises(np.AxisError, dpt.flip, dpt.asarray(np.ones(4)), axis=1) + pytest.raises(np.AxisError, dpt.flip, X, axis=2) + pytest.raises(np.AxisError, dpt.flip, X, axis=-3) + pytest.raises(np.AxisError, dpt.flip, X, axis=(0, 3)) def test_flip_0d(): From 40120396a9697d6932e09bf225eb61da405d8bd5 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 9 Mar 2023 15:45:02 -0800 Subject: [PATCH 2/2] Added finfo_object subclass to np.finfo - Improves array API conformity --- dpctl/tensor/_manipulation_functions.py | 26 +++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index 139428cd06..49c401397b 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -31,6 +31,29 @@ ) +class finfo_object(np.finfo): + """ + numpy.finfo subclass which returns Python floating-point scalars for + eps, max, min, and smallest_normal. + """ + + def __init__(self, dtype): + _supported_dtype([dpt.dtype(dtype)]) + super().__init__() + + self.eps = float(self.eps) + self.max = float(self.max) + self.min = float(self.min) + + @property + def smallest_normal(self): + return float(super().smallest_normal) + + @property + def tiny(self): + return float(super().tiny) + + def _broadcast_strides(X_shape, X_strides, res_ndim): """ Broadcasts strides to match the given dimensions; @@ -495,8 +518,7 @@ def finfo(dtype): """ if isinstance(dtype, dpt.usm_ndarray): raise TypeError("Expected dtype type, got {to}.") - _supported_dtype([dpt.dtype(dtype)]) - return np.finfo(dtype) + return finfo_object(dtype) def _supported_dtype(dtypes):