Skip to content

Commit ae7a445

Browse files
usm_ndarray API extended by GetOffset function (#769)
also make type constants and flags API symbols as well, so that pybind11 extensions can use them.
1 parent fb6b54e commit ae7a445

File tree

3 files changed

+40
-21
lines changed

3 files changed

+40
-21
lines changed

dpctl/tensor/_usmarray.pxd

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,27 @@
44
cimport dpctl
55

66

7-
cdef public int USM_ARRAY_C_CONTIGUOUS
8-
cdef public int USM_ARRAY_F_CONTIGUOUS
9-
cdef public int USM_ARRAY_WRITEABLE
10-
11-
cdef public int UAR_BOOL
12-
cdef public int UAR_BYTE
13-
cdef public int UAR_UBYTE
14-
cdef public int UAR_SHORT
15-
cdef public int UAR_USHORT
16-
cdef public int UAR_INT
17-
cdef public int UAR_UINT
18-
cdef public int UAR_LONG
19-
cdef public int UAR_ULONG
20-
cdef public int UAR_LONGLONG
21-
cdef public int UAR_ULONGLONG
22-
cdef public int UAR_FLOAT
23-
cdef public int UAR_DOUBLE
24-
cdef public int UAR_CFLOAT
25-
cdef public int UAR_CDOUBLE
26-
cdef public int UAR_TYPE_SENTINEL
27-
cdef public int UAR_HALF
7+
cdef public api int USM_ARRAY_C_CONTIGUOUS
8+
cdef public api int USM_ARRAY_F_CONTIGUOUS
9+
cdef public api int USM_ARRAY_WRITEABLE
10+
11+
cdef public api int UAR_BOOL
12+
cdef public api int UAR_BYTE
13+
cdef public api int UAR_UBYTE
14+
cdef public api int UAR_SHORT
15+
cdef public api int UAR_USHORT
16+
cdef public api int UAR_INT
17+
cdef public api int UAR_UINT
18+
cdef public api int UAR_LONG
19+
cdef public api int UAR_ULONG
20+
cdef public api int UAR_LONGLONG
21+
cdef public api int UAR_ULONGLONG
22+
cdef public api int UAR_FLOAT
23+
cdef public api int UAR_DOUBLE
24+
cdef public api int UAR_CFLOAT
25+
cdef public api int UAR_CDOUBLE
26+
cdef public api int UAR_TYPE_SENTINEL
27+
cdef public api int UAR_HALF
2828

2929

3030
cdef api class usm_ndarray [object PyUSMArrayObject, type PyUSMArrayType]:

dpctl/tensor/_usmarray.pyx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,3 +1174,9 @@ cdef api int UsmNDArray_GetFlags(usm_ndarray arr):
11741174
cdef api c_dpctl.DPCTLSyclQueueRef UsmNDArray_GetQueueRef(usm_ndarray arr):
11751175
"""Get DPCTLSyclQueueRef for queue associated with the array"""
11761176
return arr.get_queue_ref()
1177+
1178+
1179+
cdef api Py_ssize_t UsmNDArray_GetOffset(usm_ndarray arr):
1180+
"""Get offset of zero-index array element from the beginning of the USM
1181+
allocation."""
1182+
return arr.get_offset()

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,19 @@ def test_pyx_capi_get_flags():
427427
assert type(flags) is int and flags == X.flags
428428

429429

430+
def test_pyx_capi_get_offset():
431+
X = dpt.usm_ndarray(17)[1::2]
432+
get_offset_fn = _pyx_capi_fnptr_to_callable(
433+
X,
434+
"UsmNDArray_GetOffset",
435+
b"Py_ssize_t (struct PyUSMArrayObject *)",
436+
fn_restype=ctypes.c_longlong,
437+
)
438+
offset = get_offset_fn(X)
439+
assert type(offset) is int
440+
assert offset == X.__sycl_usm_array_interface__["offset"]
441+
442+
430443
def test_pyx_capi_get_queue_ref():
431444
X = dpt.usm_ndarray(17)[1::2]
432445
get_queue_ref_fn = _pyx_capi_fnptr_to_callable(

0 commit comments

Comments
 (0)