Skip to content

Commit 61430eb

Browse files
Merge pull request #1050 from IntelPython/add-usm-ndarray-creation-c-api
Add usm_ndarray creation c-api
2 parents 9e37ecb + 5145923 commit 61430eb

File tree

4 files changed

+222
-7
lines changed

4 files changed

+222
-7
lines changed

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,18 @@ class dpctl_capi
114114
int (*UsmNDArray_GetFlags_)(PyUSMArrayObject *);
115115
DPCTLSyclQueueRef (*UsmNDArray_GetQueueRef_)(PyUSMArrayObject *);
116116
py::ssize_t (*UsmNDArray_GetOffset_)(PyUSMArrayObject *);
117+
void (*UsmNDArray_SetWritableFlag_)(PyUSMArrayObject *, int);
118+
PyObject *(*UsmNDArray_MakeFromMemory_)(int,
119+
const py::ssize_t *,
120+
int,
121+
Py_MemoryObject *,
122+
py::ssize_t,
123+
char);
124+
PyObject *(*UsmNDArray_MakeFromPtr_)(size_t,
125+
int,
126+
DPCTLSyclUSMRef,
127+
DPCTLSyclQueueRef,
128+
PyObject *);
117129

118130
int USM_ARRAY_C_CONTIGUOUS_;
119131
int USM_ARRAY_F_CONTIGUOUS_;
@@ -220,11 +232,13 @@ class dpctl_capi
220232
UsmNDArray_GetShape_(nullptr), UsmNDArray_GetStrides_(nullptr),
221233
UsmNDArray_GetTypenum_(nullptr), UsmNDArray_GetElementSize_(nullptr),
222234
UsmNDArray_GetFlags_(nullptr), UsmNDArray_GetQueueRef_(nullptr),
223-
UsmNDArray_GetOffset_(nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
224-
USM_ARRAY_F_CONTIGUOUS_(0), USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1),
225-
UAR_SHORT_(-1), UAR_USHORT_(-1), UAR_INT_(-1), UAR_UINT_(-1),
226-
UAR_LONG_(-1), UAR_ULONG_(-1), UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1),
227-
UAR_FLOAT_(-1), UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
235+
UsmNDArray_GetOffset_(nullptr), UsmNDArray_SetWritableFlag_(nullptr),
236+
UsmNDArray_MakeFromMemory_(nullptr), UsmNDArray_MakeFromPtr_(nullptr),
237+
USM_ARRAY_C_CONTIGUOUS_(0), USM_ARRAY_F_CONTIGUOUS_(0),
238+
USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1), UAR_SHORT_(-1),
239+
UAR_USHORT_(-1), UAR_INT_(-1), UAR_UINT_(-1), UAR_LONG_(-1),
240+
UAR_ULONG_(-1), UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1), UAR_FLOAT_(-1),
241+
UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
228242
UAR_TYPE_SENTINEL_(-1), UAR_HALF_(-1), UAR_INT8_(-1), UAR_UINT8_(-1),
229243
UAR_INT16_(-1), UAR_UINT16_(-1), UAR_INT32_(-1), UAR_UINT32_(-1),
230244
UAR_INT64_(-1), UAR_UINT64_(-1), default_sycl_queue_{},
@@ -295,6 +309,9 @@ class dpctl_capi
295309
this->UsmNDArray_GetFlags_ = UsmNDArray_GetFlags;
296310
this->UsmNDArray_GetQueueRef_ = UsmNDArray_GetQueueRef;
297311
this->UsmNDArray_GetOffset_ = UsmNDArray_GetOffset;
312+
this->UsmNDArray_SetWritableFlag_ = UsmNDArray_SetWritableFlag;
313+
this->UsmNDArray_MakeFromMemory_ = UsmNDArray_MakeFromMemory;
314+
this->UsmNDArray_MakeFromPtr_ = UsmNDArray_MakeFromPtr;
298315

299316
// constants
300317
this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS;

dpctl/tensor/_usmarray.pyx

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,3 +1308,55 @@ cdef api Py_ssize_t UsmNDArray_GetOffset(usm_ndarray arr):
13081308
"""Get offset of zero-index array element from the beginning of the USM
13091309
allocation"""
13101310
return arr.get_offset()
1311+
1312+
cdef api void UsmNDArray_SetWritableFlag(usm_ndarray arr, int flag):
1313+
"""Set/unset USM_ARRAY_WRITABLE in the given array `arr`."""
1314+
cdef int arr_fl = arr.flags_
1315+
arr_fl ^= (arr_fl & USM_ARRAY_WRITABLE) # unset WRITABLE flag
1316+
arr_fl |= (USM_ARRAY_WRITABLE if flag else 0)
1317+
arr.flags_ = arr_fl
1318+
1319+
cdef api object UsmNDArray_MakeFromMemory(
1320+
int nd, const Py_ssize_t *shape, int typenum,
1321+
c_dpmem._Memory mobj, Py_ssize_t offset, char order
1322+
):
1323+
"""Create usm_ndarray.
1324+
1325+
Equivalent to usm_ndarray(
1326+
_make_int_tuple(nd, shape), dtype=_make_typestr(typenum),
1327+
buffer=mobj, offset=offset, order=order)
1328+
"""
1329+
cdef object shape_tuple = _make_int_tuple(nd, <Py_ssize_t *>shape)
1330+
cdef usm_ndarray arr = usm_ndarray(
1331+
shape_tuple,
1332+
dtype=_make_typestr(typenum),
1333+
buffer=mobj,
1334+
offset=offset,
1335+
order=<bytes>(order)
1336+
)
1337+
return arr
1338+
1339+
1340+
cdef api object UsmNDArray_MakeFromPtr(
1341+
size_t nelems,
1342+
int typenum,
1343+
c_dpctl.DPCTLSyclUSMRef ptr,
1344+
c_dpctl.DPCTLSyclQueueRef QRef,
1345+
object owner
1346+
):
1347+
"""Create usm_ndarray from pointer.
1348+
1349+
Argument owner=None implies transfer of USM allocation ownership
1350+
to the created array object.
1351+
"""
1352+
cdef size_t itemsize = type_bytesize(typenum)
1353+
cdef size_t nbytes = itemsize * nelems
1354+
cdef c_dpmem._Memory mobj = c_dpmem._Memory.create_from_usm_pointer_size_qref(
1355+
ptr, nbytes, QRef, memory_owner=owner
1356+
)
1357+
cdef usm_ndarray arr = usm_ndarray(
1358+
(nelems,),
1359+
dtype=_make_typestr(typenum),
1360+
buffer=mobj
1361+
)
1362+
return arr

dpctl/tests/test_sycl_usm.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,7 @@ def test_cpython_api(memory_ctor):
545545
mem_q_ref_fn_cap = mod.__pyx_capi__["Memory_GetQueueRef"]
546546
mem_ctx_ref_fn_cap = mod.__pyx_capi__["Memory_GetContextRef"]
547547
mem_nby_fn_cap = mod.__pyx_capi__["Memory_GetNumBytes"]
548+
mem_make_fn_cap = mod.__pyx_capi__["Memory_Make"]
548549
# construct Python callable to invoke functions
549550
cap_ptr_fn = ctypes.pythonapi.PyCapsule_GetPointer
550551
cap_ptr_fn.restype = ctypes.c_void_p
@@ -561,11 +562,23 @@ def test_cpython_api(memory_ctor):
561562
mem_nby_fn_ptr = cap_ptr_fn(
562563
mem_nby_fn_cap, b"size_t (struct Py_MemoryObject *)"
563564
)
565+
mem_make_fn_ptr = cap_ptr_fn(
566+
mem_make_fn_cap,
567+
b"PyObject *(DPCTLSyclUSMRef, size_t, DPCTLSyclQueueRef, PyObject *)",
568+
)
564569
callable_maker = ctypes.PYFUNCTYPE(ctypes.c_void_p, ctypes.py_object)
565570
get_ptr_fn = callable_maker(mem_ptr_fn_ptr)
566571
get_ctx_ref_fn = callable_maker(mem_ctx_ref_fn_ptr)
567572
get_q_ref_fn = callable_maker(mem_q_ref_fn_ptr)
568573
get_nby_fn = callable_maker(mem_nby_fn_ptr)
574+
make_callable_maker = ctypes.PYFUNCTYPE(
575+
ctypes.py_object,
576+
ctypes.c_void_p,
577+
ctypes.c_size_t,
578+
ctypes.c_void_p,
579+
ctypes.py_object,
580+
)
581+
make_fn = make_callable_maker(mem_make_fn_ptr)
569582

570583
capi_ptr = get_ptr_fn(mobj)
571584
direct_ptr = mobj._pointer
@@ -580,6 +593,15 @@ def test_cpython_api(memory_ctor):
580593
direct_nbytes = mobj.nbytes
581594
assert capi_nbytes == direct_nbytes
582595

596+
mobj2 = make_fn(
597+
mobj._pointer,
598+
ctypes.c_size_t(mobj.nbytes),
599+
mobj.sycl_queue.addressof_ref(),
600+
mobj,
601+
)
602+
assert mobj2._pointer == mobj._pointer
603+
assert mobj2.reference_obj is mobj
604+
583605

584606
def test_memory_construction_from_other_memory_objects():
585607
try:

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 126 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,11 @@ def test_datapi_device():
373373

374374

375375
def _pyx_capi_fnptr_to_callable(
376-
X, pyx_capi_name, caps_name, fn_restype=ctypes.c_void_p
376+
X,
377+
pyx_capi_name,
378+
caps_name,
379+
fn_restype=ctypes.c_void_p,
380+
fn_argtypes=(ctypes.py_object,),
377381
):
378382
import sys
379383

@@ -388,7 +392,7 @@ def _pyx_capi_fnptr_to_callable(
388392
cap_ptr_fn.restype = ctypes.c_void_p
389393
cap_ptr_fn.argtypes = [ctypes.py_object, ctypes.c_char_p]
390394
fn_ptr = cap_ptr_fn(cap, caps_name)
391-
callable_maker_ptr = ctypes.PYFUNCTYPE(fn_restype, ctypes.py_object)
395+
callable_maker_ptr = ctypes.PYFUNCTYPE(fn_restype, *fn_argtypes)
392396
return callable_maker_ptr(fn_ptr)
393397

394398

@@ -399,6 +403,7 @@ def test_pyx_capi_get_data():
399403
"UsmNDArray_GetData",
400404
b"char *(struct PyUSMArrayObject *)",
401405
fn_restype=ctypes.c_void_p,
406+
fn_argtypes=(ctypes.py_object,),
402407
)
403408
r1 = get_data_fn(X)
404409
sua_iface = X.__sycl_usm_array_interface__
@@ -412,6 +417,7 @@ def test_pyx_capi_get_shape():
412417
"UsmNDArray_GetShape",
413418
b"Py_ssize_t *(struct PyUSMArrayObject *)",
414419
fn_restype=ctypes.c_void_p,
420+
fn_argtypes=(ctypes.py_object,),
415421
)
416422
c_longlong_p = ctypes.POINTER(ctypes.c_longlong)
417423
shape0 = ctypes.cast(get_shape_fn(X), c_longlong_p).contents.value
@@ -425,6 +431,7 @@ def test_pyx_capi_get_strides():
425431
"UsmNDArray_GetStrides",
426432
b"Py_ssize_t *(struct PyUSMArrayObject *)",
427433
fn_restype=ctypes.c_void_p,
434+
fn_argtypes=(ctypes.py_object,),
428435
)
429436
c_longlong_p = ctypes.POINTER(ctypes.c_longlong)
430437
strides0_p = get_strides_fn(X)
@@ -441,6 +448,7 @@ def test_pyx_capi_get_ndim():
441448
"UsmNDArray_GetNDim",
442449
b"int (struct PyUSMArrayObject *)",
443450
fn_restype=ctypes.c_int,
451+
fn_argtypes=(ctypes.py_object,),
444452
)
445453
assert get_ndim_fn(X) == X.ndim
446454

@@ -452,6 +460,7 @@ def test_pyx_capi_get_typenum():
452460
"UsmNDArray_GetTypenum",
453461
b"int (struct PyUSMArrayObject *)",
454462
fn_restype=ctypes.c_int,
463+
fn_argtypes=(ctypes.py_object,),
455464
)
456465
typenum = get_typenum_fn(X)
457466
assert type(typenum) is int
@@ -465,6 +474,7 @@ def test_pyx_capi_get_elemsize():
465474
"UsmNDArray_GetElementSize",
466475
b"int (struct PyUSMArrayObject *)",
467476
fn_restype=ctypes.c_int,
477+
fn_argtypes=(ctypes.py_object,),
468478
)
469479
itemsize = get_elemsize_fn(X)
470480
assert type(itemsize) is int
@@ -478,6 +488,7 @@ def test_pyx_capi_get_flags():
478488
"UsmNDArray_GetFlags",
479489
b"int (struct PyUSMArrayObject *)",
480490
fn_restype=ctypes.c_int,
491+
fn_argtypes=(ctypes.py_object,),
481492
)
482493
flags = get_flags_fn(X)
483494
assert type(flags) is int and X.flags == flags
@@ -490,6 +501,7 @@ def test_pyx_capi_get_offset():
490501
"UsmNDArray_GetOffset",
491502
b"Py_ssize_t (struct PyUSMArrayObject *)",
492503
fn_restype=ctypes.c_longlong,
504+
fn_argtypes=(ctypes.py_object,),
493505
)
494506
offset = get_offset_fn(X)
495507
assert type(offset) is int
@@ -503,11 +515,123 @@ def test_pyx_capi_get_queue_ref():
503515
"UsmNDArray_GetQueueRef",
504516
b"DPCTLSyclQueueRef (struct PyUSMArrayObject *)",
505517
fn_restype=ctypes.c_void_p,
518+
fn_argtypes=(ctypes.py_object,),
506519
)
507520
queue_ref = get_queue_ref_fn(X) # address of a copy, should be unequal
508521
assert queue_ref != X.sycl_queue.addressof_ref()
509522

510523

524+
def test_pyx_capi_make_from_memory():
525+
q = get_queue_or_skip()
526+
n0, n1 = 4, 6
527+
c_tuple = (ctypes.c_ssize_t * 2)(n0, n1)
528+
mem = dpm.MemoryUSMShared(n0 * n1 * 4, queue=q)
529+
typenum = dpt.dtype("single").num
530+
any_usm_ndarray = dpt.empty(tuple(), dtype="i4", sycl_queue=q)
531+
make_from_memory_fn = _pyx_capi_fnptr_to_callable(
532+
any_usm_ndarray,
533+
"UsmNDArray_MakeFromMemory",
534+
b"PyObject *(int, Py_ssize_t const *, int, "
535+
b"struct Py_MemoryObject *, Py_ssize_t, char)",
536+
fn_restype=ctypes.py_object,
537+
fn_argtypes=(
538+
ctypes.c_int,
539+
ctypes.POINTER(ctypes.c_ssize_t),
540+
ctypes.c_int,
541+
ctypes.py_object,
542+
ctypes.c_ssize_t,
543+
ctypes.c_char,
544+
),
545+
)
546+
r = make_from_memory_fn(
547+
ctypes.c_int(2),
548+
c_tuple,
549+
ctypes.c_int(typenum),
550+
mem,
551+
ctypes.c_ssize_t(0),
552+
ctypes.c_char(b"C"),
553+
)
554+
assert isinstance(r, dpt.usm_ndarray)
555+
assert r.ndim == 2
556+
assert r.shape == (n0, n1)
557+
assert r._pointer == mem._pointer
558+
assert r.usm_type == "shared"
559+
assert r.sycl_queue == q
560+
assert r.flags["C"]
561+
r2 = make_from_memory_fn(
562+
ctypes.c_int(2),
563+
c_tuple,
564+
ctypes.c_int(typenum),
565+
mem,
566+
ctypes.c_ssize_t(0),
567+
ctypes.c_char(b"F"),
568+
)
569+
ptr = mem._pointer
570+
del mem
571+
del r
572+
assert isinstance(r2, dpt.usm_ndarray)
573+
assert r2._pointer == ptr
574+
assert r2.usm_type == "shared"
575+
assert r2.sycl_queue == q
576+
assert r2.flags["F"]
577+
578+
579+
def test_pyx_capi_set_writable_flag():
580+
q = get_queue_or_skip()
581+
usm_ndarray = dpt.empty((4, 5), dtype="i4", sycl_queue=q)
582+
assert isinstance(usm_ndarray, dpt.usm_ndarray)
583+
assert usm_ndarray.flags["WRITABLE"] is True
584+
set_writable = _pyx_capi_fnptr_to_callable(
585+
usm_ndarray,
586+
"UsmNDArray_SetWritableFlag",
587+
b"void (struct PyUSMArrayObject *, int)",
588+
fn_restype=None,
589+
fn_argtypes=(ctypes.py_object, ctypes.c_int),
590+
)
591+
set_writable(usm_ndarray, ctypes.c_int(0))
592+
assert isinstance(usm_ndarray, dpt.usm_ndarray)
593+
assert usm_ndarray.flags["WRITABLE"] is False
594+
set_writable(usm_ndarray, ctypes.c_int(1))
595+
assert isinstance(usm_ndarray, dpt.usm_ndarray)
596+
assert usm_ndarray.flags["WRITABLE"] is True
597+
598+
599+
def test_pyx_capi_make_from_ptr():
600+
q = get_queue_or_skip()
601+
usm_ndarray = dpt.empty(tuple(), dtype="i4", sycl_queue=q)
602+
make_from_ptr = _pyx_capi_fnptr_to_callable(
603+
usm_ndarray,
604+
"UsmNDArray_MakeFromPtr",
605+
b"PyObject *(size_t, int, DPCTLSyclUSMRef, "
606+
b"DPCTLSyclQueueRef, PyObject *)",
607+
fn_restype=ctypes.py_object,
608+
fn_argtypes=(
609+
ctypes.c_size_t,
610+
ctypes.c_int,
611+
ctypes.c_void_p,
612+
ctypes.c_void_p,
613+
ctypes.py_object,
614+
),
615+
)
616+
nelems = 10
617+
dt = dpt.int64
618+
mem = dpm.MemoryUSMDevice(nelems * dt.itemsize, queue=q)
619+
arr = make_from_ptr(
620+
ctypes.c_size_t(nelems),
621+
dt.num,
622+
mem._pointer,
623+
mem.sycl_queue.addressof_ref(),
624+
mem,
625+
)
626+
assert isinstance(arr, dpt.usm_ndarray)
627+
assert arr.shape == (nelems,)
628+
assert arr.dtype == dt
629+
assert arr.sycl_queue == q
630+
assert arr._pointer == mem._pointer
631+
del mem
632+
assert isinstance(arr.__repr__(), str)
633+
634+
511635
def _pyx_capi_int(X, pyx_capi_name, caps_name=b"int", val_restype=ctypes.c_int):
512636
import sys
513637

0 commit comments

Comments
 (0)