diff --git a/dpctl/tensor/_dlpack.pxd b/dpctl/tensor/_dlpack.pxd index 4746432803..9846f54be6 100644 --- a/dpctl/tensor/_dlpack.pxd +++ b/dpctl/tensor/_dlpack.pxd @@ -24,15 +24,24 @@ from ._usmarray cimport usm_ndarray cdef extern from 'dlpack/dlpack.h' nogil: int device_CPU 'kDLCPU' - int device_oneAPI 'kDLOneAPI' + int device_CUDA 'kDLCUDA' + int device_CUDAHost 'kDLCUDAHost' + int device_CUDAManaged 'kDLCUDAManaged' + int device_DLROCM 'kDLROCM' + int device_ROCMHost 'kDLROCMHost' int device_OpenCL 'kDLOpenCL' - + int device_Vulkan 'kDLVulkan' + int device_Metal 'kDLMetal' + int device_VPI 'kDLVPI' + int device_OneAPI 'kDLOneAPI' + int device_WebGPU 'kDLWebGPU' + int device_Hexagon 'kDLHexagon' + int device_MAIA 'kDLMAIA' cpdef object to_dlpack_capsule(usm_ndarray array) except + +cpdef object to_dlpack_versioned_capsule(usm_ndarray array, bint copied) except + cpdef usm_ndarray from_dlpack_capsule(object dltensor) except + -cpdef from_dlpack(array) - cdef int get_parent_device_ordinal_id(SyclDevice dev) except * cdef class DLPackCreationError(Exception): diff --git a/dpctl/tensor/_dlpack.pyx b/dpctl/tensor/_dlpack.pyx index 8664c1af54..7c0f96ec7f 100644 --- a/dpctl/tensor/_dlpack.pyx +++ b/dpctl/tensor/_dlpack.pyx @@ -20,7 +20,7 @@ cimport cpython from libc cimport stdlib -from libc.stdint cimport int32_t, int64_t, uint8_t, uint16_t, uint64_t +from libc.stdint cimport int32_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t cimport dpctl as c_dpctl cimport dpctl.memory as c_dpmem @@ -32,7 +32,7 @@ from .._backend cimport ( DPCTLSyclDeviceRef, DPCTLSyclUSMRef, ) -from ._usmarray cimport USM_ARRAY_C_CONTIGUOUS, usm_ndarray +from ._usmarray cimport USM_ARRAY_C_CONTIGUOUS, USM_ARRAY_WRITABLE, usm_ndarray from platform import system as sys_platform @@ -41,13 +41,25 @@ import numpy as np import dpctl import dpctl.memory as dpmem +from ._device import Device + cdef bint _IS_LINUX = sys_platform() == "Linux" del sys_platform cdef extern from 'dlpack/dlpack.h' nogil: - cdef int DLPACK_VERSION + cdef int DLPACK_MAJOR_VERSION + + cdef int DLPACK_MINOR_VERSION + + cdef int DLPACK_FLAG_BITMASK_READ_ONLY + + cdef int DLPACK_FLAG_BITMASK_IS_COPIED + + ctypedef struct DLPackVersion: + uint32_t major + uint32_t minor cdef enum DLDeviceType: kDLCPU @@ -61,6 +73,9 @@ cdef extern from 'dlpack/dlpack.h' nogil: kDLMetal kDLVPI kDLOneAPI + kDLWebGPU + kDLHexagon + kDLMAIA ctypedef struct DLDevice: DLDeviceType device_type @@ -93,17 +108,28 @@ cdef extern from 'dlpack/dlpack.h' nogil: void *manager_ctx void (*deleter)(DLManagedTensor *) # noqa: E211 + ctypedef struct DLManagedTensorVersioned: + DLPackVersion version + void *manager_ctx + void (*deleter)(DLManagedTensorVersioned *) # noqa: E211 + uint64_t flags + DLTensor dl_tensor + def get_build_dlpack_version(): """ - Returns the string value of DLPACK_VERSION from dlpack.h - :module:`dpctl.tensor` was built with. + Returns a tuple of integers representing the `major` and `minor` + version of DLPack :module:`dpctl.tensor` was built with. + This tuple can be passed as the `max_version` argument to + `__dlpack__` to guarantee :module:`dpctl.tensor` can properly + consume the capsule. Returns: - A string value of the version of DLPack used to build - `dpctl`. + Tuple[int, int] + A tuple of integers representing the `major` and `minor` + version of DLPack used to build :module:`dpctl.tensor`.
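+ + :Example: + A minimal consumer-side sketch (assumes a default-selectable + SYCL device is available):: + + import dpctl.tensor as dpt + + x = dpt.arange(10) + # request a capsule no newer than what this build supports + cap = x.__dlpack__(max_version=dpt._dlpack.get_build_dlpack_version())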
""" - return str(DLPACK_VERSION) + return (DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION) cdef void _pycapsule_deleter(object dlt_capsule) noexcept: @@ -116,11 +142,32 @@ cdef void _pycapsule_deleter(object dlt_capsule) noexcept: cdef void _managed_tensor_deleter(DLManagedTensor *dlm_tensor) noexcept with gil: if dlm_tensor is not NULL: + # we only delete shape, because we make single allocation to + # acommodate both shape and strides if strides are needed stdlib.free(dlm_tensor.dl_tensor.shape) - cpython.Py_DECREF(dlm_tensor.manager_ctx) + cpython.Py_DECREF(dlm_tensor.manager_ctx) dlm_tensor.manager_ctx = NULL stdlib.free(dlm_tensor) + +cdef void _pycapsule_versioned_deleter(object dlt_capsule) noexcept: + cdef DLManagedTensorVersioned *dlmv_tensor = NULL + if cpython.PyCapsule_IsValid(dlt_capsule, 'dltensor_versioned'): + dlmv_tensor = cpython.PyCapsule_GetPointer( + dlt_capsule, 'dltensor_versioned') + dlmv_tensor.deleter(dlmv_tensor) + + +cdef void _managed_tensor_versioned_deleter(DLManagedTensorVersioned *dlmv_tensor) noexcept with gil: + if dlmv_tensor is not NULL: + # we only delete shape, because we make single allocation to + # acommodate both shape and strides if strides are needed + stdlib.free(dlmv_tensor.dl_tensor.shape) + cpython.Py_DECREF(dlmv_tensor.manager_ctx) + dlmv_tensor.manager_ctx = NULL + stdlib.free(dlmv_tensor) + + cdef object _get_default_context(c_dpctl.SyclDevice dev) except *: try: if _IS_LINUX: @@ -155,33 +202,71 @@ cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except *: return dev.get_overall_ordinal() +cdef int get_array_dlpack_device_id( + usm_ndarray usm_ary +) except *: + """Finds ordinal number of the parent of device where array + was allocated. + """ + cdef c_dpctl.SyclQueue ary_sycl_queue + cdef c_dpctl.SyclDevice ary_sycl_device + cdef DPCTLSyclDeviceRef pDRef = NULL + cdef int device_id = -1 + + ary_sycl_queue = usm_ary.get_sycl_queue() + ary_sycl_device = ary_sycl_queue.get_sycl_device() + + default_context = _get_default_context(ary_sycl_device) + if default_context is None: + # check that ary_sycl_device is a non-partitioned device + pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref()) + if pDRef is not NULL: + DPCTLDevice_Delete(pDRef) + raise DLPackCreationError( + "to_dlpack_capsule: DLPack can only export arrays allocated " + "on non-partitioned SYCL devices on platforms where " + "default_context oneAPI extension is not supported." + ) + device_id = ary_sycl_device.get_overall_ordinal() + else: + if not usm_ary.sycl_context == default_context: + raise DLPackCreationError( + "to_dlpack_capsule: DLPack can only export arrays based on USM " + "allocations bound to a default platform SYCL context" + ) + device_id = get_parent_device_ordinal_id(ary_sycl_device) + + if device_id < 0: + raise DLPackCreationError( + "get_array_dlpack_device_id: failed to determine device_id" + ) + + return device_id + + cpdef to_dlpack_capsule(usm_ndarray usm_ary): """ to_dlpack_capsule(usm_ary) Constructs named Python capsule object referencing - instance of `DLManagerTensor` from + instance of ``DLManagedTensor`` from :class:`dpctl.tensor.usm_ndarray` instance. Args: usm_ary: An instance of :class:`dpctl.tensor.usm_ndarray` Returns: - Python a new capsule with name "dltensor" that contains - a pointer to `DLManagedTensor` struct. + A new capsule with name ``"dltensor"`` that contains + a pointer to ``DLManagedTensor`` struct. Raises: DLPackCreationError: when array can be represented as DLPack tensor. 
This may happen when array was allocated on a partitioned sycl device, or its USM allocation is not bound to the platform default SYCL context. - MemoryError: when host allocation to needed for `DLManagedTensor` + MemoryError: when host allocation needed for ``DLManagedTensor`` did not succeed. ValueError: when array elements data type could not be represented - in `DLManagedTensor`. + in ``DLManagedTensor``. """ - cdef c_dpctl.SyclQueue ary_sycl_queue - cdef c_dpctl.SyclDevice ary_sycl_device - cdef DPCTLSyclDeviceRef pDRef = NULL - cdef DPCTLSyclDeviceRef tDRef = NULL cdef DLManagedTensor *dlm_tensor = NULL cdef DLTensor *dl_tensor = NULL cdef int nd = usm_ary.get_ndim() @@ -198,42 +283,8 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary): cdef Py_ssize_t si = 1 ary_base = usm_ary.get_base() - ary_sycl_queue = usm_ary.get_sycl_queue() - ary_sycl_device = ary_sycl_queue.get_sycl_device() - default_context = _get_default_context(ary_sycl_device) - if default_context is None: - # check that ary_sycl_device is a non-partitioned device - pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref()) - if pDRef is not NULL: - DPCTLDevice_Delete(pDRef) - raise DLPackCreationError( - "to_dlpack_capsule: DLPack can only export arrays allocated " - "on non-partitioned SYCL devices on platforms where " - "default_context oneAPI extension is not supported." - ) - else: - if not usm_ary.sycl_context == default_context: - raise DLPackCreationError( - "to_dlpack_capsule: DLPack can only export arrays based on USM " - "allocations bound to a default platform SYCL context" - ) - # Find the unpartitioned parent of the allocation device - pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref()) - if pDRef is not NULL: - tDRef = DPCTLDevice_GetParentDevice(pDRef) - while tDRef is not NULL: - DPCTLDevice_Delete(pDRef) - pDRef = tDRef - tDRef = DPCTLDevice_GetParentDevice(pDRef) - ary_sycl_device = c_dpctl.SyclDevice._create(pDRef) - - # Find ordinal number of the parent device - device_id = ary_sycl_device.get_overall_ordinal() - if device_id < 0: - raise DLPackCreationError( - "to_dlpack_capsule: failed to determine device_id" - ) + device_id = get_array_dlpack_device_id(usm_ary) dlm_tensor = stdlib.malloc( sizeof(DLManagedTensor)) @@ -296,19 +347,149 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary): stdlib.free(dlm_tensor) raise ValueError("Unrecognized array data type") - dlm_tensor.manager_ctx = usm_ary - cpython.Py_INCREF(usm_ary) + dlm_tensor.manager_ctx = ary_base + cpython.Py_INCREF(ary_base) dlm_tensor.deleter = _managed_tensor_deleter return cpython.PyCapsule_New(dlm_tensor, 'dltensor', _pycapsule_deleter) +cpdef to_dlpack_versioned_capsule(usm_ndarray usm_ary, bint copied): + """ + to_dlpack_versioned_capsule(usm_ary, copied) + + Constructs named Python capsule object referencing + instance of ``DLManagedTensorVersioned`` from + :class:`dpctl.tensor.usm_ndarray` instance. + + Args: + usm_ary: An instance of :class:`dpctl.tensor.usm_ndarray` + copied: A bint indicating whether the data was previously + copied, used to set the is-copied bitmask in the + flags. + Returns: + A new capsule with name ``"dltensor_versioned"`` that + contains a pointer to ``DLManagedTensorVersioned`` struct. + Raises: + DLPackCreationError: when array can not be represented as + DLPack tensor. This may happen when array was allocated + on a partitioned sycl device, or its USM allocation is + not bound to the platform default SYCL context.
+ MemoryError: when host allocation needed for + ``DLManagedTensorVersioned`` did not succeed. + ValueError: when array elements data type could not be represented + in ``DLManagedTensorVersioned``. + """ + cdef DLManagedTensorVersioned *dlmv_tensor = NULL + cdef DLTensor *dl_tensor = NULL + cdef uint32_t dlmv_flags = 0 + cdef int nd = usm_ary.get_ndim() + cdef char *data_ptr = usm_ary.get_data() + cdef Py_ssize_t *shape_ptr = NULL + cdef Py_ssize_t *strides_ptr = NULL + cdef int64_t *shape_strides_ptr = NULL + cdef int i = 0 + cdef int device_id = -1 + cdef int flags = 0 + cdef char *base_ptr = NULL + cdef Py_ssize_t element_offset = 0 + cdef Py_ssize_t byte_offset = 0 + cdef Py_ssize_t si = 1 + + ary_base = usm_ary.get_base() + + # Find ordinal number of the parent device + device_id = get_array_dlpack_device_id(usm_ary) + + dlmv_tensor = stdlib.malloc( + sizeof(DLManagedTensorVersioned)) + if dlmv_tensor is NULL: + raise MemoryError( + "to_dlpack_versioned_capsule: Could not allocate memory " + "for DLManagedTensorVersioned" + ) + shape_strides_ptr = stdlib.malloc((sizeof(int64_t) * 2) * nd) + if shape_strides_ptr is NULL: + stdlib.free(dlmv_tensor) + raise MemoryError( + "to_dlpack_versioned_capsule: Could not allocate memory " + "for shape/strides" + ) + # this can be a separate function for handling shapes and strides + shape_ptr = usm_ary.get_shape() + for i in range(nd): + shape_strides_ptr[i] = shape_ptr[i] + strides_ptr = usm_ary.get_strides() + flags = usm_ary.flags_ + if strides_ptr: + for i in range(nd): + shape_strides_ptr[nd + i] = strides_ptr[i] + else: + if not (flags & USM_ARRAY_C_CONTIGUOUS): + si = 1 + for i in range(0, nd): + shape_strides_ptr[nd + i] = si + si = si * shape_ptr[i] + strides_ptr = &shape_strides_ptr[nd] + + # this can all be a function for building the dl_tensor + # object (separate from dlm/dlmv) + ary_dt = usm_ary.dtype + ary_dtk = ary_dt.kind + element_offset = usm_ary.get_offset() + byte_offset = element_offset * (ary_dt.itemsize) + + dl_tensor = &dlmv_tensor.dl_tensor + dl_tensor.data = (data_ptr - byte_offset) + dl_tensor.ndim = nd + dl_tensor.byte_offset = byte_offset + dl_tensor.shape = &shape_strides_ptr[0] + if strides_ptr is NULL: + dl_tensor.strides = NULL + else: + dl_tensor.strides = &shape_strides_ptr[nd] + dl_tensor.device.device_type = kDLOneAPI + dl_tensor.device.device_id = device_id + dl_tensor.dtype.lanes = 1 + dl_tensor.dtype.bits = (ary_dt.itemsize * 8) + if (ary_dtk == "b"): + dl_tensor.dtype.code = kDLBool + elif (ary_dtk == "u"): + dl_tensor.dtype.code = kDLUInt + elif (ary_dtk == "i"): + dl_tensor.dtype.code = kDLInt + elif (ary_dtk == "f"): + dl_tensor.dtype.code = kDLFloat + elif (ary_dtk == "c"): + dl_tensor.dtype.code = kDLComplex + else: + stdlib.free(shape_strides_ptr) + stdlib.free(dlmv_tensor) + raise ValueError("Unrecognized array data type") + + # set flags down here + if copied: + dlmv_flags |= DLPACK_FLAG_BITMASK_IS_COPIED + if not (flags & USM_ARRAY_WRITABLE): + dlmv_flags |= DLPACK_FLAG_BITMASK_READ_ONLY + dlmv_tensor.flags = dlmv_flags + + dlmv_tensor.version.major = DLPACK_MAJOR_VERSION + dlmv_tensor.version.minor = DLPACK_MINOR_VERSION + + dlmv_tensor.manager_ctx = ary_base + cpython.Py_INCREF(ary_base) + dlmv_tensor.deleter = _managed_tensor_versioned_deleter + + return cpython.PyCapsule_New(dlmv_tensor, 'dltensor_versioned', _pycapsule_versioned_deleter) + + cdef class _DLManagedTensorOwner: """ Helper class managing the lifetime of the DLManagedTensor struct transferred from a 'dlpack' capsule.
""" - cdef DLManagedTensor *dlm_tensor + cdef DLManagedTensor * dlm_tensor def __cinit__(self): self.dlm_tensor = NULL @@ -316,6 +497,7 @@ cdef class _DLManagedTensorOwner: def __dealloc__(self): if self.dlm_tensor: self.dlm_tensor.deleter(self.dlm_tensor) + self.dlm_tensor = NULL @staticmethod cdef _DLManagedTensorOwner _create(DLManagedTensor *dlm_tensor_src): @@ -324,6 +506,28 @@ cdef class _DLManagedTensorOwner: return res +cdef class _DLManagedTensorVersionedOwner: + """ + Helper class managing the lifetime of the DLManagedTensorVersioned + struct transferred from a 'dlpack_versioned' capsule. + """ + cdef DLManagedTensorVersioned * dlmv_tensor + + def __cinit__(self): + self.dlmv_tensor = NULL + + def __dealloc__(self): + if self.dlmv_tensor: + self.dlmv_tensor.deleter(self.dlmv_tensor) + self.dlmv_tensor = NULL + + @staticmethod + cdef _DLManagedTensorVersionedOwner _create(DLManagedTensorVersioned *dlmv_tensor_src): + cdef _DLManagedTensorVersionedOwner res = _DLManagedTensorVersionedOwner.__new__(_DLManagedTensorVersionedOwner) + res.dlmv_tensor = dlmv_tensor_src + return res + + cpdef usm_ndarray from_dlpack_capsule(object py_caps): """ from_dlpack_capsule(caps) @@ -372,8 +576,7 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps): ) else: raise TypeError( - f"A Python 'dltensor' capsule was expected, " - "got {type(dlm_tensor)}" + "`from_dlpack_capsule` expects a Python 'dltensor' capsule" ) dlm_tensor = cpython.PyCapsule_GetPointer( py_caps, "dltensor") @@ -508,29 +711,248 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps): ) -cpdef from_dlpack(array): - """ from_dlpack(obj) +cpdef usm_ndarray from_dlpack_versioned_capsule(object py_caps): + """ + from_dlpack_versioned_capsule(caps) + + Reconstructs instance of :class:`dpctl.tensor.usm_ndarray` from + named Python capsule object referencing instance of + ``DLManagedTensorVersioned`` without copy. The instance forms a + view in the memory of the tensor. + + Args: + caps: + Python capsule with name ``"dltensor_versioned"`` expected + to reference an instance of ``DLManagedTensorVersioned`` + struct. + Returns: + Instance of :class:`dpctl.tensor.usm_ndarray` with a view into + memory of the tensor. Capsule is renamed to + ``"used_dltensor_versioned"`` upon success. + Raises: + TypeError: + if argument is not a ``"dltensor_versioned"`` capsule. + ValueError: + if argument is ``"used_dltensor_versioned"`` capsule + BufferError: + if the USM pointer is not bound to the reconstructed + sycl context, or the DLPack's device_type is not supported + by :mod:`dpctl`. 
+ """ + cdef DLManagedTensorVersioned *dlmv_tensor = NULL + cdef bytes usm_type + cdef size_t sz = 1 + cdef size_t alloc_sz = 1 + cdef int i + cdef int device_id = -1 + cdef int element_bytesize = 0 + cdef Py_ssize_t offset_min = 0 + cdef Py_ssize_t offset_max = 0 + cdef char *mem_ptr = NULL + cdef Py_ssize_t mem_ptr_delta = 0 + cdef Py_ssize_t element_offset = 0 + cdef int64_t stride_i = -1 + cdef int64_t shape_i = -1 + + if not cpython.PyCapsule_IsValid(py_caps, 'dltensor_versioned'): + if cpython.PyCapsule_IsValid(py_caps, 'used_dltensor_versioned'): + raise ValueError( + "A DLPack tensor object can not be consumed multiple times" + ) + else: + raise TypeError( + "`from_dlpack_versioned_capsule` expects a Python " + "'dltensor_versioned' capsule" + ) + dlmv_tensor = cpython.PyCapsule_GetPointer( + py_caps, "dltensor_versioned") + # Verify that we can work with this device + if dlmv_tensor.dl_tensor.device.device_type == kDLOneAPI: + device_id = dlmv_tensor.dl_tensor.device.device_id + root_device = dpctl.SyclDevice(str(device_id)) + try: + if _IS_LINUX: + default_context = root_device.sycl_platform.default_context + else: + default_context = get_device_cached_queue(root_device).sycl_context + except RuntimeError: + default_context = get_device_cached_queue(root_device).sycl_context + if dlmv_tensor.dl_tensor.data is NULL: + usm_type = b"device" + q = get_device_cached_queue((default_context, root_device,)) + else: + usm_type = c_dpmem._Memory.get_pointer_type( + dlmv_tensor.dl_tensor.data, + default_context) + if usm_type == b"unknown": + raise BufferError( + "Data pointer in DLPack is not bound to default sycl " + f"context of device '{device_id}', translated to " + f"{root_device.filter_string}" + ) + alloc_device = c_dpmem._Memory.get_pointer_device( + dlmv_tensor.dl_tensor.data, + default_context + ) + q = get_device_cached_queue((default_context, alloc_device,)) + if dlmv_tensor.dl_tensor.dtype.bits % 8: + raise BufferError( + "Can not import DLPack tensor whose element's " + "bitsize is not a multiple of 8" + ) + if dlmv_tensor.dl_tensor.dtype.lanes != 1: + raise BufferError( + "Can not import DLPack tensor with lanes != 1" + ) + if dlmv_tensor.version.major > DLPACK_MAJOR_VERSION: + raise BufferError( + "Can not import DLPack tensor with major version " + f"greater than {DLPACK_MAJOR_VERSION}" + ) + offset_min = 0 + if dlmv_tensor.dl_tensor.strides is NULL: + for i in range(dlmv_tensor.dl_tensor.ndim): + sz = sz * dlmv_tensor.dl_tensor.shape[i] + offset_max = sz - 1 + else: + offset_max = 0 + for i in range(dlmv_tensor.dl_tensor.ndim): + stride_i = dlmv_tensor.dl_tensor.strides[i] + shape_i = dlmv_tensor.dl_tensor.shape[i] + if shape_i > 1: + shape_i -= 1 + if stride_i > 0: + offset_max = offset_max + stride_i * shape_i + else: + offset_min = offset_min + stride_i * shape_i + sz = offset_max - offset_min + 1 + if sz == 0: + sz = 1 + + element_bytesize = (dlmv_tensor.dl_tensor.dtype.bits // 8) + sz = sz * element_bytesize + element_offset = dlmv_tensor.dl_tensor.byte_offset // element_bytesize + + # transfer dlmv_tensor ownership + dlmv_holder = _DLManagedTensorVersionedOwner._create(dlmv_tensor) + cpython.PyCapsule_SetName(py_caps, 'used_dltensor_versioned') + + if dlmv_tensor.dl_tensor.data is NULL: + usm_mem = dpmem.MemoryUSMDevice(sz, q) + else: + mem_ptr_delta = dlmv_tensor.dl_tensor.byte_offset - ( + element_offset * element_bytesize + ) + mem_ptr = dlmv_tensor.dl_tensor.data + alloc_sz = dlmv_tensor.dl_tensor.byte_offset + ( + (offset_max + 1) * element_bytesize) + tmp = 
c_dpmem._Memory.create_from_usm_pointer_size_qref( + mem_ptr, + max(alloc_sz, element_bytesize), + (q).get_queue_ref(), + memory_owner=dlmv_holder + ) + if mem_ptr_delta == 0: + usm_mem = tmp + else: + alloc_sz = dlmv_tensor.dl_tensor.byte_offset + ( + (offset_max * element_bytesize + mem_ptr_delta)) + usm_mem = c_dpmem._Memory.create_from_usm_pointer_size_qref( + (mem_ptr + (element_bytesize - mem_ptr_delta)), + max(alloc_sz, element_bytesize), + (q).get_queue_ref(), + memory_owner=tmp + ) + py_shape = list() + for i in range(dlmv_tensor.dl_tensor.ndim): + py_shape.append(dlmv_tensor.dl_tensor.shape[i]) + if (dlmv_tensor.dl_tensor.strides is NULL): + py_strides = None + else: + py_strides = list() + for i in range(dlmv_tensor.dl_tensor.ndim): + py_strides.append(dlmv_tensor.dl_tensor.strides[i]) + if (dlmv_tensor.dl_tensor.dtype.code == kDLUInt): + ary_dt = np.dtype("u" + str(element_bytesize)) + elif (dlmv_tensor.dl_tensor.dtype.code == kDLInt): + ary_dt = np.dtype("i" + str(element_bytesize)) + elif (dlmv_tensor.dl_tensor.dtype.code == kDLFloat): + ary_dt = np.dtype("f" + str(element_bytesize)) + elif (dlmv_tensor.dl_tensor.dtype.code == kDLComplex): + ary_dt = np.dtype("c" + str(element_bytesize)) + elif (dlmv_tensor.dl_tensor.dtype.code == kDLBool): + ary_dt = np.dtype("?") + else: + raise BufferError( + "Can not import DLPack tensor with type code {}.".format( + dlmv_tensor.dl_tensor.dtype.code + ) + ) + res_ary = usm_ndarray( + py_shape, + dtype=ary_dt, + buffer=usm_mem, + strides=py_strides, + offset=element_offset + ) + if (dlmv_tensor.flags & DLPACK_FLAG_BITMASK_READ_ONLY): + res_ary.flags_ = (res_ary.flags_ & ~USM_ARRAY_WRITABLE) + return res_ary + else: + raise BufferError( + "The DLPack tensor resides on unsupported device." + ) + + +def from_dlpack(x, /, *, device=None, copy=None): + """ from_dlpack(x, /, *, device=None, copy=None) Constructs :class:`dpctl.tensor.usm_ndarray` instance from a Python - object ``obj`` that implements ``__dlpack__`` protocol. The output - array is always a zero-copy view of the input. + object ``x`` that implements ``__dlpack__`` protocol. Args: - obj: + x (Python object): A Python object representing an array that supports ``__dlpack__`` protocol. + device (Optional[str, + :class:`dpctl.SyclDevice`, + :class:`dpctl.SyclQueue`, + :class:`dpctl.tensor.Device`, + tuple([enum.Enum, int])]): + Array API concept of a device where the output array is to be placed. + ``device`` can be ``None``, a oneAPI filter selector + string, an instance of :class:`dpctl.SyclDevice` corresponding to + a non-partitioned SYCL device, an instance of + :class:`dpctl.SyclQueue`, a :class:`dpctl.tensor.Device` object + returned by :attr:`dpctl.tensor.usm_ndarray.device`, or a + 2-tuple matching the format of the output of the ``__dlpack_device__`` + method, an integer enumerator representing the device type followed by + an integer representing the index of the device. + Default: ``None``. + copy (bool, optional): + Boolean indicating whether or not to copy the input. + + * If ``copy`` is ``True``, the input will always be + copied. + * If ``False``, a ``BufferError`` will be raised if a + copy is deemed necessary. + * If ``None``, a copy will be made only if deemed + necessary, otherwise, the existing memory buffer will + be reused. + + Default: ``None``. Returns: usm_ndarray: - An array with a view into the tensor underlying the - input ``obj``. + An array containing the data in ``x``.
When ``copy`` is + ``None`` or ``False``, this may be a view into the original + memory. Raises: TypeError: - if ``obj`` does not implement ``__dlpack__`` method + if ``x`` does not implement ``__dlpack__`` method ValueError: - if zero copy view can not be constructed because - the input array resides on an unsupported device + if the input array resides on an unsupported device See https://dmlc.github.io/dlpack/latest/ for more details. @@ -556,16 +978,31 @@ cpdef from_dlpack(array): X = dpt.from_dlpack(C) """ - if not hasattr(array, "__dlpack__"): + if not hasattr(x, "__dlpack__"): raise TypeError( - "The argument of type {type(array)} does not implement " + f"The argument of type {type(x)} does not implement " "`__dlpack__` method." ) - dlpack_attr = getattr(array, "__dlpack__") + dlpack_attr = getattr(x, "__dlpack__") if not callable(dlpack_attr): raise TypeError( - "The argument of type {type(array)} does not implement " + f"The argument of type {type(x)} does not implement " "`__dlpack__` method." ) - dlpack_capsule = dlpack_attr() - return from_dlpack_capsule(dlpack_capsule) + try: + # device is converted to a dlpack_device if necessary + dl_device = None + if device: + if isinstance(device, tuple): + dl_device = device + else: + if not isinstance(device, dpctl.SyclDevice): + d = Device.create_device(device).sycl_device + dl_device = (device_OneAPI, get_parent_device_ordinal_id(d)) + else: + dl_device = (device_OneAPI, get_parent_device_ordinal_id(device)) + dlpack_capsule = dlpack_attr(max_version=get_build_dlpack_version(), dl_device=dl_device, copy=copy) + return from_dlpack_versioned_capsule(dlpack_capsule) + except TypeError: + dlpack_capsule = dlpack_attr() + return from_dlpack_capsule(dlpack_capsule) diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 928be8d9a1..0c28380222 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -36,8 +36,12 @@ cimport dpctl as c_dpctl cimport dpctl.memory as c_dpmem cimport dpctl.tensor._dlpack as c_dlpack +from ._dlpack import get_build_dlpack_version + from .._sycl_device_factory cimport _cached_default_device +from enum import IntEnum + import dpctl.tensor._flags as _flags from dpctl.tensor._tensor_impl import default_device_fp_type @@ -46,6 +50,23 @@ include "_types.pxi" include "_slicing.pxi" +class DLDeviceType(IntEnum): + kDLCPU = c_dlpack.device_CPU + kDLCUDA = c_dlpack.device_CUDA + kDLCUDAHost = c_dlpack.device_CUDAHost + kDLCUDAManaged = c_dlpack.device_CUDAManaged + kDLROCM = c_dlpack.device_DLROCM + kDLROCMHost = c_dlpack.device_ROCMHost + kDLOpenCL = c_dlpack.device_OpenCL + kDLVulkan = c_dlpack.device_Vulkan + kDLMetal = c_dlpack.device_Metal + kDLVPI = c_dlpack.device_VPI + kDLOneAPI = c_dlpack.device_OneAPI + kDLWebGPU = c_dlpack.device_WebGPU + kDLHexagon = c_dlpack.device_Hexagon + kDLMAIA = c_dlpack.device_MAIA + + cdef class InternalUSMArrayError(Exception): """ An InternalUSMArrayError exception is raised when internal @@ -177,6 +198,7 @@ cdef class usm_ndarray: cdef int itemsize = 0 cdef int err = 0 cdef int contig_flag = 0 + cdef int writable_flag = USM_ARRAY_WRITABLE cdef Py_ssize_t *shape_ptr = NULL cdef Py_ssize_t ary_nelems = 0 cdef Py_ssize_t ary_nbytes = 0 @@ -269,6 +291,8 @@ cdef class usm_ndarray: "an instance of `MemoryUSM*` object, or a usm_ndarray" "").format(buffer)) elif isinstance(buffer, usm_ndarray): + if not buffer.flags.writable: + writable_flag = 0 _buffer = buffer.usm_data else: self._cleanup() @@ -293,7 +317,7 @@ cdef class usm_ndarray: self.shape_ = 
shape_ptr self.strides_ = strides_ptr self.typenum_ = typenum - self.flags_ = (contig_flag | USM_ARRAY_WRITABLE) + self.flags_ = (contig_flag | writable_flag) self.nd_ = nd self.array_namespace_ = array_namespace @@ -917,7 +941,7 @@ cdef class usm_ndarray: cdef c_dpmem._Memory arr_buf d = Device.create_device(target_device) - if (stream is None or type(stream) is not dpctl.SyclQueue or + if (stream is None or not isinstance(stream, dpctl.SyclQueue) or stream == self.sycl_queue): pass else: @@ -1043,14 +1067,42 @@ cdef class usm_ndarray: "Implementation for operator.and" return dpctl.tensor.bitwise_and(self, other) - def __dlpack__(self, stream=None): + def __dlpack__(self, *, stream=None, max_version=None, dl_device=None, copy=None): """ Produces DLPack capsule. Args: stream (:class:`dpctl.SyclQueue`, optional): - Execution queue to synchronize with. If ``None``, - synchronization is not performed. + Execution queue to synchronize with. + If ``None``, synchronization is not performed. + Default: ``None``. + max_version (tuple[int, int], optional): + The maximum DLPack version the consumer (caller of + ``__dlpack__``) supports. As ``__dlpack__`` may not + always return a DLPack capsule with version + ``max_version``, the consumer must verify the version + even if this argument is passed. + Default: ``None``. + dl_device (tuple[enum.Enum, int], optional): + The device the returned DLPack capsule will be + placed on. + The device must be a 2-tuple matching the format of + the output of the ``__dlpack_device__`` method, an + integer enumerator representing the device type + followed by an integer representing the index of + the device. + Default: ``None``. + copy (bool, optional): + Boolean indicating whether or not to copy the input. + + * If ``copy`` is ``True``, the input will always be + copied. + * If ``False``, a ``BufferError`` will be raised if a + copy is deemed necessary. + * If ``None``, a copy will be made only if deemed + necessary, otherwise, the existing memory buffer will + be reused. + + Default: ``None``. Raises: MemoryError: @@ -1058,15 +1110,82 @@ cdef class usm_ndarray: DLPackCreationError: when array is allocated on a partitioned SYCL device, or with a non-default context. + BufferError: + when a copy is deemed necessary but ``copy`` + is ``False`` or when the provided ``dl_device`` + cannot be handled.
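+ + :Example: + A sketch of version negotiation (assumes a default-selectable + device); the kind of capsule produced depends on the + ``max_version`` the consumer passes:: + + import dpctl.tensor as dpt + + x = dpt.arange(100, dtype="i4") + # consumer supports DLPack 1.0: a "dltensor_versioned" capsule + cap = x.__dlpack__(max_version=(1, 0)) + # no max_version given: a legacy "dltensor" capsule + legacy_cap = x.__dlpack__()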
""" - _caps = c_dlpack.to_dlpack_capsule(self) - if (stream is None or type(stream) is not dpctl.SyclQueue or - stream == self.sycl_queue): - pass + if max_version is None: + # legacy path for DLManagedTensor + # copy kwarg ignored because copy flag can't be set + _caps = c_dlpack.to_dlpack_capsule(self) + if (stream is None or type(stream) is not dpctl.SyclQueue or + stream == self.sycl_queue): + pass + else: + ev = self.sycl_queue.submit_barrier() + stream.submit_barrier(dependent_events=[ev]) + return _caps else: - ev = self.sycl_queue.submit_barrier() - stream.submit_barrier(dependent_events=[ev]) - return _caps + if not isinstance(max_version, tuple) or len(max_version) != 2: + raise TypeError( + "`__dlpack__` expects `max_version` to be a " + "2-tuple of integers `(major, minor)`, instead " + f"got {type(max_version)}" + ) + dpctl_dlpack_version = get_build_dlpack_version() + if max_version[0] >= dpctl_dlpack_version[0]: + # DLManagedTensorVersioned path + # TODO: add logic for targeting a device + if dl_device is not None: + if dl_device != self.__dlpack_device__(): + raise NotImplementedError( + "targeting a device with `__dlpack__` is not " + "currently implemented" + ) + if copy is None: + copy = False + # TODO: strategy for handling stream on different device from dl_device + if copy: + if (stream is None or type(stream) is not dpctl.SyclQueue or + stream == self.sycl_queue): + pass + else: + ev = self.sycl_queue.submit_barrier() + stream.submit_barrier(dependent_events=[ev]) + nbytes = self.usm_data.nbytes + copy_buffer = type(self.usm_data)( + nbytes, queue=self.sycl_queue + ) + copy_buffer.copy_from_device(self.usm_data) + _copied_arr = usm_ndarray( + self.shape, + self.dtype, + buffer=copy_buffer, + strides=self.strides, + offset=self.get_offset() + ) + _copied_arr.flags_ = self.flags_ + _caps = c_dlpack.to_dlpack_versioned_capsule(_copied_arr, copy) + else: + _caps = c_dlpack.to_dlpack_versioned_capsule(self, copy) + if (stream is None or type(stream) is not dpctl.SyclQueue or + stream == self.sycl_queue): + pass + else: + ev = self.sycl_queue.submit_barrier() + stream.submit_barrier(dependent_events=[ev]) + return _caps + else: + # legacy path for DLManagedTensor + _caps = c_dlpack.to_dlpack_capsule(self) + if (stream is None or type(stream) is not dpctl.SyclQueue or + stream == self.sycl_queue): + pass + else: + ev = self.sycl_queue.submit_barrier() + stream.submit_barrier(dependent_events=[ev]) + return _caps def __dlpack_device__(self): """ @@ -1087,7 +1206,7 @@ cdef class usm_ndarray: ) else: return ( - c_dlpack.device_oneAPI, + DLDeviceType.kDLOneAPI, dev_id, ) diff --git a/dpctl/tensor/include/dlpack/README.md b/dpctl/tensor/include/dlpack/README.md index 2c22e9aa8d..3a7bc6d422 100644 --- a/dpctl/tensor/include/dlpack/README.md +++ b/dpctl/tensor/include/dlpack/README.md @@ -1,7 +1,7 @@ # DLPack header -The header `dlpack.h` downloaded from `https://github.com/dmlc/dlpack.git` remote at tag v0.8 commit [`365b823`](https://github.com/dmlc/dlpack/commit/365b823cedb281cd0240ca601aba9b78771f91a3). +The header `dlpack.h` downloaded from `https://github.com/dmlc/dlpack.git` remote at tag v1.0rc commit [`62100c1`](https://github.com/dmlc/dlpack/commit/62100c123144ae7a80061f4220be2dbd3cbaefc7). 
-The file can also be viewed using github web interface at https://github.com/dmlc/dlpack/blob/e2bdd3bee8cb6501558042633fa59144cc8b7f5f/include/dlpack/dlpack.h +The file can also be viewed using the GitHub web interface at https://github.com/dmlc/dlpack/blob/62100c123144ae7a80061f4220be2dbd3cbaefc7/include/dlpack/dlpack.h License file was retrieved from https://github.com/dmlc/dlpack/blob/main/LICENSE diff --git a/dpctl/tensor/include/dlpack/dlpack.h b/dpctl/tensor/include/dlpack/dlpack.h index 672448d1c6..bcb77949a8 100644 --- a/dpctl/tensor/include/dlpack/dlpack.h +++ b/dpctl/tensor/include/dlpack/dlpack.h @@ -15,11 +15,11 @@ #define DLPACK_EXTERN_C #endif -/*! \brief The current version of dlpack */ -#define DLPACK_VERSION 80 +/*! \brief The current major version of dlpack */ +#define DLPACK_MAJOR_VERSION 1 -/*! \brief The current ABI version of dlpack */ -#define DLPACK_ABI_VERSION 1 +/*! \brief The current minor version of dlpack */ +#define DLPACK_MINOR_VERSION 0 /*! \brief DLPACK_DLL prefix for windows */ #ifdef _WIN32 @@ -38,6 +38,33 @@ #ifdef __cplusplus extern "C" { #endif + +/*! + * \brief The DLPack version. + * + * A change in major version indicates that we have changed the + * data layout of the ABI - DLManagedTensorVersioned. + * + * A change in minor version indicates that we have added new + * code, such as a new device type, but the ABI is kept the same. + * + * If an obtained DLPack tensor has a major version that disagrees + * with the version number specified in this header file + * (i.e. major != DLPACK_MAJOR_VERSION), the consumer must call the deleter + * (and it is safe to do so). It is not safe to access any other fields + * as the memory layout will have changed. + * + * In the case of a minor version mismatch, the tensor can be safely used as + * long as the consumer knows how to interpret all fields. Minor version + * updates indicate the addition of enumeration values. + */ +typedef struct { + /*! \brief DLPack major version. */ + uint32_t major; + /*! \brief DLPack minor version. */ + uint32_t minor; +} DLPackVersion; + /*! * \brief The device type in DLDevice. */ @@ -89,6 +116,8 @@ typedef enum { kDLWebGPU = 15, /*! \brief Qualcomm Hexagon DSP */ kDLHexagon = 16, + /*! \brief Microsoft MAIA devices */ + kDLMAIA = 17, } DLDeviceType; /*! @@ -168,7 +197,7 @@ typedef struct { * `byte_offset` field should be used to point to the beginning of the data. * * Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow, - * TVM, perhaps others) do not adhere to this 256 byte alignment requirement + * TVM, perhaps others) do not adhere to this 256 byte alignment requirement * on CPU/CUDA/ROCm, and always use `byte_offset=0`. This must be fixed * (after which this note will be updated); at the moment it is recommended * to not rely on the data pointer being correctly aligned. @@ -186,6 +215,9 @@ typedef struct { * return size; * } * \endcode + * + * Note that if the tensor is of size zero, then the data pointer should be + * set to `NULL`. */ void* data; /*! \brief The device of the tensor */ @@ -211,6 +243,13 @@ typedef struct { * not meant to transfer the tensor. When the borrowing framework doesn't need * the tensor, it should call the deleter to notify the host that the resource * is no longer needed. + * + * \note This data structure is used as Legacy DLManagedTensor + * in DLPack exchange and is deprecated after DLPack v0.8. + * Use DLManagedTensorVersioned instead. + * This data structure may get renamed or deleted in future versions.
+ * + * \sa DLManagedTensorVersioned */ typedef struct DLManagedTensor { /*! \brief DLTensor which is being memory managed */ @@ -219,13 +258,74 @@ typedef struct DLManagedTensor { * which DLManagedTensor is used in the framework. It can also be NULL. */ void * manager_ctx; - /*! \brief Destructor signature void (*)(void*) - this should be called - * to destruct manager_ctx which holds the DLManagedTensor. It can be NULL - * if there is no way for the caller to provide a reasonable destructor. - * The destructors deletes the argument self as well. + /*! + * \brief Destructor - this should be called + * to destruct the manager_ctx which backs the DLManagedTensor. It can be + * NULL if there is no way for the caller to provide a reasonable destructor. + * The destructor deletes the argument self as well. */ void (*deleter)(struct DLManagedTensor * self); } DLManagedTensor; + +// bit masks used in the DLManagedTensorVersioned + +/*! \brief bit mask to indicate that the tensor is read only. */ +#define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL) + +/*! + * \brief bit mask to indicate that the tensor is a copy made by the producer. + * + * If set, the tensor is considered solely owned throughout its lifetime by the + * consumer, until the producer-provided deleter is invoked. + */ +#define DLPACK_FLAG_BITMASK_IS_COPIED (1UL << 1UL) + +/*! + * \brief A versioned and managed C Tensor object, manages memory of DLTensor. + * + * This data structure is intended to facilitate the borrowing of DLTensor by + * another framework. It is not meant to transfer the tensor. When the borrowing + * framework doesn't need the tensor, it should call the deleter to notify the + * host that the resource is no longer needed. + * + * \note This is the current standard DLPack exchange data structure. + */ +struct DLManagedTensorVersioned { + /*! + * \brief The API and ABI version of the current managed Tensor + */ + DLPackVersion version; + /*! + * \brief the context of the original host framework. + * + * Stores the context in which DLManagedTensorVersioned is used in + * the framework. It can also be NULL. + */ + void *manager_ctx; + /*! + * \brief Destructor. + * + * This should be called to destruct manager_ctx which holds the DLManagedTensorVersioned. + * It can be NULL if there is no way for the caller to provide a reasonable + * destructor. The destructor deletes the argument self as well. + */ + void (*deleter)(struct DLManagedTensorVersioned *self); + /*! + * \brief Additional bitmask flags information about the tensor. + * + * By default the flags should be set to 0. + * + * \note Future ABI changes should keep everything until this field + * stable, to ensure that deleter can be correctly called. + * + * \sa DLPACK_FLAG_BITMASK_READ_ONLY + * \sa DLPACK_FLAG_BITMASK_IS_COPIED + */ + uint64_t flags; + /*!
\brief DLTensor which is being memory managed */ + DLTensor dl_tensor; +}; + #ifdef __cplusplus } // DLPACK_EXTERN_C #endif diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index 0c15038b47..a0f2414fce 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -148,6 +148,15 @@ def test_usm_ndarray_writable_flag_views(): assert not a.imag.flags.writable +def test_usm_ndarray_from_usm_ndarray_readonly(): + get_queue_or_skip() + + x1 = dpt.arange(10, dtype="f4") + x1.flags["W"] = False + x2 = dpt.usm_ndarray(x1.shape, dtype="f4", buffer=x1) + assert not x2.flags.writable + + @pytest.mark.parametrize( "dtype", [ @@ -2159,9 +2168,6 @@ def test_meshgrid2(): assert z1.shape == z2.shape and z2.shape == z3.shape assert y1.shape == (len(x2), len(x1), len(x3)) assert z1.shape == (len(x1), len(x2), len(x3)) - # FIXME: uncomment out once gh-921 is merged - # assert all(z.flags["C"] for z in (z1, z2, z3)) - # assert all(y.flags["C"] for y in (y1, y2, y3)) def test_common_arg_validation(): diff --git a/dpctl/tests/test_usm_ndarray_dlpack.py b/dpctl/tests/test_usm_ndarray_dlpack.py index 92727f02cd..a4994f01e6 100644 --- a/dpctl/tests/test_usm_ndarray_dlpack.py +++ b/dpctl/tests/test_usm_ndarray_dlpack.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections import ctypes import pytest @@ -21,6 +22,7 @@ import dpctl import dpctl.tensor as dpt +import dpctl.tensor._dlpack as _dlp device_oneAPI = 14 # DLDeviceType.kDLOneAPI @@ -55,8 +57,23 @@ def typestr(request): return request.param -def test_dlpack_device(usm_type): - all_root_devices = dpctl.get_devices() +@pytest.fixture +def all_root_devices(): + """ + Limits the number of root devices tested.
For the sake of speed + of test suite execution, keep at most two + devices from each platform + """ + devs = dpctl.get_devices() + devs_per_platform = collections.defaultdict(list) + for dev in devs: + devs_per_platform[dev.sycl_platform].append(dev) + + pruned = map(lambda li: li[:2], devs_per_platform.values()) + return sum(pruned, start=[]) + + +def test_dlpack_device(usm_type, all_root_devices): for sycl_dev in all_root_devices: X = dpt.empty((64,), dtype="u1", usm_type=usm_type, device=sycl_dev) dev = X.__dlpack_device__() @@ -66,11 +83,10 @@ def test_dlpack_device(usm_type): assert sycl_dev == all_root_devices[dev[1]] -def test_dlpack_exporter(typestr, usm_type): +def test_dlpack_exporter(typestr, usm_type, all_root_devices): caps_fn = ctypes.pythonapi.PyCapsule_IsValid caps_fn.restype = bool caps_fn.argtypes = [ctypes.py_object, ctypes.c_char_p] - all_root_devices = dpctl.get_devices() for sycl_dev in all_root_devices: skip_if_dtype_not_supported(typestr, sycl_dev) X = dpt.empty((64,), dtype=typestr, usm_type=usm_type, device=sycl_dev) @@ -119,8 +135,7 @@ def test_dlpack_exporter_stream(): @pytest.mark.parametrize("shape", [tuple(), (2,), (3, 0, 1), (2, 2, 2)]) -def test_from_dlpack(shape, typestr, usm_type): - all_root_devices = dpctl.get_devices() +def test_from_dlpack(shape, typestr, usm_type, all_root_devices): for sycl_dev in all_root_devices: skip_if_dtype_not_supported(typestr, sycl_dev) X = dpt.empty(shape, dtype=typestr, usm_type=usm_type, device=sycl_dev) @@ -139,8 +154,7 @@ def test_from_dlpack(shape, typestr, usm_type): @pytest.mark.parametrize("mod", [2, 5]) -def test_from_dlpack_strides(mod, typestr, usm_type): - all_root_devices = dpctl.get_devices() +def test_from_dlpack_strides(mod, typestr, usm_type, all_root_devices): for sycl_dev in all_root_devices: skip_if_dtype_not_supported(typestr, sycl_dev) X0 = dpt.empty( @@ -163,8 +177,8 @@ def test_from_dlpack_strides(mod, typestr, usm_type): def test_from_dlpack_input_validation(): - vstr = dpt._dlpack.get_build_dlpack_version() - assert type(vstr) is str + v = dpt._dlpack.get_build_dlpack_version() + assert type(v) is tuple with pytest.raises(TypeError): dpt.from_dlpack(None) @@ -215,9 +229,8 @@ def test_dlpack_from_subdevice(): except dpctl.SyclSubDeviceCreationError: sdevs = None try: - sdevs = ( - dev.create_sub_devices(partition=[1, 1]) if sdevs is None else sdevs - ) + if sdevs is None: + sdevs = dev.create_sub_devices(partition=[1, 1]) except dpctl.SyclSubDeviceCreationError: pytest.skip("Default device can not be partitioned") assert isinstance(sdevs, list) and len(sdevs) > 0 @@ -233,3 +246,227 @@ def test_dlpack_from_subdevice(): ar = dpt.arange(n, dtype=dpt.int32, sycl_queue=q) ar2 = dpt.from_dlpack(ar) assert ar2.sycl_device == sdevs[0] + + +def test_legacy_dlpack_capsule(): + try: + x = dpt.arange(100, dtype="i4") + except dpctl.SyclDeviceCreationError: + pytest.skip("No default device available") + + legacy_ver = (0, 8) + + cap = x.__dlpack__(max_version=legacy_ver) + y = _dlp.from_dlpack_capsule(cap) + del cap + assert x._pointer == y._pointer + + x = dpt.arange(100, dtype="u4") + x2 = dpt.reshape(x, (10, 10)).mT + cap = x2.__dlpack__(max_version=legacy_ver) + y = _dlp.from_dlpack_capsule(cap) + del cap + assert x2._pointer == y._pointer + del x2 + + x = dpt.arange(100, dtype="f4") + x2 = dpt.asarray(dpt.reshape(x, (10, 10)), order="F") + cap = x2.__dlpack__(max_version=legacy_ver) + y = _dlp.from_dlpack_capsule(cap) + del cap + assert x2._pointer == y._pointer + + x = dpt.arange(100, dtype="c8") + x3 = 
x[::-2] + cap = x3.__dlpack__(max_version=legacy_ver) + y = _dlp.from_dlpack_capsule(cap) + assert x3._pointer == y._pointer + del x3, y, x + del cap + + x = dpt.ones(100, dtype="?") + x4 = x[::-2] + cap = x4.__dlpack__(max_version=legacy_ver) + y = _dlp.from_dlpack_capsule(cap) + assert x4._pointer == y._pointer + del x4, y, x + del cap + + +def test_versioned_dlpack_capsule(): + try: + x = dpt.arange(100, dtype="i4") + except dpctl.SyclDeviceCreationError: + pytest.skip("No default device available") + + max_supported_ver = _dlp.get_build_dlpack_version() + cap = x.__dlpack__(max_version=max_supported_ver) + y = _dlp.from_dlpack_versioned_capsule(cap) + del cap + assert x._pointer == y._pointer + + x2 = dpt.asarray(dpt.reshape(x, (10, 10)), order="F") + cap = x2.__dlpack__(max_version=max_supported_ver) + y = _dlp.from_dlpack_versioned_capsule(cap) + del cap + assert x2._pointer == y._pointer + del x2 + + x3 = x[::-2] + cap = x3.__dlpack__(max_version=max_supported_ver) + y = _dlp.from_dlpack_versioned_capsule(cap) + assert x3._pointer == y._pointer + del x3, y, x + del cap + + # read-only array + x = dpt.arange(100, dtype="i4") + x.flags["W"] = False + cap = x.__dlpack__(max_version=max_supported_ver) + y = _dlp.from_dlpack_versioned_capsule(cap) + assert x._pointer == y._pointer + assert not y.flags.writable + + # read-only array, and copy + cap = x.__dlpack__(max_version=max_supported_ver, copy=True) + y = _dlp.from_dlpack_versioned_capsule(cap) + assert x._pointer != y._pointer + assert not y.flags.writable + + +def test_from_dlpack_kwargs(): + try: + x = dpt.arange(100, dtype="i4") + except dpctl.SyclDeviceCreationError: + pytest.skip("No default device available") + + y = dpt.from_dlpack(x, copy=True) + assert x._pointer != y._pointer + + z = dpt.from_dlpack(x, device=x.sycl_device) + assert z._pointer == x._pointer + + +def test_dlpack_deleters(): + try: + x = dpt.arange(100, dtype="i4") + except dpctl.SyclDeviceCreationError: + pytest.skip("No default device available") + + legacy_ver = (0, 8) + cap = x.__dlpack__(max_version=legacy_ver) + del cap + + max_supported_ver = _dlp.get_build_dlpack_version() + cap = x.__dlpack__(max_version=max_supported_ver) + del cap + + +def test_from_dlpack_device(): + try: + x = dpt.arange(100, dtype="i4") + except dpctl.SyclDeviceCreationError: + pytest.skip("No default device available") + + out = dpt.from_dlpack(x, device=x.__dlpack_device__()) + assert x.device == out.device + assert x._pointer == out._pointer + + out = dpt.from_dlpack(x, device=x.device) + assert x.device == out.device + assert x._pointer == out._pointer + + out = dpt.from_dlpack(x, device=x.sycl_device) + assert x.device == out.device + assert x._pointer == out._pointer + + +def test_used_dlpack_capsule(): + try: + x = dpt.arange(100, dtype="i4") + except dpctl.SyclDeviceCreationError: + pytest.skip("No default device available") + + legacy_ver = (0, 8) + cap = x.__dlpack__(max_version=legacy_ver) + _dlp.from_dlpack_capsule(cap) + with pytest.raises( + ValueError, + match="A DLPack tensor object can not be consumed multiple times", + ): + _dlp.from_dlpack_capsule(cap) + del cap + + max_supported_ver = _dlp.get_build_dlpack_version() + cap = x.__dlpack__(max_version=max_supported_ver) + _dlp.from_dlpack_versioned_capsule(cap) + with pytest.raises( + ValueError, + match="A DLPack tensor object can not be consumed multiple times", + ): + _dlp.from_dlpack_versioned_capsule(cap) + del cap + + +def test_dlpack_size_0(): + try: + x = dpt.ones(0, dtype="i4") + except 
dpctl.SyclDeviceCreationError: + pytest.skip("No default device available") + + legacy_ver = (0, 8) + cap = x.__dlpack__(max_version=legacy_ver) + y = _dlp.from_dlpack_capsule(cap) + assert y._pointer == x._pointer + + max_supported_ver = _dlp.get_build_dlpack_version() + cap = x.__dlpack__(max_version=max_supported_ver) + y = _dlp.from_dlpack_versioned_capsule(cap) + assert y._pointer == x._pointer + + +def test_dlpack_max_version_validation(): + try: + x = dpt.ones(100, dtype="i4") + except dpctl.SyclDeviceCreationError: + pytest.skip("No default device available") + + with pytest.raises( + TypeError, + match=r"`__dlpack__` expects `max_version` to be a " + r"2-tuple of integers `\(major, minor\)`, instead " + r"got .*", + ): + x.__dlpack__(max_version=1) + + +def test_dlpack_kwargs(): + try: + q1 = dpctl.SyclQueue() + q2 = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Could not create default queues") + x = dpt.arange(100, dtype="i4", sycl_queue=q1) + + legacy_ver = (0, 8) + cap = x.__dlpack__(stream=q2, max_version=legacy_ver, copy=True) + # `copy` ignored for legacy path + y = _dlp.from_dlpack_capsule(cap) + assert y._pointer == x._pointer + del x, y + del cap + + x1 = dpt.arange(100, dtype="i4", sycl_queue=q1) + max_supported_ver = _dlp.get_build_dlpack_version() + cap = x1.__dlpack__(stream=q2, max_version=max_supported_ver, copy=False) + y = _dlp.from_dlpack_versioned_capsule(cap) + assert y._pointer == x1._pointer + del x1, y + del cap + + x2 = dpt.arange(100, dtype="i4", sycl_queue=q1) + cap = x2.__dlpack__(stream=q2, max_version=max_supported_ver, copy=True) + y = _dlp.from_dlpack_versioned_capsule(cap) + assert y._pointer != x2._pointer + del x2, y + del cap
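+ + +def test_from_dlpack_legacy_fallback(): + # Illustrative sketch: a producer whose ``__dlpack__`` does not + # accept the new keywords makes ``dpt.from_dlpack`` fall back to + # the legacy "dltensor" capsule path. + try: + x = dpt.arange(10, dtype="i4") + except dpctl.SyclDeviceCreationError: + pytest.skip("No default device available") + + class LegacyProducer: + def __init__(self, ary): + self.ary = ary + + def __dlpack__(self, stream=None): + # legacy signature: no max_version/dl_device/copy + return self.ary.__dlpack__(stream=stream) + + def __dlpack_device__(self): + return self.ary.__dlpack_device__() + + y = dpt.from_dlpack(LegacyProducer(x)) + assert y._pointer == x._pointer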