@@ -1413,6 +1413,8 @@ pi_result cuda_piContextCreate(const pi_context_properties *properties,
1413
1413
1414
1414
std::unique_ptr<_pi_context> piContextPtr{nullptr };
1415
1415
try {
1416
+ CUcontext current = nullptr ;
1417
+
1416
1418
if (property_cuda_primary) {
1417
1419
// Use the CUDA primary context and assume that we want to use it
1418
1420
// immediately as we want to forgo context switches.
@@ -1424,23 +1426,26 @@ pi_result cuda_piContextCreate(const pi_context_properties *properties,
1424
1426
errcode_ret = PI_CHECK_ERROR (cuCtxPushCurrent (Ctxt));
1425
1427
} else {
1426
1428
// Create a scoped context.
1427
- CUcontext newContext, current ;
1429
+ CUcontext newContext;
1428
1430
PI_CHECK_ERROR (cuCtxGetCurrent (¤t));
1429
1431
errcode_ret = PI_CHECK_ERROR (
1430
1432
cuCtxCreate (&newContext, CU_CTX_MAP_HOST, devices[0 ]->get ()));
1431
1433
piContextPtr = std::unique_ptr<_pi_context>(new _pi_context{
1432
1434
_pi_context::kind::user_defined, newContext, *devices});
1433
- // For scoped contexts keep the last active CUDA one on top of the stack
1434
- // as `cuCtxCreate` replaces it implicitly otherwise.
1435
- if (current != nullptr ) {
1436
- PI_CHECK_ERROR (cuCtxSetCurrent (current));
1437
- }
1438
1435
}
1439
1436
1440
1437
// Use default stream to record base event counter
1441
1438
PI_CHECK_ERROR (cuEventCreate (&piContextPtr->evBase_ , CU_EVENT_DEFAULT));
1442
1439
PI_CHECK_ERROR (cuEventRecord (piContextPtr->evBase_ , 0 ));
1443
1440
1441
+ // For non-primary scoped contexts keep the last active on top of the stack
1442
+ // as `cuCtxCreate` replaces it implicitly otherwise.
1443
+ // Primary contexts are kept on top of the stack, so the previous context
1444
+ // is not queried and therefore not recovered.
1445
+ if (current != nullptr ) {
1446
+ PI_CHECK_ERROR (cuCtxSetCurrent (current));
1447
+ }
1448
+
1444
1449
*retcontext = piContextPtr.release ();
1445
1450
} catch (pi_result err) {
1446
1451
errcode_ret = err;
0 commit comments