11
11
#include " program.hpp"
12
12
13
13
/// Constructs an empty program bound to the context \p Ctxt.
///
/// The program starts with no module, no binary (size 0), a reference count of
/// one and an empty required-work-group-size metadata table. The context is
/// retained here; the matching release is performed by the destructor.
///
/// \param Ctxt Context the program belongs to.
ur_program_handle_t_::ur_program_handle_t_(ur_context_handle_t Ctxt)
    : Module{nullptr}, Binary{}, BinarySizeInBytes{0}, RefCount{1},
      Context{Ctxt}, KernelReqdWorkGroupSizeMD{} {
  urContextRetain(Context);
}
18
18
19
19
/// Releases the context reference held by this program.
ur_program_handle_t_::~ur_program_handle_t_() {
  urContextRelease(Context);
}
20
20
21
+ ur_result_t
22
+ ur_program_handle_t_::setMetadata (const ur_program_metadata_t *Metadata,
23
+ size_t Length) {
24
+ for (size_t i = 0 ; i < Length; ++i) {
25
+ const ur_program_metadata_t MetadataElement = Metadata[i];
26
+ std::string MetadataElementName{MetadataElement.pName };
27
+
28
+ std::string Prefix{};
29
+ std::string Tag{};
30
+ size_t SplitPos = MetadataElementName.rfind (' @' );
31
+ if (SplitPos != std::string::npos) {
32
+ Prefix = MetadataElementName.substr (0 , SplitPos);
33
+ Tag = MetadataElementName.substr (SplitPos, MetadataElementName.length ());
34
+ }
35
+
36
+ if (Tag == __SYCL_UR_PROGRAM_METADATA_TAG_REQD_WORK_GROUP_SIZE) {
37
+ // If metadata is reqd_work_group_size, record it for the corresponding
38
+ // kernel name.
39
+ size_t MDElemsSize = MetadataElement.size - sizeof (std::uint64_t );
40
+
41
+ // Expect between 1 and 3 32-bit integer values.
42
+ UR_ASSERT (MDElemsSize >= sizeof (std::uint32_t ) &&
43
+ MDElemsSize <= sizeof (std::uint32_t ) * 3 ,
44
+ UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE);
45
+
46
+ // Get pointer to data, skipping 64-bit size at the start of the data.
47
+ const char *ValuePtr =
48
+ reinterpret_cast <const char *>(MetadataElement.value .pData ) +
49
+ sizeof (std::uint64_t );
50
+ // Read values and pad with 1's for values not present.
51
+ std::uint32_t ReqdWorkGroupElements[] = {1 , 1 , 1 };
52
+ std::memcpy (ReqdWorkGroupElements, ValuePtr, MDElemsSize);
53
+ KernelReqdWorkGroupSizeMD[Prefix] =
54
+ std::make_tuple (ReqdWorkGroupElements[0 ], ReqdWorkGroupElements[1 ],
55
+ ReqdWorkGroupElements[2 ]);
56
+ }
57
+ }
58
+
59
+ return UR_RESULT_SUCCESS;
60
+ }
61
+
21
62
ur_result_t ur_program_handle_t_::setBinary (const char *Source, size_t Length) {
22
63
// Do not re-set program binary data which has already been set as that will
23
64
// delete the old binary data.
@@ -246,7 +287,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetNativeHandle(
246
287
// / Note: Only supports one device
247
288
UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary (
248
289
ur_context_handle_t hContext, ur_device_handle_t hDevice, size_t size,
249
- const uint8_t *pBinary, const ur_program_properties_t *,
290
+ const uint8_t *pBinary, const ur_program_properties_t *pProperties ,
250
291
ur_program_handle_t *phProgram) {
251
292
UR_ASSERT (pBinary != nullptr && size != 0 , UR_RESULT_ERROR_INVALID_BINARY);
252
293
UR_ASSERT (hContext->getDevice ()->get () == hDevice->get (),
@@ -257,8 +298,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
257
298
std::unique_ptr<ur_program_handle_t_> RetProgram{
258
299
new ur_program_handle_t_{hContext}};
259
300
260
- // TODO: Set metadata here and use reqd_work_group_size information.
261
- // See urProgramCreateWithBinary in CUDA adapter.
301
+ if (pProperties) {
302
+ if (pProperties->count > 0 && pProperties->pMetadatas == nullptr ) {
303
+ return UR_RESULT_ERROR_INVALID_NULL_POINTER;
304
+ } else if (pProperties->count == 0 && pProperties->pMetadatas != nullptr ) {
305
+ return UR_RESULT_ERROR_INVALID_SIZE;
306
+ }
307
+ Result =
308
+ RetProgram->setMetadata (pProperties->pMetadatas , pProperties->count );
309
+ UR_ASSERT (Result == UR_RESULT_SUCCESS, Result);
310
+ }
262
311
263
312
auto pBinary_string = reinterpret_cast <const char *>(pBinary);
264
313
if (size == 0 ) {
0 commit comments