diff --git a/sycl/plugins/cuda/pi_cuda.cpp b/sycl/plugins/cuda/pi_cuda.cpp index 8a44c3ff6eb56..6d562ccecfd08 100644 --- a/sycl/plugins/cuda/pi_cuda.cpp +++ b/sycl/plugins/cuda/pi_cuda.cpp @@ -542,9 +542,28 @@ pi_result cuda_piPlatformsGet(pi_uint32 num_entries, pi_platform *platforms, static std::once_flag initFlag; static _pi_platform platformId; - std::call_once(initFlag, - [](pi_result &err) { err = PI_CHECK_ERROR(cuInit(0)); }, - err); + std::call_once( + initFlag, + [](pi_result &err) { + err = PI_CHECK_ERROR(cuInit(0)); + + int numDevices = 0; + err = PI_CHECK_ERROR(cuDeviceGetCount(&numDevices)); + platformId.devices_.reserve(numDevices); + try { + for (int i = 0; i < numDevices; ++i) { + CUdevice device; + err = PI_CHECK_ERROR(cuDeviceGet(&device, i)); + platformId.devices_.emplace_back( + new _pi_device{device, &platformId}); + } + } catch (...) { + // Clear and rethrow to allow retry + platformId.devices_.clear(); + throw; + } + }, + err); *platforms = &platformId; } @@ -594,22 +613,16 @@ pi_result cuda_piDevicesGet(pi_platform platform, pi_device_type device_type, pi_result err = PI_SUCCESS; const bool askingForGPU = (device_type & PI_DEVICE_TYPE_GPU); - size_t numDevices = askingForGPU ? 1 : 0; + size_t numDevices = askingForGPU ? platform->devices_.size() : 0; try { if (num_devices) { *num_devices = numDevices; } - if (askingForGPU) { - if (devices) { - CUdevice device; - err = PI_CHECK_ERROR(cuDeviceGet(&device, 0)); - *devices = new _pi_device{device, platform}; - } - } else { - if (devices) { - *devices = nullptr; + if (askingForGPU && devices) { + for (size_t i = 0; i < std::min(size_t(num_entries), numDevices); ++i) { + devices[i] = platform->devices_[i].get(); } } diff --git a/sycl/plugins/cuda/pi_cuda.hpp b/sycl/plugins/cuda/pi_cuda.hpp index 2ec7ad49abc7f..9978917b321c8 100644 --- a/sycl/plugins/cuda/pi_cuda.hpp +++ b/sycl/plugins/cuda/pi_cuda.hpp @@ -46,6 +46,7 @@ pi_result cuda_piKernelRelease(pi_kernel); } struct _pi_platform { + std::vector> devices_; }; struct _pi_device {