Skip to content

[SYCL][CUDA] Move base event into the device #8312

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 19 additions & 34 deletions sycl/plugins/cuda/pi_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,31 +566,27 @@ bool _pi_event::is_completed() const noexcept {
return true;
}

pi_uint64 _pi_event::get_queued_time() const {
pi_uint64 _pi_device::get_elapsed_time(CUevent ev) const {
float miliSeconds = 0.0f;
assert(is_started());

PI_CHECK_ERROR(
cuEventElapsedTime(&miliSeconds, _pi_platform::evBase_, evQueued_));
PI_CHECK_ERROR(cuEventElapsedTime(&miliSeconds, evBase_, ev));

return static_cast<pi_uint64>(miliSeconds * 1.0e6);
}

pi_uint64 _pi_event::get_start_time() const {
float miliSeconds = 0.0f;
/// Returns the time elapsed between the base event of the device that owns
/// this event's queue and the point at which this event was queued
/// (evQueued_). The computation is delegated to
/// _pi_device::get_elapsed_time.
///
/// Precondition (asserted): the event has been started.
pi_uint64 _pi_event::get_queued_time() const {
assert(is_started());
// The base event lives on the device, so go through the queue's device
// rather than a platform-wide static.
return queue_->get_device()->get_elapsed_time(evQueued_);
}

PI_CHECK_ERROR(
cuEventElapsedTime(&miliSeconds, _pi_platform::evBase_, evStart_));
return static_cast<pi_uint64>(miliSeconds * 1.0e6);
/// Returns the time elapsed between the base event of the device that owns
/// this event's queue and this event's start marker (evStart_). The
/// computation is delegated to _pi_device::get_elapsed_time.
///
/// Precondition (asserted): the event has been started.
pi_uint64 _pi_event::get_start_time() const {
assert(is_started());
// Measured against the per-device base event held by the queue's device.
return queue_->get_device()->get_elapsed_time(evStart_);
}

pi_uint64 _pi_event::get_end_time() const {
float miliSeconds = 0.0f;
assert(is_started() && is_recorded());

PI_CHECK_ERROR(
cuEventElapsedTime(&miliSeconds, _pi_platform::evBase_, evEnd_));
return static_cast<pi_uint64>(miliSeconds * 1.0e6);
return queue_->get_device()->get_elapsed_time(evEnd_);
}

pi_result _pi_event::record() {
Expand Down Expand Up @@ -905,8 +901,15 @@ pi_result cuda_piPlatformsGet(pi_uint32 num_entries, pi_platform *platforms,
CUcontext context;
err = PI_CHECK_ERROR(cuDevicePrimaryCtxRetain(&context, device));

ScopedContext active(context);
CUevent evBase;
err = PI_CHECK_ERROR(cuEventCreate(&evBase, CU_EVENT_DEFAULT));

// Use default stream to record base event counter
err = PI_CHECK_ERROR(cuEventRecord(evBase, 0));

platformIds[i].devices_.emplace_back(
new _pi_device{device, context, &platformIds[i]});
new _pi_device{device, context, evBase, &platformIds[i]});

{
const auto &dev = platformIds[i].devices_.back().get();
Expand Down Expand Up @@ -2136,18 +2139,6 @@ pi_result cuda_piContextCreate(const pi_context_properties *properties,
std::unique_ptr<_pi_context> piContextPtr{nullptr};
try {
piContextPtr = std::unique_ptr<_pi_context>(new _pi_context{*devices});

static std::once_flag initFlag;
std::call_once(
initFlag,
[](pi_result &err) {
// Use default stream to record base event counter
PI_CHECK_ERROR(
cuEventCreate(&_pi_platform::evBase_, CU_EVENT_DEFAULT));
PI_CHECK_ERROR(cuEventRecord(_pi_platform::evBase_, 0));
},
errcode_ret);

*retcontext = piContextPtr.release();
} catch (pi_result err) {
errcode_ret = err;
Expand Down Expand Up @@ -5615,11 +5606,7 @@ pi_result cuda_piGetDeviceAndHostTimer(pi_device Device, uint64_t *DeviceTime,

if (DeviceTime) {
PI_CHECK_ERROR(cuEventSynchronize(event));

float elapsedTime = 0.0f;
PI_CHECK_ERROR(
cuEventElapsedTime(&elapsedTime, _pi_platform::evBase_, event));
*DeviceTime = (uint64_t)(elapsedTime * (double)1e6);
*DeviceTime = Device->get_elapsed_time(event);
}

return PI_SUCCESS;
Expand Down Expand Up @@ -5786,5 +5773,3 @@ pi_result piPluginInit(pi_plugin *PluginInit) {
}

} // extern "C"

CUevent _pi_platform::evBase_{nullptr};
11 changes: 7 additions & 4 deletions sycl/plugins/cuda/pi_cuda.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ using _pi_stream_guard = std::unique_lock<std::mutex>;
/// when devices are used.
///
struct _pi_platform {
static CUevent evBase_; // CUDA event used as base counter
std::vector<std::unique_ptr<_pi_device>> devices_;
};

Expand All @@ -87,6 +86,7 @@ struct _pi_device {

native_type cuDevice_;
CUcontext cuContext_;
CUevent evBase_; // CUDA event used as base counter
std::atomic_uint32_t refCount_;
pi_platform platform_;

Expand All @@ -95,9 +95,10 @@ struct _pi_device {
int max_work_group_size;

public:
_pi_device(native_type cuDevice, CUcontext cuContext, pi_platform platform)
: cuDevice_(cuDevice), cuContext_(cuContext), refCount_{1},
platform_(platform) {}
_pi_device(native_type cuDevice, CUcontext cuContext, CUevent evBase,
pi_platform platform)
: cuDevice_(cuDevice), cuContext_(cuContext),
evBase_(evBase), refCount_{1}, platform_(platform) {}

~_pi_device() { cuDevicePrimaryCtxRelease(cuDevice_); }

Expand All @@ -109,6 +110,8 @@ struct _pi_device {

pi_platform get_platform() const noexcept { return platform_; };

pi_uint64 get_elapsed_time(CUevent) const;

void save_max_work_item_sizes(size_t size,
size_t *save_max_work_item_sizes) noexcept {
memcpy(max_work_item_sizes, save_max_work_item_sizes, size);
Expand Down