Skip to content

[SYCL] Initial printf support for non-constant AS format strings #5069

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 32 commits into from
Dec 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
7f05718
[SYCL] Initial printf support for non-constant AS format strings
Nov 23, 2021
0039b2c
Code style & a few conceptual upgrades
Dec 2, 2021
a24796d
Register the pass with the new pass manager
Dec 3, 2021
b5dc85a
Merge remote-tracking branch 'intel/sycl' into printf-address-space
Dec 6, 2021
f134ee8
Ensure correct llvm::Type* for literal & more refactoring
Dec 8, 2021
c16fcbb
Improve the function search
Dec 8, 2021
41eaf3b
Ugly implementation for the O1 case
Dec 8, 2021
a26f7c8
Extend value stripping to loads/stores for O0 cases
Dec 9, 2021
9150b8d
Replace call operands instead of re-creating calls
Dec 9, 2021
f467434
Huge refactoring
Dec 9, 2021
00066b6
More refactoring & descriptive comments added
Dec 9, 2021
896f296
Fix comment style
Dec 9, 2021
2fa7434
Refactor getCASLiteral back to being a function
Dec 9, 2021
ac0372b
Merge remote-tracking branch 'intel/sycl' into printf-address-space
Dec 10, 2021
1e83f16
Avoid "null character SetName" assertions caused by StringRef specifics
Dec 13, 2021
5c5b875
Improve formatting
Dec 13, 2021
e5ae867
Avoid using the error-prone Module::getOrInsertGlobal w/ callback
Dec 14, 2021
39efbb4
Wrap the newly created globals into constant pointer casts if needed
Dec 14, 2021
4d1c6be
Replace the wrapper calls directly with builtin calls
Dec 14, 2021
87d4137
Add LIT tests
Dec 13, 2021
6c8122b
Simplify call replacing (since we're always dealing with constants)
Dec 15, 2021
c1563da
Correct registration with the new PM & run LITs with that as well
Dec 15, 2021
f783309
[Review] Move/rename the LIT tests
Dec 16, 2021
52f7ca7
[Review] Address LIT-related comments
Dec 16, 2021
a37be74
[Review] Address code comments, pt. I
Dec 16, 2021
d2fb5f9
[Review][RFC] Address code comments, pt. II
Dec 16, 2021
6959690
[Review] Address code comments, pt. III
Dec 16, 2021
dd68d58
[Review] Update CODEOWNERS
Dec 16, 2021
d74fb42
[Review] Address code comments, pt. IV
Dec 16, 2021
040752c
[Review] Update the comment section in builtins.hpp
Dec 16, 2021
ee76f9d
Fix LIT setup for negative checks & factor them out into a separate test
Dec 20, 2021
1e3493e
Remove unneeded CHECK-NOT's from regular LIT
Dec 20, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,8 @@ sycl/doc/extensions/ExplicitSIMD/ @kbobrovs @v-klochkov @kychendev
llvm/lib/Transforms/Instrumentation/SPIRITTAnnotations.cpp @MrSidims @vzakhari
llvm/include/llvm/Transforms/Instrumentation/SPIRITTAnnotations.h @MrSidims @vzakhari
llvm/test/Transforms/SPIRITTAnnotations/* @MrSidims @vzakhari

# Generic address space support for printf
llvm/lib/SYCLLowerIR/MutatePrintfAddrspace.cpp @AGindinson @AlexeySachkov @mlychkov
llvm/include/llvm/SYCLLowerIR/MutatePrintfAddrspace.h @AGindinson @AlexeySachkov @mlychkov
llvm/test/SYCLLowerIR/printf_addrspace/* @AGindinson @AlexeySachkov @mlychkov
5 changes: 5 additions & 0 deletions clang/lib/CodeGen/BackendUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "llvm/Passes/StandardInstrumentations.h"
#include "llvm/SYCLLowerIR/ESIMDVerifier.h"
#include "llvm/SYCLLowerIR/LowerWGLocalMemory.h"
#include "llvm/SYCLLowerIR/MutatePrintfAddrspace.h"
#include "llvm/Support/BuryPointer.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MemoryBuffer.h"
Expand Down Expand Up @@ -1053,6 +1054,7 @@ void EmitAssemblyHelper::EmitAssemblyWithLegacyPassManager(
if (CodeGenOpts.DisableLLVMPasses)
PerModulePasses.add(createAlwaysInlinerLegacyPass(false));
PerModulePasses.add(createSYCLLowerWGLocalMemoryLegacyPass());
PerModulePasses.add(createSYCLMutatePrintfAddrspaceLegacyPass());
}

switch (Action) {
Expand Down Expand Up @@ -1470,6 +1472,9 @@ void EmitAssemblyHelper::RunOptimizationPipeline(
MPM.addPass(ModuleMemProfilerPass());
}
}
if (LangOpts.SYCLIsDevice) {
MPM.addPass(SYCLMutatePrintfAddrspacePass());
}

// Add a verifier pass if requested. We don't have to do this if the action
// requires code generation because there will already be a verifier pass in
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/InitializePasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ void initializeStripSymbolsPass(PassRegistry&);
void initializeStructurizeCFGLegacyPassPass(PassRegistry &);
void initializeSYCLLowerWGScopeLegacyPassPass(PassRegistry &);
void initializeSYCLLowerESIMDLegacyPassPass(PassRegistry &);
void initializeSYCLMutatePrintfAddrspaceLegacyPassPass(PassRegistry &);
void initializeSPIRITTAnnotationsLegacyPassPass(PassRegistry &);
void initializeESIMDLowerLoadStorePass(PassRegistry &);
void initializeESIMDLowerVecArgLegacyPassPass(PassRegistry &);
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/LinkAllPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/SYCLLowerIR/ESIMDVerifier.h"
#include "llvm/SYCLLowerIR/MutatePrintfAddrspace.h"
#include "llvm/Support/Valgrind.h"
#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
#include "llvm/Transforms/IPO.h"
Expand Down
32 changes: 32 additions & 0 deletions llvm/include/llvm/SYCLLowerIR/MutatePrintfAddrspace.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//===------- MutatePrintfAddrspace.h - SYCL printf AS mutation Pass -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A transformation pass which detects non-constant address space
// literals usage for the first argument of SYCL experimental printf
// function, and moves the string literal to constant address
// space. This a temporary solution for printf's support of generic
// address space literals; the pass should be dropped once SYCL device
// backends learn to handle the generic address-spaced argument properly.
//===----------------------------------------------------------------------===//

#pragma once

#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"

namespace llvm {

class SYCLMutatePrintfAddrspacePass
: public PassInfoMixin<SYCLMutatePrintfAddrspacePass> {
public:
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
};

ModulePass *createSYCLMutatePrintfAddrspaceLegacyPass();

} // namespace llvm
1 change: 1 addition & 0 deletions llvm/lib/Passes/PassBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
#include "llvm/SYCLLowerIR/ESIMDVerifier.h"
#include "llvm/SYCLLowerIR/LowerESIMD.h"
#include "llvm/SYCLLowerIR/LowerWGScope.h"
#include "llvm/SYCLLowerIR/MutatePrintfAddrspace.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Passes/PassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ MODULE_PASS("pseudo-probe-update", PseudoProbeUpdatePass())
MODULE_PASS("LowerESIMD", SYCLLowerESIMDPass())
MODULE_PASS("ESIMDLowerVecArg", ESIMDLowerVecArgPass())
MODULE_PASS("esimd-verifier", ESIMDVerifierPass())
MODULE_PASS("SYCLMutatePrintfAddrspace", SYCLMutatePrintfAddrspacePass())
MODULE_PASS("SPIRITTAnnotations", SPIRITTAnnotationsPass())
MODULE_PASS("deadargelim-sycl", DeadArgumentEliminationSYCLPass())
#undef MODULE_PASS
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/SYCLLowerIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ add_llvm_component_library(LLVMSYCLLowerIR
LowerESIMDVecArg.cpp
LowerWGLocalMemory.cpp
ESIMDVerifier.cpp
MutatePrintfAddrspace.cpp

ADDITIONAL_HEADER_DIRS
${LLVM_MAIN_INCLUDE_DIR}/llvm/SYCLLowerIR
Expand Down
253 changes: 253 additions & 0 deletions llvm/lib/SYCLLowerIR/MutatePrintfAddrspace.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
//===------ MutatePrintfAddrspace.cpp - SYCL printf AS mutation Pass ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A transformation pass which detects non-constant address space
// literals usage for the first argument of SYCL experimental printf
// function, and moves the string literal to constant address
// space. This a temporary solution for printf's support of generic
// address space literals; the pass should be dropped once SYCL device
// backends learn to handle the generic address-spaced argument properly.
//===----------------------------------------------------------------------===//

#include "llvm/SYCLLowerIR/MutatePrintfAddrspace.h"

#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/InitializePasses.h"

using namespace llvm;

namespace {
// Wrapper for the pass to make it working with the old pass manager
class SYCLMutatePrintfAddrspaceLegacyPass : public ModulePass {
public:
static char ID;
SYCLMutatePrintfAddrspaceLegacyPass() : ModulePass(ID) {
initializeSYCLMutatePrintfAddrspaceLegacyPassPass(
*PassRegistry::getPassRegistry());
}

// run the SYCLMutatePrintfAddrspace pass on the specified module
bool runOnModule(Module &M) override {
ModuleAnalysisManager MAM;
auto PA = Impl.run(M, MAM);
return !PA.areAllPreserved();
}

private:
SYCLMutatePrintfAddrspacePass Impl;
};

static constexpr unsigned ConstantAddrspaceID = 2;
// If the variadic version gets picked during FE compilation, we'll only have
// 1 function to replace. However, unique declarations are emitted for each
// of the non-variadic (variadic template) calls.
using FunctionVecTy = SmallVector<Function *, 8>;

Function *getCASPrintfFunction(Module &M, PointerType *CASLiteralType);
size_t setFuncCallsOntoCASPrintf(Function *F, Function *CASPrintfFunc,
FunctionVecTy &FunctionsToDrop);
} // namespace

char SYCLMutatePrintfAddrspaceLegacyPass::ID = 0;
INITIALIZE_PASS(SYCLMutatePrintfAddrspaceLegacyPass,
"SYCLMutatePrintfAddrspace",
"Move SYCL printf literal arguments to constant address space",
false, false)

// Public interface to the SYCLMutatePrintfAddrspacePass.
ModulePass *llvm::createSYCLMutatePrintfAddrspaceLegacyPass() {
return new SYCLMutatePrintfAddrspaceLegacyPass();
}

PreservedAnalyses
SYCLMutatePrintfAddrspacePass::run(Module &M, ModuleAnalysisManager &MAM) {
Type *Int8Type = Type::getInt8Ty(M.getContext());
auto *CASLiteralType = PointerType::get(Int8Type, ConstantAddrspaceID);
Function *CASPrintfFunc = getCASPrintfFunction(M, CASLiteralType);

FunctionVecTy FunctionsToDrop;
bool ModuleChanged = false;
for (Function &F : M) {
if (!F.isDeclaration())
continue;
if (!F.getName().startswith("_Z18__spirv_ocl_printf"))
continue;
if (F.getArg(0)->getType() == CASLiteralType)
// No need to replace the literal type and its printf users
continue;
ModuleChanged |=
setFuncCallsOntoCASPrintf(&F, CASPrintfFunc, FunctionsToDrop);
}
for (Function *F : FunctionsToDrop)
F->eraseFromParent();

return ModuleChanged ? PreservedAnalyses::all() : PreservedAnalyses::none();
}

/// Helper implementations
namespace {

/// Get the constant addrspace version of the __spirv_ocl_printf declaration,
/// or generate it if the IR module doesn't have it yet. Also make it
/// variadic so that it could replace all non-variadic generic AS versions.
Function *getCASPrintfFunction(Module &M, PointerType *CASLiteralType) {
Type *Int32Type = Type::getInt32Ty(M.getContext());
auto *CASPrintfFuncTy = FunctionType::get(Int32Type, CASLiteralType,
/*isVarArg=*/true);
// extern int __spirv_ocl_printf(
// const __attribute__((opencl_constant)) char *Format, ...)
FunctionCallee CASPrintfFuncCallee =
M.getOrInsertFunction("_Z18__spirv_ocl_printfPU3AS2Kcz", CASPrintfFuncTy);
auto *CASPrintfFunc = cast<Function>(CASPrintfFuncCallee.getCallee());
CASPrintfFunc->setCallingConv(CallingConv::SPIR_FUNC);
CASPrintfFunc->setDSOLocal(true);
return CASPrintfFunc;
}

/// Generate the constant addrspace version of the generic addrspace-residing
/// global string. If one exists already, get it from the module.
Constant *getCASLiteral(GlobalVariable *GenericASLiteral) {
Module *M = GenericASLiteral->getParent();
// Appending the stable suffix ensures that only one CAS copy is made for each
// string. In case of the matching name, llvm::Module APIs will ensure that
// the existing global is returned.
std::string CASLiteralName = GenericASLiteral->getName().str() + "._AS2";
if (GlobalVariable *ExistingGlobal =
M->getGlobalVariable(CASLiteralName, /*AllowInternal=*/true))
return ExistingGlobal;

StringRef LiteralValue;
getConstantStringInfo(GenericASLiteral, LiteralValue);
IRBuilder<> Builder(M->getContext());
GlobalVariable *Res = Builder.CreateGlobalString(LiteralValue, CASLiteralName,
ConstantAddrspaceID, M);
Res->setLinkage(GlobalValue::LinkageTypes::InternalLinkage);
Res->setUnnamedAddr(GlobalValue::UnnamedAddr::None);
return Res;
}

/// Encapsulates the update of CallInst's literal argument.
void setCallArgOntoCASPrintf(CallInst *CI, Constant *CASArg,
Function *CASPrintfFunc) {
CI->setCalledFunction(CASPrintfFunc);
auto *Const = CASArg;
// In case there's a misalignment between the updated function type and
// the constant literal type, create a constant pointer cast so as to
// duck module verifier complaints.
Type *ParamType = CASPrintfFunc->getFunctionType()->getParamType(0);
if (Const->getType() != ParamType)
Const = ConstantExpr::getPointerCast(Const, ParamType);
CI->setArgOperand(0, Const);
}

/// The function's effect is similar to V->stripPointerCastsAndAliases(), but
/// also strips load/store aliases.
/// NB: This function can only operate on simple CFG, where load/store pairs
/// leading to the global variable are merely a consequence of low optimization
/// level. Re-using it for complex CFG with arbitrary memory paths is definitely
/// not recommended.
Value *stripToMemorySource(Value *V) {
Value *MemoryAccess = V;
if (auto *LI = dyn_cast<LoadInst>(MemoryAccess)) {
Value *LoadSource = LI->getPointerOperand();
auto *Store = cast<StoreInst>(*llvm::find_if(
LoadSource->users(), [](User *U) { return isa<StoreInst>(U); }));
MemoryAccess = Store->getValueOperand();
}
return MemoryAccess->stripPointerCastsAndAliases();
}

void emitError(Function *PrintfInstance, CallInst *PrintfCall,
StringRef RecommendationToUser = "") {
std::string ErrorMsg =
std::string("experimental::printf requires format string to reside "
"in constant "
"address space. The compiler wasn't able to "
"automatically convert "
"your format string into constant address space when "
"processing builtin ") +
PrintfInstance->getName().str() + " called in function " +
PrintfCall->getFunction()->getName().str() + ".\n" +
RecommendationToUser.str();
PrintfInstance->getContext().emitError(PrintfCall, ErrorMsg);
}

/// This routine goes over CallInst users of F, resetting the called function
/// to CASPrintfFunc and generating/retracting constant addrspace format
/// strings to use as operands of the mutated calls.
size_t setFuncCallsOntoCASPrintf(Function *F, Function *CASPrintfFunc,
FunctionVecTy &FunctionsToDrop) {
size_t MutatedCallsCount = 0;
SmallVector<std::pair<CallInst *, Constant *>, 16> CallsToMutate;
for (User *U : F->users()) {
if (!isa<CallInst>(U))
continue;
auto *CI = cast<CallInst>(U);

// This key algorithm reaches the global string used as an argument to a
// __spirv_ocl_printf call. It then generates a constant AS copy of that
// global (or gets an existing one). For the return value, the call
// instruction is paired with its future constant addrspace string
// argument.
Value *Stripped = stripToMemorySource(CI->getArgOperand(0));
if (auto *Literal = dyn_cast<GlobalVariable>(Stripped))
CallsToMutate.emplace_back(CI, getCASLiteral(Literal));
else if (auto *Arg = dyn_cast<Argument>(Stripped)) {
// The global literal is passed to __spirv_ocl_printf via a wrapper
// function argument. We'll update the wrapper calls to use the builtin
// function directly instead.
Function *WrapperFunc = Arg->getParent();
std::string BadWrapperErrorMsg =
"Consider simplifying the code by "
"passing format strings directly into experimental::printf calls, "
"avoiding indirection via wrapper function arguments.";
if (!WrapperFunc->getName().contains("6oneapi12experimental6printf")) {
emitError(WrapperFunc, CI, BadWrapperErrorMsg);
return 0;
}
for (User *WrapperU : WrapperFunc->users()) {
auto *WrapperCI = cast<CallInst>(WrapperU);
Value *StrippedArg = stripToMemorySource(WrapperCI->getArgOperand(0));
auto *Literal = dyn_cast<GlobalVariable>(StrippedArg);
// We only expect 1 level of wrappers
if (!Literal) {
emitError(WrapperFunc, WrapperCI, BadWrapperErrorMsg);
return 0;
}
CallsToMutate.emplace_back(WrapperCI, getCASLiteral(Literal));
}
// We're certain that the wrapper won't have any uses, since we've just
// marked all its calls for replacement with __spirv_ocl_printf.
FunctionsToDrop.emplace_back(WrapperFunc);
// Similar certainty for the generic AS version of __spirv_ocl_printf
// itself - we've determined it only gets called inside the
// soon-to-be-removed wrapper.
assert(F->hasOneUse() && "Unexpected __spirv_ocl_printf call outside of "
"SYCL wrapper function");
FunctionsToDrop.emplace_back(F);
} else {
emitError(
F, CI,
"Make sure each format string literal is "
"known at compile time or use OpenCL constant address space literals "
"for device-side printf calls");
return 0;
}
}
for (auto &CallConstantPair : CallsToMutate) {
setCallArgOntoCASPrintf(CallConstantPair.first, CallConstantPair.second,
CASPrintfFunc);
++MutatedCallsCount;
}
if (F->hasNUses(0))
FunctionsToDrop.emplace_back(F);
return MutatedCallsCount;
}
} // namespace
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include <CL/sycl.hpp>

using namespace sycl;

int main() {
queue q;
q.submit([&](handler &cgh) {
cgh.single_task([=]() {
ext::oneapi::experimental::printf("String No. %f\n", 1.0f);
const char *IntFormatString = "String No. %i\n";
ext::oneapi::experimental::printf(IntFormatString, 2);
ext::oneapi::experimental::printf(IntFormatString, 3);
});
});

return 0;
}
Loading