Skip to content

Commit d08793d

Browse files
Merge pull request #1604 from IntelPython/backport-gh-1560
Backport gh-1560 to 0.16.x maintenance branch
2 parents 6efb2c9 + 4e53753 commit d08793d

File tree

6 files changed

+94
-14
lines changed

6 files changed

+94
-14
lines changed

dpctl/_sycl_platform.pyx

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,14 +267,20 @@ cdef class SyclPlatform(_SyclPlatform):
267267
"""Returns the default platform context for this platform
268268
269269
Returns:
270-
SyclContext: The default context for the platform.
270+
SyclContext
271+
The default context for the platform.
272+
Raises:
273+
SyclContextCreationError
274+
If default_context is not supported
271275
"""
272276
cdef DPCTLSyclContextRef CRef = (
273277
DPCTLPlatform_GetDefaultContext(self._platform_ref)
274278
)
275279

276280
if (CRef == NULL):
277-
raise RuntimeError("Getting default error ran into a problem")
281+
raise SyclContextCreationError(
282+
"Getting default_context ran into a problem"
283+
)
278284
else:
279285
return SyclContext._create(CRef)
280286

dpctl/tensor/_dlpack.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# cython: language_level=3
1919
# cython: linetrace=True
2020

21+
from .._sycl_device cimport SyclDevice
2122
from ._usmarray cimport usm_ndarray
2223

2324

@@ -32,6 +33,8 @@ cpdef usm_ndarray from_dlpack_capsule(object dltensor) except +
3233

3334
cpdef from_dlpack(array)
3435

36+
cdef int get_parent_device_ordinal_id(SyclDevice dev) except *
37+
3538
cdef class DLPackCreationError(Exception):
3639
"""
3740
A DLPackCreateError exception is raised when constructing

dpctl/tensor/_dlpack.pyx

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,39 @@ cdef void _managed_tensor_deleter(DLManagedTensor *dlm_tensor) noexcept with gil
121121
dlm_tensor.manager_ctx = NULL
122122
stdlib.free(dlm_tensor)
123123

124+
cdef object _get_default_context(c_dpctl.SyclDevice dev) except *:
125+
try:
126+
if _IS_LINUX:
127+
default_context = dev.sycl_platform.default_context
128+
else:
129+
default_context = None
130+
except RuntimeError:
131+
# RT does not support default_context, e.g. Windows
132+
default_context = None
133+
134+
return default_context
135+
136+
137+
cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except *:
138+
cdef DPCTLSyclDeviceRef pDRef = NULL
139+
cdef DPCTLSyclDeviceRef tDRef = NULL
140+
cdef c_dpctl.SyclDevice p_dev
141+
142+
pDRef = DPCTLDevice_GetParentDevice(dev.get_device_ref())
143+
if pDRef is not NULL:
144+
# if dev is a sub-device, find its parent
145+
# and return its overall ordinal id
146+
tDRef = DPCTLDevice_GetParentDevice(pDRef)
147+
while tDRef is not NULL:
148+
DPCTLDevice_Delete(pDRef)
149+
pDRef = tDRef
150+
tDRef = DPCTLDevice_GetParentDevice(pDRef)
151+
p_dev = c_dpctl.SyclDevice._create(pDRef)
152+
return p_dev.get_overall_ordinal()
153+
154+
# return overall ordinal id of argument device
155+
return dev.get_overall_ordinal()
156+
124157

125158
cpdef to_dlpack_capsule(usm_ndarray usm_ary):
126159
"""
@@ -168,14 +201,7 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
168201
ary_sycl_queue = usm_ary.get_sycl_queue()
169202
ary_sycl_device = ary_sycl_queue.get_sycl_device()
170203

171-
try:
172-
if _IS_LINUX:
173-
default_context = ary_sycl_device.sycl_platform.default_context
174-
else:
175-
default_context = None
176-
except RuntimeError:
177-
# RT does not support default_context, e.g. Windows
178-
default_context = None
204+
default_context = _get_default_context(ary_sycl_device)
179205
if default_context is None:
180206
# check that ary_sycl_device is a non-partitioned device
181207
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())

dpctl/tensor/_usmarray.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -954,10 +954,10 @@ cdef class usm_ndarray:
954954
DLPackCreationError: when array is allocation on a partitioned
955955
SYCL device
956956
"""
957-
cdef int dev_id = (<c_dpctl.SyclDevice>self.sycl_device).get_overall_ordinal()
957+
cdef int dev_id = c_dlpack.get_parent_device_ordinal_id(<c_dpctl.SyclDevice>self.sycl_device)
958958
if dev_id < 0:
959959
raise c_dlpack.DLPackCreationError(
960-
"DLPack protocol is only supported for non-partitioned devices"
960+
"Could not determine id of the device where array was allocated."
961961
)
962962
else:
963963
return (

dpctl/tests/test_usm_ndarray_dlpack.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,39 @@ def test_from_dlpack_fortran_contig_array_roundtripping():
197197

198198
assert dpt.all(dpt.equal(ar2d_f, ar2d_r))
199199
assert dpt.all(dpt.equal(ar2d_c, ar2d_r))
200+
201+
202+
def test_dlpack_from_subdevice():
203+
"""
204+
This test checks that array allocated on a sub-device,
205+
with memory bound to platform-default SyclContext can be
206+
exported and imported via DLPack.
207+
"""
208+
n = 64
209+
try:
210+
dev = dpctl.SyclDevice()
211+
except dpctl.SyclDeviceCreationError:
212+
pytest.skip("No default device available")
213+
try:
214+
sdevs = dev.create_sub_devices(partition="next_partitionable")
215+
except dpctl.SyclSubDeviceCreationError:
216+
sdevs = None
217+
try:
218+
sdevs = (
219+
dev.create_sub_devices(partition=[1, 1]) if sdevs is None else sdevs
220+
)
221+
except dpctl.SyclSubDeviceCreationError:
222+
pytest.skip("Default device can not be partitioned")
223+
assert isinstance(sdevs, list) and len(sdevs) > 0
224+
try:
225+
ctx = sdevs[0].sycl_platform.default_context
226+
except dpctl.SyclContextCreationError:
227+
pytest.skip("Platform's default_context is not available")
228+
try:
229+
q = dpctl.SyclQueue(ctx, sdevs[0])
230+
except dpctl.SyclQueueCreationError:
231+
pytest.skip("Queue could not be created")
232+
233+
ar = dpt.arange(n, dtype=dpt.int32, sycl_queue=q)
234+
ar2 = dpt.from_dlpack(ar)
235+
assert ar2.sycl_device == sdevs[0]

libsyclinterface/source/dpctl_sycl_platform_interface.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,17 @@ DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef)
225225
{
226226
auto P = unwrap<platform>(PRef);
227227
if (P) {
228-
const auto &default_ctx = P->ext_oneapi_get_default_context();
229-
return wrap<context>(new context(default_ctx));
228+
#ifdef SYCL_EXT_ONEAPI_DEFAULT_CONTEXT
229+
try {
230+
const auto &default_ctx = P->ext_oneapi_get_default_context();
231+
return wrap<context>(new context(default_ctx));
232+
} catch (const std::exception &ex) {
233+
error_handler(ex, __FILE__, __func__, __LINE__);
234+
return nullptr;
235+
}
236+
#else
237+
return nullptr;
238+
#endif
230239
}
231240
else {
232241
error_handler(

0 commit comments

Comments
 (0)