diff --git a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp index 484be252f3e7a..b08560df11845 100644 --- a/clang/lib/Driver/ToolChains/Clang.cpp +++ b/clang/lib/Driver/ToolChains/Clang.cpp @@ -4635,7 +4635,7 @@ void Clang::ConstructJob(Compilation &C, const JobAction &JA, } // Turn on Dead Parameter Elimination Optimization with early optimizations - if (!(RawTriple.isNVPTX() || RawTriple.isAMDGCN()) && + if (!(RawTriple.isAMDGCN()) && Args.hasFlag(options::OPT_fsycl_dead_args_optimization, options::OPT_fno_sycl_dead_args_optimization, false)) CmdArgs.push_back("-fenable-sycl-dae"); @@ -8920,8 +8920,7 @@ void SYCLPostLink::ConstructJob(Compilation &C, const JobAction &JA, // -fsycl-device-code-split=auto // Turn on Dead Parameter Elimination Optimization with early optimizations - if (!(getToolChain().getTriple().isNVPTX() || - getToolChain().getTriple().isAMDGCN()) && + if (!(getToolChain().getTriple().isAMDGCN()) && TCArgs.hasFlag(options::OPT_fsycl_dead_args_optimization, options::OPT_fno_sycl_dead_args_optimization, false)) addArgs(CmdArgs, TCArgs, {"-emit-param-info"}); diff --git a/clang/test/Driver/sycl-triple-dae-flags.cpp b/clang/test/Driver/sycl-triple-dae-flags.cpp index cd356e051dcea..9c1190163a704 100644 --- a/clang/test/Driver/sycl-triple-dae-flags.cpp +++ b/clang/test/Driver/sycl-triple-dae-flags.cpp @@ -1,11 +1,11 @@ -// RUN: %clangxx -### -fsycl -fsycl-targets=nvptx64-nvidia-cuda -fsycl-dead-args-optimization %s 2> %t.cuda.out -// RUN: FileCheck %s --input-file %t.cuda.out -// // RUN: %clangxx -### -fsycl -fsycl-targets=amdgcn-amd-amdhsa -Xsycl-target-backend --offload-arch=gfx906 -fsycl-dead-args-optimization %s 2> %t.rocm.out // RUN: FileCheck %s --input-file %t.rocm.out // CHECK-NOT: -fenable-sycl-dae // CHECK-NOT: -emit-param-info // +// RUN: %clangxx -### -fsycl -fsycl-targets=nvptx64-nvidia-cuda -fsycl-dead-args-optimization %s 2> %t.cuda.out +// RUN: FileCheck %s --check-prefixes=CHECK-FENABLE,CHECK-EMIT 
--input-file %t.cuda.out +// // RUN: %clangxx -### -fsycl -fsycl-targets=spir64-unknown-unknown -fsycl-dead-args-optimization %s 2> %t.out // RUN: FileCheck %s --check-prefixes=CHECK-FENABLE,CHECK-EMIT --input-file %t.out // CHECK-FENABLE: -fenable-sycl-dae diff --git a/llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h b/llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h index 7dbe55ca0bcd2..0738968a4c8fc 100644 --- a/llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h +++ b/llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h @@ -21,6 +21,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Twine.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/PassManager.h" #include @@ -74,9 +75,9 @@ class DeadArgumentEliminationPass enum Liveness { Live, MaybeLive }; DeadArgumentEliminationPass(bool ShouldHackArguments_ = false, - bool CheckSpirKernels_ = false) + bool CheckSYCLKernels_ = false) : ShouldHackArguments(ShouldHackArguments_), - CheckSpirKernels(CheckSpirKernels_) {} + CheckSYCLKernels(CheckSYCLKernels_) {} PreservedAnalyses run(Module &M, ModuleAnalysisManager &); @@ -123,9 +124,9 @@ class DeadArgumentEliminationPass /// (used only by bugpoint). 
bool ShouldHackArguments = false; - /// This allows to eliminate dead arguments in SPIR kernel functions with - /// external linkage in SYCL environment - bool CheckSpirKernels = false; + /// This allows eliminating dead arguments in SYCL kernel wrapper functions + /// with external linkage + bool CheckSYCLKernels = false; private: Liveness MarkIfNotLive(RetOrArg Use, UseVector &MaybeLiveUses); @@ -143,6 +144,45 @@ class DeadArgumentEliminationPass bool RemoveDeadStuffFromFunction(Function *F); bool DeleteDeadVarargs(Function &Fn); bool RemoveDeadArgumentsFromCallers(Function &Fn); + + void UpdateNVPTXMetadata(Module &M, Function *F, Function *NF); + llvm::DenseSet NVPTXKernelSet; + + bool IsNVPTXKernel(const Function *F) { return NVPTXKernelSet.contains(F); }; + + void BuildNVPTXKernelSet(const Module &M) { + + auto *NvvmMetadata = M.getNamedMetadata("nvvm.annotations"); + if (!NvvmMetadata) + return; + + for (auto *MetadataNode : NvvmMetadata->operands()) { + if (MetadataNode->getNumOperands() != 3) + continue; + + // NVPTX identifies kernel entry points using metadata nodes of the form: + // !X = !{<function>, !"kernel", i32 1} + auto *Type = dyn_cast(MetadataNode->getOperand(1)); + // Only process kernel entry points. + if (!Type || Type->getString() != "kernel") + continue; + + // Get a pointer to the entry point function from the metadata. 
+ if (const auto &FuncOperand = MetadataNode->getOperand(0)) { + if (auto *FuncConstant = dyn_cast(FuncOperand)) { + if (auto *Func = dyn_cast(FuncConstant->getValue())) { + if (auto *Val = mdconst::dyn_extract( + MetadataNode->getOperand(2))) { + if (Val->getValue() == 1) { + NVPTXKernelSet.insert(Func); + } + } + } + } + } + } + return; + } }; class DeadArgumentEliminationSYCLPass @@ -155,7 +195,7 @@ class DeadArgumentEliminationSYCLPass private: DeadArgumentEliminationPass Impl = DeadArgumentEliminationPass(/* ShouldHackArguemtns */ false, - /* CheckSpirKernels */ true); + /* CheckSYCLKernels */ true); }; } // end namespace llvm diff --git a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp index cbfe98cdd5f79..a71c93f21e3d2 100644 --- a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp +++ b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp @@ -78,14 +78,14 @@ namespace { if (skipModule(M)) return false; DeadArgumentEliminationPass DAEP(ShouldHackArguments(), - CheckSpirKernels()); + CheckSYCLKernels()); ModuleAnalysisManager DummyMAM; PreservedAnalyses PA = DAEP.run(M, DummyMAM); return !PA.areAllPreserved(); } virtual bool ShouldHackArguments() const { return false; } - virtual bool CheckSpirKernels() const { return false; } + virtual bool CheckSYCLKernels() const { return false; } }; } // end anonymous namespace @@ -105,7 +105,7 @@ namespace { DAH() : DAE(ID) {} bool ShouldHackArguments() const override { return true; } - bool CheckSpirKernels() const override { return false; } + bool CheckSYCLKernels() const override { return false; } }; } // end anonymous namespace @@ -118,7 +118,7 @@ INITIALIZE_PASS(DAH, "deadarghaX0r", namespace { -/// DAESYCL - DeadArgumentElimination pass for SPIR kernel functions even +/// DAESYCL - DeadArgumentElimination pass for SYCL kernel functions even /// if they are external. 
struct DAESYCL : public DAE { static char ID; @@ -128,21 +128,19 @@ struct DAESYCL : public DAE { } StringRef getPassName() const override { - return "Dead Argument Elimination for SPIR kernels in SYCL environment"; + return "Dead Argument Elimination for SYCL kernels"; } bool ShouldHackArguments() const override { return false; } - bool CheckSpirKernels() const override { return true; } + bool CheckSYCLKernels() const override { return true; } }; } // end anonymous namespace char DAESYCL::ID = 0; -INITIALIZE_PASS( - DAESYCL, "deadargelim-sycl", - "Dead Argument Elimination for SPIR kernels in SYCL environment", false, - false) +INITIALIZE_PASS(DAESYCL, "deadargelim-sycl", + "Dead Argument Elimination for SYCL kernels", false, false) /// createDeadArgEliminationPass - This pass removes arguments from functions /// which are not used by the body of the function. @@ -573,12 +571,13 @@ void DeadArgumentEliminationPass::SurveyFunction(const Function &F) { } // We can't modify arguments if the function is not local - // but we can do so for SPIR kernel function in SYCL environment. + // but we can do so for SYCL kernel functions. // DAE is not currently supported for ESIMD kernels. 
- bool FuncIsSpirNonEsimdKernel = - CheckSpirKernels && F.getCallingConv() == CallingConv::SPIR_KERNEL && + bool FuncIsSyclNonEsimdKernel = + CheckSYCLKernels && + (F.getCallingConv() == CallingConv::SPIR_KERNEL || IsNVPTXKernel(&F)) && !F.getMetadata("sycl_explicit_simd"); - bool FuncIsLive = !F.hasLocalLinkage() && !FuncIsSpirNonEsimdKernel; + bool FuncIsLive = !F.hasLocalLinkage() && !FuncIsSyclNonEsimdKernel; if (FuncIsLive && (!ShouldHackArguments || F.isIntrinsic())) { MarkLive(F); return; @@ -812,7 +811,7 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { } } - if (CheckSpirKernels) { + if (CheckSYCLKernels) { SmallVector MDOmitArgs; auto MDOmitArgTrue = llvm::ConstantAsMetadata::get( ConstantInt::get(Type::getInt1Ty(F->getContext()), 1)); @@ -820,7 +819,7 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { ConstantInt::get(Type::getInt1Ty(F->getContext()), 0)); for (auto &AliveArg : ArgAlive) MDOmitArgs.push_back(AliveArg ? MDOmitArgFalse : MDOmitArgTrue); - F->setMetadata("spir_kernel_omit_args", + F->setMetadata("sycl_kernel_omit_args", llvm::MDNode::get(F->getContext(), MDOmitArgs)); } @@ -1131,6 +1130,9 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { for (auto MD : MDs) NF->addMetadata(MD.first, *MD.second); + if (IsNVPTXKernel(F)) + UpdateNVPTXMetadata(*(F->getParent()), F, NF); + // Now that the old function is dead, delete it. F->eraseFromParent(); @@ -1141,6 +1143,8 @@ PreservedAnalyses DeadArgumentEliminationPass::run(Module &M, ModuleAnalysisManager &) { bool Changed = false; + BuildNVPTXKernelSet(M); + // First pass: Do a simple check to see if any functions can have their "..." // removed. We can do this if they never call va_start. 
This loop cannot be // fused with the next loop, because deleting a function invalidates @@ -1173,3 +1177,25 @@ PreservedAnalyses DeadArgumentEliminationPass::run(Module &M, return PreservedAnalyses::all(); return PreservedAnalyses::none(); } + +void DeadArgumentEliminationPass::UpdateNVPTXMetadata(Module &M, Function *F, + Function *NF) { + + auto *NvvmMetadata = M.getNamedMetadata("nvvm.annotations"); + if (!NvvmMetadata) + return; + + for (auto *MetadataNode : NvvmMetadata->operands()) { + const auto &FuncOperand = MetadataNode->getOperand(0); + if (!FuncOperand) + continue; + auto FuncConstant = dyn_cast(FuncOperand); + if (!FuncConstant) + continue; + auto *Func = dyn_cast(FuncConstant->getValue()); + if (Func != F) + continue; + // Update the metadata with the new function + MetadataNode->replaceOperandWith(0, llvm::ConstantAsMetadata::get(NF)); + } +} diff --git a/llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll b/llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll index 7e2a1adaf8327..211a1bf4d731f 100644 --- a/llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll +++ b/llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll @@ -7,7 +7,7 @@ target triple = "spir64" ; This test ensures dead arguments are not eliminated ; from a global function that is not a SPIR kernel. 
-; CHECK-NOT: !spir_kernel_omit_args +; CHECK-NOT: !sycl_kernel_omit_args define weak_odr void @NotASpirKernel(float %arg1, float %arg2) { ; CHECK-LABEL: define {{[^@]+}}@NotASpirKernel diff --git a/llvm/test/Transforms/DeadArgElim/sycl-kernels.ll b/llvm/test/Transforms/DeadArgElim/sycl-kernels.ll index 6725047170c42..3b9b2de4e32f4 100644 --- a/llvm/test/Transforms/DeadArgElim/sycl-kernels.ll +++ b/llvm/test/Transforms/DeadArgElim/sycl-kernels.ll @@ -14,7 +14,7 @@ define weak_odr spir_kernel void @SpirKernel1(float %arg1, float %arg2) { ; CHECK-NEXT: ret void ; ; CHECK-SYCL-LABEL: define {{[^@]+}}@SpirKernel1 -; CHECK-SYCL-SAME: (float [[ARG1:%.*]]) !spir_kernel_omit_args ![[KERN_ARGS1:[0-9]]] +; CHECK-SYCL-SAME: (float [[ARG1:%.*]]) !sycl_kernel_omit_args ![[KERN_ARGS1:[0-9]]] ; CHECK-SYCL-NEXT: call void @foo(float [[ARG1]]) ; CHECK-SYCL-NEXT: ret void @@ -29,7 +29,7 @@ define weak_odr spir_kernel void @SpirKernel2(float %arg1, float %arg2) { ; CHECK-NEXT: ret void ; ; CHECK-SYCL-LABEL: define {{[^@]+}}@SpirKernel2 -; CHECK-SYCL-SAME: (float [[ARG2:%.*]]) !spir_kernel_omit_args ![[KERN_ARGS2:[0-9]]] +; CHECK-SYCL-SAME: (float [[ARG2:%.*]]) !sycl_kernel_omit_args ![[KERN_ARGS2:[0-9]]] ; CHECK-SYCL-NEXT: call void @foo(float [[ARG2]]) ; CHECK-SYCL-NEXT: ret void diff --git a/llvm/test/tools/sycl-post-link/omit_kernel_args.ll b/llvm/test/tools/sycl-post-link/omit_kernel_args.ll index a29dd4ddf3b01..96826d5e37533 100644 --- a/llvm/test/tools/sycl-post-link/omit_kernel_args.ll +++ b/llvm/test/tools/sycl-post-link/omit_kernel_args.ll @@ -8,12 +8,12 @@ target triple = "spir64-unknown-unknown" -define weak_odr spir_kernel void @SpirKernel1(float %arg1) !spir_kernel_omit_args !0 { +define weak_odr spir_kernel void @SpirKernel1(float %arg1) !sycl_kernel_omit_args !0 { call void @foo(float %arg1) ret void } -define weak_odr spir_kernel void @SpirKernel2(i8 %arg1, i8 %arg2, i8 %arg3) !spir_kernel_omit_args !1 { +define weak_odr spir_kernel void @SpirKernel2(i8 %arg1, i8 
%arg2, i8 %arg3) !sycl_kernel_omit_args !1 { call void @bar(i8 %arg1) call void @bar(i8 %arg2) call void @bar(i8 %arg3) diff --git a/llvm/tools/sycl-post-link/CMakeLists.txt b/llvm/tools/sycl-post-link/CMakeLists.txt index 07fba760ba8f9..95d88d7d49d93 100644 --- a/llvm/tools/sycl-post-link/CMakeLists.txt +++ b/llvm/tools/sycl-post-link/CMakeLists.txt @@ -19,7 +19,7 @@ include_directories( add_llvm_tool(sycl-post-link sycl-post-link.cpp - SPIRKernelParamOptInfo.cpp + SYCLKernelParamOptInfo.cpp SpecConstants.cpp SYCLDeviceLibReqMask.cpp ADDITIONAL_HEADER_DIRS diff --git a/llvm/tools/sycl-post-link/SPIRKernelParamOptInfo.cpp b/llvm/tools/sycl-post-link/SYCLKernelParamOptInfo.cpp similarity index 70% rename from llvm/tools/sycl-post-link/SPIRKernelParamOptInfo.cpp rename to llvm/tools/sycl-post-link/SYCLKernelParamOptInfo.cpp index 9532ee0e0f836..5a1c6f4e3e03b 100644 --- a/llvm/tools/sycl-post-link/SPIRKernelParamOptInfo.cpp +++ b/llvm/tools/sycl-post-link/SYCLKernelParamOptInfo.cpp @@ -1,4 +1,4 @@ -//==-- SPIRKernelParamOptInfo.cpp -- get kernel param optimization info ----==// +//==-- SYCLKernelParamOptInfo.cpp -- get kernel param optimization info ----==// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "SPIRKernelParamOptInfo.h" +#include "SYCLKernelParamOptInfo.h" #include "llvm/IR/Constants.h" #include "llvm/Support/Casting.h" @@ -14,23 +14,23 @@ namespace { // Must match the one produced by DeadArgumentElimination -static constexpr char MetaDataID[] = "spir_kernel_omit_args"; +static constexpr char MetaDataID[] = "sycl_kernel_omit_args"; } // anonymous namespace namespace llvm { -void SPIRKernelParamOptInfo::releaseMemory() { clear(); } +void SYCLKernelParamOptInfo::releaseMemory() { clear(); } -SPIRKernelParamOptInfo -SPIRKernelParamOptInfoAnalysis::run(Module &M, ModuleAnalysisManager &AM) { - SPIRKernelParamOptInfo Res; +SYCLKernelParamOptInfo +SYCLKernelParamOptInfoAnalysis::run(Module &M, ModuleAnalysisManager &AM) { + SYCLKernelParamOptInfo Res; for (const Function &F : M) { MDNode *MD = F.getMetadata(MetaDataID); if (!MD) continue; - using BaseTy = SPIRKernelParamOptInfoBaseTy; + using BaseTy = SYCLKernelParamOptInfoBaseTy; auto Ins = Res.insert(BaseTy::value_type{F.getName(), BaseTy::mapped_type{}}); assert(Ins.second && "duplicate kernel?"); @@ -46,6 +46,6 @@ SPIRKernelParamOptInfoAnalysis::run(Module &M, ModuleAnalysisManager &AM) { return Res; } -AnalysisKey SPIRKernelParamOptInfoAnalysis::Key; +AnalysisKey SYCLKernelParamOptInfoAnalysis::Key; -} // namespace llvm \ No newline at end of file +} // namespace llvm diff --git a/llvm/tools/sycl-post-link/SPIRKernelParamOptInfo.h b/llvm/tools/sycl-post-link/SYCLKernelParamOptInfo.h similarity index 72% rename from llvm/tools/sycl-post-link/SPIRKernelParamOptInfo.h rename to llvm/tools/sycl-post-link/SYCLKernelParamOptInfo.h index a5ba8f1d4eadb..00eeadedc1e00 100644 --- a/llvm/tools/sycl-post-link/SPIRKernelParamOptInfo.h +++ b/llvm/tools/sycl-post-link/SYCLKernelParamOptInfo.h @@ -1,4 +1,4 @@ -//==-- SPIRKernelParamOptInfo.h -- get kernel param optimization info ------==// +//==-- 
SYCLKernelParamOptInfo.h -- get kernel param optimization info ------==// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -21,25 +21,25 @@ namespace llvm { // the StringRef key refers to a function name -using SPIRKernelParamOptInfoBaseTy = DenseMap; +using SYCLKernelParamOptInfoBaseTy = DenseMap; -class SPIRKernelParamOptInfo : public SPIRKernelParamOptInfoBaseTy { +class SYCLKernelParamOptInfo : public SYCLKernelParamOptInfoBaseTy { public: void releaseMemory(); }; -class SPIRKernelParamOptInfoAnalysis - : public AnalysisInfoMixin { - friend AnalysisInfoMixin; +class SYCLKernelParamOptInfoAnalysis + : public AnalysisInfoMixin { + friend AnalysisInfoMixin; static AnalysisKey Key; public: /// Provide the result type for this analysis pass. - using Result = SPIRKernelParamOptInfo; + using Result = SYCLKernelParamOptInfo; /// Run the analysis pass over a function and produce BPI. - SPIRKernelParamOptInfo run(Module &M, ModuleAnalysisManager &AM); + SYCLKernelParamOptInfo run(Module &M, ModuleAnalysisManager &AM); }; } // namespace llvm diff --git a/llvm/tools/sycl-post-link/sycl-post-link.cpp b/llvm/tools/sycl-post-link/sycl-post-link.cpp index 0faf5bc3948cb..c4d99474b886d 100644 --- a/llvm/tools/sycl-post-link/sycl-post-link.cpp +++ b/llvm/tools/sycl-post-link/sycl-post-link.cpp @@ -13,8 +13,8 @@ // - specialization constant intrinsic transformation //===----------------------------------------------------------------------===// -#include "SPIRKernelParamOptInfo.h" #include "SYCLDeviceLibReqMask.h" +#include "SYCLKernelParamOptInfo.h" #include "SpecConstants.h" #include "llvm/ADT/SetVector.h" @@ -583,9 +583,10 @@ static string_vector saveDeviceImageProperty( // Register required analysis MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); // Register the payload analysis - MAM.registerPass([&] { return SPIRKernelParamOptInfoAnalysis(); }); - 
SPIRKernelParamOptInfo PInfo = - MAM.getResult( + + MAM.registerPass([&] { return SYCLKernelParamOptInfoAnalysis(); }); + SYCLKernelParamOptInfo PInfo = + MAM.getResult( *ResultModules[I].ModulePtr); // convert analysis results into properties and record them