diff --git a/sycl/plugins/cuda/pi_cuda.cpp b/sycl/plugins/cuda/pi_cuda.cpp index c9c7838876a4a..1f97892d99018 100644 --- a/sycl/plugins/cuda/pi_cuda.cpp +++ b/sycl/plugins/cuda/pi_cuda.cpp @@ -566,31 +566,27 @@ bool _pi_event::is_completed() const noexcept { return true; } -pi_uint64 _pi_event::get_queued_time() const { +pi_uint64 _pi_device::get_elapsed_time(CUevent ev) const { float miliSeconds = 0.0f; - assert(is_started()); - PI_CHECK_ERROR( - cuEventElapsedTime(&miliSeconds, _pi_platform::evBase_, evQueued_)); + PI_CHECK_ERROR(cuEventElapsedTime(&miliSeconds, evBase_, ev)); + return static_cast(miliSeconds * 1.0e6); } -pi_uint64 _pi_event::get_start_time() const { - float miliSeconds = 0.0f; +pi_uint64 _pi_event::get_queued_time() const { assert(is_started()); + return queue_->get_device()->get_elapsed_time(evQueued_); +} - PI_CHECK_ERROR( - cuEventElapsedTime(&miliSeconds, _pi_platform::evBase_, evStart_)); - return static_cast(miliSeconds * 1.0e6); +pi_uint64 _pi_event::get_start_time() const { + assert(is_started()); + return queue_->get_device()->get_elapsed_time(evStart_); } pi_uint64 _pi_event::get_end_time() const { - float miliSeconds = 0.0f; assert(is_started() && is_recorded()); - - PI_CHECK_ERROR( - cuEventElapsedTime(&miliSeconds, _pi_platform::evBase_, evEnd_)); - return static_cast(miliSeconds * 1.0e6); + return queue_->get_device()->get_elapsed_time(evEnd_); } pi_result _pi_event::record() { @@ -905,8 +901,15 @@ pi_result cuda_piPlatformsGet(pi_uint32 num_entries, pi_platform *platforms, CUcontext context; err = PI_CHECK_ERROR(cuDevicePrimaryCtxRetain(&context, device)); + ScopedContext active(context); + CUevent evBase; + err = PI_CHECK_ERROR(cuEventCreate(&evBase, CU_EVENT_DEFAULT)); + + // Use default stream to record base event counter + err = PI_CHECK_ERROR(cuEventRecord(evBase, 0)); + platformIds[i].devices_.emplace_back( - new _pi_device{device, context, &platformIds[i]}); + new _pi_device{device, context, evBase, &platformIds[i]}); { const auto &dev = platformIds[i].devices_.back().get(); @@ -2136,18 +2139,6 @@ pi_result cuda_piContextCreate(const pi_context_properties *properties, std::unique_ptr<_pi_context> piContextPtr{nullptr}; try { piContextPtr = std::unique_ptr<_pi_context>(new _pi_context{*devices}); - - static std::once_flag initFlag; - std::call_once( - initFlag, - [](pi_result &err) { - // Use default stream to record base event counter - PI_CHECK_ERROR( - cuEventCreate(&_pi_platform::evBase_, CU_EVENT_DEFAULT)); - PI_CHECK_ERROR(cuEventRecord(_pi_platform::evBase_, 0)); - }, - errcode_ret); - *retcontext = piContextPtr.release(); } catch (pi_result err) { errcode_ret = err; @@ -5615,11 +5606,7 @@ pi_result cuda_piGetDeviceAndHostTimer(pi_device Device, uint64_t *DeviceTime, if (DeviceTime) { PI_CHECK_ERROR(cuEventSynchronize(event)); - - float elapsedTime = 0.0f; - PI_CHECK_ERROR( - cuEventElapsedTime(&elapsedTime, _pi_platform::evBase_, event)); - *DeviceTime = (uint64_t)(elapsedTime * (double)1e6); + *DeviceTime = Device->get_elapsed_time(event); } return PI_SUCCESS; @@ -5786,5 +5773,3 @@ pi_result piPluginInit(pi_plugin *PluginInit) { } } // extern "C" - -CUevent _pi_platform::evBase_{nullptr}; diff --git a/sycl/plugins/cuda/pi_cuda.hpp b/sycl/plugins/cuda/pi_cuda.hpp index a957b8df603c7..4b679e58d9230 100644 --- a/sycl/plugins/cuda/pi_cuda.hpp +++ b/sycl/plugins/cuda/pi_cuda.hpp @@ -72,7 +72,6 @@ using _pi_stream_guard = std::unique_lock; /// when devices are used. /// struct _pi_platform { - static CUevent evBase_; // CUDA event used as base counter std::vector> devices_; }; @@ -87,6 +86,7 @@ struct _pi_device { native_type cuDevice_; CUcontext cuContext_; + CUevent evBase_; // CUDA event used as base counter std::atomic_uint32_t refCount_; pi_platform platform_; @@ -95,9 +95,10 @@ struct _pi_device { int max_work_group_size; public: - _pi_device(native_type cuDevice, CUcontext cuContext, pi_platform platform) - : cuDevice_(cuDevice), cuContext_(cuContext), refCount_{1}, - platform_(platform) {} + _pi_device(native_type cuDevice, CUcontext cuContext, CUevent evBase, + pi_platform platform) + : cuDevice_(cuDevice), cuContext_(cuContext), + evBase_(evBase), refCount_{1}, platform_(platform) {} ~_pi_device() { cuDevicePrimaryCtxRelease(cuDevice_); } @@ -109,6 +110,8 @@ struct _pi_device { pi_platform get_platform() const noexcept { return platform_; }; + pi_uint64 get_elapsed_time(CUevent) const; + void save_max_work_item_sizes(size_t size, size_t *save_max_work_item_sizes) noexcept { memcpy(max_work_item_sizes, save_max_work_item_sizes, size);