Skip to content

Commit 12e59d4

Browse files
committed
[L0] Add support for multi-device kernel compilation
Signed-off-by: Spruit, Neil R <[email protected]>
1 parent 5e914c5 commit 12e59d4

File tree

4 files changed

+188
-112
lines changed

4 files changed

+188
-112
lines changed

source/adapters/level_zero/kernel.cpp

Lines changed: 63 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
4141
*OutEvent ///< [in,out][optional] return an event object that identifies
4242
///< this particular kernel execution instance.
4343
) {
44+
auto ZeDevice = Queue->Device->ZeDevice;
45+
46+
ze_kernel_handle_t ZeKernel{};
47+
if (Kernel->ZeKernelMap.empty()) {
48+
ZeKernel = Kernel->ZeKernel;
49+
} else {
50+
auto It = Kernel->ZeKernelMap.find(ZeDevice);
51+
ZeKernel = It->second;
52+
}
4453
// Lock automatically releases when this goes out of scope.
4554
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock(
4655
Queue->Mutex, Kernel->Mutex, Kernel->Program->Mutex);
@@ -51,7 +60,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
5160
}
5261

5362
ZE2UR_CALL(zeKernelSetGlobalOffsetExp,
54-
(Kernel->ZeKernel, GlobalWorkOffset[0], GlobalWorkOffset[1],
63+
(ZeKernel, GlobalWorkOffset[0], GlobalWorkOffset[1],
5564
GlobalWorkOffset[2]));
5665
}
5766

@@ -65,7 +74,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
6574
Queue->Device));
6675
}
6776
ZE2UR_CALL(zeKernelSetArgumentValue,
68-
(Kernel->ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
77+
(ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
6978
}
7079
Kernel->PendingArguments.clear();
7180

@@ -99,7 +108,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
99108
}
100109
if (SuggestGroupSize) {
101110
ZE2UR_CALL(zeKernelSuggestGroupSize,
102-
(Kernel->ZeKernel, GlobalWorkSize[0], GlobalWorkSize[1],
111+
(ZeKernel, GlobalWorkSize[0], GlobalWorkSize[1],
103112
GlobalWorkSize[2], &WG[0], &WG[1], &WG[2]));
104113
} else {
105114
for (int I : {0, 1, 2}) {
@@ -175,7 +184,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
175184
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
176185
}
177186

178-
ZE2UR_CALL(zeKernelSetGroupSize, (Kernel->ZeKernel, WG[0], WG[1], WG[2]));
187+
ZE2UR_CALL(zeKernelSetGroupSize, (ZeKernel, WG[0], WG[1], WG[2]));
179188

180189
bool UseCopyEngine = false;
181190
_ur_ze_event_list_t TmpWaitList;
@@ -227,18 +236,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
227236
Queue->CaptureIndirectAccesses();
228237
// Add the command to the command list, which implies submission.
229238
ZE2UR_CALL(zeCommandListAppendLaunchKernel,
230-
(CommandList->first, Kernel->ZeKernel, &ZeThreadGroupDimensions,
231-
ZeEvent, (*Event)->WaitList.Length,
232-
(*Event)->WaitList.ZeEventList));
239+
(CommandList->first, ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
240+
(*Event)->WaitList.Length, (*Event)->WaitList.ZeEventList));
233241
} else {
234242
// Add the command to the command list for later submission.
235243
// No lock is needed here, unlike the immediate commandlist case above,
236244
// because the kernels are not actually submitted yet. Kernels will be
237245
// submitted only when the comamndlist is closed. Then, a lock is held.
238246
ZE2UR_CALL(zeCommandListAppendLaunchKernel,
239-
(CommandList->first, Kernel->ZeKernel, &ZeThreadGroupDimensions,
240-
ZeEvent, (*Event)->WaitList.Length,
241-
(*Event)->WaitList.ZeEventList));
247+
(CommandList->first, ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
248+
(*Event)->WaitList.Length, (*Event)->WaitList.ZeEventList));
242249
}
243250

244251
urPrint("calling zeCommandListAppendLaunchKernel() with"
@@ -363,23 +370,46 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelCreate(
363370
return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
364371
}
365372

366-
ZeStruct<ze_kernel_desc_t> ZeKernelDesc;
367-
ZeKernelDesc.flags = 0;
368-
ZeKernelDesc.pKernelName = KernelName;
369-
370-
ze_kernel_handle_t ZeKernel;
371-
ZE2UR_CALL(zeKernelCreate, (Program->ZeModule, &ZeKernelDesc, &ZeKernel));
372-
373373
try {
374-
ur_kernel_handle_t_ *UrKernel =
375-
new ur_kernel_handle_t_(ZeKernel, true, Program);
374+
ur_kernel_handle_t_ *UrKernel = new ur_kernel_handle_t_(true, Program);
376375
*RetKernel = reinterpret_cast<ur_kernel_handle_t>(UrKernel);
377376
} catch (const std::bad_alloc &) {
378377
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
379378
} catch (...) {
380379
return UR_RESULT_ERROR_UNKNOWN;
381380
}
382381

382+
for (auto It : Program->ZeModuleMap) {
383+
auto ZeModule = It.second;
384+
ZeStruct<ze_kernel_desc_t> ZeKernelDesc;
385+
ZeKernelDesc.flags = 0;
386+
ZeKernelDesc.pKernelName = KernelName;
387+
388+
ze_kernel_handle_t ZeKernel;
389+
ZE2UR_CALL(zeKernelCreate, (ZeModule, &ZeKernelDesc, &ZeKernel));
390+
391+
auto ZeDevice = It.first;
392+
393+
// Store the kernel in the ZeKernelMap so the correct
394+
// kernel can be retrieved later for a specific device
395+
// where a queue is being submitted.
396+
(*RetKernel)->ZeKernelMap[ZeDevice] = ZeKernel;
397+
(*RetKernel)->ZeKernels.push_back(ZeKernel);
398+
399+
// If the device used to create the module's kernel is a root-device
400+
// then store the kernel also using the sub-devices, since application
401+
// could submit the root-device's kernel to a sub-device's queue.
402+
uint32_t SubDevicesCount = 0;
403+
zeDeviceGetSubDevices(ZeDevice, &SubDevicesCount, nullptr);
404+
std::vector<ze_device_handle_t> ZeSubDevices(SubDevicesCount);
405+
zeDeviceGetSubDevices(ZeDevice, &SubDevicesCount, ZeSubDevices.data());
406+
for (auto ZeSubDevice : ZeSubDevices) {
407+
(*RetKernel)->ZeKernelMap[ZeSubDevice] = ZeKernel;
408+
}
409+
}
410+
411+
(*RetKernel)->ZeKernel = (*RetKernel)->ZeKernelMap.begin()->second;
412+
383413
UR_CALL((*RetKernel)->initialize());
384414

385415
return UR_RESULT_SUCCESS;
@@ -409,8 +439,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
409439
}
410440

411441
std::scoped_lock<ur_shared_mutex> Guard(Kernel->Mutex);
412-
ZE2UR_CALL(zeKernelSetArgumentValue,
413-
(Kernel->ZeKernel, ArgIndex, ArgSize, PArgValue));
442+
for (auto It : Kernel->ZeKernelMap) {
443+
auto ZeKernel = It.second;
444+
ZE2UR_CALL(zeKernelSetArgumentValue,
445+
(ZeKernel, ArgIndex, ArgSize, PArgValue));
446+
}
414447

415448
return UR_RESULT_SUCCESS;
416449
}
@@ -596,10 +629,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelRelease(
596629

597630
auto KernelProgram = Kernel->Program;
598631
if (Kernel->OwnNativeHandle) {
599-
auto ZeResult = ZE_CALL_NOCHECK(zeKernelDestroy, (Kernel->ZeKernel));
600-
// Gracefully handle the case that L0 was already unloaded.
601-
if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED)
602-
return ze2urResult(ZeResult);
632+
for (auto &ZeKernel : Kernel->ZeKernels) {
633+
auto ZeResult = ZE_CALL_NOCHECK(zeKernelDestroy, (ZeKernel));
634+
// Gracefully handle the case that L0 was already unloaded.
635+
if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED)
636+
return ze2urResult(ZeResult);
637+
}
603638
}
604639
if (IndirectAccessTrackingEnabled) {
605640
UR_CALL(urContextRelease(KernelProgram->Context));
@@ -639,6 +674,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
639674
std::ignore = PropSize;
640675
std::ignore = Properties;
641676

677+
auto ZeKernel = Kernel->ZeKernel;
642678
std::scoped_lock<ur_shared_mutex> Guard(Kernel->Mutex);
643679
if (PropName == UR_KERNEL_EXEC_INFO_USM_INDIRECT_ACCESS &&
644680
*(static_cast<const ur_bool_t *>(PropValue)) == true) {
@@ -649,7 +685,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
649685
ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST |
650686
ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE |
651687
ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED;
652-
ZE2UR_CALL(zeKernelSetIndirectAccess, (Kernel->ZeKernel, IndirectFlags));
688+
ZE2UR_CALL(zeKernelSetIndirectAccess, (ZeKernel, IndirectFlags));
653689
} else if (PropName == UR_KERNEL_EXEC_INFO_CACHE_CONFIG) {
654690
ze_cache_config_flag_t ZeCacheConfig{};
655691
auto CacheConfig =
@@ -663,7 +699,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
663699
else
664700
// Unexpected cache configuration value.
665701
return UR_RESULT_ERROR_INVALID_VALUE;
666-
ZE2UR_CALL(zeKernelSetCacheConfig, (Kernel->ZeKernel, ZeCacheConfig););
702+
ZE2UR_CALL(zeKernelSetCacheConfig, (ZeKernel, ZeCacheConfig););
667703
} else {
668704
urPrint("urKernelSetExecInfo: unsupported ParamName\n");
669705
return UR_RESULT_ERROR_INVALID_VALUE;

source/adapters/level_zero/kernel.hpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@
1414
#include <unordered_set>
1515

1616
struct ur_kernel_handle_t_ : _ur_object {
17-
ur_kernel_handle_t_(ze_kernel_handle_t Kernel, bool OwnZeHandle,
18-
ur_program_handle_t Program)
19-
: Program{Program}, ZeKernel{Kernel}, SubmissionsCount{0}, MemAllocs{} {
17+
ur_kernel_handle_t_(bool OwnZeHandle, ur_program_handle_t Program)
18+
: Program{Program}, SubmissionsCount{0}, MemAllocs{} {
2019
OwnNativeHandle = OwnZeHandle;
2120
}
2221

@@ -35,6 +34,15 @@ struct ur_kernel_handle_t_ : _ur_object {
3534
// Level Zero function handle.
3635
ze_kernel_handle_t ZeKernel;
3736

37+
// Map of L0 kernels created for all the devices for which a UR Program
38+
// has been built. It may contain duplicated kernel entries for a root
39+
// device and its sub-devices.
40+
std::unordered_map<ze_device_handle_t, ze_kernel_handle_t> ZeKernelMap;
41+
42+
// Vector of L0 kernels. Each entry is unique, so this is used for
43+
// destroying the kernels instead of ZeKernelMap
44+
std::vector<ze_kernel_handle_t> ZeKernels;
45+
3846
// Counter to track the number of submissions of the kernel.
3947
// When this value is zero, it means that kernel is not submitted for an
4048
// execution - at this time we can release memory allocations referenced by

0 commit comments

Comments
 (0)