diff --git a/sycl/plugins/cuda/pi_cuda.cpp b/sycl/plugins/cuda/pi_cuda.cpp index 8a47942427283..05ce962f85d16 100644 --- a/sycl/plugins/cuda/pi_cuda.cpp +++ b/sycl/plugins/cuda/pi_cuda.cpp @@ -4825,12 +4825,19 @@ pi_result cuda_piextUSMGetMemAllocInfo(pi_context context, const void *ptr, #endif } case PI_MEM_ALLOC_DEVICE: { - unsigned int value; + // get device index associated with this pointer + unsigned int device_idx; result = PI_CHECK_ERROR(cuPointerGetAttribute( - &value, CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, (CUdeviceptr)ptr)); - pi_platform platform; - result = cuda_piPlatformsGet(1, &platform, nullptr); - pi_device device = platform->devices_[value].get(); + &device_idx, CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, (CUdeviceptr)ptr)); + + // currently each device is in its own platform, so find the platform at + // the same index + std::vector platforms; + platforms.resize(device_idx + 1); + result = cuda_piPlatformsGet(device_idx + 1, platforms.data(), nullptr); + + // get the device from the platform + pi_device device = platforms[device_idx]->devices_[0].get(); return getInfo(param_value_size, param_value, param_value_size_ret, device); } diff --git a/sycl/plugins/hip/pi_hip.cpp b/sycl/plugins/hip/pi_hip.cpp index d81b24396bd0d..40439ccde48be 100644 --- a/sycl/plugins/hip/pi_hip.cpp +++ b/sycl/plugins/hip/pi_hip.cpp @@ -4792,15 +4792,19 @@ pi_result hip_piextUSMGetMemAllocInfo(pi_context context, const void *ptr, } case PI_MEM_ALLOC_DEVICE: { - unsigned int value; + // get device index associated with this pointer result = PI_CHECK_ERROR( hipPointerGetAttributes(&hipPointerAttributeType, ptr)); - auto devicePointer = - static_cast(hipPointerAttributeType.devicePointer); - value = *devicePointer; - pi_platform platform; - result = hip_piPlatformsGet(1, &platform, nullptr); - pi_device device = platform->devices_[value].get(); + int device_idx = hipPointerAttributeType.device; + + // currently each device is in its own platform, so find the platform at + // the same index + std::vector platforms; + platforms.resize(device_idx + 1); + result = hip_piPlatformsGet(device_idx + 1, platforms.data(), nullptr); + + // get the device from the platform + pi_device device = platforms[device_idx]->devices_[0].get(); return getInfo(param_value_size, param_value, param_value_size_ret, device); }