diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index eab1febb89..de36a6c8d0 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -24,7 +24,7 @@ import dpctl.tensor as dpt import dpctl.tensor._tensor_impl as ti import dpctl.utils -from dpctl.tensor._ctors import _get_dtype +from dpctl.tensor._data_types import _get_dtype from dpctl.tensor._device import normalize_queue_device __doc__ = ( @@ -354,11 +354,11 @@ def _empty_like_orderK(X, dt, usm_type=None, dev=None): range(X.ndim), key=lambda d: builtins.abs(st[d]), reverse=True ) inv_perm = sorted(range(X.ndim), key=lambda i: perm[i]) - st_sorted = [st[i] for i in perm] sh = X.shape sh_sorted = tuple(sh[i] for i in perm) R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C") - if min(st_sorted) < 0: + if min(st) < 0: + st_sorted = [st[i] for i in perm] sl = tuple( slice(None, None, -1) if st_sorted[i] < 0 diff --git a/dpctl/tensor/_ctors.py b/dpctl/tensor/_ctors.py index 19f65794f5..ba16f0f1fc 100644 --- a/dpctl/tensor/_ctors.py +++ b/dpctl/tensor/_ctors.py @@ -23,6 +23,8 @@ import dpctl.tensor as dpt import dpctl.tensor._tensor_impl as ti import dpctl.utils +from dpctl.tensor._copy_utils import _empty_like_orderK +from dpctl.tensor._data_types import _get_dtype from dpctl.tensor._device import normalize_queue_device from dpctl.tensor._usmarray import _is_object_with_buffer_protocol @@ -32,24 +34,6 @@ _host_set = frozenset([None]) -def _get_dtype(dtype, sycl_obj, ref_type=None): - if dtype is None: - if ref_type in [None, float] or np.issubdtype(ref_type, np.floating): - dtype = ti.default_device_fp_type(sycl_obj) - return dpt.dtype(dtype) - if ref_type in [bool, np.bool_]: - dtype = ti.default_device_bool_type(sycl_obj) - return dpt.dtype(dtype) - if ref_type is int or np.issubdtype(ref_type, np.integer): - dtype = ti.default_device_int_type(sycl_obj) - return dpt.dtype(dtype) - if ref_type is complex or np.issubdtype(ref_type, np.complexfloating): - dtype = ti.default_device_complex_type(sycl_obj) - return dpt.dtype(dtype) - raise TypeError(f"Reference type {ref_type} not recognized.") - return dpt.dtype(dtype) - - def _array_info_dispatch(obj): if isinstance(obj, dpt.usm_ndarray): return obj.shape, obj.dtype, frozenset([obj.sycl_queue]) @@ -162,28 +146,7 @@ def _asarray_from_usm_ndarray( order = "C" if c_contig else "F" if order == "K": _ensure_native_dtype_device_support(dtype, copy_q.sycl_device) - # new USM allocation - res = dpt.usm_ndarray( - usm_ndary.shape, - dtype=dtype, - buffer=usm_type, - order="C", - buffer_ctor_kwargs={"queue": copy_q}, - ) - original_strides = usm_ndary.strides - ind = sorted( - range(usm_ndary.ndim), - key=lambda i: abs(original_strides[i]), - reverse=True, - ) - new_strides = tuple(res.strides[ind[i]] for i in ind) - # reuse previously made USM allocation - res = dpt.usm_ndarray( - usm_ndary.shape, - dtype=res.dtype, - buffer=res.usm_data, - strides=new_strides, - ) + res = _empty_like_orderK(usm_ndary, dtype, usm_type, copy_q) else: _ensure_native_dtype_device_support(dtype, copy_q.sycl_device) res = dpt.usm_ndarray( diff --git a/dpctl/tensor/_data_types.py b/dpctl/tensor/_data_types.py index a3ecda64c8..bf8a5f59c2 100644 --- a/dpctl/tensor/_data_types.py +++ b/dpctl/tensor/_data_types.py @@ -14,7 +14,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +from numpy import bool_ as np_bool_ +from numpy import complexfloating as np_complexfloating from numpy import dtype +from numpy import floating as np_floating +from numpy import integer as np_integer +from numpy import issubdtype as np_issubdtype + +from dpctl.tensor._tensor_impl import ( + default_device_bool_type as ti_default_device_bool_type, +) +from dpctl.tensor._tensor_impl import ( + default_device_complex_type as ti_default_device_complex_type, +) +from dpctl.tensor._tensor_impl import ( + default_device_fp_type as ti_default_device_fp_type, +) +from dpctl.tensor._tensor_impl import ( + default_device_int_type as ti_default_device_int_type, +) bool = dtype("bool") int8 = dtype("int8") @@ -74,6 +92,32 @@ def isdtype(dtype_, kind): raise TypeError(f"Unsupported data type kind: {kind}") +def _get_dtype(inp_dt, sycl_obj, ref_type=None): + """ + Type inference utility to construct data type + object with defaults based on reference type. + + _get_dtype is used by dpctl.tensor.asarray + to infer data type of the output array from the + input sequence. + """ + if inp_dt is None: + if ref_type in [None, float] or np_issubdtype(ref_type, np_floating): + fp_dt = ti_default_device_fp_type(sycl_obj) + return dtype(fp_dt) + if ref_type in [bool, np_bool_]: + bool_dt = ti_default_device_bool_type(sycl_obj) + return dtype(bool_dt) + if ref_type is int or np_issubdtype(ref_type, np_integer): + int_dt = ti_default_device_int_type(sycl_obj) + return dtype(int_dt) + if ref_type is complex or np_issubdtype(ref_type, np_complexfloating): + cfp_dt = ti_default_device_complex_type(sycl_obj) + return dtype(cfp_dt) + raise TypeError(f"Reference type {ref_type} not recognized.") + return dtype(inp_dt) + + __all__ = [ "dtype", "isdtype", diff --git a/dpctl/tests/test_tensor_asarray.py b/dpctl/tests/test_tensor_asarray.py index 2c1b6501a9..f9bc31972c 100644 --- a/dpctl/tests/test_tensor_asarray.py +++ b/dpctl/tests/test_tensor_asarray.py @@ -383,3 +383,15 @@ def test_ulonglong_gh_1167(): assert x.dtype == dpt.uint64 x = dpt.asarray(9223372036854775808, dtype="u8") assert x.dtype == dpt.uint64 + + +def test_orderK_gh_1350(): + get_queue_or_skip() + a = dpt.empty((2, 3, 4), dtype="u1") + b = dpt.permute_dims(a, (2, 0, 1)) + c = dpt.asarray(b, copy=True, order="K") + + assert c.shape == b.shape + assert c.strides == b.strides + assert c._element_offset == 0 + assert not c._pointer == b._pointer