diff --git a/sycl/include/CL/sycl/detail/pi.def b/sycl/include/CL/sycl/detail/pi.def index 6fe9471d134da..c16f1d4331695 100644 --- a/sycl/include/CL/sycl/detail/pi.def +++ b/sycl/include/CL/sycl/detail/pi.def @@ -31,6 +31,7 @@ _PI_API(piContextCreate) _PI_API(piContextGetInfo) _PI_API(piContextRetain) _PI_API(piContextRelease) +_PI_API(piextContextSetExtendedDeleter) // Queue _PI_API(piQueueCreate) _PI_API(piQueueGetInfo) diff --git a/sycl/include/CL/sycl/detail/pi.h b/sycl/include/CL/sycl/detail/pi.h index 03bf0e8f5b2c8..62404d75e809d 100644 --- a/sycl/include/CL/sycl/detail/pi.h +++ b/sycl/include/CL/sycl/detail/pi.h @@ -824,6 +824,12 @@ pi_result piContextRetain(pi_context context); pi_result piContextRelease(pi_context context); +typedef void (*pi_context_extended_deleter)(void *user_data); + +pi_result piextContextSetExtendedDeleter(pi_context context, + pi_context_extended_deleter func, + void *user_data); + // // Queue // diff --git a/sycl/include/CL/sycl/detail/pi.hpp b/sycl/include/CL/sycl/detail/pi.hpp index 73ba98a4e4530..e401284990a36 100644 --- a/sycl/include/CL/sycl/detail/pi.hpp +++ b/sycl/include/CL/sycl/detail/pi.hpp @@ -30,6 +30,9 @@ struct trace_event_data_t; __SYCL_INLINE_NAMESPACE(cl) { namespace sycl { + +class context; + namespace detail { enum class PiApiKind { @@ -95,6 +98,10 @@ using PiMemObjectType = ::pi_mem_type; using PiMemImageChannelOrder = ::pi_image_channel_order; using PiMemImageChannelType = ::pi_image_channel_type; +void contextSetExtendedDeleter(const cl::sycl::context &constext, + pi_context_extended_deleter func, + void *user_data); + // Function to load the shared library // Implementation is OS dependent. void *loadOsLibrary(const std::string &Library); diff --git a/sycl/plugins/cuda/pi_cuda.cpp b/sycl/plugins/cuda/pi_cuda.cpp index 936bedad90912..c2bb65b6cbeaa 100644 --- a/sycl/plugins/cuda/pi_cuda.cpp +++ b/sycl/plugins/cuda/pi_cuda.cpp @@ -777,6 +777,12 @@ pi_result cuda_piContextRetain(pi_context context) { return PI_SUCCESS; } +pi_result cuda_piextContextSetExtendedDeleter( + pi_context context, pi_context_extended_deleter function, void *user_data) { + context->set_extended_deleter(function, user_data); + return PI_SUCCESS; +} + /// Not applicable to CUDA, devices cannot be partitioned. /// pi_result cuda_piDevicePartition( @@ -1462,7 +1468,7 @@ pi_result cuda_piContextRelease(pi_context ctxt) { if (ctxt->decrement_reference_count() > 0) { return PI_SUCCESS; } - ctxt->invoke_callback(); + ctxt->invoke_extended_deleters(); std::unique_ptr<_pi_context> context{ctxt}; @@ -3586,6 +3592,7 @@ pi_result piPluginInit(pi_plugin *PluginInit) { _PI_CL(piextDeviceSelectBinary, cuda_piextDeviceSelectBinary) _PI_CL(piextGetDeviceFunctionPointer, cuda_piextGetDeviceFunctionPointer) // Context + _PI_CL(piextContextSetExtendedDeleter, cuda_piextContextSetExtendedDeleter) _PI_CL(piContextCreate, cuda_piContextCreate) _PI_CL(piContextGetInfo, cuda_piContextGetInfo) _PI_CL(piContextRetain, cuda_piContextRetain) diff --git a/sycl/plugins/cuda/pi_cuda.hpp b/sycl/plugins/cuda/pi_cuda.hpp index e04cd2ab47d21..d6989a87cf66f 100644 --- a/sycl/plugins/cuda/pi_cuda.hpp +++ b/sycl/plugins/cuda/pi_cuda.hpp @@ -121,6 +121,14 @@ class _pi_device { /// See proposal for details. /// struct _pi_context { + + struct deleter_data { + pi_context_extended_deleter function; + void *user_data; + + void operator()() { function(user_data); } + }; + using native_type = CUcontext; enum class kind { primary, user_defined } kind_; @@ -138,20 +146,17 @@ struct _pi_context { ~_pi_context() { cuda_piDeviceRelease(deviceId_); } - void invoke_callback() - { + void invoke_extended_deleters() { std::lock_guard guard(mutex_); - for(const auto& callback : destruction_callbacks_) - { - callback(); + for (auto &deleter : extended_deleters_) { + deleter(); } } - template - void register_callback(Func&& callback) - { + void set_extended_deleter(pi_context_extended_deleter function, + void *user_data) { std::lock_guard guard(mutex_); - destruction_callbacks_.emplace_back(std::forward(callback)); + extended_deleters_.emplace_back(deleter_data{function, user_data}); } pi_device get_device() const noexcept { return deviceId_; } @@ -168,7 +173,7 @@ struct _pi_context { private: std::mutex mutex_; - std::vector> destruction_callbacks_; + std::vector extended_deleters_; }; /// PI Mem mapping to a CUDA memory allocation diff --git a/sycl/source/detail/pi.cpp b/sycl/source/detail/pi.cpp index 0655211154002..4ec13c447a1c7 100644 --- a/sycl/source/detail/pi.cpp +++ b/sycl/source/detail/pi.cpp @@ -11,6 +11,8 @@ /// /// \ingroup sycl_pi +#include "context_impl.hpp" +#include #include #include #include @@ -53,6 +55,16 @@ namespace pi { bool XPTIInitDone = false; +void contextSetExtendedDeleter(const cl::sycl::context &context, + pi_context_extended_deleter func, + void *user_data) { + auto impl = getSyclObjImpl(context); + auto contextHandle = reinterpret_cast(impl->getHandleRef()); + auto plugin = impl->getPlugin(); + plugin.call_nocheck( + contextHandle, func, user_data); +} + std::string platformInfoToString(pi_platform_info info) { switch (info) { case PI_PLATFORM_INFO_PROFILE: