Skip to content

Commit 416ec37

Browse files
committed
[SYCL] Handle required wg size attribute in HIP
1 parent 0d3fd4b commit 416ec37

File tree

5 files changed

+85
-13
lines changed

5 files changed

+85
-13
lines changed

source/adapters/hip/enqueue.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
200200
size_t MaxWorkGroupSize = 0u;
201201
size_t MaxThreadsPerBlock[3] = {};
202202
bool ProvidedLocalWorkGroupSize = (pLocalWorkSize != nullptr);
203+
size_t *ReqdThreadsPerBlock = hKernel->ReqdThreadsPerBlock;
203204

204205
{
205206
ur_result_t Result = urDeviceGetInfo(
@@ -217,6 +218,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
217218

218219
if (ProvidedLocalWorkGroupSize) {
219220
auto isValid = [&](int dim) {
221+
if (ReqdThreadsPerBlock[dim] != 0 &&
222+
pLocalWorkSize[dim] != ReqdThreadsPerBlock[dim])
223+
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
220224
UR_ASSERT(pLocalWorkSize[dim] <= MaxThreadsPerBlock[dim],
221225
UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE);
222226
// Checks that local work sizes are a divisor of the global work sizes

source/adapters/hip/kernel.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,17 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
8787
return ReturnValue(size_t(MaxThreads));
8888
}
8989
case UR_KERNEL_GROUP_INFO_COMPILE_WORK_GROUP_SIZE: {
90-
size_t group_size[3] = {0, 0, 0};
91-
// Returns the work-group size specified in the kernel source or IL.
92-
// If the work-group size is not specified in the kernel source or IL,
93-
// (0, 0, 0) is returned.
94-
// https://www.khronos.org/registry/OpenCL/sdk/2.1/docs/man/xhtml/clGetKernelWorkGroupInfo.html
95-
96-
// TODO: can we extract the work group size from the PTX?
97-
return ReturnValue(group_size, 3);
90+
size_t GroupSize[3] = {0, 0, 0};
91+
const auto &ReqdWGSizeMDMap =
92+
hKernel->getProgram()->KernelReqdWorkGroupSizeMD;
93+
const auto ReqdWGSizeMD = ReqdWGSizeMDMap.find(hKernel->getName());
94+
if (ReqdWGSizeMD != ReqdWGSizeMDMap.end()) {
95+
const auto ReqdWGSize = ReqdWGSizeMD->second;
96+
GroupSize[0] = std::get<0>(ReqdWGSize);
97+
GroupSize[1] = std::get<1>(ReqdWGSize);
98+
GroupSize[2] = std::get<2>(ReqdWGSize);
99+
}
100+
return ReturnValue(GroupSize, 3);
98101
}
99102
case UR_KERNEL_GROUP_INFO_LOCAL_MEM_SIZE: {
100103
// OpenCL LOCAL == HIP SHARED

source/adapters/hip/kernel.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ struct ur_kernel_handle_t_ {
4343
ur_program_handle_t Program;
4444
std::atomic_uint32_t RefCount;
4545

46+
static constexpr uint32_t ReqdThreadsPerBlockDimensions = 3u;
47+
size_t ReqdThreadsPerBlock[ReqdThreadsPerBlockDimensions];
48+
4649
/// Structure that holds the arguments to the kernel.
4750
/// Note each argument size is known, since it comes
4851
/// from the kernel signature.
@@ -134,6 +137,12 @@ struct ur_kernel_handle_t_ {
134137
ur_context_handle_t Ctxt)
135138
: Function{Func}, FunctionWithOffsetParam{FuncWithOffsetParam},
136139
Name{Name}, Context{Ctxt}, Program{Program}, RefCount{1} {
140+
ur_result_t RetError = urKernelGetGroupInfo(
141+
this, Context->getDevice(),
142+
UR_KERNEL_GROUP_INFO_COMPILE_WORK_GROUP_SIZE,
143+
sizeof(ReqdThreadsPerBlock), ReqdThreadsPerBlock, nullptr);
144+
(void)RetError;
145+
assert(RetError == UR_RESULT_SUCCESS);
137146
urProgramRetain(Program);
138147
urContextRetain(Context);
139148
}

source/adapters/hip/program.cpp

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,54 @@
1111
#include "program.hpp"
1212

1313
ur_program_handle_t_::ur_program_handle_t_(ur_context_handle_t Ctxt)
14-
: Module{nullptr}, Binary{}, BinarySizeInBytes{0}, RefCount{1}, Context{
15-
Ctxt} {
14+
: Module{nullptr}, Binary{}, BinarySizeInBytes{0}, RefCount{1},
15+
Context{Ctxt}, KernelReqdWorkGroupSizeMD{} {
1616
urContextRetain(Context);
1717
}
1818

1919
/// Releases the reference on the context that this program holds.
ur_program_handle_t_::~ur_program_handle_t_() {
  urContextRelease(Context);
}
2020

21+
/// Parses program metadata entries and records the ones this adapter uses.
///
/// Each entry name is split at its last '@' into a kernel-name prefix and a
/// tag (the tag keeps the leading '@'). Entries whose tag equals
/// __SYCL_UR_PROGRAM_METADATA_TAG_REQD_WORK_GROUP_SIZE are stored in
/// KernelReqdWorkGroupSizeMD keyed by the prefix; every other entry is
/// ignored.
///
/// \param Metadata Array of \p Length metadata entries.
/// \param Length   Number of entries in \p Metadata.
/// \return UR_RESULT_SUCCESS on success;
///         UR_RESULT_ERROR_INVALID_NULL_POINTER if an entry has no name or a
///         reqd_work_group_size entry has no data pointer;
///         UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE if a reqd_work_group_size
///         payload does not hold between one and three 32-bit values after
///         its 64-bit size header.
ur_result_t
ur_program_handle_t_::setMetadata(const ur_program_metadata_t *Metadata,
                                  size_t Length) {
  for (size_t i = 0; i < Length; ++i) {
    // Bind by reference: the element is only read, so no copy is needed.
    const ur_program_metadata_t &MetadataElement = Metadata[i];

    // Constructing a std::string from a null pointer is undefined behaviour.
    UR_ASSERT(MetadataElement.pName != nullptr,
              UR_RESULT_ERROR_INVALID_NULL_POINTER);
    const std::string MetadataElementName{MetadataElement.pName};

    // Names without an '@' leave both parts empty.
    std::string Prefix{};
    std::string Tag{};
    const size_t SplitPos = MetadataElementName.rfind('@');
    if (SplitPos != std::string::npos) {
      Prefix = MetadataElementName.substr(0, SplitPos);
      Tag = MetadataElementName.substr(SplitPos);
    }

    if (Tag == __SYCL_UR_PROGRAM_METADATA_TAG_REQD_WORK_GROUP_SIZE) {
      // If metadata is reqd_work_group_size, record it for the corresponding
      // kernel name.

      // The payload starts with a 64-bit size field. Reject payloads that are
      // too short to contain it instead of relying on unsigned wrap-around of
      // the subtraction below.
      UR_ASSERT(MetadataElement.size >= sizeof(std::uint64_t),
                UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE);
      const size_t MDElemsSize = MetadataElement.size - sizeof(std::uint64_t);

      // Expect between 1 and 3 32-bit integer values.
      UR_ASSERT(MDElemsSize >= sizeof(std::uint32_t) &&
                    MDElemsSize <= sizeof(std::uint32_t) * 3,
                UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE);

      // The data pointer is offset and read below, so it must be valid.
      UR_ASSERT(MetadataElement.value.pData != nullptr,
                UR_RESULT_ERROR_INVALID_NULL_POINTER);

      // Get pointer to data, skipping 64-bit size at the start of the data.
      const char *ValuePtr =
          reinterpret_cast<const char *>(MetadataElement.value.pData) +
          sizeof(std::uint64_t);

      // Read values and pad with 1's for values not present.
      std::uint32_t ReqdWorkGroupElements[] = {1, 1, 1};
      std::memcpy(ReqdWorkGroupElements, ValuePtr, MDElemsSize);
      KernelReqdWorkGroupSizeMD[Prefix] =
          std::make_tuple(ReqdWorkGroupElements[0], ReqdWorkGroupElements[1],
                          ReqdWorkGroupElements[2]);
    }
  }

  return UR_RESULT_SUCCESS;
}
61+
2162
ur_result_t ur_program_handle_t_::setBinary(const char *Source, size_t Length) {
2263
// Do not re-set program binary data which has already been set as that will
2364
// delete the old binary data.
@@ -246,7 +287,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetNativeHandle(
246287
/// Note: Only supports one device
247288
UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
248289
ur_context_handle_t hContext, ur_device_handle_t hDevice, size_t size,
249-
const uint8_t *pBinary, const ur_program_properties_t *,
290+
const uint8_t *pBinary, const ur_program_properties_t *pProperties,
250291
ur_program_handle_t *phProgram) {
251292
UR_ASSERT(pBinary != nullptr && size != 0, UR_RESULT_ERROR_INVALID_BINARY);
252293
UR_ASSERT(hContext->getDevice()->get() == hDevice->get(),
@@ -257,8 +298,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
257298
std::unique_ptr<ur_program_handle_t_> RetProgram{
258299
new ur_program_handle_t_{hContext}};
259300

260-
// TODO: Set metadata here and use reqd_work_group_size information.
261-
// See urProgramCreateWithBinary in CUDA adapter.
301+
if (pProperties) {
302+
if (pProperties->count > 0 && pProperties->pMetadatas == nullptr) {
303+
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
304+
} else if (pProperties->count == 0 && pProperties->pMetadatas != nullptr) {
305+
return UR_RESULT_ERROR_INVALID_SIZE;
306+
}
307+
Result =
308+
RetProgram->setMetadata(pProperties->pMetadatas, pProperties->count);
309+
UR_ASSERT(Result == UR_RESULT_SUCCESS, Result);
310+
}
262311

263312
auto pBinary_string = reinterpret_cast<const char *>(pBinary);
264313
if (size == 0) {

source/adapters/hip/program.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <ur_api.h>
1313

1414
#include <atomic>
15+
#include <unordered_map>
1516

1617
#include "context.hpp"
1718

@@ -24,6 +25,10 @@ struct ur_program_handle_t_ {
2425
std::atomic_uint32_t RefCount;
2526
ur_context_handle_t Context;
2627

28+
// Metadata
29+
std::unordered_map<std::string, std::tuple<uint32_t, uint32_t, uint32_t>>
30+
KernelReqdWorkGroupSizeMD;
31+
2732
constexpr static size_t MAX_LOG_SIZE = 8192u;
2833

2934
char ErrorLog[MAX_LOG_SIZE], InfoLog[MAX_LOG_SIZE];
@@ -33,6 +38,8 @@ struct ur_program_handle_t_ {
3338
ur_program_handle_t_(ur_context_handle_t Ctxt);
3439
~ur_program_handle_t_();
3540

41+
ur_result_t setMetadata(const ur_program_metadata_t *Metadata, size_t Length);
42+
3643
ur_result_t setBinary(const char *Binary, size_t BinarySizeInBytes);
3744

3845
ur_result_t buildProgram(const char *BuildOptions);

0 commit comments

Comments
 (0)