Skip to content

[SYCL] Enable Dead Argument Elimination for NVPTX backend #4617

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Oct 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions clang/lib/Driver/ToolChains/Clang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4635,7 +4635,7 @@ void Clang::ConstructJob(Compilation &C, const JobAction &JA,
}

// Turn on Dead Parameter Elimination Optimization with early optimizations
if (!(RawTriple.isNVPTX() || RawTriple.isAMDGCN()) &&
if (!(RawTriple.isAMDGCN()) &&
Args.hasFlag(options::OPT_fsycl_dead_args_optimization,
options::OPT_fno_sycl_dead_args_optimization, false))
CmdArgs.push_back("-fenable-sycl-dae");
Expand Down Expand Up @@ -8920,8 +8920,7 @@ void SYCLPostLink::ConstructJob(Compilation &C, const JobAction &JA,
// -fsycl-device-code-split=auto

// Turn on Dead Parameter Elimination Optimization with early optimizations
if (!(getToolChain().getTriple().isNVPTX() ||
getToolChain().getTriple().isAMDGCN()) &&
if (!(getToolChain().getTriple().isAMDGCN()) &&
TCArgs.hasFlag(options::OPT_fsycl_dead_args_optimization,
options::OPT_fno_sycl_dead_args_optimization, false))
addArgs(CmdArgs, TCArgs, {"-emit-param-info"});
Expand Down
6 changes: 3 additions & 3 deletions clang/test/Driver/sycl-triple-dae-flags.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// RUN: %clangxx -### -fsycl -fsycl-targets=nvptx64-nvidia-cuda -fsycl-dead-args-optimization %s 2> %t.cuda.out
// RUN: FileCheck %s --input-file %t.cuda.out
//
// RUN: %clangxx -### -fsycl -fsycl-targets=amdgcn-amd-amdhsa -Xsycl-target-backend --offload-arch=gfx906 -fsycl-dead-args-optimization %s 2> %t.rocm.out
// RUN: FileCheck %s --input-file %t.rocm.out
// CHECK-NOT: -fenable-sycl-dae
// CHECK-NOT: -emit-param-info
//
// RUN: %clangxx -### -fsycl -fsycl-targets=nvptx64-nvidia-cuda -fsycl-dead-args-optimization %s 2> %t.cuda.out
// RUN: FileCheck %s --check-prefixes=CHECK-FENABLE,CHECK-EMIT --input-file %t.cuda.out
//
// RUN: %clangxx -### -fsycl -fsycl-targets=spir64-unknown-unknown -fsycl-dead-args-optimization %s 2> %t.out
// RUN: FileCheck %s --check-prefixes=CHECK-FENABLE,CHECK-EMIT --input-file %t.out
// CHECK-FENABLE: -fenable-sycl-dae
Expand Down
52 changes: 46 additions & 6 deletions llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/PassManager.h"
#include <map>
Expand Down Expand Up @@ -74,9 +75,9 @@ class DeadArgumentEliminationPass
enum Liveness { Live, MaybeLive };

DeadArgumentEliminationPass(bool ShouldHackArguments_ = false,
bool CheckSpirKernels_ = false)
bool CheckSYCLKernels_ = false)
: ShouldHackArguments(ShouldHackArguments_),
CheckSpirKernels(CheckSpirKernels_) {}
CheckSYCLKernels(CheckSYCLKernels_) {}

PreservedAnalyses run(Module &M, ModuleAnalysisManager &);

Expand Down Expand Up @@ -123,9 +124,9 @@ class DeadArgumentEliminationPass
/// (used only by bugpoint).
bool ShouldHackArguments = false;

/// This allows to eliminate dead arguments in SPIR kernel functions with
/// external linkage in SYCL environment
bool CheckSpirKernels = false;
/// Allows eliminating dead arguments in SYCL kernel wrapper functions that
/// have external linkage.
bool CheckSYCLKernels = false;

private:
Liveness MarkIfNotLive(RetOrArg Use, UseVector &MaybeLiveUses);
Expand All @@ -143,6 +144,45 @@ class DeadArgumentEliminationPass
bool RemoveDeadStuffFromFunction(Function *F);
bool DeleteDeadVarargs(Function &Fn);
bool RemoveDeadArgumentsFromCallers(Function &Fn);

/// Rewrites every "nvvm.annotations" node whose first operand refers to \p F
/// so that it refers to \p NF instead.
void UpdateNVPTXMetadata(Module &M, Function *F, Function *NF);
/// Functions that BuildNVPTXKernelSet found annotated as NVPTX kernel entry
/// points; queried through IsNVPTXKernel.
llvm::DenseSet<Function *> NVPTXKernelSet;

bool IsNVPTXKernel(const Function *F) { return NVPTXKernelSet.contains(F); };

void BuildNVPTXKernelSet(const Module &M) {

auto *NvvmMetadata = M.getNamedMetadata("nvvm.annotations");
if (!NvvmMetadata)
return;

for (auto *MetadataNode : NvvmMetadata->operands()) {
if (MetadataNode->getNumOperands() != 3)
continue;

// NVPTX identifies kernel entry points using metadata nodes of the form:
// !X = !{<function>, !"kernel", i32 1}
auto *Type = dyn_cast<MDString>(MetadataNode->getOperand(1));
// Only process kernel entry points.
if (!Type || Type->getString() != "kernel")
continue;

// Get a pointer to the entry point function from the metadata.
if (const auto &FuncOperand = MetadataNode->getOperand(0)) {
if (auto *FuncConstant = dyn_cast<ConstantAsMetadata>(FuncOperand)) {
if (auto *Func = dyn_cast<Function>(FuncConstant->getValue())) {
if (auto *Val = mdconst::dyn_extract<ConstantInt>(
MetadataNode->getOperand(2))) {
if (Val->getValue() == 1) {
NVPTXKernelSet.insert(Func);
}
}
}
}
}
}
return;
}
};

class DeadArgumentEliminationSYCLPass
Expand All @@ -155,7 +195,7 @@ class DeadArgumentEliminationSYCLPass
private:
DeadArgumentEliminationPass Impl =
DeadArgumentEliminationPass(/* ShouldHackArguemtns */ false,
/* CheckSpirKernels */ true);
/* CheckSYCLKernels */ true);
};

} // end namespace llvm
Expand Down
58 changes: 42 additions & 16 deletions llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,14 @@ namespace {
if (skipModule(M))
return false;
DeadArgumentEliminationPass DAEP(ShouldHackArguments(),
CheckSpirKernels());
CheckSYCLKernels());
ModuleAnalysisManager DummyMAM;
PreservedAnalyses PA = DAEP.run(M, DummyMAM);
return !PA.areAllPreserved();
}

virtual bool ShouldHackArguments() const { return false; }
virtual bool CheckSpirKernels() const { return false; }
virtual bool CheckSYCLKernels() const { return false; }
};

} // end anonymous namespace
Expand All @@ -105,7 +105,7 @@ namespace {
DAH() : DAE(ID) {}

bool ShouldHackArguments() const override { return true; }
bool CheckSpirKernels() const override { return false; }
bool CheckSYCLKernels() const override { return false; }
};

} // end anonymous namespace
Expand All @@ -118,7 +118,7 @@ INITIALIZE_PASS(DAH, "deadarghaX0r",

namespace {

/// DAESYCL - DeadArgumentElimination pass for SPIR kernel functions even
/// DAESYCL - DeadArgumentElimination pass for SYCL kernel functions even
/// if they are external.
struct DAESYCL : public DAE {
static char ID;
Expand All @@ -128,21 +128,19 @@ struct DAESYCL : public DAE {
}

StringRef getPassName() const override {
return "Dead Argument Elimination for SPIR kernels in SYCL environment";
return "Dead Argument Elimination for SYCL kernels";
}

bool ShouldHackArguments() const override { return false; }
bool CheckSpirKernels() const override { return true; }
bool CheckSYCLKernels() const override { return true; }
};

} // end anonymous namespace

char DAESYCL::ID = 0;

INITIALIZE_PASS(
DAESYCL, "deadargelim-sycl",
"Dead Argument Elimination for SPIR kernels in SYCL environment", false,
false)
INITIALIZE_PASS(DAESYCL, "deadargelim-sycl",
"Dead Argument Elimination for SYCL kernels", false, false)

/// createDeadArgEliminationPass - This pass removes arguments from functions
/// which are not used by the body of the function.
Expand Down Expand Up @@ -573,12 +571,13 @@ void DeadArgumentEliminationPass::SurveyFunction(const Function &F) {
}

// We can't modify arguments if the function is not local
// but we can do so for SPIR kernel function in SYCL environment.
// but we can do so for SYCL kernel functions.
// DAE is not currently supported for ESIMD kernels.
bool FuncIsSpirNonEsimdKernel =
CheckSpirKernels && F.getCallingConv() == CallingConv::SPIR_KERNEL &&
bool FuncIsSyclNonEsimdKernel =
CheckSYCLKernels &&
(F.getCallingConv() == CallingConv::SPIR_KERNEL || IsNVPTXKernel(&F)) &&
!F.getMetadata("sycl_explicit_simd");
bool FuncIsLive = !F.hasLocalLinkage() && !FuncIsSpirNonEsimdKernel;
bool FuncIsLive = !F.hasLocalLinkage() && !FuncIsSyclNonEsimdKernel;
if (FuncIsLive && (!ShouldHackArguments || F.isIntrinsic())) {
MarkLive(F);
return;
Expand Down Expand Up @@ -812,15 +811,15 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) {
}
}

if (CheckSpirKernels) {
if (CheckSYCLKernels) {
SmallVector<Metadata *, 10> MDOmitArgs;
auto MDOmitArgTrue = llvm::ConstantAsMetadata::get(
ConstantInt::get(Type::getInt1Ty(F->getContext()), 1));
auto MDOmitArgFalse = llvm::ConstantAsMetadata::get(
ConstantInt::get(Type::getInt1Ty(F->getContext()), 0));
for (auto &AliveArg : ArgAlive)
MDOmitArgs.push_back(AliveArg ? MDOmitArgFalse : MDOmitArgTrue);
F->setMetadata("spir_kernel_omit_args",
F->setMetadata("sycl_kernel_omit_args",
llvm::MDNode::get(F->getContext(), MDOmitArgs));
}

Expand Down Expand Up @@ -1131,6 +1130,9 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) {
for (auto MD : MDs)
NF->addMetadata(MD.first, *MD.second);

if (IsNVPTXKernel(F))
UpdateNVPTXMetadata(*(F->getParent()), F, NF);

// Now that the old function is dead, delete it.
F->eraseFromParent();

Expand All @@ -1141,6 +1143,8 @@ PreservedAnalyses DeadArgumentEliminationPass::run(Module &M,
ModuleAnalysisManager &) {
bool Changed = false;

BuildNVPTXKernelSet(M);

// First pass: Do a simple check to see if any functions can have their "..."
// removed. We can do this if they never call va_start. This loop cannot be
// fused with the next loop, because deleting a function invalidates
Expand Down Expand Up @@ -1173,3 +1177,25 @@ PreservedAnalyses DeadArgumentEliminationPass::run(Module &M,
return PreservedAnalyses::all();
return PreservedAnalyses::none();
}

/// Redirects NVPTX annotations from \p F to its replacement \p NF.
///
/// Walks the "nvvm.annotations" named metadata of \p M and, for every node
/// whose first operand is \p F, replaces that operand with \p NF. Nodes with
/// no operands, a null first operand, or a first operand that is not a
/// constant are left untouched. Does nothing when the module has no
/// "nvvm.annotations" metadata.
void DeadArgumentEliminationPass::UpdateNVPTXMetadata(Module &M, Function *F,
                                                      Function *NF) {
  auto *NvvmMetadata = M.getNamedMetadata("nvvm.annotations");
  if (!NvvmMetadata)
    return;

  for (auto *MetadataNode : NvvmMetadata->operands()) {
    // An empty node has no operand 0 to inspect.
    if (MetadataNode->getNumOperands() == 0)
      continue;

    auto *FuncConstant = dyn_cast_or_null<ConstantAsMetadata>(
        MetadataNode->getOperand(0).get());
    if (!FuncConstant)
      continue;

    // Compare the wrapped value directly so that a non-function constant can
    // never be mistaken for a match.
    if (FuncConstant->getValue() != F)
      continue;

    // Update the metadata with the new function.
    MetadataNode->replaceOperandWith(0, llvm::ConstantAsMetadata::get(NF));
  }
}
2 changes: 1 addition & 1 deletion llvm/test/Transforms/DeadArgElim/sycl-kernels-neg.ll
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ target triple = "spir64"
; This test ensures dead arguments are not eliminated
; from a global function that is not a SPIR kernel.

; CHECK-NOT: !spir_kernel_omit_args
; CHECK-NOT: !sycl_kernel_omit_args

define weak_odr void @NotASpirKernel(float %arg1, float %arg2) {
; CHECK-LABEL: define {{[^@]+}}@NotASpirKernel
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/Transforms/DeadArgElim/sycl-kernels.ll
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ define weak_odr spir_kernel void @SpirKernel1(float %arg1, float %arg2) {
; CHECK-NEXT: ret void
;
; CHECK-SYCL-LABEL: define {{[^@]+}}@SpirKernel1
; CHECK-SYCL-SAME: (float [[ARG1:%.*]]) !spir_kernel_omit_args ![[KERN_ARGS1:[0-9]]]
; CHECK-SYCL-SAME: (float [[ARG1:%.*]]) !sycl_kernel_omit_args ![[KERN_ARGS1:[0-9]]]
; CHECK-SYCL-NEXT: call void @foo(float [[ARG1]])
; CHECK-SYCL-NEXT: ret void

Expand All @@ -29,7 +29,7 @@ define weak_odr spir_kernel void @SpirKernel2(float %arg1, float %arg2) {
; CHECK-NEXT: ret void
;
; CHECK-SYCL-LABEL: define {{[^@]+}}@SpirKernel2
; CHECK-SYCL-SAME: (float [[ARG2:%.*]]) !spir_kernel_omit_args ![[KERN_ARGS2:[0-9]]]
; CHECK-SYCL-SAME: (float [[ARG2:%.*]]) !sycl_kernel_omit_args ![[KERN_ARGS2:[0-9]]]
; CHECK-SYCL-NEXT: call void @foo(float [[ARG2]])
; CHECK-SYCL-NEXT: ret void

Expand Down
4 changes: 2 additions & 2 deletions llvm/test/tools/sycl-post-link/omit_kernel_args.ll
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@

target triple = "spir64-unknown-unknown"

define weak_odr spir_kernel void @SpirKernel1(float %arg1) !spir_kernel_omit_args !0 {
define weak_odr spir_kernel void @SpirKernel1(float %arg1) !sycl_kernel_omit_args !0 {
call void @foo(float %arg1)
ret void
}

define weak_odr spir_kernel void @SpirKernel2(i8 %arg1, i8 %arg2, i8 %arg3) !spir_kernel_omit_args !1 {
define weak_odr spir_kernel void @SpirKernel2(i8 %arg1, i8 %arg2, i8 %arg3) !sycl_kernel_omit_args !1 {
call void @bar(i8 %arg1)
call void @bar(i8 %arg2)
call void @bar(i8 %arg3)
Expand Down
2 changes: 1 addition & 1 deletion llvm/tools/sycl-post-link/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ include_directories(

add_llvm_tool(sycl-post-link
sycl-post-link.cpp
SPIRKernelParamOptInfo.cpp
SYCLKernelParamOptInfo.cpp
SpecConstants.cpp
SYCLDeviceLibReqMask.cpp
ADDITIONAL_HEADER_DIRS
Expand Down
Original file line number Diff line number Diff line change
@@ -1,36 +1,36 @@
//==-- SPIRKernelParamOptInfo.cpp -- get kernel param optimization info ----==//
//==-- SYCLKernelParamOptInfo.cpp -- get kernel param optimization info ----==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "SPIRKernelParamOptInfo.h"
#include "SYCLKernelParamOptInfo.h"

#include "llvm/IR/Constants.h"
#include "llvm/Support/Casting.h"

namespace {

// Must match the one produced by DeadArgumentElimination
static constexpr char MetaDataID[] = "spir_kernel_omit_args";
static constexpr char MetaDataID[] = "sycl_kernel_omit_args";

} // anonymous namespace

namespace llvm {

void SPIRKernelParamOptInfo::releaseMemory() { clear(); }
void SYCLKernelParamOptInfo::releaseMemory() { clear(); }

SPIRKernelParamOptInfo
SPIRKernelParamOptInfoAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
SPIRKernelParamOptInfo Res;
SYCLKernelParamOptInfo
SYCLKernelParamOptInfoAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
SYCLKernelParamOptInfo Res;

for (const Function &F : M) {
MDNode *MD = F.getMetadata(MetaDataID);
if (!MD)
continue;
using BaseTy = SPIRKernelParamOptInfoBaseTy;
using BaseTy = SYCLKernelParamOptInfoBaseTy;
auto Ins =
Res.insert(BaseTy::value_type{F.getName(), BaseTy::mapped_type{}});
assert(Ins.second && "duplicate kernel?");
Expand All @@ -46,6 +46,6 @@ SPIRKernelParamOptInfoAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
return Res;
}

AnalysisKey SPIRKernelParamOptInfoAnalysis::Key;
AnalysisKey SYCLKernelParamOptInfoAnalysis::Key;

} // namespace llvm
} // namespace llvm
Loading