diff --git a/sycl/plugins/cuda/pi_cuda.cpp b/sycl/plugins/cuda/pi_cuda.cpp index 5a2c8e0615317..23f67e2ab71d5 100644 --- a/sycl/plugins/cuda/pi_cuda.cpp +++ b/sycl/plugins/cuda/pi_cuda.cpp @@ -678,13 +678,16 @@ pi_result cuda_piDeviceGetInfo(pi_device device, pi_device_info param_name, /// Triggers the CUDA Driver initialization (cuInit) the first time, so this /// must be the first PI API called. /// +/// However because multiple devices in a context is not currently supported, +/// place each device in a separate platform. +/// pi_result cuda_piPlatformsGet(pi_uint32 num_entries, pi_platform *platforms, pi_uint32 *num_platforms) { try { static std::once_flag initFlag; static pi_uint32 numPlatforms = 1; - static _pi_platform platformId; + static std::vector<_pi_platform> platformIds; if (num_entries == 0 && platforms != nullptr) { return PI_INVALID_VALUE; @@ -709,14 +712,18 @@ pi_result cuda_piPlatformsGet(pi_uint32 num_entries, pi_platform *platforms, return; } try { - platformId.devices_.reserve(numDevices); + // make one platform per device + numPlatforms = numDevices; + platformIds.resize(numDevices); + for (int i = 0; i < numDevices; ++i) { CUdevice device; err = PI_CHECK_ERROR(cuDeviceGet(&device, i)); - platformId.devices_.emplace_back( - new _pi_device{device, &platformId}); + platformIds[i].devices_.emplace_back( + new _pi_device{device, &platformIds[i]}); + { - const auto &dev = platformId.devices_.back().get(); + const auto &dev = platformIds[i].devices_.back().get(); size_t maxWorkGroupSize = 0u; size_t maxThreadsPerBlock[3] = {}; pi_result retError = cuda_piDeviceGetInfo( @@ -737,11 +744,17 @@ pi_result cuda_piPlatformsGet(pi_uint32 num_entries, pi_platform *platforms, } } catch (const std::bad_alloc &) { // Signal out-of-memory situation - platformId.devices_.clear(); + for (int i = 0; i < numDevices; ++i) { + platformIds[i].devices_.clear(); + } + platformIds.clear(); err = PI_OUT_OF_HOST_MEMORY; } catch (...) { // Clear and rethrow to allow retry - platformId.devices_.clear(); + for (int i = 0; i < numDevices; ++i) { + platformIds[i].devices_.clear(); + } + platformIds.clear(); throw; } }, @@ -752,7 +765,9 @@ pi_result cuda_piPlatformsGet(pi_uint32 num_entries, pi_platform *platforms, } if (platforms != nullptr) { - *platforms = &platformId; + for (unsigned i = 0; i < std::min(num_entries, numPlatforms); ++i) { + platforms[i] = &platformIds[i]; + } } return err; diff --git a/sycl/plugins/hip/pi_hip.cpp b/sycl/plugins/hip/pi_hip.cpp index d5019a234d30c..4a34005c63be5 100644 --- a/sycl/plugins/hip/pi_hip.cpp +++ b/sycl/plugins/hip/pi_hip.cpp @@ -671,13 +671,16 @@ extern "C" { /// Triggers the HIP Driver initialization (hipInit) the first time, so this /// must be the first PI API called. /// +/// However because multiple devices in a context is not currently supported, +/// place each device in a separate platform. +/// pi_result hip_piPlatformsGet(pi_uint32 num_entries, pi_platform *platforms, pi_uint32 *num_platforms) { try { static std::once_flag initFlag; static pi_uint32 numPlatforms = 1; - static _pi_platform platformId; + static std::vector<_pi_platform> platformIds; if (num_entries == 0 and platforms != nullptr) { return PI_INVALID_VALUE; @@ -707,20 +710,28 @@ pi_result hip_piPlatformsGet(pi_uint32 num_entries, pi_platform *platforms, return; } try { - platformId.devices_.reserve(numDevices); + numPlatforms = numDevices; + platformIds.resize(numDevices); + for (int i = 0; i < numDevices; ++i) { hipDevice_t device; err = PI_CHECK_ERROR(hipDeviceGet(&device, i)); - platformId.devices_.emplace_back( - new _pi_device{device, &platformId}); + platformIds[i].devices_.emplace_back( + new _pi_device{device, &platformIds[i]}); } } catch (const std::bad_alloc &) { // Signal out-of-memory situation - platformId.devices_.clear(); + for (int i = 0; i < numDevices; ++i) { + platformIds[i].devices_.clear(); + } + platformIds.clear(); err = PI_OUT_OF_HOST_MEMORY; } catch (...) { // Clear and rethrow to allow retry - platformId.devices_.clear(); + for (int i = 0; i < numDevices; ++i) { + platformIds[i].devices_.clear(); + } + platformIds.clear(); throw; } }, @@ -731,7 +742,9 @@ pi_result hip_piPlatformsGet(pi_uint32 num_entries, pi_platform *platforms, } if (platforms != nullptr) { - *platforms = &platformId; + for (unsigned i = 0; i < std::min(num_entries, numPlatforms); ++i) { + platforms[i] = &platformIds[i]; + } } return err;