@@ -566,31 +566,27 @@ bool _pi_event::is_completed() const noexcept {
566
566
return true ;
567
567
}
568
568
569
- pi_uint64 _pi_event::get_queued_time ( ) const {
569
// Returns the time elapsed between this device's base event (evBase_) and
// `ev`, as reported by cuEventElapsedTime, scaled by 1.0e6 and truncated to
// pi_uint64. cuEventElapsedTime reports milliseconds (CUDA driver API), so the
// returned value is in nanoseconds.
//
// In this diff view, lines marked '-' are the removed body of the former
// _pi_event::get_queued_time, which measured against the static
// _pi_platform::evBase_; lines marked '+' form the new per-device helper.
//
// NOTE(review): PI_CHECK_ERROR is defined outside this view; presumably it
// reports/throws on a non-success CUresult — confirm.
// NOTE(review): cuEventElapsedTime needs both events to have completed; no
// synchronization on evBase_ is visible here — confirm callers guarantee it.
+ pi_uint64 _pi_device::get_elapsed_time (CUevent ev ) const {
570
570
float miliSeconds = 0 .0f ;
571
- assert (is_started ());
572
571
573
- PI_CHECK_ERROR (
574
- cuEventElapsedTime (&miliSeconds, _pi_platform::evBase_, evQueued_));
572
+ PI_CHECK_ERROR (cuEventElapsedTime (&miliSeconds, evBase_, ev));
573
+
575
574
// Milliseconds -> nanoseconds; the cast truncates the fractional part.
return static_cast <pi_uint64>(miliSeconds * 1.0e6 );
576
575
}
577
576
578
- pi_uint64 _pi_event::get_start_time () const {
579
- float miliSeconds = 0 .0f ;
577
// Returns the timestamp of this event's evQueued_ CUDA event, computed by the
// device that owns queue_ via get_elapsed_time(evQueued_).
// Precondition (checked only when asserts are enabled): is_started().
// NOTE(review): queue_ and the result of get_device() are dereferenced
// without a null check — presumably both are always valid here; confirm.
+ pi_uint64 _pi_event::get_queued_time () const {
580
578
assert (is_started ());
579
+ return queue_->get_device ()->get_elapsed_time (evQueued_);
580
+ }
581
581
582
- PI_CHECK_ERROR (
583
- cuEventElapsedTime (&miliSeconds, _pi_platform::evBase_, evStart_ ));
584
- return static_cast <pi_uint64>(miliSeconds * 1.0e6 );
582
// Returns the timestamp of this event's evStart_ CUDA event, computed by the
// device that owns queue_ via get_elapsed_time(evStart_).
// Precondition (checked only when asserts are enabled): is_started().
// NOTE(review): queue_ and the result of get_device() are dereferenced
// without a null check — presumably both are always valid here; confirm.
+ pi_uint64 _pi_event::get_start_time () const {
583
+ assert ( is_started ( ));
584
+ return queue_-> get_device ()-> get_elapsed_time (evStart_ );
585
585
}
586
586
587
587
// Returns the timestamp of this event's evEnd_ CUDA event, computed by the
// device that owns queue_ via get_elapsed_time(evEnd_).
// Precondition (checked only when asserts are enabled): the event is both
// started and recorded (is_started() && is_recorded()).
//
// In this diff view, lines marked '-' are the removed inline computation
// against the static _pi_platform::evBase_; the '+' line replaces it with a
// call to the per-device helper.
pi_uint64 _pi_event::get_end_time () const {
588
- float miliSeconds = 0 .0f ;
589
588
assert (is_started () && is_recorded ());
590
-
591
- PI_CHECK_ERROR (
592
- cuEventElapsedTime (&miliSeconds, _pi_platform::evBase_, evEnd_));
593
- return static_cast <pi_uint64>(miliSeconds * 1.0e6 );
589
+ return queue_->get_device ()->get_elapsed_time (evEnd_);
594
590
}
595
591
596
592
pi_result _pi_event::record () {
@@ -830,8 +826,15 @@ pi_result cuda_piPlatformsGet(pi_uint32 num_entries, pi_platform *platforms,
830
826
CUcontext context;
831
827
err = PI_CHECK_ERROR (cuDevicePrimaryCtxRetain (&context, device));
832
828
829
+ ScopedContext active (context);
830
+ CUevent evBase;
831
+ err = PI_CHECK_ERROR (cuEventCreate (&evBase, CU_EVENT_DEFAULT));
832
+
833
+ // Use default stream to record base event counter
834
+ err = PI_CHECK_ERROR (cuEventRecord (evBase, 0 ));
835
+
833
836
platformIds[i].devices_ .emplace_back (
834
- new _pi_device{device, context, &platformIds[i]});
837
+ new _pi_device{device, context, evBase, &platformIds[i]});
835
838
836
839
{
837
840
const auto &dev = platformIds[i].devices_ .back ().get ();
@@ -2061,18 +2064,6 @@ pi_result cuda_piContextCreate(const pi_context_properties *properties,
2061
2064
std::unique_ptr<_pi_context> piContextPtr{nullptr };
2062
2065
try {
2063
2066
piContextPtr = std::unique_ptr<_pi_context>(new _pi_context{*devices});
2064
-
2065
- static std::once_flag initFlag;
2066
- std::call_once (
2067
- initFlag,
2068
- [](pi_result &err) {
2069
- // Use default stream to record base event counter
2070
- PI_CHECK_ERROR (
2071
- cuEventCreate (&_pi_platform::evBase_, CU_EVENT_DEFAULT));
2072
- PI_CHECK_ERROR (cuEventRecord (_pi_platform::evBase_, 0 ));
2073
- },
2074
- errcode_ret);
2075
-
2076
2067
*retcontext = piContextPtr.release ();
2077
2068
} catch (pi_result err) {
2078
2069
errcode_ret = err;
@@ -5537,11 +5528,7 @@ pi_result cuda_piGetDeviceAndHostTimer(pi_device Device, uint64_t *DeviceTime,
5537
5528
5538
5529
if (DeviceTime) {
5539
5530
PI_CHECK_ERROR (cuEventSynchronize (event));
5540
-
5541
- float elapsedTime = 0 .0f ;
5542
- PI_CHECK_ERROR (
5543
- cuEventElapsedTime (&elapsedTime, _pi_platform::evBase_, event));
5544
- *DeviceTime = (uint64_t )(elapsedTime * (double )1e6 );
5531
+ *DeviceTime = Device->get_elapsed_time (event);
5545
5532
}
5546
5533
5547
5534
return PI_SUCCESS;
@@ -5708,5 +5695,3 @@ pi_result piPluginInit(pi_plugin *PluginInit) {
5708
5695
}
5709
5696
5710
5697
} // extern "C"
5711
-
5712
- CUevent _pi_platform::evBase_{nullptr };
0 commit comments