diff --git a/sycl/plugins/cuda/pi_cuda.cpp b/sycl/plugins/cuda/pi_cuda.cpp index b02462259ea9d..b7c64ef9f9a58 100644 --- a/sycl/plugins/cuda/pi_cuda.cpp +++ b/sycl/plugins/cuda/pi_cuda.cpp @@ -188,7 +188,7 @@ pi_result check_error(CUresult result, const char *function, int line, /// contexts to be restored by SYCL. class ScopedContext { public: - ScopedContext(pi_context ctxt) { + ScopedContext(pi_context ctxt) : device(nullptr) { if (!ctxt) { throw PI_ERROR_INVALID_CONTEXT; } @@ -196,9 +196,22 @@ class ScopedContext { set_context(ctxt->get()); } - ScopedContext(CUcontext ctxt) { set_context(ctxt); } + ScopedContext(CUcontext ctxt) : device(nullptr) { set_context(ctxt); } - ~ScopedContext() {} + // Creating a scoped context from a device will simply use the primary + // context, this should be used when there is no other appropriate context, + // such as for the device infos. + ScopedContext(pi_device device) : device(device) { + CUcontext ctxt; + cuDevicePrimaryCtxRetain(&ctxt, device->get()); + + set_context(ctxt); + } + + ~ScopedContext() { + if (device) + cuDevicePrimaryCtxRelease(device->get()); + } private: void set_context(CUcontext desired) { @@ -212,6 +225,8 @@ class ScopedContext { PI_CHECK_ERROR(cuCtxSetCurrent(desired)); } } + + pi_device device; }; /// \cond NODOXY @@ -1946,29 +1961,12 @@ pi_result cuda_piDeviceGetInfo(pi_device device, pi_device_info param_name, } case PI_EXT_INTEL_DEVICE_INFO_FREE_MEMORY: { - // Check the device of the currently set context uses the same device. - // CUDA_ERROR_INVALID_CONTEXT signifies the absence of an active context. - CUdevice current_ctx_device; - CUresult current_ctx_device_ret = cuCtxGetDevice(¤t_ctx_device); - if (current_ctx_device_ret != CUDA_ERROR_INVALID_CONTEXT) - PI_CHECK_ERROR(current_ctx_device_ret); - bool need_primary_ctx = - current_ctx_device_ret == CUDA_ERROR_INVALID_CONTEXT || - current_ctx_device != device->get(); - if (need_primary_ctx) { - // Use the primary context for the device if no context with the device is - // set. - CUcontext primary_context; - PI_CHECK_ERROR(cuDevicePrimaryCtxRetain(&primary_context, device->get())); - PI_CHECK_ERROR(cuCtxSetCurrent(primary_context)); - } + ScopedContext active(device); size_t FreeMemory = 0; size_t TotalMemory = 0; sycl::detail::pi::assertion(cuMemGetInfo(&FreeMemory, &TotalMemory) == CUDA_SUCCESS, "failed cuMemGetInfo() API."); - if (need_primary_ctx) - PI_CHECK_ERROR(cuDevicePrimaryCtxRelease(device->get())); return getInfo(param_value_size, param_value, param_value_size_ret, FreeMemory); }