Skip to content

Commit a2b6a9c

Browse files
C-api functions for memory/tensor objects(#740)
* Added more public API methods for _Memory object * changed test to reflect change in public API function name change * Added Memory C-API functions to dpctl_capi.h Renamed functions to avoid name clashes for get_queue_ref/get_context_ref defined for different signatures. (C-API requires names to be different). Adjsuted tests to reflect changes in C-API names * Renamed Python C-API functions to ObjectName_CamelCasedFunctionName * Renamed usm_ndarray's C-API functions as UsmNDArray_* * dpctl_capi imports tensor submodule as well as part of importing dpctl * Added docstring as per PR feedback * populated docstrings of C-API functions of usmarray
2 parents 2614bb9 + 50f16f4 commit a2b6a9c

14 files changed

+116
-85
lines changed

dpctl/_sycl_context.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -480,15 +480,15 @@ cdef class SyclContext(_SyclContext):
480480
&_context_capsule_deleter
481481
)
482482

483-
cdef api DPCTLSyclContextRef get_context_ref(SyclContext ctx):
483+
cdef api DPCTLSyclContextRef SyclContext_GetContextRef(SyclContext ctx):
484484
"""
485485
C-API function to get opaque context reference from
486486
:class:`dpctl.SyclContext` instance.
487487
"""
488488
return ctx.get_context_ref()
489489

490490

491-
cdef api SyclContext make_SyclContext(DPCTLSyclContextRef CRef):
491+
cdef api SyclContext SyclContext_Make(DPCTLSyclContextRef CRef):
492492
"""
493493
C-API function to create :class:`dpctl.SyclContext` instance
494494
from the given opaque context reference.

dpctl/_sycl_device.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,15 +1130,15 @@ cdef class SyclDevice(_SyclDevice):
11301130
else:
11311131
return str(relId)
11321132

1133-
cdef api DPCTLSyclDeviceRef get_device_ref(SyclDevice dev):
1133+
cdef api DPCTLSyclDeviceRef SyclDevice_GetDeviceRef(SyclDevice dev):
11341134
"""
11351135
C-API function to get opaque device reference from
11361136
:class:`dpctl.SyclDevice` instance.
11371137
"""
11381138
return dev.get_device_ref()
11391139

11401140

1141-
cdef api SyclDevice make_SyclDevice(DPCTLSyclDeviceRef DRef):
1141+
cdef api SyclDevice SyclDevice_Make(DPCTLSyclDeviceRef DRef):
11421142
"""
11431143
C-API function to create :class:`dpctl.SyclDevice` instance
11441144
from the given opaque device reference.

dpctl/_sycl_event.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,14 @@ __all__ = [
5757
_logger = logging.getLogger(__name__)
5858

5959

60-
cdef api DPCTLSyclEventRef get_event_ref(SyclEvent ev):
60+
cdef api DPCTLSyclEventRef SyclEvent_GetEventRef(SyclEvent ev):
6161
""" C-API function to access opaque event reference from
6262
Python object of type :class:`dpctl.SyclEvent`.
6363
"""
6464
return ev.get_event_ref()
6565

6666

67-
cdef api SyclEvent make_SyclEvent(DPCTLSyclEventRef ERef):
67+
cdef api SyclEvent SyclEvent_Make(DPCTLSyclEventRef ERef):
6868
"""
6969
C-API function to create :class:`dpctl.SyclEvent`
7070
instance from opaque sycl event reference.

dpctl/_sycl_queue.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,15 +1001,15 @@ cdef class SyclQueue(_SyclQueue):
10011001
self.sycl_device.print_device_info()
10021002

10031003

1004-
cdef api DPCTLSyclQueueRef get_queue_ref(SyclQueue q):
1004+
cdef api DPCTLSyclQueueRef SyclQueue_GetQueueRef(SyclQueue q):
10051005
"""
10061006
C-API function to get opaque queue reference from
10071007
:class:`dpctl.SyclQueue` instance.
10081008
"""
10091009
return q.get_queue_ref()
10101010

10111011

1012-
cdef api SyclQueue make_SyclQueue(DPCTLSyclQueueRef QRef):
1012+
cdef api SyclQueue SyclQueue_Make(DPCTLSyclQueueRef QRef):
10131013
"""
10141014
C-API function to create :class:`dpctl.SyclQueue` instance
10151015
from the given opaque queue reference.

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ template <> struct type_caster<sycl::queue>
4949
{
5050
PyObject *source = src.ptr();
5151
if (PyObject_TypeCheck(source, &PySyclQueueType)) {
52-
DPCTLSyclQueueRef QRef =
53-
get_queue_ref(reinterpret_cast<PySyclQueueObject *>(source));
52+
DPCTLSyclQueueRef QRef = SyclQueue_GetQueueRef(
53+
reinterpret_cast<PySyclQueueObject *>(source));
5454
sycl::queue *q = reinterpret_cast<sycl::queue *>(QRef);
5555
value = *q;
5656
return true;
@@ -63,7 +63,7 @@ template <> struct type_caster<sycl::queue>
6363

6464
static handle cast(sycl::queue src, return_value_policy, handle)
6565
{
66-
auto tmp = make_SyclQueue(reinterpret_cast<DPCTLSyclQueueRef>(&src));
66+
auto tmp = SyclQueue_Make(reinterpret_cast<DPCTLSyclQueueRef>(&src));
6767
return handle(reinterpret_cast<PyObject *>(tmp));
6868
}
6969
};
@@ -87,8 +87,8 @@ template <> struct type_caster<sycl::device>
8787
{
8888
PyObject *source = src.ptr();
8989
if (PyObject_TypeCheck(source, &PySyclDeviceType)) {
90-
DPCTLSyclDeviceRef DRef =
91-
get_device_ref(reinterpret_cast<PySyclDeviceObject *>(source));
90+
DPCTLSyclDeviceRef DRef = SyclDevice_GetDeviceRef(
91+
reinterpret_cast<PySyclDeviceObject *>(source));
9292
sycl::device *d = reinterpret_cast<sycl::device *>(DRef);
9393
value = *d;
9494
return true;
@@ -101,7 +101,7 @@ template <> struct type_caster<sycl::device>
101101

102102
static handle cast(sycl::device src, return_value_policy, handle)
103103
{
104-
auto tmp = make_SyclDevice(reinterpret_cast<DPCTLSyclDeviceRef>(&src));
104+
auto tmp = SyclDevice_Make(reinterpret_cast<DPCTLSyclDeviceRef>(&src));
105105
return handle(reinterpret_cast<PyObject *>(tmp));
106106
}
107107
};
@@ -125,7 +125,7 @@ template <> struct type_caster<sycl::context>
125125
{
126126
PyObject *source = src.ptr();
127127
if (PyObject_TypeCheck(source, &PySyclContextType)) {
128-
DPCTLSyclContextRef CRef = get_context_ref(
128+
DPCTLSyclContextRef CRef = SyclContext_GetContextRef(
129129
reinterpret_cast<PySyclContextObject *>(source));
130130
sycl::context *ctx = reinterpret_cast<sycl::context *>(CRef);
131131
value = *ctx;
@@ -140,7 +140,7 @@ template <> struct type_caster<sycl::context>
140140
static handle cast(sycl::context src, return_value_policy, handle)
141141
{
142142
auto tmp =
143-
make_SyclContext(reinterpret_cast<DPCTLSyclContextRef>(&src));
143+
SyclContext_Make(reinterpret_cast<DPCTLSyclContextRef>(&src));
144144
return handle(reinterpret_cast<PyObject *>(tmp));
145145
}
146146
};
@@ -164,8 +164,8 @@ template <> struct type_caster<sycl::event>
164164
{
165165
PyObject *source = src.ptr();
166166
if (PyObject_TypeCheck(source, &PySyclEventType)) {
167-
DPCTLSyclEventRef ERef =
168-
get_event_ref(reinterpret_cast<PySyclEventObject *>(source));
167+
DPCTLSyclEventRef ERef = SyclEvent_GetEventRef(
168+
reinterpret_cast<PySyclEventObject *>(source));
169169
sycl::event *ev = reinterpret_cast<sycl::event *>(ERef);
170170
value = *ev;
171171
return true;
@@ -178,7 +178,7 @@ template <> struct type_caster<sycl::event>
178178

179179
static handle cast(sycl::event src, return_value_policy, handle)
180180
{
181-
auto tmp = make_SyclEvent(reinterpret_cast<DPCTLSyclEventRef>(&src));
181+
auto tmp = SyclEvent_Make(reinterpret_cast<DPCTLSyclEventRef>(&src));
182182
return handle(reinterpret_cast<PyObject *>(tmp));
183183
}
184184
};

dpctl/apis/include/dpctl_capi.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636
#include "../_sycl_event_api.h"
3737
#include "../_sycl_queue.h"
3838
#include "../_sycl_queue_api.h"
39+
#include "../memory/_memory.h"
40+
#include "../memory/_memory_api.h"
41+
#include "../tensor/_usmarray.h"
42+
#include "../tensor/_usmarray_api.h"
3943
// clang-format on
4044

4145
/*
@@ -50,6 +54,7 @@ void import_dpctl(void)
5054
import_dpctl___sycl_context();
5155
import_dpctl___sycl_event();
5256
import_dpctl___sycl_queue();
53-
57+
import_dpctl__memory___memory();
58+
import_dpctl__tensor___usmarray();
5459
return;
5560
}

dpctl/memory/_memory.pyx

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -751,11 +751,30 @@ def as_usm_memory(obj):
751751
)
752752

753753

754-
cdef api DPCTLSyclUSMRef get_usm_pointer(_Memory obj):
754+
cdef api DPCTLSyclUSMRef Memory_GetUsmPointer(_Memory obj):
755+
"Pointer of USM allocation"
755756
return obj.memory_ptr
756757

757-
cdef api DPCTLSyclContextRef get_context(_Memory obj):
758+
cdef api DPCTLSyclContextRef Memory_GetContextRef(_Memory obj):
759+
"Context reference to which USM allocation is bound"
758760
return obj.queue._context.get_context_ref()
759761

760-
cdef api size_t get_nbytes(_Memory obj):
762+
cdef api DPCTLSyclQueueRef Memory_GetQueueRef(_Memory obj):
763+
"""Queue associated with this allocation, used
764+
for copying, population, etc."""
765+
return obj.queue.get_queue_ref()
766+
767+
cdef api size_t Memory_GetNumBytes(_Memory obj):
768+
"Size of the allocation in bytes."
761769
return <size_t>obj.nbytes
770+
771+
cdef api object Memory_Make(
772+
DPCTLSyclUSMRef ptr,
773+
size_t nbytes,
774+
DPCTLSyclQueueRef QRef,
775+
object owner
776+
):
777+
"Create _Memory Python object from preallocated memory."
778+
return _Memory.create_from_usm_pointer_size_qref(
779+
ptr, nbytes, QRef, memory_owner=owner
780+
)

dpctl/tensor/_usmarray.pyx

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,37 +1136,36 @@ cdef usm_ndarray _zero_like(usm_ndarray ary):
11361136
return r
11371137

11381138

1139-
cdef api char* usm_ndarray_get_data(usm_ndarray arr):
1140-
"""
1141-
"""
1139+
cdef api char* UsmNDArray_GetData(usm_ndarray arr):
1140+
"""Get allocation pointer of zero index element of array """
11421141
return arr.get_data()
11431142

11441143

1145-
cdef api int usm_ndarray_get_ndim(usm_ndarray arr):
1146-
""""""
1144+
cdef api int UsmNDArray_GetNDim(usm_ndarray arr):
1145+
"""Get array rank: length of its shape"""
11471146
return arr.get_ndim()
11481147

11491148

1150-
cdef api Py_ssize_t* usm_ndarray_get_shape(usm_ndarray arr):
1151-
""" """
1149+
cdef api Py_ssize_t* UsmNDArray_GetShape(usm_ndarray arr):
1150+
"""Get host pointer to shape vector"""
11521151
return arr.get_shape()
11531152

11541153

1155-
cdef api Py_ssize_t* usm_ndarray_get_strides(usm_ndarray arr):
1156-
""" """
1154+
cdef api Py_ssize_t* UsmNDArray_GetStrides(usm_ndarray arr):
1155+
"""Get host pointer to strides vector"""
11571156
return arr.get_strides()
11581157

11591158

1160-
cdef api int usm_ndarray_get_typenum(usm_ndarray arr):
1161-
""" """
1159+
cdef api int UsmNDArray_GetTypenum(usm_ndarray arr):
1160+
"""Get type number for data type of array elements"""
11621161
return arr.get_typenum()
11631162

11641163

1165-
cdef api int usm_ndarray_get_flags(usm_ndarray arr):
1166-
""" """
1164+
cdef api int UsmNDArray_GetFlags(usm_ndarray arr):
1165+
"""Get flags of array"""
11671166
return arr.get_flags()
11681167

11691168

1170-
cdef api c_dpctl.DPCTLSyclQueueRef usm_ndarray_get_queue_ref(usm_ndarray arr):
1171-
""" """
1169+
cdef api c_dpctl.DPCTLSyclQueueRef UsmNDArray_GetQueueRef(usm_ndarray arr):
1170+
"""Get DPCTLSyclQueueRef for queue associated with the array"""
11721171
return arr.get_queue_ref()

dpctl/tests/test_sycl_context.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -190,15 +190,15 @@ def test_context_repr():
190190
assert type(ctx.__repr__()) is str
191191

192192

193-
def test_cpython_api_get_context_ref():
193+
def test_cpython_api_SyclContext_GetContextRef():
194194
import ctypes
195195
import sys
196196

197197
ctx = dpctl.SyclContext()
198198
mod = sys.modules[ctx.__class__.__module__]
199-
# get capsule storign get_context_ref function ptr
200-
ctx_ref_fn_cap = mod.__pyx_capi__["get_context_ref"]
201-
# construct Python callable to invoke "get_context_ref"
199+
# get capsule storign SyclContext_GetContextRef function ptr
200+
ctx_ref_fn_cap = mod.__pyx_capi__["SyclContext_GetContextRef"]
201+
# construct Python callable to invoke "SyclContext_GetContextRef"
202202
cap_ptr_fn = ctypes.pythonapi.PyCapsule_GetPointer
203203
cap_ptr_fn.restype = ctypes.c_void_p
204204
cap_ptr_fn.argtypes = [ctypes.py_object, ctypes.c_char_p]
@@ -213,15 +213,15 @@ def test_cpython_api_get_context_ref():
213213
assert r1 == r2
214214

215215

216-
def test_cpython_api_make_SyclContext():
216+
def test_cpython_api_SyclContext_Make():
217217
import ctypes
218218
import sys
219219

220220
ctx = dpctl.SyclContext()
221221
mod = sys.modules[ctx.__class__.__module__]
222-
# get capsule storign make_SyclContext function ptr
223-
make_ctx_fn_cap = mod.__pyx_capi__["make_SyclContext"]
224-
# construct Python callable to invoke "make_SyclContext"
222+
# get capsule storign SyclContext_Make function ptr
223+
make_ctx_fn_cap = mod.__pyx_capi__["SyclContext_Make"]
224+
# construct Python callable to invoke "SyclContext_Make"
225225
cap_ptr_fn = ctypes.pythonapi.PyCapsule_GetPointer
226226
cap_ptr_fn.restype = ctypes.c_void_p
227227
cap_ptr_fn.argtypes = [ctypes.py_object, ctypes.c_char_p]

dpctl/tests/test_sycl_device.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -731,15 +731,15 @@ def test_handle_no_device():
731731
dpctl.select_device_with_aspects("cpu", excluded_aspects="cpu")
732732

733733

734-
def test_cpython_api_get_device_ref():
734+
def test_cpython_api_SyclDevice_GetDeviceRef():
735735
import ctypes
736736
import sys
737737

738738
d = dpctl.SyclDevice()
739739
mod = sys.modules[d.__class__.__module__]
740-
# get capsule storign get_device_ref function ptr
741-
d_ref_fn_cap = mod.__pyx_capi__["get_device_ref"]
742-
# construct Python callable to invoke "get_device_ref"
740+
# get capsule storing SyclDevice_GetDeviceRef function ptr
741+
d_ref_fn_cap = mod.__pyx_capi__["SyclDevice_GetDeviceRef"]
742+
# construct Python callable to invoke "SyclDevice_GetDeviceRef"
743743
cap_ptr_fn = ctypes.pythonapi.PyCapsule_GetPointer
744744
cap_ptr_fn.restype = ctypes.c_void_p
745745
cap_ptr_fn.argtypes = [ctypes.py_object, ctypes.c_char_p]
@@ -754,15 +754,15 @@ def test_cpython_api_get_device_ref():
754754
assert r1 == r2
755755

756756

757-
def test_cpython_api_make_SyclDevice():
757+
def test_cpython_api_SyclDevice_Make():
758758
import ctypes
759759
import sys
760760

761761
d = dpctl.SyclDevice()
762762
mod = sys.modules[d.__class__.__module__]
763-
# get capsule storign make_SyclContext function ptr
764-
make_d_fn_cap = mod.__pyx_capi__["make_SyclDevice"]
765-
# construct Python callable to invoke "make_SyclDevice"
763+
# get capsule storign SyclContext_Make function ptr
764+
make_d_fn_cap = mod.__pyx_capi__["SyclDevice_Make"]
765+
# construct Python callable to invoke "SyclDevice_Make"
766766
cap_ptr_fn = ctypes.pythonapi.PyCapsule_GetPointer
767767
cap_ptr_fn.restype = ctypes.c_void_p
768768
cap_ptr_fn.argtypes = [ctypes.py_object, ctypes.c_char_p]

0 commit comments

Comments
 (0)