Skip to content

Commit bf9b8df

Browse files
committed
[SYCL] Handle required wg size attribute in HIP
1 parent 9fc8230 commit bf9b8df

File tree

7 files changed

+67
-22
lines changed

7 files changed

+67
-22
lines changed

source/adapters/cuda/program.cpp

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
//
99
//===----------------------------------------------------------------------===//
1010

11+
#include "ur_util.hpp"
1112
#include "program.hpp"
1213

1314
bool getMaxRegistersJitOptionValue(const std::string &BuildOptions,
@@ -52,15 +53,6 @@ ur_program_handle_t_::ur_program_handle_t_(ur_context_handle_t Context)
5253

5354
ur_program_handle_t_::~ur_program_handle_t_() { urContextRelease(Context); }
5455

55-
std::pair<std::string, std::string>
56-
splitMetadataName(const std::string &metadataName) {
57-
size_t splitPos = metadataName.rfind('@');
58-
if (splitPos == std::string::npos)
59-
return std::make_pair(metadataName, std::string{});
60-
return std::make_pair(metadataName.substr(0, splitPos),
61-
metadataName.substr(splitPos, metadataName.length()));
62-
}
63-
6456
ur_result_t
6557
ur_program_handle_t_::setMetadata(const ur_program_metadata_t *Metadata,
6658
size_t Length) {

source/adapters/hip/enqueue.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
201201
size_t MaxWorkGroupSize = 0u;
202202
size_t MaxThreadsPerBlock[3] = {};
203203
bool ProvidedLocalWorkGroupSize = (pLocalWorkSize != nullptr);
204+
size_t *ReqdThreadsPerBlock = hKernel->ReqdThreadsPerBlock;
204205

205206
{
206207
ur_result_t Result = urDeviceGetInfo(
@@ -218,6 +219,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
218219

219220
if (ProvidedLocalWorkGroupSize) {
220221
auto isValid = [&](int dim) {
222+
if (ReqdThreadsPerBlock[dim] != 0 &&
223+
pLocalWorkSize[dim] != ReqdThreadsPerBlock[dim])
224+
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
221225
UR_ASSERT(pLocalWorkSize[dim] <= MaxThreadsPerBlock[dim],
222226
UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE);
223227
// Checks that local work sizes are a divisor of the global work sizes

source/adapters/hip/kernel.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,17 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
8787
return ReturnValue(size_t(MaxThreads));
8888
}
8989
case UR_KERNEL_GROUP_INFO_COMPILE_WORK_GROUP_SIZE: {
90-
size_t group_size[3] = {0, 0, 0};
91-
// Returns the work-group size specified in the kernel source or IL.
92-
// If the work-group size is not specified in the kernel source or IL,
93-
// (0, 0, 0) is returned.
94-
// https://www.khronos.org/registry/OpenCL/sdk/2.1/docs/man/xhtml/clGetKernelWorkGroupInfo.html
95-
96-
// TODO: can we extract the work group size from the PTX?
97-
return ReturnValue(group_size, 3);
90+
size_t GroupSize[3] = {0, 0, 0};
91+
const auto &ReqdWGSizeMDMap =
92+
hKernel->getProgram()->KernelReqdWorkGroupSizeMD;
93+
const auto ReqdWGSizeMD = ReqdWGSizeMDMap.find(hKernel->getName());
94+
if (ReqdWGSizeMD != ReqdWGSizeMDMap.end()) {
95+
const auto ReqdWGSize = ReqdWGSizeMD->second;
96+
GroupSize[0] = std::get<0>(ReqdWGSize);
97+
GroupSize[1] = std::get<1>(ReqdWGSize);
98+
GroupSize[2] = std::get<2>(ReqdWGSize);
99+
}
100+
return ReturnValue(GroupSize, 3);
98101
}
99102
case UR_KERNEL_GROUP_INFO_LOCAL_MEM_SIZE: {
100103
// OpenCL LOCAL == HIP SHARED

source/adapters/hip/kernel.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ struct ur_kernel_handle_t_ {
4343
ur_program_handle_t Program;
4444
std::atomic_uint32_t RefCount;
4545

46+
static constexpr uint32_t ReqdThreadsPerBlockDimensions = 3u;
47+
size_t ReqdThreadsPerBlock[ReqdThreadsPerBlockDimensions];
48+
4649
/// Structure that holds the arguments to the kernel.
4750
/// Note earch argument size is known, since it comes
4851
/// from the kernel signature.
@@ -134,6 +137,12 @@ struct ur_kernel_handle_t_ {
134137
ur_context_handle_t Ctxt)
135138
: Function{Func}, FunctionWithOffsetParam{FuncWithOffsetParam},
136139
Name{Name}, Context{Ctxt}, Program{Program}, RefCount{1} {
140+
ur_result_t RetError = urKernelGetGroupInfo(
141+
this, Context->getDevice(),
142+
UR_KERNEL_GROUP_INFO_COMPILE_WORK_GROUP_SIZE,
143+
sizeof(ReqdThreadsPerBlock), ReqdThreadsPerBlock, nullptr);
144+
(void)RetError;
145+
assert(RetError == UR_RESULT_SUCCESS);
137146
urProgramRetain(Program);
138147
urContextRetain(Context);
139148
}

source/adapters/hip/program.cpp

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
//
99
//===----------------------------------------------------------------------===//
1010

11+
#include "ur_util.hpp"
1112
#include "program.hpp"
1213

1314
#ifdef SYCL_ENABLE_KERNEL_FUSION
@@ -75,8 +76,8 @@ void getCoMgrBuildLog(const amd_comgr_data_set_t BuildDataSet, char *BuildLog,
7576
#endif
7677

7778
ur_program_handle_t_::ur_program_handle_t_(ur_context_handle_t Ctxt)
78-
: Module{nullptr}, Binary{}, BinarySizeInBytes{0}, RefCount{1}, Context{
79-
Ctxt} {
79+
: Module{nullptr}, Binary{}, BinarySizeInBytes{0}, RefCount{1},
80+
Context{Ctxt}, KernelReqdWorkGroupSizeMD{} {
8081
urContextRetain(Context);
8182
}
8283

@@ -94,7 +95,32 @@ ur_program_handle_t_::setMetadata(const ur_program_metadata_t *Metadata,
9495
assert(MetadataElement.type == UR_PROGRAM_METADATA_TYPE_UINT32);
9596
IsRelocatable = MetadataElement.value.data32;
9697
}
98+
99+
auto [Prefix, Tag] = splitMetadataName(MetadataElementName);
100+
101+
if (Tag == __SYCL_UR_PROGRAM_METADATA_TAG_REQD_WORK_GROUP_SIZE) {
102+
// If metadata is reqd_work_group_size, record it for the corresponding
103+
// kernel name.
104+
size_t MDElemsSize = MetadataElement.size - sizeof(std::uint64_t);
105+
106+
// Expect between 1 and 3 32-bit integer values.
107+
UR_ASSERT(MDElemsSize >= sizeof(std::uint32_t) &&
108+
MDElemsSize <= sizeof(std::uint32_t) * 3,
109+
UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE);
110+
111+
// Get pointer to data, skipping 64-bit size at the start of the data.
112+
const char *ValuePtr =
113+
reinterpret_cast<const char *>(MetadataElement.value.pData) +
114+
sizeof(std::uint64_t);
115+
// Read values and pad with 1's for values not present.
116+
std::uint32_t ReqdWorkGroupElements[] = {1, 1, 1};
117+
std::memcpy(ReqdWorkGroupElements, ValuePtr, MDElemsSize);
118+
KernelReqdWorkGroupSizeMD[Prefix] =
119+
std::make_tuple(ReqdWorkGroupElements[0], ReqdWorkGroupElements[1],
120+
ReqdWorkGroupElements[2]);
121+
}
97122
}
123+
98124
return UR_RESULT_SUCCESS;
99125
}
100126

@@ -410,8 +436,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
410436
std::unique_ptr<ur_program_handle_t_> RetProgram{
411437
new ur_program_handle_t_{hContext}};
412438

413-
// TODO: Set metadata here and use reqd_work_group_size information.
414-
// See urProgramCreateWithBinary in CUDA adapter.
415439
if (pProperties) {
416440
if (pProperties->count > 0 && pProperties->pMetadatas == nullptr) {
417441
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
@@ -420,8 +444,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
420444
}
421445
Result =
422446
RetProgram->setMetadata(pProperties->pMetadatas, pProperties->count);
447+
UR_ASSERT(Result == UR_RESULT_SUCCESS, Result);
423448
}
424-
UR_ASSERT(Result == UR_RESULT_SUCCESS, Result);
425449

426450
auto pBinary_string = reinterpret_cast<const char *>(pBinary);
427451
if (size == 0) {

source/adapters/hip/program.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <ur_api.h>
1313

1414
#include <atomic>
15+
#include <unordered_map>
1516

1617
#include "context.hpp"
1718

@@ -28,6 +29,10 @@ struct ur_program_handle_t_ {
2829
// Metadata
2930
bool IsRelocatable = false;
3031

32+
// Metadata
33+
std::unordered_map<std::string, std::tuple<uint32_t, uint32_t, uint32_t>>
34+
KernelReqdWorkGroupSizeMD;
35+
3136
constexpr static size_t MAX_LOG_SIZE = 8192u;
3237

3338
char ErrorLog[MAX_LOG_SIZE], InfoLog[MAX_LOG_SIZE];

source/common/ur_util.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,4 +293,12 @@ inline ur_result_t exceptionToResult(std::exception_ptr eptr) {
293293
}
294294
}
295295

296+
inline std::pair<std::string, std::string>
297+
splitMetadataName(const std::string &metadataName) {
298+
size_t splitPos = metadataName.rfind('@');
299+
if (splitPos == std::string::npos)
300+
return std::make_pair(metadataName, std::string{});
301+
return std::make_pair(metadataName.substr(0, splitPos),
302+
metadataName.substr(splitPos, metadataName.length()));
303+
}
296304
#endif /* UR_UTIL_H */

0 commit comments

Comments
 (0)