Skip to content

Commit 80568c6

Browse files
author
aidan.belton
committed
add cuda interop context, queue, event
1 parent d17840d commit 80568c6

File tree

3 files changed

+89
-20
lines changed

3 files changed

+89
-20
lines changed

sycl/include/sycl/ext/oneapi/experimental/backend/cuda.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ namespace cuda {
2222

2323
// Implementation of cuda::make<device>
2424
inline __SYCL_EXPORT device make_device(pi_native_handle NativeHandle) {
25-
return detail::make_device(NativeHandle, backend::cuda);
25+
return sycl::detail::make_device(NativeHandle, backend::cuda);
2626
}
2727

2828
// Implementation of cuda::make<platform>
2929
inline __SYCL_EXPORT platform make_platform(pi_native_handle NativeHandle) {
30-
return detail::make_platform(NativeHandle, backend::cuda);
30+
return sycl::detail::make_platform(NativeHandle, backend::cuda);
3131
}
3232

3333
} // namespace cuda

sycl/plugins/cuda/pi_cuda.cpp

Lines changed: 79 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,12 @@ _pi_event::_pi_event(pi_command_type type, pi_context context, pi_queue queue)
344344
cuda_piContextRetain(context_);
345345
}
346346

347+
_pi_event::_pi_event(pi_context context, CUevent eventNative)
348+
: commandType_{PI_COMMAND_TYPE_USER}, refCount_{1}, hasBeenWaitedOn_{false},
349+
isRecorded_{false}, isStarted_{false}, evEnd_{eventNative},
350+
evStart_{nullptr}, evQueued_{nullptr}, queue_{nullptr}, context_{
351+
context} {}
352+
347353
_pi_event::~_pi_event() {
348354
if (queue_ != nullptr) {
349355
cuda_piQueueRelease(queue_);
@@ -1977,15 +1983,14 @@ pi_result cuda_piContextRelease(pi_context ctxt) {
19771983

19781984
std::unique_ptr<_pi_context> context{ctxt};
19791985

1980-
PI_CHECK_ERROR(cuEventDestroy(context->evBase_));
1981-
19821986
if (!ctxt->is_primary()) {
19831987
CUcontext cuCtxt = ctxt->get();
19841988
CUcontext current = nullptr;
19851989
cuCtxGetCurrent(&current);
19861990
if (cuCtxt != current) {
19871991
PI_CHECK_ERROR(cuCtxPushCurrent(cuCtxt));
19881992
}
1993+
PI_CHECK_ERROR(cuEventDestroy(context->evBase_));
19891994
PI_CHECK_ERROR(cuCtxSynchronize());
19901995
cuCtxGetCurrent(&current);
19911996
if (cuCtxt == current) {
@@ -1994,6 +1999,7 @@ pi_result cuda_piContextRelease(pi_context ctxt) {
19941999
return PI_CHECK_ERROR(cuCtxDestroy(cuCtxt));
19952000
} else {
19962001
// Primary context is not destroyed, but released
2002+
PI_CHECK_ERROR(cuEventDestroy(context->evBase_));
19972003
CUdevice cuDev = ctxt->get_device()->get();
19982004
CUcontext current;
19992005
cuCtxPopCurrent(&current);
@@ -2021,12 +2027,43 @@ pi_result cuda_piextContextGetNativeHandle(pi_context context,
20212027
/// \param[out] context Set to the PI context object created from native handle.
20222028
///
20232029
/// \return TBD
2024-
pi_result cuda_piextContextCreateWithNativeHandle(pi_native_handle, pi_uint32,
2025-
const pi_device *, bool,
2026-
pi_context *) {
2027-
cl::sycl::detail::pi::die(
2028-
"Creation of PI context from native handle not implemented");
2029-
return {};
2030+
pi_result cuda_piextContextCreateWithNativeHandle(pi_native_handle nativeHandle,
2031+
pi_uint32 num_devices,
2032+
const pi_device *devices,
2033+
bool ownNativeHandle,
2034+
pi_context *piContext) {
2035+
(void)num_devices;
2036+
(void)devices;
2037+
(void)ownNativeHandle;
2038+
assert(piContext != nullptr);
2039+
assert(ownNativeHandle == false);
2040+
2041+
CUcontext newContext = reinterpret_cast<CUcontext>(nativeHandle);
2042+
2043+
// Push native context to thread
2044+
pi_result retErr = PI_CHECK_ERROR(cuCtxPushCurrent(newContext));
2045+
2046+
// Get context's native device
2047+
CUdevice cu_device;
2048+
retErr = PI_CHECK_ERROR(cuCtxGetDevice(&cu_device));
2049+
2050+
// Create a SYCL device from the ctx device
2051+
pi_device device = nullptr;
2052+
retErr = cuda_piextDeviceCreateWithNativeHandle(cu_device, nullptr, &device);
2053+
2054+
// Create sycl context
2055+
*piContext =
2056+
new _pi_context{_pi_context::kind::user_defined, newContext, device};
2057+
2058+
// Use default stream to record base event counter
2059+
retErr =
2060+
PI_CHECK_ERROR(cuEventCreate(&(*piContext)->evBase_, CU_EVENT_DEFAULT));
2061+
retErr = PI_CHECK_ERROR(cuEventRecord((*piContext)->evBase_, 0));
2062+
2063+
// Pop native context
2064+
retErr = PI_CHECK_ERROR(cuCtxPopCurrent(nullptr));
2065+
2066+
return retErr;
20302067
}
20312068

20322069
/// Creates a PI Memory object using a CUDA memory allocation.
@@ -2430,13 +2467,29 @@ pi_result cuda_piextQueueGetNativeHandle(pi_queue queue,
24302467
/// the native handle, if it can.
24312468
///
24322469
/// \return TBD
2433-
pi_result cuda_piextQueueCreateWithNativeHandle(pi_native_handle, pi_context,
2434-
pi_queue *,
2470+
pi_result cuda_piextQueueCreateWithNativeHandle(pi_native_handle nativeHandle,
2471+
pi_context context,
2472+
pi_queue *queue,
24352473
bool ownNativeHandle) {
24362474
(void)ownNativeHandle;
2437-
cl::sycl::detail::pi::die(
2438-
"Creation of PI queue from native handle not implemented");
2439-
return {};
2475+
assert(ownNativeHandle == 1);
2476+
2477+
unsigned int flags;
2478+
CUstream cuStream = reinterpret_cast<CUstream>(nativeHandle);
2479+
2480+
auto retErr = PI_CHECK_ERROR(cuStreamGetFlags(cuStream, &flags));
2481+
2482+
pi_queue_properties properties = 0;
2483+
if (flags == CU_STREAM_DEFAULT)
2484+
properties = __SYCL_PI_CUDA_USE_DEFAULT_STREAM;
2485+
else if (flags == CU_STREAM_NON_BLOCKING)
2486+
properties = __SYCL_PI_CUDA_SYNC_WITH_DEFAULT;
2487+
else
2488+
cl::sycl::detail::pi::die("Unknown cuda stream");
2489+
2490+
*queue = new _pi_queue{cuStream, context, context->get_device(), properties};
2491+
2492+
return retErr;
24402493
}
24412494

24422495
pi_result cuda_piEnqueueMemBufferWrite(pi_queue command_queue, pi_mem buffer,
@@ -3699,11 +3752,19 @@ pi_result cuda_piextEventGetNativeHandle(pi_event event,
36993752
/// \param[out] event Set to the PI event object created from native handle.
37003753
///
37013754
/// \return TBD
3702-
pi_result cuda_piextEventCreateWithNativeHandle(pi_native_handle, pi_context,
3703-
bool, pi_event *) {
3704-
cl::sycl::detail::pi::die(
3705-
"Creation of PI event from native handle not implemented");
3706-
return {};
3755+
pi_result cuda_piextEventCreateWithNativeHandle(pi_native_handle nativeHandle,
3756+
pi_context context,
3757+
bool ownNativeHandle,
3758+
pi_event *event) {
3759+
(void)ownNativeHandle;
3760+
assert(ownNativeHandle == true);
3761+
3762+
std::unique_ptr<_pi_event> event_ptr{nullptr};
3763+
3764+
*event = _pi_event::make_with_native(context,
3765+
reinterpret_cast<CUevent>(nativeHandle));
3766+
3767+
return PI_SUCCESS;
37073768
}
37083769

37093770
/// Creates a PI sampler object

sycl/plugins/cuda/pi_cuda.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,10 @@ struct _pi_event {
477477
return new _pi_event(type, queue->get_context(), queue);
478478
}
479479

480+
static pi_event make_with_native(pi_context context, CUevent eventNative) {
481+
return new _pi_event(context, eventNative);
482+
}
483+
480484
pi_result release();
481485

482486
~_pi_event();
@@ -486,6 +490,10 @@ struct _pi_event {
486490
// make_user static members in order to create a pi_event for CUDA.
487491
_pi_event(pi_command_type type, pi_context context, pi_queue queue);
488492

493+
// This constructor is private to force programmers to use the
494+
// make_from_native / for event introp
495+
_pi_event(pi_context context, CUevent eventNative);
496+
489497
pi_command_type commandType_; // The type of command associated with event.
490498

491499
std::atomic_uint32_t refCount_; // Event reference count.

0 commit comments

Comments
 (0)