@@ -41,6 +41,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
41
41
*OutEvent // /< [in,out][optional] return an event object that identifies
42
42
// /< this particular kernel execution instance.
43
43
) {
44
+ auto ZeDevice = Queue->Device ->ZeDevice ;
45
+
46
+ ze_kernel_handle_t ZeKernel{};
47
+ if (Kernel->ZeKernelMap .empty ()) {
48
+ ZeKernel = Kernel->ZeKernel ;
49
+ } else {
50
+ auto It = Kernel->ZeKernelMap .find (ZeDevice);
51
+ ZeKernel = It->second ;
52
+ }
44
53
// Lock automatically releases when this goes out of scope.
45
54
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock (
46
55
Queue->Mutex , Kernel->Mutex , Kernel->Program ->Mutex );
@@ -51,7 +60,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
51
60
}
52
61
53
62
ZE2UR_CALL (zeKernelSetGlobalOffsetExp,
54
- (Kernel-> ZeKernel , GlobalWorkOffset[0 ], GlobalWorkOffset[1 ],
63
+ (ZeKernel, GlobalWorkOffset[0 ], GlobalWorkOffset[1 ],
55
64
GlobalWorkOffset[2 ]));
56
65
}
57
66
@@ -65,7 +74,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
65
74
Queue->Device ));
66
75
}
67
76
ZE2UR_CALL (zeKernelSetArgumentValue,
68
- (Kernel-> ZeKernel , Arg.Index , Arg.Size , ZeHandlePtr));
77
+ (ZeKernel, Arg.Index , Arg.Size , ZeHandlePtr));
69
78
}
70
79
Kernel->PendingArguments .clear ();
71
80
@@ -99,7 +108,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
99
108
}
100
109
if (SuggestGroupSize) {
101
110
ZE2UR_CALL (zeKernelSuggestGroupSize,
102
- (Kernel-> ZeKernel , GlobalWorkSize[0 ], GlobalWorkSize[1 ],
111
+ (ZeKernel, GlobalWorkSize[0 ], GlobalWorkSize[1 ],
103
112
GlobalWorkSize[2 ], &WG[0 ], &WG[1 ], &WG[2 ]));
104
113
} else {
105
114
for (int I : {0 , 1 , 2 }) {
@@ -175,7 +184,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
175
184
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
176
185
}
177
186
178
- ZE2UR_CALL (zeKernelSetGroupSize, (Kernel-> ZeKernel , WG[0 ], WG[1 ], WG[2 ]));
187
+ ZE2UR_CALL (zeKernelSetGroupSize, (ZeKernel, WG[0 ], WG[1 ], WG[2 ]));
179
188
180
189
bool UseCopyEngine = false ;
181
190
_ur_ze_event_list_t TmpWaitList;
@@ -227,18 +236,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
227
236
Queue->CaptureIndirectAccesses ();
228
237
// Add the command to the command list, which implies submission.
229
238
ZE2UR_CALL (zeCommandListAppendLaunchKernel,
230
- (CommandList->first , Kernel->ZeKernel , &ZeThreadGroupDimensions,
231
- ZeEvent, (*Event)->WaitList .Length ,
232
- (*Event)->WaitList .ZeEventList ));
239
+ (CommandList->first , ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
240
+ (*Event)->WaitList .Length , (*Event)->WaitList .ZeEventList ));
233
241
} else {
234
242
// Add the command to the command list for later submission.
235
243
// No lock is needed here, unlike the immediate commandlist case above,
236
244
// because the kernels are not actually submitted yet. Kernels will be
237
245
// submitted only when the commandlist is closed. Then, a lock is held.
238
246
ZE2UR_CALL (zeCommandListAppendLaunchKernel,
239
- (CommandList->first , Kernel->ZeKernel , &ZeThreadGroupDimensions,
240
- ZeEvent, (*Event)->WaitList .Length ,
241
- (*Event)->WaitList .ZeEventList ));
247
+ (CommandList->first , ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
248
+ (*Event)->WaitList .Length , (*Event)->WaitList .ZeEventList ));
242
249
}
243
250
244
251
urPrint (" calling zeCommandListAppendLaunchKernel() with"
@@ -363,23 +370,46 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelCreate(
363
370
return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
364
371
}
365
372
366
- ZeStruct<ze_kernel_desc_t > ZeKernelDesc;
367
- ZeKernelDesc.flags = 0 ;
368
- ZeKernelDesc.pKernelName = KernelName;
369
-
370
- ze_kernel_handle_t ZeKernel;
371
- ZE2UR_CALL (zeKernelCreate, (Program->ZeModule , &ZeKernelDesc, &ZeKernel));
372
-
373
373
try {
374
- ur_kernel_handle_t_ *UrKernel =
375
- new ur_kernel_handle_t_ (ZeKernel, true , Program);
374
+ ur_kernel_handle_t_ *UrKernel = new ur_kernel_handle_t_ (true , Program);
376
375
*RetKernel = reinterpret_cast <ur_kernel_handle_t >(UrKernel);
377
376
} catch (const std::bad_alloc &) {
378
377
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
379
378
} catch (...) {
380
379
return UR_RESULT_ERROR_UNKNOWN;
381
380
}
382
381
382
+ for (auto It : Program->ZeModuleMap ) {
383
+ auto ZeModule = It.second ;
384
+ ZeStruct<ze_kernel_desc_t > ZeKernelDesc;
385
+ ZeKernelDesc.flags = 0 ;
386
+ ZeKernelDesc.pKernelName = KernelName;
387
+
388
+ ze_kernel_handle_t ZeKernel;
389
+ ZE2UR_CALL (zeKernelCreate, (ZeModule, &ZeKernelDesc, &ZeKernel));
390
+
391
+ auto ZeDevice = It.first ;
392
+
393
+ // Store the kernel in the ZeKernelMap so the correct
394
+ // kernel can be retrieved later for a specific device
395
+ // where a queue is being submitted.
396
+ (*RetKernel)->ZeKernelMap [ZeDevice] = ZeKernel;
397
+ (*RetKernel)->ZeKernels .push_back (ZeKernel);
398
+
399
+ // If the device used to create the module's kernel is a root-device
400
+ // then store the kernel also using the sub-devices, since application
401
+ // could submit the root-device's kernel to a sub-device's queue.
402
+ uint32_t SubDevicesCount = 0 ;
403
+ zeDeviceGetSubDevices (ZeDevice, &SubDevicesCount, nullptr );
404
+ std::vector<ze_device_handle_t > ZeSubDevices (SubDevicesCount);
405
+ zeDeviceGetSubDevices (ZeDevice, &SubDevicesCount, ZeSubDevices.data ());
406
+ for (auto ZeSubDevice : ZeSubDevices) {
407
+ (*RetKernel)->ZeKernelMap [ZeSubDevice] = ZeKernel;
408
+ }
409
+ }
410
+
411
+ (*RetKernel)->ZeKernel = (*RetKernel)->ZeKernelMap .begin ()->second ;
412
+
383
413
UR_CALL ((*RetKernel)->initialize ());
384
414
385
415
return UR_RESULT_SUCCESS;
@@ -409,8 +439,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
409
439
}
410
440
411
441
std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
412
- ZE2UR_CALL (zeKernelSetArgumentValue,
413
- (Kernel->ZeKernel , ArgIndex, ArgSize, PArgValue));
442
+ for (auto It : Kernel->ZeKernelMap ) {
443
+ auto ZeKernel = It.second ;
444
+ ZE2UR_CALL (zeKernelSetArgumentValue,
445
+ (ZeKernel, ArgIndex, ArgSize, PArgValue));
446
+ }
414
447
415
448
return UR_RESULT_SUCCESS;
416
449
}
@@ -596,10 +629,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelRelease(
596
629
597
630
auto KernelProgram = Kernel->Program ;
598
631
if (Kernel->OwnNativeHandle ) {
599
- auto ZeResult = ZE_CALL_NOCHECK (zeKernelDestroy, (Kernel->ZeKernel ));
600
- // Gracefully handle the case that L0 was already unloaded.
601
- if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED)
602
- return ze2urResult (ZeResult);
632
+ for (auto &ZeKernel : Kernel->ZeKernels ) {
633
+ auto ZeResult = ZE_CALL_NOCHECK (zeKernelDestroy, (ZeKernel));
634
+ // Gracefully handle the case that L0 was already unloaded.
635
+ if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED)
636
+ return ze2urResult (ZeResult);
637
+ }
603
638
}
604
639
if (IndirectAccessTrackingEnabled) {
605
640
UR_CALL (urContextRelease (KernelProgram->Context ));
@@ -639,6 +674,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
639
674
std::ignore = PropSize;
640
675
std::ignore = Properties;
641
676
677
+ auto ZeKernel = Kernel->ZeKernel ;
642
678
std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
643
679
if (PropName == UR_KERNEL_EXEC_INFO_USM_INDIRECT_ACCESS &&
644
680
*(static_cast <const ur_bool_t *>(PropValue)) == true ) {
@@ -649,7 +685,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
649
685
ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST |
650
686
ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE |
651
687
ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED;
652
- ZE2UR_CALL (zeKernelSetIndirectAccess, (Kernel-> ZeKernel , IndirectFlags));
688
+ ZE2UR_CALL (zeKernelSetIndirectAccess, (ZeKernel, IndirectFlags));
653
689
} else if (PropName == UR_KERNEL_EXEC_INFO_CACHE_CONFIG) {
654
690
ze_cache_config_flag_t ZeCacheConfig{};
655
691
auto CacheConfig =
@@ -663,7 +699,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
663
699
else
664
700
// Unexpected cache configuration value.
665
701
return UR_RESULT_ERROR_INVALID_VALUE;
666
- ZE2UR_CALL (zeKernelSetCacheConfig, (Kernel-> ZeKernel , ZeCacheConfig););
702
+ ZE2UR_CALL (zeKernelSetCacheConfig, (ZeKernel, ZeCacheConfig););
667
703
} else {
668
704
urPrint (" urKernelSetExecInfo: unsupported ParamName\n " );
669
705
return UR_RESULT_ERROR_INVALID_VALUE;
0 commit comments