diff --git a/source/adapters/cuda/event.cpp b/source/adapters/cuda/event.cpp index 1e8f2dd384..11f9df37bc 100644 --- a/source/adapters/cuda/event.cpp +++ b/source/adapters/cuda/event.cpp @@ -42,11 +42,34 @@ ur_event_handle_t_::ur_event_handle_t_(ur_context_handle_t Context, urContextRetain(Context); } +void ur_event_handle_t_::reset() { + detail::ur::assertion( + RefCount == 0, "Attempting to reset an event that is still referenced"); + + HasBeenWaitedOn = false; + IsRecorded = false; + IsStarted = false; + Queue = nullptr; + Context = nullptr; +} + ur_event_handle_t_::~ur_event_handle_t_() { - if (Queue != nullptr) { + if (HasOwnership) { + if (EvEnd) + UR_CHECK_ERROR(cuEventDestroy(EvEnd)); + + if (EvQueued) + UR_CHECK_ERROR(cuEventDestroy(EvQueued)); + + if (EvStart) + UR_CHECK_ERROR(cuEventDestroy(EvStart)); + } + if (Queue) { urQueueRelease(Queue); } - urContextRelease(Context); + if (Context) { + urContextRelease(Context); + } } ur_result_t ur_event_handle_t_::start() { @@ -141,22 +164,6 @@ ur_result_t ur_event_handle_t_::wait() { return Result; } -ur_result_t ur_event_handle_t_::release() { - if (!backendHasOwnership()) - return UR_RESULT_SUCCESS; - - assert(Queue != nullptr); - - UR_CHECK_ERROR(cuEventDestroy(EvEnd)); - - if (Queue->URFlags & UR_QUEUE_FLAG_PROFILING_ENABLE || isTimestampEvent()) { - UR_CHECK_ERROR(cuEventDestroy(EvQueued)); - UR_CHECK_ERROR(cuEventDestroy(EvStart)); - } - - return UR_RESULT_SUCCESS; -} - UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent, ur_event_info_t propName, size_t propValueSize, @@ -254,16 +261,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) { // decrement ref count. If it is 0, delete the event. if (hEvent->decrementReferenceCount() == 0) { std::unique_ptr event_ptr{hEvent}; - ur_result_t Result = UR_RESULT_ERROR_INVALID_EVENT; try { ScopedContext Active(hEvent->getContext()); - Result = hEvent->release(); - } catch (...) 
{ - Result = UR_RESULT_ERROR_OUT_OF_RESOURCES; + if (!hEvent->backendHasOwnership()) { + return UR_RESULT_SUCCESS; + } else { + auto Queue = event_ptr->getQueue(); + auto Context = event_ptr->getContext(); + + event_ptr->reset(); + if (Queue) { + Queue->cache_event(event_ptr.release()); + urQueueRelease(Queue); + } + urContextRelease(Context); + } + } catch (ur_result_t Err) { + return Err; } - return Result; } - return UR_RESULT_SUCCESS; } diff --git a/source/adapters/cuda/event.hpp b/source/adapters/cuda/event.hpp index 5ed68f0f25..802a29240e 100644 --- a/source/adapters/cuda/event.hpp +++ b/source/adapters/cuda/event.hpp @@ -90,6 +90,21 @@ struct ur_event_handle_t_ { const bool RequiresTimings = Queue->URFlags & UR_QUEUE_FLAG_PROFILING_ENABLE || Type == UR_COMMAND_TIMESTAMP_RECORDING_EXP; + if (Queue->has_cached_events()) { + auto retEvent = Queue->get_cached_event(); + + retEvent->Stream = Stream; + retEvent->StreamToken = StreamToken; + retEvent->CommandType = Type; + retEvent->Queue = Queue; + retEvent->Context = Queue->Context; + retEvent->RefCount = 1; + + urQueueRetain(retEvent->Queue); + urContextRetain(retEvent->Context); + + return retEvent; + } native_type EvEnd = nullptr, EvQueued = nullptr, EvStart = nullptr; UR_CHECK_ERROR(cuEventCreate( &EvEnd, RequiresTimings ? CU_EVENT_DEFAULT : CU_EVENT_DISABLE_TIMING)); @@ -107,7 +122,9 @@ struct ur_event_handle_t_ { return new ur_event_handle_t_(context, eventNative); } - ur_result_t release(); + // Resets attributes of an event. + // Asserts that its RefCount is 0; resetting a still-referenced event is an error. 
+ void reset(); ~ur_event_handle_t_(); diff --git a/source/adapters/cuda/queue.cpp b/source/adapters/cuda/queue.cpp index 120d665524..4dce8eb534 100644 --- a/source/adapters/cuda/queue.cpp +++ b/source/adapters/cuda/queue.cpp @@ -32,6 +32,17 @@ void ur_queue_handle_t_::transferStreamWaitForBarrierIfNeeded( } } +ur_queue_handle_t_::~ur_queue_handle_t_() { + urContextRelease(Context); + urDeviceRelease(Device); + + std::lock_guard CacheGuard(CacheMutex); + while (!CachedEvents.empty()) { + std::unique_ptr Ev{CachedEvents.top()}; + CachedEvents.pop(); + } +} + CUstream ur_queue_handle_t_::getNextComputeStream(uint32_t *StreamToken) { uint32_t StreamI; uint32_t Token; diff --git a/source/adapters/cuda/queue.hpp b/source/adapters/cuda/queue.hpp index c79ca18a9b..e3fb2d1d51 100644 --- a/source/adapters/cuda/queue.hpp +++ b/source/adapters/cuda/queue.hpp @@ -13,6 +13,7 @@ #include #include +#include #include using ur_stream_guard_ = std::unique_lock; @@ -35,6 +36,9 @@ struct ur_queue_handle_t_ { // keep track of which streams have applied barrier std::vector ComputeAppliedBarrier; std::vector TransferAppliedBarrier; + // CachedEvents assumes ownership of events. + // Events on the stack are destructed when queue is destructed as well. + std::stack CachedEvents; ur_context_handle_t_ *Context; ur_device_handle_t_ *Device; CUevent BarrierEvent = nullptr; @@ -57,6 +61,8 @@ struct ur_queue_handle_t_ { std::mutex ComputeStreamMutex; std::mutex TransferStreamMutex; std::mutex BarrierMutex; + // The event cache might be accessed in multiple threads. 
+ std::mutex CacheMutex; bool HasOwnership; ur_queue_handle_t_(std::vector &&ComputeStreams, @@ -77,10 +83,7 @@ struct ur_queue_handle_t_ { urDeviceRetain(Device); } - ~ur_queue_handle_t_() { - urContextRelease(Context); - urDeviceRelease(Device); - } + ~ur_queue_handle_t_(); void computeStreamWaitForBarrierIfNeeded(CUstream Strean, uint32_t StreamI); void transferStreamWaitForBarrierIfNeeded(CUstream Stream, uint32_t StreamI); @@ -245,4 +248,23 @@ struct ur_queue_handle_t_ { uint32_t getNextEventID() noexcept { return ++EventCount; } bool backendHasOwnership() const noexcept { return HasOwnership; } + + bool has_cached_events() { + std::lock_guard CacheGuard(CacheMutex); + return !CachedEvents.empty(); + } + + void cache_event(ur_event_handle_t Event) { + std::lock_guard CacheGuard(CacheMutex); + CachedEvents.push(Event); + } + + // Returns and removes an event from the CachedEvents stack. + ur_event_handle_t get_cached_event() { + std::lock_guard CacheGuard(CacheMutex); + assert(!CachedEvents.empty()); + auto RetEv = CachedEvents.top(); + CachedEvents.pop(); + return RetEv; + } }; diff --git a/source/adapters/hip/event.cpp b/source/adapters/hip/event.cpp index 5327c43a3b..9cd6e00df6 100644 --- a/source/adapters/hip/event.cpp +++ b/source/adapters/hip/event.cpp @@ -47,11 +47,34 @@ ur_event_handle_t_::ur_event_handle_t_(ur_context_handle_t Context, urContextRetain(Context); } +void ur_event_handle_t_::reset() { + detail::ur::assertion( + RefCount == 0, "Attempting to reset an event that is still referenced"); + + HasBeenWaitedOn = false; + IsRecorded = false; + IsStarted = false; + Queue = nullptr; + Context = nullptr; +} + ur_event_handle_t_::~ur_event_handle_t_() { - if (Queue != nullptr) { + if (HasOwnership) { + if (EvEnd) + UR_CHECK_ERROR(hipEventDestroy(EvEnd)); + + if (EvQueued) + UR_CHECK_ERROR(hipEventDestroy(EvQueued)); + + if (EvStart) + UR_CHECK_ERROR(hipEventDestroy(EvStart)); + } + if (Queue) { urQueueRelease(Queue); } - urContextRelease(Context); + 
if (Context) { + urContextRelease(Context); + } } ur_result_t ur_event_handle_t_::start() { @@ -171,21 +194,6 @@ ur_result_t ur_event_handle_t_::wait() { return Result; } -ur_result_t ur_event_handle_t_::release() { - if (!backendHasOwnership()) - return UR_RESULT_SUCCESS; - - assert(Queue != nullptr); - UR_CHECK_ERROR(hipEventDestroy(EvEnd)); - - if (Queue->URFlags & UR_QUEUE_FLAG_PROFILING_ENABLE || isTimestampEvent()) { - UR_CHECK_ERROR(hipEventDestroy(EvQueued)); - UR_CHECK_ERROR(hipEventDestroy(EvStart)); - } - - return UR_RESULT_SUCCESS; -} - UR_APIEXPORT ur_result_t UR_APICALL urEventWait(uint32_t numEvents, const ur_event_handle_t *phEventWaitList) { UR_ASSERT(numEvents > 0, UR_RESULT_ERROR_INVALID_VALUE); @@ -291,15 +299,24 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) { // decrement ref count. If it is 0, delete the event. if (hEvent->decrementReferenceCount() == 0) { std::unique_ptr event_ptr{hEvent}; - ur_result_t Result = UR_RESULT_ERROR_INVALID_EVENT; try { - Result = hEvent->release(); - } catch (...) 
{ - Result = UR_RESULT_ERROR_OUT_OF_RESOURCES; + if (!hEvent->backendHasOwnership()) { + return UR_RESULT_SUCCESS; + } else { + auto Queue = event_ptr->getQueue(); + auto Context = event_ptr->getContext(); + + event_ptr->reset(); + if (Queue) { + Queue->cache_event(event_ptr.release()); + urQueueRelease(Queue); + } + urContextRelease(Context); + } + } catch (ur_result_t Err) { + return Err; } - return Result; } - return UR_RESULT_SUCCESS; } diff --git a/source/adapters/hip/event.hpp b/source/adapters/hip/event.hpp index 64e8b2d9c8..2125f13f55 100644 --- a/source/adapters/hip/event.hpp +++ b/source/adapters/hip/event.hpp @@ -82,6 +82,21 @@ struct ur_event_handle_t_ { static ur_event_handle_t makeNative(ur_command_t Type, ur_queue_handle_t Queue, hipStream_t Stream, uint32_t StreamToken = std::numeric_limits::max()) { + if (Queue->has_cached_events()) { + auto retEvent = Queue->get_cached_event(); + + retEvent->Stream = Stream; + retEvent->StreamToken = StreamToken; + retEvent->CommandType = Type; + retEvent->Queue = Queue; + retEvent->Context = Queue->Context; + retEvent->RefCount = 1; + + urQueueRetain(retEvent->Queue); + urContextRetain(retEvent->Context); + + return retEvent; + } return new ur_event_handle_t_(Type, Queue->getContext(), Queue, Stream, StreamToken); } @@ -91,7 +106,9 @@ struct ur_event_handle_t_ { return new ur_event_handle_t_(context, eventNative); } - ur_result_t release(); + // Resets attributes of an event. + // Asserts that its RefCount is 0; resetting a still-referenced event is an error. 
+ void reset(); ~ur_event_handle_t_(); diff --git a/source/adapters/hip/queue.cpp b/source/adapters/hip/queue.cpp index 6e6496fec1..3758b40104 100644 --- a/source/adapters/hip/queue.cpp +++ b/source/adapters/hip/queue.cpp @@ -28,6 +28,17 @@ void ur_queue_handle_t_::transferStreamWaitForBarrierIfNeeded( } } +ur_queue_handle_t_::~ur_queue_handle_t_() { + urContextRelease(Context); + urDeviceRelease(Device); + + std::lock_guard CacheGuard(CacheMutex); + while (!CachedEvents.empty()) { + std::unique_ptr Ev{CachedEvents.top()}; + CachedEvents.pop(); + } +} + hipStream_t ur_queue_handle_t_::getNextComputeStream(uint32_t *StreamToken) { uint32_t Stream_i; uint32_t Token; diff --git a/source/adapters/hip/queue.hpp b/source/adapters/hip/queue.hpp index ad2f0f016e..99e96c752e 100644 --- a/source/adapters/hip/queue.hpp +++ b/source/adapters/hip/queue.hpp @@ -10,6 +10,7 @@ #pragma once #include "common.hpp" +#include using ur_stream_quard = std::unique_lock; @@ -30,6 +31,9 @@ struct ur_queue_handle_t_ { // keep track of which streams have applied barrier std::vector ComputeAppliedBarrier; std::vector TransferAppliedBarrier; + // CachedEvents assumes ownership of events. + // Events on the stack are destructed when queue is destructed as well. + std::stack CachedEvents; ur_context_handle_t Context; ur_device_handle_t Device; hipEvent_t BarrierEvent = nullptr; @@ -52,6 +56,8 @@ struct ur_queue_handle_t_ { std::mutex ComputeStreamMutex; std::mutex TransferStreamMutex; std::mutex BarrierMutex; + // The event cache might be accessed in multiple threads. 
+ std::mutex CacheMutex; bool HasOwnership; ur_queue_handle_t_(std::vector &&ComputeStreams, @@ -72,10 +78,7 @@ struct ur_queue_handle_t_ { urDeviceRetain(Device); } - ~ur_queue_handle_t_() { - urContextRelease(Context); - urDeviceRelease(Device); - } + ~ur_queue_handle_t_(); void computeStreamWaitForBarrierIfNeeded(hipStream_t Stream, uint32_t Stream_i); @@ -242,4 +245,23 @@ struct ur_queue_handle_t_ { uint32_t getNextEventId() noexcept { return ++EventCount; } bool backendHasOwnership() const noexcept { return HasOwnership; } + + bool has_cached_events() { + std::lock_guard CacheGuard(CacheMutex); + return !CachedEvents.empty(); + } + + void cache_event(ur_event_handle_t Event) { + std::lock_guard CacheGuard(CacheMutex); + CachedEvents.push(Event); + } + + // Returns and removes an event from the CachedEvents stack. + ur_event_handle_t get_cached_event() { + std::lock_guard CacheGuard(CacheMutex); + assert(!CachedEvents.empty()); + auto RetEv = CachedEvents.top(); + CachedEvents.pop(); + return RetEv; + } };