diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index 17fc36fbe2ac8..90d2a2056fe1b 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -4651,6 +4651,13 @@ def HLSLNumThreads: InheritableAttr { let Documentation = [NumThreadsDocs]; } +def HLSLSV_GroupThreadID: HLSLAnnotationAttr { + let Spellings = [HLSLAnnotation<"SV_GroupThreadID">]; + let Subjects = SubjectList<[ParmVar, Field]>; + let LangOpts = [HLSL]; + let Documentation = [HLSLSV_GroupThreadIDDocs]; +} + def HLSLSV_GroupID: HLSLAnnotationAttr { let Spellings = [HLSLAnnotation<"SV_GroupID">]; let Subjects = SubjectList<[ParmVar, Field]>; diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td index 7a82b8fa32059..fdad4c9a3ea19 100644 --- a/clang/include/clang/Basic/AttrDocs.td +++ b/clang/include/clang/Basic/AttrDocs.td @@ -7941,6 +7941,17 @@ randomized. }]; } +def HLSLSV_GroupThreadIDDocs : Documentation { + let Category = DocCatFunction; + let Content = [{ +The ``SV_GroupThreadID`` semantic, when applied to an input parameter, specifies which +individual thread within a thread group is executing in. This attribute is +only supported in compute shaders. + +The full documentation is available here: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-groupthreadid + }]; +} + def HLSLSV_GroupIDDocs : Documentation { let Category = DocCatFunction; let Content = [{ diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index ee685d95c9615..f4cd11f423a84 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -119,6 +119,7 @@ class SemaHLSL : public SemaBase { void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL); void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL); void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL); + void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL); void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL); void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL); void handleShaderAttr(Decl *D, const ParsedAttr &AL); diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp index 2c293523fca8c..fb15b1993e74a 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.cpp +++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp @@ -389,6 +389,11 @@ llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B, CGM.getIntrinsic(getThreadIdIntrinsic()); return buildVectorInput(B, ThreadIDIntrinsic, Ty); } + if (D.hasAttr()) { + llvm::Function *GroupThreadIDIntrinsic = + CGM.getIntrinsic(getGroupThreadIdIntrinsic()); + return buildVectorInput(B, GroupThreadIDIntrinsic, Ty); + } if (D.hasAttr()) { llvm::Function *GroupIDIntrinsic = CGM.getIntrinsic(Intrinsic::dx_group_id); return buildVectorInput(B, GroupIDIntrinsic, Ty); diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h index bb120c8b5e9e6..f9efb1bc99641 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.h +++ b/clang/lib/CodeGen/CGHLSLRuntime.h @@ -86,6 +86,7 @@ class CGHLSLRuntime { GENERATE_HLSL_INTRINSIC_FUNCTION(Step, step) GENERATE_HLSL_INTRINSIC_FUNCTION(Radians, radians) GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id) + GENERATE_HLSL_INTRINSIC_FUNCTION(GroupThreadId, thread_id_in_group) GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot) GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot) GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot) diff --git a/clang/lib/Parse/ParseHLSL.cpp b/clang/lib/Parse/ParseHLSL.cpp index 4de342b63ed80..443bf2b9ec626 100644 --- a/clang/lib/Parse/ParseHLSL.cpp +++ b/clang/lib/Parse/ParseHLSL.cpp @@ -280,6 +280,7 @@ void Parser::ParseHLSLAnnotations(ParsedAttributes &Attrs, case ParsedAttr::UnknownAttribute: Diag(Loc, diag::err_unknown_hlsl_semantic) << II; return; + case ParsedAttr::AT_HLSLSV_GroupThreadID: case ParsedAttr::AT_HLSLSV_GroupID: case ParsedAttr::AT_HLSLSV_GroupIndex: case ParsedAttr::AT_HLSLSV_DispatchThreadID: diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp index 4fd8ef6dbebf8..5d7ee09738377 100644 --- a/clang/lib/Sema/SemaDeclAttr.cpp +++ b/clang/lib/Sema/SemaDeclAttr.cpp @@ -7114,6 +7114,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL, case ParsedAttr::AT_HLSLWaveSize: S.HLSL().handleWaveSizeAttr(D, AL); break; + case ParsedAttr::AT_HLSLSV_GroupThreadID: + S.HLSL().handleSV_GroupThreadIDAttr(D, AL); + break; case ParsedAttr::AT_HLSLSV_GroupID: S.HLSL().handleSV_GroupIDAttr(D, AL); break; diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 88db3e1254119..600c800029fd0 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -434,6 +434,7 @@ void SemaHLSL::CheckSemanticAnnotation( switch (AnnotationAttr->getKind()) { case attr::HLSLSV_DispatchThreadID: case attr::HLSLSV_GroupIndex: + case attr::HLSLSV_GroupThreadID: case attr::HLSLSV_GroupID: if (ST == llvm::Triple::Compute) return; @@ -787,6 +788,15 @@ void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) { HLSLSV_DispatchThreadIDAttr(getASTContext(), AL)); } +void SemaHLSL::handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL) { + auto *VD = cast(D); + if (!diagnoseInputIDType(VD->getType(), AL)) + return; + + D->addAttr(::new (getASTContext()) + HLSLSV_GroupThreadIDAttr(getASTContext(), AL)); +} + void SemaHLSL::handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL) { auto *VD = cast(D); if (!diagnoseInputIDType(VD->getType(), AL)) diff --git a/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl b/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl new file mode 100644 index 0000000000000..3d347b973f39c --- /dev/null +++ b/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl @@ -0,0 +1,36 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL -DTARGET=dx +// RUN: %clang_cc1 -triple spirv-linux-vulkan-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV -DTARGET=spv + +// Make sure SV_GroupThreadID translated into dx.thread.id.in.group for directx target and spv.thread.id.in.group for spirv target. + +// CHECK: define void @foo() +// CHECK: %[[#ID:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 0) +// CHECK-DXIL: call void @{{.*}}foo{{.*}}(i32 %[[#ID]]) +// CHECK-SPIRV: call spir_func void @{{.*}}foo{{.*}}(i32 %[[#ID]]) +[shader("compute")] +[numthreads(8,8,1)] +void foo(uint Idx : SV_GroupThreadID) {} + +// CHECK: define void @bar() +// CHECK: %[[#ID_X:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 0) +// CHECK: %[[#ID_X_:]] = insertelement <2 x i32> poison, i32 %[[#ID_X]], i64 0 +// CHECK: %[[#ID_Y:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 1) +// CHECK: %[[#ID_XY:]] = insertelement <2 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1 +// CHECK-DXIL: call void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]]) +// CHECK-SPIRV: call spir_func void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]]) +[shader("compute")] +[numthreads(8,8,1)] +void bar(uint2 Idx : SV_GroupThreadID) {} + +// CHECK: define void @test() +// CHECK: %[[#ID_X:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 0) +// CHECK: %[[#ID_X_:]] = insertelement <3 x i32> poison, i32 %[[#ID_X]], i64 0 +// CHECK: %[[#ID_Y:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 1) +// CHECK: %[[#ID_XY:]] = insertelement <3 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1 +// CHECK: %[[#ID_Z:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 2) +// CHECK: %[[#ID_XYZ:]] = insertelement <3 x i32> %[[#ID_XY]], i32 %[[#ID_Z]], i64 2 +// CHECK-DXIL: call void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]]) +// CHECK-SPIRV: call spir_func void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]]) +[shader("compute")] +[numthreads(8,8,1)] +void test(uint3 Idx : SV_GroupThreadID) {} diff --git a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl index 13c07038d2e4a..393d7300605c0 100644 --- a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl +++ b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl @@ -1,16 +1,14 @@ // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -hlsl-entry CSMain -x hlsl -finclude-default-header -ast-dump -o - %s | FileCheck %s -// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -hlsl-entry CSMain -x hlsl -finclude-default-header -verify -o - %s [numthreads(8,8,1)] -// expected-error@+3 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}} -// expected-error@+2 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}} -// expected-error@+1 {{attribute 'SV_GroupID' is unsupported in 'mesh' shaders, requires compute}} -void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID) { -// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint)' +void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID, uint GThreadID : SV_GroupThreadID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint, uint)' // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:17 GI 'int' // CHECK-NEXT: HLSLSV_GroupIndexAttr // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:42 ID 'uint' // CHECK-NEXT: HLSLSV_DispatchThreadIDAttr // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:73 GID 'uint' // CHECK-NEXT: HLSLSV_GroupIDAttr +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:96 GThreadID 'uint' +// CHECK-NEXT: HLSLSV_GroupThreadIDAttr } diff --git a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl index 4e1f88aa2294b..1bb4ee5182d62 100644 --- a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl +++ b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl @@ -49,3 +49,33 @@ struct ST2_GID { static uint GID : SV_GroupID; uint s_gid : SV_GroupID; }; + +[numthreads(8,8,1)] +// expected-error@+1 {{attribute 'SV_GroupThreadID' only applies to a field or parameter of type 'uint/uint2/uint3'}} +void CSMain_GThreadID(float ID : SV_GroupThreadID) { +} + +[numthreads(8,8,1)] +// expected-error@+1 {{attribute 'SV_GroupThreadID' only applies to a field or parameter of type 'uint/uint2/uint3'}} +void CSMain2_GThreadID(ST GID : SV_GroupThreadID) { + +} + +void foo_GThreadID() { +// expected-warning@+1 {{'SV_GroupThreadID' attribute only applies to parameters and non-static data members}} + uint GThreadIS : SV_GroupThreadID; +} + +struct ST2_GThreadID { +// expected-warning@+1 {{'SV_GroupThreadID' attribute only applies to parameters and non-static data members}} + static uint GThreadID : SV_GroupThreadID; + uint s_gthreadid : SV_GroupThreadID; +}; + + +[shader("vertex")] +// expected-error@+4 {{attribute 'SV_GroupIndex' is unsupported in 'vertex' shaders, requires compute}} +// expected-error@+3 {{attribute 'SV_DispatchThreadID' is unsupported in 'vertex' shaders, requires compute}} +// expected-error@+2 {{attribute 'SV_GroupID' is unsupported in 'vertex' shaders, requires compute}} +// expected-error@+1 {{attribute 'SV_GroupThreadID' is unsupported in 'vertex' shaders, requires compute}} +void vs_main(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID, uint GThreadID : SV_GroupThreadID) {} diff --git a/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl index 10a5e5dabac87..6781f9241df24 100644 --- a/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl +++ b/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl @@ -49,3 +49,28 @@ void CSMain3_GID(uint3 : SV_GroupID) { // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 'uint3' // CHECK-NEXT: HLSLSV_GroupIDAttr } + +[numthreads(8,8,1)] +void CSMain_GThreadID(uint ID : SV_GroupThreadID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain_GThreadID 'void (uint)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:28 ID 'uint' +// CHECK-NEXT: HLSLSV_GroupThreadIDAttr +} +[numthreads(8,8,1)] +void CSMain1_GThreadID(uint2 ID : SV_GroupThreadID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain1_GThreadID 'void (uint2)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 ID 'uint2' +// CHECK-NEXT: HLSLSV_GroupThreadIDAttr +} +[numthreads(8,8,1)] +void CSMain2_GThreadID(uint3 ID : SV_GroupThreadID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain2_GThreadID 'void (uint3)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 ID 'uint3' +// CHECK-NEXT: HLSLSV_GroupThreadIDAttr +} +[numthreads(8,8,1)] +void CSMain3_GThreadID(uint3 : SV_GroupThreadID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain3_GThreadID 'void (uint3)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 'uint3' +// CHECK-NEXT: HLSLSV_GroupThreadIDAttr +} diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td index 1ae3129774e50..fd0c3b2a59e1d 100644 --- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td +++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td @@ -59,6 +59,7 @@ let TargetPrefix = "spv" in { // The following intrinsic(s) are mirrored from IntrinsicsDirectX.td for HLSL support. def int_spv_thread_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>; + def int_spv_thread_id_in_group : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>; def int_spv_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>; def int_spv_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>; def int_spv_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index b64030508cfc1..824bd6b721679 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -262,9 +262,6 @@ class SPIRVInstructionSelector : public InstructionSelector { bool selectSaturate(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; - bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType, - MachineInstr &I) const; - bool selectWaveOpInst(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, unsigned Opcode) const; @@ -310,6 +307,9 @@ class SPIRVInstructionSelector : public InstructionSelector { void extractSubvector(Register &ResVReg, const SPIRVType *ResType, Register &ReadReg, MachineInstr &InsertionPoint) const; bool BuildCOPY(Register DestReg, Register SrcReg, MachineInstr &I) const; + bool loadVec3BuiltinInputID(SPIRV::BuiltIn::BuiltIn BuiltInValue, + Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; }; } // end anonymous namespace @@ -2825,7 +2825,21 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, return BuildCOPY(ResVReg, I.getOperand(2).getReg(), I); break; case Intrinsic::spv_thread_id: - return selectSpvThreadId(ResVReg, ResType, I); + // The HLSL SV_DispatchThreadID semantic is lowered to llvm.spv.thread.id + // intrinsic in LLVM IR for SPIR-V backend. + // + // In SPIR-V backend, llvm.spv.thread.id is now correctly translated to a + // `GlobalInvocationId` builtin variable + return loadVec3BuiltinInputID(SPIRV::BuiltIn::GlobalInvocationId, ResVReg, + ResType, I); + case Intrinsic::spv_thread_id_in_group: + // The HLSL SV_GroupThreadId semantic is lowered to + // llvm.spv.thread.id.in.group intrinsic in LLVM IR for SPIR-V backend. + // + // In SPIR-V backend, llvm.spv.thread.id.in.group is now correctly + // translated to a `LocalInvocationId` builtin variable + return loadVec3BuiltinInputID(SPIRV::BuiltIn::LocalInvocationId, ResVReg, + ResType, I); case Intrinsic::spv_fdot: return selectFloatDot(ResVReg, ResType, I); case Intrinsic::spv_udot: @@ -3525,13 +3539,12 @@ bool SPIRVInstructionSelector::selectLog10(Register ResVReg, .constrainAllUses(TII, TRI, RBI); } -bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg, - const SPIRVType *ResType, - MachineInstr &I) const { - // DX intrinsic: @llvm.dx.thread.id(i32) - // ID Name Description - // 93 ThreadId reads the thread ID - +// Generate the instructions to load 3-element vector builtin input +// IDs/Indices. +// Like: GlobalInvocationId, LocalInvocationId, etc.... +bool SPIRVInstructionSelector::loadVec3BuiltinInputID( + SPIRV::BuiltIn::BuiltIn BuiltInValue, Register ResVReg, + const SPIRVType *ResType, MachineInstr &I) const { MachineIRBuilder MIRBuilder(I); const SPIRVType *U32Type = GR.getOrCreateSPIRVIntegerType(32, MIRBuilder); const SPIRVType *Vec3Ty = @@ -3539,16 +3552,16 @@ bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg, const SPIRVType *PtrType = GR.getOrCreateSPIRVPointerType( Vec3Ty, MIRBuilder, SPIRV::StorageClass::Input); - // Create new register for GlobalInvocationID builtin variable. + // Create new register for the input ID builtin variable. Register NewRegister = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass); MIRBuilder.getMRI()->setType(NewRegister, LLT::pointer(0, 64)); GR.assignSPIRVTypeToVReg(PtrType, NewRegister, MIRBuilder.getMF()); - // Build GlobalInvocationID global variable with the necessary decorations. + // Build global variable with the necessary decorations for the input ID + // builtin variable. Register Variable = GR.buildGlobalVariable( - NewRegister, PtrType, - getLinkStringForBuiltIn(SPIRV::BuiltIn::GlobalInvocationId), nullptr, + NewRegister, PtrType, getLinkStringForBuiltIn(BuiltInValue), nullptr, SPIRV::StorageClass::Input, nullptr, true, true, SPIRV::LinkageType::Import, MIRBuilder, false); @@ -3565,12 +3578,12 @@ bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg, .addUse(GR.getSPIRVTypeID(Vec3Ty)) .addUse(Variable); - // Get Thread ID index. Expecting operand is a constant immediate value, + // Get the input ID index. Expecting operand is a constant immediate value, // wrapped in a type assignment. assert(I.getOperand(2).isReg()); const uint32_t ThreadId = foldImm(I.getOperand(2), MRI); - // Extract the thread ID from the loaded vector value. + // Extract the input ID from the loaded vector value. MachineBasicBlock &BB = *I.getParent(); auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract)) .addDef(ResVReg) diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_GroupThreadID.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_GroupThreadID.ll new file mode 100644 index 0000000000000..a88debf97fa7b --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_GroupThreadID.ll @@ -0,0 +1,76 @@ +; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %} + +; This file generated from the following command: +; clang -cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -finclude-default-header - -o - < noundef %ID) #0 { +entry: + %ID.addr = alloca <3 x i32>, align 16 + store <3 x i32> %ID, ptr %ID.addr, align 16 + ret void +} + +; Function Attrs: norecurse +define void @main.1() #1 { +entry: + +; CHECK: %[[#load:]] = OpLoad %[[#v3int]] %[[#LocalInvocationId]] +; CHECK: %[[#load0:]] = OpCompositeExtract %[[#int]] %[[#load]] 0 + %0 = call i32 @llvm.spv.thread.id.in.group(i32 0) + +; CHECK: %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load0]] %[[#tempvar]] 0 + %1 = insertelement <3 x i32> poison, i32 %0, i64 0 + +; CHECK: %[[#load:]] = OpLoad %[[#v3int]] %[[#LocalInvocationId]] +; CHECK: %[[#load1:]] = OpCompositeExtract %[[#int]] %[[#load]] 1 + %2 = call i32 @llvm.spv.thread.id.in.group(i32 1) + +; CHECK: %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load1]] %[[#tempvar]] 1 + %3 = insertelement <3 x i32> %1, i32 %2, i64 1 + +; CHECK: %[[#load:]] = OpLoad %[[#v3int]] %[[#LocalInvocationId]] +; CHECK: %[[#load2:]] = OpCompositeExtract %[[#int]] %[[#load]] 2 + %4 = call i32 @llvm.spv.thread.id.in.group(i32 2) + +; CHECK: %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load2]] %[[#tempvar]] 2 + %5 = insertelement <3 x i32> %3, i32 %4, i64 2 + + call void @main(<3 x i32> %5) + ret void +} + +; Function Attrs: nounwind willreturn memory(none) +declare i32 @llvm.spv.thread.id.in.group(i32) #2 + +attributes #0 = { noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" } +attributes #1 = { norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" } +attributes #2 = { nounwind willreturn memory(none) } + +!llvm.module.flags = !{!0, !1} +!llvm.ident = !{!2} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{i32 4, !"dx.disable_optimizations", i32 1} +!2 = !{!"clang version 19.0.0git (git@github.com:llvm/llvm-project.git 91600507765679e92434ec7c5edb883bf01f847f)"}