From c92b50385657a3665232ed6a2929fca297cd0460 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 1 Aug 2023 01:58:36 -0500 Subject: [PATCH 1/2] Closes gh-1241 The DLPack exporter functionality was not populating strides information for F-contiguous arrays where usm_ndarray may carry null strides. This PR changes that fixing the reported bug. --- dpctl/tensor/_dlpack.pyx | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/dpctl/tensor/_dlpack.pyx b/dpctl/tensor/_dlpack.pyx index 62ab1ca7e0..9005886e31 100644 --- a/dpctl/tensor/_dlpack.pyx +++ b/dpctl/tensor/_dlpack.pyx @@ -32,7 +32,7 @@ from .._backend cimport ( DPCTLSyclDeviceRef, DPCTLSyclUSMRef, ) -from ._usmarray cimport usm_ndarray +from ._usmarray cimport USM_ARRAY_C_CONTIGUOUS, usm_ndarray from platform import system as sys_platform @@ -158,9 +158,11 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary): cdef int64_t *shape_strides_ptr = NULL cdef int i = 0 cdef int device_id = -1 + cdef int flags = 0 cdef char *base_ptr = NULL cdef Py_ssize_t element_offset = 0 cdef Py_ssize_t byte_offset = 0 + cdef Py_ssize_t si = 1 ary_base = usm_ary.get_base() ary_sycl_queue = usm_ary.get_sycl_queue() @@ -223,9 +225,17 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary): for i in range(nd): shape_strides_ptr[i] = shape_ptr[i] strides_ptr = usm_ary.get_strides() + flags = usm_ary.flags_ if strides_ptr: for i in range(nd): shape_strides_ptr[nd + i] = strides_ptr[i] + else: + if not (flags & USM_ARRAY_C_CONTIGUOUS): + si = 1 + for i in range(0, nd): + shape_strides_ptr[nd + i] = si + si = si * shape_ptr[i] + strides_ptr = &shape_strides_ptr[nd] ary_dt = usm_ary.dtype ary_dtk = ary_dt.kind From 9d2c969c1e8ef71ebc79dffdbc18164f82d4609d Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 1 Aug 2023 02:01:14 -0500 Subject: [PATCH 2/2] Add tests based on gh-1241 example --- dpctl/tests/test_usm_ndarray_dlpack.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_dlpack.py b/dpctl/tests/test_usm_ndarray_dlpack.py index 5801ddf2d1..c82a27807c 100644 --- a/dpctl/tests/test_usm_ndarray_dlpack.py +++ b/dpctl/tests/test_usm_ndarray_dlpack.py @@ -178,3 +178,18 @@ def __dlpack__(self): with pytest.raises(TypeError): dpt.from_dlpack(DummyWithMethod()) + + +def test_from_dlpack_fortran_contig_array_roundtripping(): + """Based on examples from issue gh-1241""" + n0, n1 = 3, 5 + try: + ar1d = dpt.arange(n0 * n1, dtype="i4") + except dpctl.SyclDeviceCreationError: + pytest.skip("No default device available") + ar2d_c = dpt.reshape(ar1d, (n0, n1), order="C") + ar2d_f = dpt.asarray(ar2d_c, order="F") + ar2d_r = dpt.from_dlpack(ar2d_f) + + assert dpt.all(dpt.equal(ar2d_f, ar2d_r)) + assert dpt.all(dpt.equal(ar2d_c, ar2d_r))