Skip to content

Commit 6aac9ac

Browse files
committed
[SYCL] Add work_group_num_dim metadata
Emit metadata to describe the number of dimensions specified in reqd_work_group_size. This is needed in order to use that metadata correctly: it was specified for OpenCL and SYCL piggy-backs on it, so backends correctly assert if all 3 dimensions are not provided. work_group_num_dim allows the compiler to pad the missing dimensions with 1, while preserving the notion of how many dimensions were specified.
1 parent 7271d61 commit 6aac9ac

File tree

11 files changed

+124
-58
lines changed

11 files changed

+124
-58
lines changed

clang/lib/CodeGen/CodeGenFunction.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,10 @@ void CodeGenFunction::EmitKernelMetadata(const FunctionDecl *FD,
664664
AttrMDArgs.push_back(
665665
llvm::ConstantAsMetadata::get(Builder.getInt(*XDimVal)));
666666

667+
for (auto i = AttrMDArgs.size(); i < 3; ++i)
668+
AttrMDArgs.push_back(
669+
llvm::ConstantAsMetadata::get(Builder.getInt(llvm::APInt(32, 1))));
670+
667671
Fn->setMetadata("work_group_size_hint",
668672
llvm::MDNode::get(Context, AttrMDArgs));
669673
}
@@ -684,16 +688,28 @@ void CodeGenFunction::EmitKernelMetadata(const FunctionDecl *FD,
684688
std::optional<llvm::APSInt> ZDimVal = A->getZDimVal();
685689
llvm::SmallVector<llvm::Metadata *, 3> AttrMDArgs;
686690

691+
llvm::APInt NumDims(32, 1); // X
687692
// On SYCL target the dimensions are reversed if present.
688-
if (ZDimVal)
693+
if (ZDimVal) {
689694
AttrMDArgs.push_back(
690695
llvm::ConstantAsMetadata::get(Builder.getInt(*ZDimVal)));
691-
if (YDimVal)
696+
++NumDims;
697+
}
698+
if (YDimVal) {
692699
AttrMDArgs.push_back(
693700
llvm::ConstantAsMetadata::get(Builder.getInt(*YDimVal)));
701+
++NumDims;
702+
}
694703
AttrMDArgs.push_back(
695704
llvm::ConstantAsMetadata::get(Builder.getInt(*XDimVal)));
696705

706+
for (auto i = NumDims.getZExtValue(); i < 3; ++i)
707+
AttrMDArgs.push_back(
708+
llvm::ConstantAsMetadata::get(Builder.getInt(llvm::APInt(32, 1))));
709+
710+
Fn->setMetadata("work_group_num_dim",
711+
llvm::MDNode::get(Context, llvm::ConstantAsMetadata::get(
712+
Builder.getInt(NumDims))));
697713
Fn->setMetadata("reqd_work_group_size",
698714
llvm::MDNode::get(Context, AttrMDArgs));
699715
}
Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
// RUN: %clang_cc1 -fno-sycl-force-inline-kernel-lambda -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -disable-llvm-passes -sycl-std=2020 -emit-llvm -o - %s | FileCheck %s
22

3+
// RUN: %clang_cc1 -fno-sycl-force-inline-kernel-lambda -fsycl-is-device -internal-isystem %S/Inputs -triple amdgcn-amd-amdhsa -disable-llvm-passes -sycl-std=2020 -emit-llvm -o - %s | FileCheck %s
4+
// RUN: %clang_cc1 -fno-sycl-force-inline-kernel-lambda -fsycl-is-device -internal-isystem %S/Inputs -triple nvptx-nvidia-cuda -disable-llvm-passes -sycl-std=2020 -emit-llvm -o - %s | FileCheck %s
5+
// RUN: %clang_cc1 -fno-sycl-force-inline-kernel-lambda -fsycl-is-device -internal-isystem %S/Inputs -triple nvptx64-nvidia-cuda -disable-llvm-passes -sycl-std=2020 -emit-llvm -o - %s | FileCheck %s
6+
37
// Tests that work_group_size_hint and reqd_work_group_size generate the same
48
// metadata nodes for the same arguments.
59

@@ -11,21 +15,24 @@ int main() {
1115
queue q;
1216

1317
q.submit([&](handler &h) {
14-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_1d() #0 {{.*}} !work_group_size_hint ![[WG1D:[0-9]+]]{{.*}} !reqd_work_group_size ![[WG1D]]
18+
// CHECK: define {{.*}} void @{{.*}}kernel_1d() #0 {{.*}} !work_group_size_hint ![[WGSH1D:[0-9]+]]{{.*}} !work_group_num_dim ![[NDRWGS1D:[0-9]+]]{{.*}} !reqd_work_group_size ![[WGSH1D]]
1519
h.single_task<class kernel_1d>([]() [[sycl::work_group_size_hint(8)]] [[sycl::reqd_work_group_size(8)]] {});
1620
});
1721

1822
q.submit([&](handler &h) {
19-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_2d() #0 {{.*}} !work_group_size_hint ![[WG2D:[0-9]+]]{{.*}} !reqd_work_group_size ![[WG2D]]
23+
// CHECK: define {{.*}} void @{{.*}}kernel_2d() #0 {{.*}} !work_group_size_hint ![[WGSH2D:[0-9]+]]{{.*}} !work_group_num_dim ![[NDRWGS2D:[0-9]+]]{{.*}} !reqd_work_group_size ![[WGSH2D:[0-9]+]]{{.*}}
2024
h.single_task<class kernel_2d>([]() [[sycl::work_group_size_hint(8, 16)]] [[sycl::reqd_work_group_size(8, 16)]] {});
2125
});
2226

2327
q.submit([&](handler &h) {
24-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_3d() #0 {{.*}} !work_group_size_hint ![[WG3D:[0-9]+]]{{.*}} !reqd_work_group_size ![[WG3D]]
28+
// CHECK: define {{.*}} void @{{.*}}kernel_3d() #0 {{.*}} !work_group_size_hint ![[WG3D:[0-9]+]]{{.*}} !work_group_num_dim ![[NDRWGS3D:[0-9]+]]{{.*}} !reqd_work_group_size ![[WG3D]]
2529
h.single_task<class kernel_3d>([]() [[sycl::work_group_size_hint(8, 16, 32)]] [[sycl::reqd_work_group_size(8, 16, 32)]] {});
2630
});
2731
}
2832

29-
// CHECK: ![[WG1D]] = !{i32 8}
30-
// CHECK: ![[WG2D]] = !{i32 16, i32 8}
33+
// CHECK: ![[WGSH1D]] = !{i32 8, i32 1, i32 1}
34+
// CHECK: ![[NDRWGS1D]] = !{i32 1}
35+
// CHECK: ![[WGSH2D]] = !{i32 16, i32 8, i32 1}
36+
// CHECK: ![[NDRWGS2D]] = !{i32 2}
3137
// CHECK: ![[WG3D]] = !{i32 32, i32 16, i32 8}
38+
// CHECK: ![[NDRWGS3D]] = !{i32 3}

clang/test/CodeGenSYCL/reqd-work-group-size.cpp

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -disable-llvm-passes -sycl-std=2017 -emit-llvm -o - %s | FileCheck %s
2+
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple amdgcn-amd-amdhsa -disable-llvm-passes -sycl-std=2017 -emit-llvm -o - %s | FileCheck %s
3+
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple nvptx-nvidia-cuda -disable-llvm-passes -sycl-std=2017 -emit-llvm -o - %s | FileCheck %s
4+
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple nvptx64-nvidia-cuda -disable-llvm-passes -sycl-std=2017 -emit-llvm -o - %s | FileCheck %s
25

36
#include "sycl.hpp"
47

@@ -163,43 +166,46 @@ int main() {
163166
return 0;
164167
}
165168

166-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name1() #0 {{.*}} !reqd_work_group_size ![[WGSIZE3D32:[0-9]+]]
167-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name2() #0 {{.*}} !reqd_work_group_size ![[WGSIZE3D8:[0-9]+]]
168-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name3() #0 {{.*}} !reqd_work_group_size ![[WGSIZE3D88:[0-9]+]]
169-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name4() #0 {{.*}} !reqd_work_group_size ![[WGSIZE3D22:[0-9]+]]
170-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name5() #0 {{.*}} !reqd_work_group_size ![[WGSIZE3D44:[0-9]+]]
171-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name6() #0 {{.*}} !reqd_work_group_size ![[WGSIZE3D2:[0-9]+]]
172-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name7() #0 {{.*}} !reqd_work_group_size ![[WGSIZE3D32]]
173-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name8() #0 {{.*}} !reqd_work_group_size ![[WGSIZE3D8]]
174-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name9() #0 {{.*}} !reqd_work_group_size ![[WGSIZE3D88]]
175-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name10() #0 {{.*}} !reqd_work_group_size ![[WGSIZE3D22]]
176-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name11() #0 {{.*}} !reqd_work_group_size ![[WGSIZE3D44]]
177-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name12() #0 {{.*}} !reqd_work_group_size ![[WGSIZE3D2]]
178-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name13() #0 {{.*}} !reqd_work_group_size ![[WGSIZE2D32:[0-9]+]]
179-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name14() #0 {{.*}} !reqd_work_group_size ![[WGSIZE2D8:[0-9]+]]
180-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name15() #0 {{.*}} !reqd_work_group_size ![[WGSIZE2D88:[0-9]+]]
181-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name16() #0 {{.*}} !reqd_work_group_size ![[WGSIZE2D22:[0-9]+]]
182-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name17() #0 {{.*}} !reqd_work_group_size ![[WGSIZE2D44:[0-9]+]]
183-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name18() #0 {{.*}} !reqd_work_group_size ![[WGSIZE2D2:[0-9]+]]
184-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name19() #0 {{.*}} !reqd_work_group_size ![[WGSIZE1D32:[0-9]+]]
185-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name20() #0 {{.*}} !reqd_work_group_size ![[WGSIZE1D8:[0-9]+]]
186-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name21() #0 {{.*}} !reqd_work_group_size ![[WGSIZE1D8]]
187-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name22() #0 {{.*}} !reqd_work_group_size ![[WGSIZE1D22:[0-9]+]]
188-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name23() #0 {{.*}} !reqd_work_group_size ![[WGSIZE1D8]]
189-
// CHECK: define {{.*}}spir_kernel void @{{.*}}kernel_name24() #0 {{.*}} !reqd_work_group_size ![[WGSIZE1D2:[0-9]+]]
169+
// CHECK: define {{.*}} void @{{.*}}kernel_name1() #0 {{.*}} !work_group_num_dim ![[NDRWGS3D:[0-9]+]] !reqd_work_group_size ![[WGSIZE3D32:[0-9]+]]
170+
// CHECK: define {{.*}} void @{{.*}}kernel_name2() #0 {{.*}} !work_group_num_dim ![[NDRWGS3D:[0-9]+]] !reqd_work_group_size ![[WGSIZE3D8:[0-9]+]]
171+
// CHECK: define {{.*}} void @{{.*}}kernel_name3() #0 {{.*}} !work_group_num_dim ![[NDRWGS3D:[0-9]+]] !reqd_work_group_size ![[WGSIZE3D88:[0-9]+]]
172+
// CHECK: define {{.*}} void @{{.*}}kernel_name4() #0 {{.*}} !work_group_num_dim ![[NDRWGS3D:[0-9]+]] !reqd_work_group_size ![[WGSIZE3D22:[0-9]+]]
173+
// CHECK: define {{.*}} void @{{.*}}kernel_name5() #0 {{.*}} !work_group_num_dim ![[NDRWGS3D:[0-9]+]] !reqd_work_group_size ![[WGSIZE3D44:[0-9]+]]
174+
// CHECK: define {{.*}} void @{{.*}}kernel_name6() #0 {{.*}} !work_group_num_dim ![[NDRWGS3D:[0-9]+]] !reqd_work_group_size ![[WGSIZE3D2:[0-9]+]]
175+
// CHECK: define {{.*}} void @{{.*}}kernel_name7() #0 {{.*}} !work_group_num_dim ![[NDRWGS3D:[0-9]+]] !reqd_work_group_size ![[WGSIZE3D32]]
176+
// CHECK: define {{.*}} void @{{.*}}kernel_name8() #0 {{.*}} !work_group_num_dim ![[NDRWGS3D:[0-9]+]] !reqd_work_group_size ![[WGSIZE3D8]]
177+
// CHECK: define {{.*}} void @{{.*}}kernel_name9() #0 {{.*}} !work_group_num_dim ![[NDRWGS3D:[0-9]+]] !reqd_work_group_size ![[WGSIZE3D88]]
178+
// CHECK: define {{.*}} void @{{.*}}kernel_name10() #0 {{.*}} !work_group_num_dim ![[NDRWGS3D:[0-9]+]] !reqd_work_group_size ![[WGSIZE3D22]]
179+
// CHECK: define {{.*}} void @{{.*}}kernel_name11() #0 {{.*}} !work_group_num_dim ![[NDRWGS3D:[0-9]+]] !reqd_work_group_size ![[WGSIZE3D44]]
180+
// CHECK: define {{.*}} void @{{.*}}kernel_name12() #0 {{.*}} !work_group_num_dim ![[NDRWGS3D:[0-9]+]] !reqd_work_group_size ![[WGSIZE3D2]]
181+
// CHECK: define {{.*}} void @{{.*}}kernel_name13() #0 {{.*}} !work_group_num_dim ![[NDRWGS2D:[0-9]+]] !reqd_work_group_size ![[WGSIZE2D32:[0-9]+]]
182+
// CHECK: define {{.*}} void @{{.*}}kernel_name14() #0 {{.*}} !work_group_num_dim ![[NDRWGS2D:[0-9]+]] !reqd_work_group_size ![[WGSIZE2D8:[0-9]+]]
183+
// CHECK: define {{.*}} void @{{.*}}kernel_name15() #0 {{.*}} !work_group_num_dim ![[NDRWGS2D:[0-9]+]] !reqd_work_group_size ![[WGSIZE2D88:[0-9]+]]
184+
// CHECK: define {{.*}} void @{{.*}}kernel_name16() #0 {{.*}} !work_group_num_dim ![[NDRWGS2D:[0-9]+]] !reqd_work_group_size ![[WGSIZE2D22:[0-9]+]]
185+
// CHECK: define {{.*}} void @{{.*}}kernel_name17() #0 {{.*}} !work_group_num_dim ![[NDRWGS2D:[0-9]+]] !reqd_work_group_size ![[WGSIZE2D44:[0-9]+]]
186+
// CHECK: define {{.*}} void @{{.*}}kernel_name18() #0 {{.*}} !work_group_num_dim ![[NDRWGS2D:[0-9]+]] !reqd_work_group_size ![[WGSIZE2D2_or_WGSIZE1D8:[0-9]+]]
187+
// CHECK: define {{.*}} void @{{.*}}kernel_name19() #0 {{.*}} !work_group_num_dim ![[NDRWGS1D:[0-9]+]] !reqd_work_group_size ![[WGSIZE1D32:[0-9]+]]
188+
// CHECK: define {{.*}} void @{{.*}}kernel_name20() #0 {{.*}} !work_group_num_dim ![[NDRWGS1D:[0-9]+]] !reqd_work_group_size ![[WGSIZE2D2_or_WGSIZE1D8]]
189+
// CHECK: define {{.*}} void @{{.*}}kernel_name21() #0 {{.*}} !work_group_num_dim ![[NDRWGS1D:[0-9]+]] !reqd_work_group_size ![[WGSIZE2D2_or_WGSIZE1D8]]
190+
// CHECK: define {{.*}} void @{{.*}}kernel_name22() #0 {{.*}} !work_group_num_dim ![[NDRWGS1D:[0-9]+]] !reqd_work_group_size ![[WGSIZE1D22:[0-9]+]]
191+
// CHECK: define {{.*}} void @{{.*}}kernel_name23() #0 {{.*}} !work_group_num_dim ![[NDRWGS1D:[0-9]+]] !reqd_work_group_size ![[WGSIZE2D2_or_WGSIZE1D8]]
192+
// CHECK: define {{.*}} void @{{.*}}kernel_name24() #0 {{.*}} !work_group_num_dim ![[NDRWGS1D:[0-9]+]] !reqd_work_group_size ![[WGSIZE1D2:[0-9]+]]
193+
194+
// CHECK: ![[NDRWGS3D]] = !{i32 3}
190195
// CHECK: ![[WGSIZE3D32]] = !{i32 16, i32 16, i32 32}
191196
// CHECK: ![[WGSIZE3D8]] = !{i32 1, i32 1, i32 8}
192197
// CHECK: ![[WGSIZE3D88]] = !{i32 8, i32 8, i32 8}
193198
// CHECK: ![[WGSIZE3D22]] = !{i32 2, i32 2, i32 2}
194199
// CHECK: ![[WGSIZE3D44]] = !{i32 4, i32 4, i32 8}
195200
// CHECK: ![[WGSIZE3D2]] = !{i32 2, i32 8, i32 1}
196-
// CHECK: ![[WGSIZE2D32]] = !{i32 16, i32 32}
197-
// CHECK: ![[WGSIZE2D8]] = !{i32 1, i32 8}
198-
// CHECK: ![[WGSIZE2D88]] = !{i32 8, i32 8}
199-
// CHECK: ![[WGSIZE2D22]] = !{i32 2, i32 2}
200-
// CHECK: ![[WGSIZE2D44]] = !{i32 4, i32 8}
201-
// CHECK: ![[WGSIZE2D2]] = !{i32 8, i32 1}
202-
// CHECK: ![[WGSIZE1D32]] = !{i32 32}
203-
// CHECK: ![[WGSIZE1D8]] = !{i32 8}
204-
// CHECK: ![[WGSIZE1D22]] = !{i32 2}
205-
// CHECK: ![[WGSIZE1D2]] = !{i32 1}
201+
// CHECK: ![[NDRWGS2D]] = !{i32 2}
202+
// CHECK: ![[WGSIZE2D32]] = !{i32 16, i32 32, i32 1}
203+
// CHECK: ![[WGSIZE2D8]] = !{i32 1, i32 8, i32 1}
204+
// CHECK: ![[WGSIZE2D88]] = !{i32 8, i32 8, i32 1}
205+
// CHECK: ![[WGSIZE2D22]] = !{i32 2, i32 2, i32 1}
206+
// CHECK: ![[WGSIZE2D44]] = !{i32 4, i32 8, i32 1}
207+
// CHECK: ![[WGSIZE2D2_or_WGSIZE1D8]] = !{i32 8, i32 1, i32 1}
208+
// CHECK: ![[NDRWGS1D]] = !{i32 1}
209+
// CHECK: ![[WGSIZE1D32]] = !{i32 32, i32 1, i32 1}
210+
// CHECK: ![[WGSIZE1D22]] = !{i32 2, i32 1, i32 1}
211+
// CHECK: ![[WGSIZE1D2]] = !{i32 1, i32 1, i32 1}

llvm/include/llvm/SYCLLowerIR/SYCLDeviceRequirements.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ struct SYCLDeviceRequirements {
3333
std::set<uint32_t> Aspects;
3434
std::set<uint32_t> FixedTarget;
3535
std::optional<llvm::SmallVector<uint64_t, 3>> ReqdWorkGroupSize;
36+
std::optional<uint32_t> WorkGroupNumDim;
3637
std::optional<llvm::SmallString<256>> JointMatrix;
3738
std::optional<llvm::SmallString<256>> JointMatrixMad;
3839
std::optional<uint32_t> SubGroupSize;

llvm/lib/SYCLLowerIR/ModuleSplitter.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -982,6 +982,7 @@ getDeviceCodeSplitter(ModuleDesc &&MD, IRSplitMode Mode, bool IROutputOnly,
982982
Categorizer.registerSimpleStringAttributeRule("sycl-grf-size");
983983
Categorizer.registerListOfIntegersInMetadataSortedRule("sycl_used_aspects");
984984
Categorizer.registerListOfIntegersInMetadataRule("reqd_work_group_size");
985+
Categorizer.registerListOfIntegersInMetadataRule("work_group_num_dim");
985986
Categorizer.registerListOfIntegersInMetadataRule(
986987
"intel_reqd_sub_group_size");
987988
Categorizer.registerSimpleStringAttributeRule(

llvm/lib/SYCLLowerIR/SYCLDeviceRequirements.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ llvm::computeDeviceRequirements(const module_split::ModuleDesc &MD) {
5757
}
5858
}
5959

60+
if (auto *MDN = F.getMetadata("work_group_num_dim")) {
61+
uint32_t WGND = ExtractUnsignedIntegerFromMDNodeOperand(MDN, 0);
62+
if (!Reqs.ReqdWorkGroupSize.has_value())
63+
Reqs.WorkGroupNumDim = WGND;
64+
}
65+
6066
if (auto *MDN = F.getMetadata("reqd_work_group_size")) {
6167
llvm::SmallVector<uint64_t, 3> NewReqdWorkGroupSize;
6268
for (size_t I = 0, E = MDN->getNumOperands(); I < E; ++I)
@@ -133,5 +139,8 @@ std::map<StringRef, util::PropertyValue> SYCLDeviceRequirements::asMap() const {
133139
if (SubGroupSize.has_value())
134140
Requirements["reqd_sub_group_size"] = *SubGroupSize;
135141

142+
if (WorkGroupNumDim.has_value())
143+
Requirements["work_group_num_dim"] = *WorkGroupNumDim;
144+
136145
return Requirements;
137146
}

llvm/tools/sycl-post-link/sycl-post-link.cpp

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,16 @@ bool isModuleUsingAsan(const Module &M) {
347347
return MDVal->getString() == "asan";
348348
}
349349

350+
// Gets work_group_num_dim information for function Func, conveniently 0 if
351+
// metadata is not present.
352+
uint32_t getKernelWorkGroupNumDim(const Function &Func) {
353+
MDNode *MaxDimMD = Func.getMetadata("work_group_num_dim");
354+
if (!MaxDimMD)
355+
return 0;
356+
assert(MaxDimMD->getNumOperands() == 1 && "Malformed node.");
357+
return mdconst::extract<ConstantInt>(MaxDimMD->getOperand(0))->getZExtValue();
358+
}
359+
350360
// Gets reqd_work_group_size information for function Func.
351361
std::vector<uint32_t> getKernelReqdWorkGroupSizeMetadata(const Function &Func) {
352362
MDNode *ReqdWorkGroupSizeMD = Func.getMetadata("reqd_work_group_size");
@@ -473,15 +483,23 @@ std::string saveModuleProperties(module_split::ModuleDesc &MD,
473483
SmallVector<std::string, 4> MetadataNames;
474484

475485
if (GlobProps.EmitProgramMetadata) {
476-
// Add reqd_work_group_size information to program metadata
486+
// Add reqd_work_group_size and work_group_num_dim information to
487+
// program metadata.
477488
for (const Function &Func : M.functions()) {
478489
std::vector<uint32_t> KernelReqdWorkGroupSize =
479490
getKernelReqdWorkGroupSizeMetadata(Func);
480-
if (KernelReqdWorkGroupSize.empty())
481-
continue;
482-
MetadataNames.push_back(Func.getName().str() + "@reqd_work_group_size");
483-
PropSet.add(PropSetRegTy::SYCL_PROGRAM_METADATA, MetadataNames.back(),
484-
KernelReqdWorkGroupSize);
491+
if (!KernelReqdWorkGroupSize.empty()) {
492+
MetadataNames.push_back(Func.getName().str() + "@reqd_work_group_size");
493+
PropSet.add(PropSetRegTy::SYCL_PROGRAM_METADATA, MetadataNames.back(),
494+
KernelReqdWorkGroupSize);
495+
}
496+
497+
uint32_t WorkGroupNumDim = getKernelWorkGroupNumDim(Func);
498+
if (WorkGroupNumDim) {
499+
MetadataNames.push_back(Func.getName().str() + "@work_group_num_dim");
500+
PropSet.add(PropSetRegTy::SYCL_PROGRAM_METADATA, MetadataNames.back(),
501+
WorkGroupNumDim);
502+
}
485503
}
486504

487505
// Add global_id_mapping information with mapping between device-global

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2681,8 +2681,7 @@ checkDevSupportDeviceRequirements(const device &Dev,
26812681
const RTDeviceBinaryImage &Img,
26822682
const NDRDescT &NDRDesc) {
26832683
auto getPropIt = [&Img](const std::string &PropName) {
2684-
const RTDeviceBinaryImage::PropertyRange &PropRange =
2685-
Img.getDeviceRequirements();
2684+
auto &PropRange = Img.getDeviceRequirements();
26862685
RTDeviceBinaryImage::PropertyRange::ConstIterator PropIt = std::find_if(
26872686
PropRange.begin(), PropRange.end(),
26882687
[&PropName](RTDeviceBinaryImage::PropertyRange::ConstIterator &&Prop) {
@@ -2700,6 +2699,7 @@ checkDevSupportDeviceRequirements(const device &Dev,
27002699
auto ReqdWGSizeUint32TPropIt = getPropIt("reqd_work_group_size");
27012700
auto ReqdWGSizeUint64TPropIt = getPropIt("reqd_work_group_size_uint64_t");
27022701
auto ReqdSubGroupSizePropIt = getPropIt("reqd_sub_group_size");
2702+
auto WorkGroupNumDim = getPropIt("work_group_num_dim");
27032703

27042704
// Checking if device supports defined aspects
27052705
if (AspectsPropIt) {
@@ -2796,7 +2796,23 @@ checkDevSupportDeviceRequirements(const device &Dev,
27962796
Dims++;
27972797
}
27982798

2799-
if (NDRDesc.Dims != 0 && NDRDesc.Dims != static_cast<size_t>(Dims))
2799+
size_t UserProvidedNumDims = 0;
2800+
if (WorkGroupNumDim) {
2801+
// We know the dimensions have been padded to 3, make sure that the pad
2802+
// value is always set to 1 and record the number of dimensions specified
2803+
// by the user.
2804+
UserProvidedNumDims =
2805+
DeviceBinaryProperty(*(WorkGroupNumDim.value())).asUint32();
2806+
#ifndef NDEBUG
2807+
for (unsigned i = UserProvidedNumDims; i < 3; ++i)
2808+
assert(ReqdWGSizeVec[i] == 1 &&
2809+
"Incorrect padding in required work-group size metadata.");
2810+
#endif // NDEBUG
2811+
} else {
2812+
UserProvidedNumDims = Dims;
2813+
}
2814+
2815+
if (NDRDesc.Dims != 0 && NDRDesc.Dims != UserProvidedNumDims)
28002816
return sycl::exception(
28012817
sycl::errc::nd_range,
28022818
"The local size dimension of submitted nd_range doesn't match the "

sycl/test-e2e/Basic/reqd_work_group_size.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
// RUN: %{build} -o %t.out
22
// RUN: %{run} %t.out
3-
//
4-
// Failing negative test with HIP
5-
// UNSUPPORTED: hip
63

74
#include <sycl/detail/core.hpp>
85

sycl/test-e2e/Basic/reqd_work_group_size_check_exception.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
// RUN: %clangxx -fsycl -fsycl-targets=%{sycl_triple} %s -o %t.out
22
// RUN: %{run} %t.out
33

4-
// UNSUPPORTED: hip
5-
64
#include <sycl/detail/core.hpp>
75

86
#define CHECK_INVALID_REQD_WORK_GROUP_SIZE(Dim, ...) \

0 commit comments

Comments
 (0)