Skip to content

Commit 6b3c20f

Browse files
authored
[SYCL] Enable Dead Argument Elimination for NVPTX backend (#4617)
This addresses #2359 for NVPTX backend only. Changes are to: - Detect NVPTX kernels for DAE (using `nvvm.annotations` metadata) - Update the metadata with new function signatures as required - Renaming some metadata, variables & files to reflect the fact that SYCL DAE is no longer SPIR specific - Updating tests to reflect the expectation that -fenable-sycl-dae flag should be passed through for nvptx targets
1 parent d7af47a commit 6b3c20f

File tree

11 files changed

+122
-56
lines changed

11 files changed

+122
-56
lines changed

clang/lib/Driver/ToolChains/Clang.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4635,7 +4635,7 @@ void Clang::ConstructJob(Compilation &C, const JobAction &JA,
46354635
}
46364636

46374637
// Turn on Dead Parameter Elimination Optimization with early optimizations
4638-
if (!(RawTriple.isNVPTX() || RawTriple.isAMDGCN()) &&
4638+
if (!(RawTriple.isAMDGCN()) &&
46394639
Args.hasFlag(options::OPT_fsycl_dead_args_optimization,
46404640
options::OPT_fno_sycl_dead_args_optimization, false))
46414641
CmdArgs.push_back("-fenable-sycl-dae");
@@ -8920,8 +8920,7 @@ void SYCLPostLink::ConstructJob(Compilation &C, const JobAction &JA,
89208920
// -fsycl-device-code-split=auto
89218921

89228922
// Turn on Dead Parameter Elimination Optimization with early optimizations
8923-
if (!(getToolChain().getTriple().isNVPTX() ||
8924-
getToolChain().getTriple().isAMDGCN()) &&
8923+
if (!(getToolChain().getTriple().isAMDGCN()) &&
89258924
TCArgs.hasFlag(options::OPT_fsycl_dead_args_optimization,
89268925
options::OPT_fno_sycl_dead_args_optimization, false))
89278926
addArgs(CmdArgs, TCArgs, {"-emit-param-info"});

clang/test/Driver/sycl-triple-dae-flags.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
// RUN: %clangxx -### -fsycl -fsycl-targets=nvptx64-nvidia-cuda -fsycl-dead-args-optimization %s 2> %t.cuda.out
2-
// RUN: FileCheck %s --input-file %t.cuda.out
3-
//
41
// RUN: %clangxx -### -fsycl -fsycl-targets=amdgcn-amd-amdhsa -Xsycl-target-backend --offload-arch=gfx906 -fsycl-dead-args-optimization %s 2> %t.rocm.out
52
// RUN: FileCheck %s --input-file %t.rocm.out
63
// CHECK-NOT: -fenable-sycl-dae
74
// CHECK-NOT: -emit-param-info
85
//
6+
// RUN: %clangxx -### -fsycl -fsycl-targets=nvptx64-nvidia-cuda -fsycl-dead-args-optimization %s 2> %t.cuda.out
7+
// RUN: FileCheck %s --check-prefixes=CHECK-FENABLE,CHECK-EMIT --input-file %t.cuda.out
8+
//
99
// RUN: %clangxx -### -fsycl -fsycl-targets=spir64-unknown-unknown -fsycl-dead-args-optimization %s 2> %t.out
1010
// RUN: FileCheck %s --check-prefixes=CHECK-FENABLE,CHECK-EMIT --input-file %t.out
1111
// CHECK-FENABLE: -fenable-sycl-dae

llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include "llvm/ADT/SmallVector.h"
2323
#include "llvm/ADT/Twine.h"
24+
#include "llvm/IR/Constants.h"
2425
#include "llvm/IR/Function.h"
2526
#include "llvm/IR/PassManager.h"
2627
#include <map>
@@ -74,9 +75,9 @@ class DeadArgumentEliminationPass
7475
enum Liveness { Live, MaybeLive };
7576

7677
DeadArgumentEliminationPass(bool ShouldHackArguments_ = false,
77-
bool CheckSpirKernels_ = false)
78+
bool CheckSYCLKernels_ = false)
7879
: ShouldHackArguments(ShouldHackArguments_),
79-
CheckSpirKernels(CheckSpirKernels_) {}
80+
CheckSYCLKernels(CheckSYCLKernels_) {}
8081

8182
PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
8283

@@ -123,9 +124,9 @@ class DeadArgumentEliminationPass
123124
/// (used only by bugpoint).
124125
bool ShouldHackArguments = false;
125126

126-
/// This allows to eliminate dead arguments in SPIR kernel functions with
127-
/// external linkage in SYCL environment
128-
bool CheckSpirKernels = false;
127+
/// This allows to eliminate dead arguments in SYCL kernel wrapper functions
128+
/// with external linkage
129+
bool CheckSYCLKernels = false;
129130

130131
private:
131132
Liveness MarkIfNotLive(RetOrArg Use, UseVector &MaybeLiveUses);
@@ -143,6 +144,45 @@ class DeadArgumentEliminationPass
143144
bool RemoveDeadStuffFromFunction(Function *F);
144145
bool DeleteDeadVarargs(Function &Fn);
145146
bool RemoveDeadArgumentsFromCallers(Function &Fn);
147+
148+
void UpdateNVPTXMetadata(Module &M, Function *F, Function *NF);
149+
llvm::DenseSet<Function *> NVPTXKernelSet;
150+
151+
bool IsNVPTXKernel(const Function *F) { return NVPTXKernelSet.contains(F); };
152+
153+
void BuildNVPTXKernelSet(const Module &M) {
154+
155+
auto *NvvmMetadata = M.getNamedMetadata("nvvm.annotations");
156+
if (!NvvmMetadata)
157+
return;
158+
159+
for (auto *MetadataNode : NvvmMetadata->operands()) {
160+
if (MetadataNode->getNumOperands() != 3)
161+
continue;
162+
163+
// NVPTX identifies kernel entry points using metadata nodes of the form:
164+
// !X = !{<function>, !"kernel", i32 1}
165+
auto *Type = dyn_cast<MDString>(MetadataNode->getOperand(1));
166+
// Only process kernel entry points.
167+
if (!Type || Type->getString() != "kernel")
168+
continue;
169+
170+
// Get a pointer to the entry point function from the metadata.
171+
if (const auto &FuncOperand = MetadataNode->getOperand(0)) {
172+
if (auto *FuncConstant = dyn_cast<ConstantAsMetadata>(FuncOperand)) {
173+
if (auto *Func = dyn_cast<Function>(FuncConstant->getValue())) {
174+
if (auto *Val = mdconst::dyn_extract<ConstantInt>(
175+
MetadataNode->getOperand(2))) {
176+
if (Val->getValue() == 1) {
177+
NVPTXKernelSet.insert(Func);
178+
}
179+
}
180+
}
181+
}
182+
}
183+
}
184+
return;
185+
}
146186
};
147187

148188
class DeadArgumentEliminationSYCLPass
@@ -155,7 +195,7 @@ class DeadArgumentEliminationSYCLPass
155195
private:
156196
DeadArgumentEliminationPass Impl =
157197
DeadArgumentEliminationPass(/* ShouldHackArguemtns */ false,
158-
/* CheckSpirKernels */ true);
198+
/* CheckSYCLKernels */ true);
159199
};
160200

161201
} // end namespace llvm

llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,14 @@ namespace {
7878
if (skipModule(M))
7979
return false;
8080
DeadArgumentEliminationPass DAEP(ShouldHackArguments(),
81-
CheckSpirKernels());
81+
CheckSYCLKernels());
8282
ModuleAnalysisManager DummyMAM;
8383
PreservedAnalyses PA = DAEP.run(M, DummyMAM);
8484
return !PA.areAllPreserved();
8585
}
8686

8787
virtual bool ShouldHackArguments() const { return false; }
88-
virtual bool CheckSpirKernels() const { return false; }
88+
virtual bool CheckSYCLKernels() const { return false; }
8989
};
9090

9191
} // end anonymous namespace
@@ -105,7 +105,7 @@ namespace {
105105
DAH() : DAE(ID) {}
106106

107107
bool ShouldHackArguments() const override { return true; }
108-
bool CheckSpirKernels() const override { return false; }
108+
bool CheckSYCLKernels() const override { return false; }
109109
};
110110

111111
} // end anonymous namespace
@@ -118,7 +118,7 @@ INITIALIZE_PASS(DAH, "deadarghaX0r",
118118

119119
namespace {
120120

121-
/// DAESYCL - DeadArgumentElimination pass for SPIR kernel functions even
121+
/// DAESYCL - DeadArgumentElimination pass for SYCL kernel functions even
122122
/// if they are external.
123123
struct DAESYCL : public DAE {
124124
static char ID;
@@ -128,21 +128,19 @@ struct DAESYCL : public DAE {
128128
}
129129

130130
StringRef getPassName() const override {
131-
return "Dead Argument Elimination for SPIR kernels in SYCL environment";
131+
return "Dead Argument Elimination for SYCL kernels";
132132
}
133133

134134
bool ShouldHackArguments() const override { return false; }
135-
bool CheckSpirKernels() const override { return true; }
135+
bool CheckSYCLKernels() const override { return true; }
136136
};
137137

138138
} // end anonymous namespace
139139

140140
char DAESYCL::ID = 0;
141141

142-
INITIALIZE_PASS(
143-
DAESYCL, "deadargelim-sycl",
144-
"Dead Argument Elimination for SPIR kernels in SYCL environment", false,
145-
false)
142+
INITIALIZE_PASS(DAESYCL, "deadargelim-sycl",
143+
"Dead Argument Elimination for SYCL kernels", false, false)
146144

147145
/// createDeadArgEliminationPass - This pass removes arguments from functions
148146
/// which are not used by the body of the function.
@@ -573,12 +571,13 @@ void DeadArgumentEliminationPass::SurveyFunction(const Function &F) {
573571
}
574572

575573
// We can't modify arguments if the function is not local
576-
// but we can do so for SPIR kernel function in SYCL environment.
574+
// but we can do so for SYCL kernel functions.
577575
// DAE is not currently supported for ESIMD kernels.
578-
bool FuncIsSpirNonEsimdKernel =
579-
CheckSpirKernels && F.getCallingConv() == CallingConv::SPIR_KERNEL &&
576+
bool FuncIsSyclNonEsimdKernel =
577+
CheckSYCLKernels &&
578+
(F.getCallingConv() == CallingConv::SPIR_KERNEL || IsNVPTXKernel(&F)) &&
580579
!F.getMetadata("sycl_explicit_simd");
581-
bool FuncIsLive = !F.hasLocalLinkage() && !FuncIsSpirNonEsimdKernel;
580+
bool FuncIsLive = !F.hasLocalLinkage() && !FuncIsSyclNonEsimdKernel;
582581
if (FuncIsLive && (!ShouldHackArguments || F.isIntrinsic())) {
583582
MarkLive(F);
584583
return;
@@ -812,15 +811,15 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) {
812811
}
813812
}
814813

815-
if (CheckSpirKernels) {
814+
if (CheckSYCLKernels) {
816815
SmallVector<Metadata *, 10> MDOmitArgs;
817816
auto MDOmitArgTrue = llvm::ConstantAsMetadata::get(
818817
ConstantInt::get(Type::getInt1Ty(F->getContext()), 1));
819818
auto MDOmitArgFalse = llvm::ConstantAsMetadata::get(
820819
ConstantInt::get(Type::getInt1Ty(F->getContext()), 0));
821820
for (auto &AliveArg : ArgAlive)
822821
MDOmitArgs.push_back(AliveArg ? MDOmitArgFalse : MDOmitArgTrue);
823-
F->setMetadata("spir_kernel_omit_args",
822+
F->setMetadata("sycl_kernel_omit_args",
824823
llvm::MDNode::get(F->getContext(), MDOmitArgs));
825824
}
826825

@@ -1131,6 +1130,9 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) {
11311130
for (auto MD : MDs)
11321131
NF->addMetadata(MD.first, *MD.second);
11331132

1133+
if (IsNVPTXKernel(F))
1134+
UpdateNVPTXMetadata(*(F->getParent()), F, NF);
1135+
11341136
// Now that the old function is dead, delete it.
11351137
F->eraseFromParent();
11361138

@@ -1141,6 +1143,8 @@ PreservedAnalyses DeadArgumentEliminationPass::run(Module &M,
11411143
ModuleAnalysisManager &) {
11421144
bool Changed = false;
11431145

1146+
BuildNVPTXKernelSet(M);
1147+
11441148
// First pass: Do a simple check to see if any functions can have their "..."
11451149
// removed. We can do this if they never call va_start. This loop cannot be
11461150
// fused with the next loop, because deleting a function invalidates
@@ -1173,3 +1177,25 @@ PreservedAnalyses DeadArgumentEliminationPass::run(Module &M,
11731177
return PreservedAnalyses::all();
11741178
return PreservedAnalyses::none();
11751179
}
1180+
1181+
void DeadArgumentEliminationPass::UpdateNVPTXMetadata(Module &M, Function *F,
1182+
Function *NF) {
1183+
1184+
auto *NvvmMetadata = M.getNamedMetadata("nvvm.annotations");
1185+
if (!NvvmMetadata)
1186+
return;
1187+
1188+
for (auto *MetadataNode : NvvmMetadata->operands()) {
1189+
const auto &FuncOperand = MetadataNode->getOperand(0);
1190+
if (!FuncOperand)
1191+
continue;
1192+
auto FuncConstant = dyn_cast<ConstantAsMetadata>(FuncOperand);
1193+
if (!FuncConstant)
1194+
continue;
1195+
auto *Func = dyn_cast<Function>(FuncConstant->getValue());
1196+
if (Func != F)
1197+
continue;
1198+
// Update the metadata with the new function
1199+
MetadataNode->replaceOperandWith(0, llvm::ConstantAsMetadata::get(NF));
1200+
}
1201+
}

llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ target triple = "spir64"
77
; This test ensures dead arguments are not eliminated
88
; from a global function that is not a SPIR kernel.
99

10-
; CHECK-NOT: !spir_kernel_omit_args
10+
; CHECK-NOT: !sycl_kernel_omit_args
1111

1212
define weak_odr void @NotASpirKernel(float %arg1, float %arg2) {
1313
; CHECK-LABEL: define {{[^@]+}}@NotASpirKernel

llvm/test/Transforms/DeadArgElim/sycl-kernels.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ define weak_odr spir_kernel void @SpirKernel1(float %arg1, float %arg2) {
1414
; CHECK-NEXT: ret void
1515
;
1616
; CHECK-SYCL-LABEL: define {{[^@]+}}@SpirKernel1
17-
; CHECK-SYCL-SAME: (float [[ARG1:%.*]]) !spir_kernel_omit_args ![[KERN_ARGS1:[0-9]]]
17+
; CHECK-SYCL-SAME: (float [[ARG1:%.*]]) !sycl_kernel_omit_args ![[KERN_ARGS1:[0-9]]]
1818
; CHECK-SYCL-NEXT: call void @foo(float [[ARG1]])
1919
; CHECK-SYCL-NEXT: ret void
2020

@@ -29,7 +29,7 @@ define weak_odr spir_kernel void @SpirKernel2(float %arg1, float %arg2) {
2929
; CHECK-NEXT: ret void
3030
;
3131
; CHECK-SYCL-LABEL: define {{[^@]+}}@SpirKernel2
32-
; CHECK-SYCL-SAME: (float [[ARG2:%.*]]) !spir_kernel_omit_args ![[KERN_ARGS2:[0-9]]]
32+
; CHECK-SYCL-SAME: (float [[ARG2:%.*]]) !sycl_kernel_omit_args ![[KERN_ARGS2:[0-9]]]
3333
; CHECK-SYCL-NEXT: call void @foo(float [[ARG2]])
3434
; CHECK-SYCL-NEXT: ret void
3535

llvm/test/tools/sycl-post-link/omit_kernel_args.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88

99
target triple = "spir64-unknown-unknown"
1010

11-
define weak_odr spir_kernel void @SpirKernel1(float %arg1) !spir_kernel_omit_args !0 {
11+
define weak_odr spir_kernel void @SpirKernel1(float %arg1) !sycl_kernel_omit_args !0 {
1212
call void @foo(float %arg1)
1313
ret void
1414
}
1515

16-
define weak_odr spir_kernel void @SpirKernel2(i8 %arg1, i8 %arg2, i8 %arg3) !spir_kernel_omit_args !1 {
16+
define weak_odr spir_kernel void @SpirKernel2(i8 %arg1, i8 %arg2, i8 %arg3) !sycl_kernel_omit_args !1 {
1717
call void @bar(i8 %arg1)
1818
call void @bar(i8 %arg2)
1919
call void @bar(i8 %arg3)

llvm/tools/sycl-post-link/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ include_directories(
1919

2020
add_llvm_tool(sycl-post-link
2121
sycl-post-link.cpp
22-
SPIRKernelParamOptInfo.cpp
22+
SYCLKernelParamOptInfo.cpp
2323
SpecConstants.cpp
2424
SYCLDeviceLibReqMask.cpp
2525
ADDITIONAL_HEADER_DIRS
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,36 @@
1-
//==-- SPIRKernelParamOptInfo.cpp -- get kernel param optimization info ----==//
1+
//==-- SYCLKernelParamOptInfo.cpp -- get kernel param optimization info ----==//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "SPIRKernelParamOptInfo.h"
9+
#include "SYCLKernelParamOptInfo.h"
1010

1111
#include "llvm/IR/Constants.h"
1212
#include "llvm/Support/Casting.h"
1313

1414
namespace {
1515

1616
// Must match the one produced by DeadArgumentElimination
17-
static constexpr char MetaDataID[] = "spir_kernel_omit_args";
17+
static constexpr char MetaDataID[] = "sycl_kernel_omit_args";
1818

1919
} // anonymous namespace
2020

2121
namespace llvm {
2222

23-
void SPIRKernelParamOptInfo::releaseMemory() { clear(); }
23+
void SYCLKernelParamOptInfo::releaseMemory() { clear(); }
2424

25-
SPIRKernelParamOptInfo
26-
SPIRKernelParamOptInfoAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
27-
SPIRKernelParamOptInfo Res;
25+
SYCLKernelParamOptInfo
26+
SYCLKernelParamOptInfoAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
27+
SYCLKernelParamOptInfo Res;
2828

2929
for (const Function &F : M) {
3030
MDNode *MD = F.getMetadata(MetaDataID);
3131
if (!MD)
3232
continue;
33-
using BaseTy = SPIRKernelParamOptInfoBaseTy;
33+
using BaseTy = SYCLKernelParamOptInfoBaseTy;
3434
auto Ins =
3535
Res.insert(BaseTy::value_type{F.getName(), BaseTy::mapped_type{}});
3636
assert(Ins.second && "duplicate kernel?");
@@ -46,6 +46,6 @@ SPIRKernelParamOptInfoAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
4646
return Res;
4747
}
4848

49-
AnalysisKey SPIRKernelParamOptInfoAnalysis::Key;
49+
AnalysisKey SYCLKernelParamOptInfoAnalysis::Key;
5050

51-
} // namespace llvm
51+
} // namespace llvm

0 commit comments

Comments
 (0)