diff --git a/llvm/include/llvm/Transforms/Utils/SplitModuleByCategory.h b/llvm/include/llvm/Transforms/Utils/SplitModuleByCategory.h new file mode 100644 index 0000000000000..8e2a6127df7dd --- /dev/null +++ b/llvm/include/llvm/Transforms/Utils/SplitModuleByCategory.h @@ -0,0 +1,44 @@ +//===-------- SplitModuleByCategory.h - module split ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Functionality to split a module by categories. +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORM_UTILS_SPLIT_MODULE_BY_CATEGORY_H +#define LLVM_TRANSFORM_UTILS_SPLIT_MODULE_BY_CATEGORY_H + +#include "llvm/ADT/STLFunctionalExtras.h" + +#include +#include +#include + +namespace llvm { + +class Module; +class Function; + +/// Splits the given module \p M using the given \p FunctionCategorizer. +/// \p FunctionCategorizer returns an integer category for an input Function. +/// It may return std::nullopt if a function doesn't have a category. +/// The module's functions are grouped by category. Every such group +/// populates a call graph containing the group's functions themselves and all +/// reachable functions and globals. Split outputs are populated from each call +/// graph associated with some category. +/// +/// Every split output is passed to \p Callback for further possible +/// processing. +/// +/// Currently, the supported targets are SPIRV, AMDGPU and NVPTX. 
+void splitModuleByCategory( + std::unique_ptr M, + function_ref(const Function &F)> FunctionCategorizer, + function_ref Part)> Callback); + +} // namespace llvm + +#endif // LLVM_TRANSFORM_UTILS_SPLIT_MODULE_BY_CATEGORY_H diff --git a/llvm/lib/Transforms/Utils/CMakeLists.txt b/llvm/lib/Transforms/Utils/CMakeLists.txt index 78cad0d253be8..bcc081d1f91d3 100644 --- a/llvm/lib/Transforms/Utils/CMakeLists.txt +++ b/llvm/lib/Transforms/Utils/CMakeLists.txt @@ -82,6 +82,7 @@ add_llvm_component_library(LLVMTransformUtils SimplifyLibCalls.cpp SizeOpts.cpp SplitModule.cpp + SplitModuleByCategory.cpp StripNonLineTableDebugInfo.cpp SymbolRewriter.cpp UnifyFunctionExitNodes.cpp diff --git a/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp b/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp new file mode 100644 index 0000000000000..dcc9624088c0b --- /dev/null +++ b/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp @@ -0,0 +1,325 @@ +//===-------- SplitModuleByCategory.cpp - split a module by categories ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// See comments in the header. 
+//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/SplitModuleByCategory.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Utils/Cloning.h" + +#include +#include +#include + +using namespace llvm; + +#define DEBUG_TYPE "split-module-by-category" + +namespace { + +// A set of functions that belong to the same category. +using EntryPointSet = SetVector; + +/// Represents a group of functions that share one category (ID). +struct EntryPointGroup { + int ID = 0; // Never left indeterminate by the defaulted constructor. + EntryPointSet Functions; + + EntryPointGroup() = default; + + EntryPointGroup(int ID, EntryPointSet Functions = EntryPointSet()) + : ID(ID), Functions(std::move(Functions)) {} + + void clear() { Functions.clear(); } + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + LLVM_DUMP_METHOD void dump() const { + constexpr size_t INDENT = 4; + dbgs().indent(INDENT) << "ENTRY POINTS" + << " " << ID << " {\n"; + for (const Function *F : Functions) + dbgs().indent(INDENT) << " " << F->getName() << "\n"; + + dbgs().indent(INDENT) << "}\n"; + } +#endif +}; + +/// Annotates an llvm::Module with information necessary to perform and track +/// the result of code (llvm::Module instances) splitting: +/// - entry points group from the module. 
+class ModuleDesc { + std::unique_ptr M; + EntryPointGroup EntryPoints; + +public: + ModuleDesc(std::unique_ptr M, + EntryPointGroup EntryPoints = EntryPointGroup()) + : M(std::move(M)), EntryPoints(std::move(EntryPoints)) { + assert(this->M && "Module should be non-null"); + } + + Module &getModule() { return *M; } + const Module &getModule() const { return *M; } + + std::unique_ptr releaseModule() { + EntryPoints.clear(); + return std::move(M); + } + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + LLVM_DUMP_METHOD void dump() const { + dbgs() << "ModuleDesc[" << M->getName() << "] {\n"; + EntryPoints.dump(); + dbgs() << "}\n"; + } +#endif +}; + +bool isKernel(const Function &F) { + return F.getCallingConv() == CallingConv::SPIR_KERNEL || + F.getCallingConv() == CallingConv::AMDGPU_KERNEL || + F.getCallingConv() == CallingConv::PTX_Kernel; +} + +// Represents "dependency" or "use" graph of global objects (functions and +// global variables) in a module. It is used during device code split to +// understand which global variables and functions (other than entry points) +// should be included into a split module. +// +// Nodes of the graph represent LLVM's GlobalObjects, edges "A" -> "B" represent +// the fact that if "A" is included into a module, then "B" should be included +// as well. +// +// Examples of dependencies which are represented in this graph: +// - Function FA calls function FB +// - Function FA uses global variable GA +// - Global variable GA references (initialized with) function FB +// - Function FA stores address of a function FB somewhere +// +// The following cases are treated as dependencies between global objects: +// 1. Global object A is used by a global object B in any way (store, +// bitcast, phi node, call, etc.): "B" -> "A" edge will be added to the +// graph; +// 2. function A performs an indirect call of a function with signature S and +// there is a function B with signature S. 
"A" -> "B" edge will be added to +// the graph; +class DependencyGraph { +public: + using GlobalSet = SmallPtrSet; + + DependencyGraph(const Module &M) { + // Group functions by their signature to handle case (2) described above + DenseMap + FuncTypeToFuncsMap; + for (const Function &F : M.functions()) { + // Kernels can't be called (either directly or indirectly). + if (isKernel(F)) + continue; + + FuncTypeToFuncsMap[F.getFunctionType()].insert(&F); + } + + for (const Function &F : M.functions()) { + // case (1), see comment above the class definition + for (const Value *U : F.users()) + addUserToGraphRecursively(cast(U), &F); + + // case (2): merge a copy so F's case (1) edges and the shared set survive + for (const Instruction &I : instructions(F)) { + const CallBase *CB = dyn_cast(&I); + if (!CB || !CB->isIndirectCall()) // Direct calls were handled above + continue; + + const FunctionType *Signature = CB->getFunctionType(); + GlobalSet &PotentialCallees = FuncTypeToFuncsMap[Signature]; + Graph[&F].insert(PotentialCallees.begin(), PotentialCallees.end()); + } + } + + // And every global variable (but their handling is a bit simpler) + for (const GlobalVariable &GV : M.globals()) + for (const Value *U : GV.users()) + addUserToGraphRecursively(cast(U), &GV); + } + + iterator_range + dependencies(const GlobalValue *Val) const { + auto It = Graph.find(Val); + return (It == Graph.end()) + ? make_range(EmptySet.begin(), EmptySet.end()) + : make_range(It->second.begin(), It->second.end()); + } + +private: + void addUserToGraphRecursively(const User *Root, const GlobalValue *V) { + SmallVector WorkList; + WorkList.push_back(Root); + + while (!WorkList.empty()) { + const User *U = WorkList.pop_back_val(); + if (const auto *I = dyn_cast(U)) { + const Function *UFunc = I->getFunction(); + Graph[UFunc].insert(V); + } else if (isa(U)) { + if (const auto *GV = dyn_cast(U)) + Graph[GV].insert(V); + // This could be a global variable or some constant expression (like + // bitcast or gep). 
We trace users of this constant further to reach + // global objects they are used by and add them to the graph. + for (const User *UU : U->users()) + WorkList.push_back(UU); + } else { + llvm_unreachable("Unhandled type of function user"); + } + } + } + + DenseMap Graph; + SmallPtrSet EmptySet; +}; + +void collectFunctionsAndGlobalVariablesToExtract( + SetVector &GVs, const Module &M, + const EntryPointGroup &ModuleEntryPoints, const DependencyGraph &DG) { + // We start with module entry points + for (const Function *F : ModuleEntryPoints.Functions) + GVs.insert(F); + + // Non-discardable global variables are also included in the initial set + for (const GlobalVariable &GV : M.globals()) + if (!GV.isDiscardableIfUnused()) + GVs.insert(&GV); + + // GVs has SetVector type. This type inserts a value only if it is not yet + // present there. So, recursion is not expected here. + size_t Idx = 0; + while (Idx < GVs.size()) { + const GlobalValue *Obj = GVs[Idx++]; + + for (const GlobalValue *Dep : DG.dependencies(Obj)) { + if (const auto *Func = dyn_cast(Dep)) { + if (!Func->isDeclaration()) + GVs.insert(Func); + } else { + GVs.insert(Dep); // Global variables are added unconditionally + } + } + } +} + +ModuleDesc extractSubModule(const Module &M, + const SetVector &GVs, + EntryPointGroup ModuleEntryPoints) { + // VMap records, for every original value, its clone in the sub-module. + ValueToValueMapTy VMap; + // Clone definitions only for needed globals. Others will be added as + // declarations and removed later. + std::unique_ptr SubM = CloneModule( + M, VMap, [&](const GlobalValue *GV) { return GVs.contains(GV); }); + // Replace entry points with cloned ones. 
+ EntryPointSet NewEPs; + const EntryPointSet &EPs = ModuleEntryPoints.Functions; + llvm::for_each( + EPs, [&](const Function *F) { NewEPs.insert(cast(VMap[F])); }); + ModuleEntryPoints.Functions = std::move(NewEPs); + return ModuleDesc{std::move(SubM), std::move(ModuleEntryPoints)}; +} + +// The function produces a copy of input LLVM IR module M with only those +// functions and globals that can be called from entry points that are specified +// in ModuleEntryPoints vector, in addition to the entry point functions. +ModuleDesc extractCallGraph(const Module &M, EntryPointGroup ModuleEntryPoints, + const DependencyGraph &DG) { + SetVector GVs; + collectFunctionsAndGlobalVariablesToExtract(GVs, M, ModuleEntryPoints, DG); + + ModuleDesc SplitM = extractSubModule(M, GVs, std::move(ModuleEntryPoints)); + LLVM_DEBUG(SplitM.dump()); + return SplitM; +} + +using EntryPointGroupVec = SmallVector; + +/// Module Splitter. +/// It gets a module (in a form of module descriptor, to get additional info) +/// and a collection of entry points groups. Each group specifies subset entry +/// points from input module that should be included in a split module. +class ModuleSplitter { +private: + ModuleDesc Input; + EntryPointGroupVec Groups; + DependencyGraph DG; + +private: + EntryPointGroup drawEntryPointGroup() { + assert(Groups.size() > 0 && "Reached end of entry point groups list."); + EntryPointGroup Group = std::move(Groups.back()); + Groups.pop_back(); + return Group; + } + +public: + ModuleSplitter(ModuleDesc MD, EntryPointGroupVec GroupVec) + : Input(std::move(MD)), Groups(std::move(GroupVec)), + DG(Input.getModule()) { + assert(!Groups.empty() && "Entry points groups collection is empty!"); + } + + /// Gets next subsequence of entry points in an input module and provides + /// split submodule containing these entry points and their dependencies. 
+ ModuleDesc getNextSplit() { + return extractCallGraph(Input.getModule(), drawEntryPointGroup(), DG); + } + + /// Check that there are still submodules to split. + bool hasMoreSplits() const { return Groups.size() > 0; } +}; + +EntryPointGroupVec +selectEntryPointGroups(const Module &M, + function_ref(const Function &F)> FC) { + // std::map is used here to ensure stable ordering of entry point groups, + // which is based on their contents, this greatly helps LIT tests + std::map EntryPointsMap; + + for (const auto &F : M.functions()) + if (std::optional Category = FC(F); Category) + EntryPointsMap[*Category].insert(&F); + + EntryPointGroupVec Groups; + Groups.reserve(EntryPointsMap.size()); + // Groups are appended in ascending category order (std::map iteration). + for (auto &[Key, EntryPoints] : EntryPointsMap) + Groups.emplace_back(Key, std::move(EntryPoints)); + + return Groups; +} + +} // namespace + +void llvm::splitModuleByCategory( + std::unique_ptr M, + function_ref(const Function &F)> FunctionCategorizer, + function_ref Part)> Callback) { + EntryPointGroupVec Groups = selectEntryPointGroups(*M, FunctionCategorizer); + if (Groups.empty()) // No function has a category: there is nothing to emit. + return; + + ModuleSplitter Splitter(ModuleDesc{std::move(M)}, std::move(Groups)); + while (Splitter.hasMoreSplits()) + Callback(Splitter.getNextSplit().releaseModule()); +} diff --git a/llvm/test/tools/llvm-split/SplitByCategory/amd-kernel-split.ll b/llvm/test/tools/llvm-split/SplitByCategory/amd-kernel-split.ll new file mode 100644 index 0000000000000..41a4674118267 --- /dev/null +++ b/llvm/test/tools/llvm-split/SplitByCategory/amd-kernel-split.ll @@ -0,0 +1,17 @@ +; -- Per-kernel split +; RUN: llvm-split -split-by-category=kernel -S < %s -o %tC +; RUN: FileCheck %s -input-file=%tC_0.ll --check-prefixes CHECK-A0 +; RUN: FileCheck %s -input-file=%tC_1.ll --check-prefixes CHECK-A1 + +define dso_local amdgpu_kernel void @KernelA() { + ret void +} + +define dso_local amdgpu_kernel void @KernelB() { + ret void +} + +; 
CHECK-A0: define dso_local amdgpu_kernel void @KernelB() +; CHECK-A0-NOT: define dso_local amdgpu_kernel void @KernelA() +; CHECK-A1-NOT: define dso_local amdgpu_kernel void @KernelB() +; CHECK-A1: define dso_local amdgpu_kernel void @KernelA() diff --git a/llvm/test/tools/llvm-split/SplitByCategory/complex-indirect-call-chain.ll b/llvm/test/tools/llvm-split/SplitByCategory/complex-indirect-call-chain.ll new file mode 100644 index 0000000000000..80123d4dd8fb7 --- /dev/null +++ b/llvm/test/tools/llvm-split/SplitByCategory/complex-indirect-call-chain.ll @@ -0,0 +1,75 @@ +; Check that Module splitting can trace through more complex call stacks +; involving several nested indirect calls. + +; RUN: llvm-split -split-by-category=module-id -S < %s -o %t +; RUN: FileCheck %s -input-file=%t_0.ll --check-prefix CHECK0 \ +; RUN: --implicit-check-not @foo --implicit-check-not @kernel_A \ +; RUN: --implicit-check-not @kernel_B --implicit-check-not @baz +; RUN: FileCheck %s -input-file=%t_1.ll --check-prefix CHECK1 \ +; RUN: --implicit-check-not @kernel_A --implicit-check-not @kernel_C +; RUN: FileCheck %s -input-file=%t_2.ll --check-prefix CHECK2 \ +; RUN: --implicit-check-not @foo --implicit-check-not @bar \ +; RUN: --implicit-check-not @BAZ --implicit-check-not @kernel_B \ +; RUN: --implicit-check-not @kernel_C + +; RUN: llvm-split -split-by-category=kernel -S < %s -o %t +; RUN: FileCheck %s -input-file=%t_0.ll --check-prefix CHECK0 \ +; RUN: --implicit-check-not @foo --implicit-check-not @kernel_A \ +; RUN: --implicit-check-not @kernel_B +; RUN: FileCheck %s -input-file=%t_1.ll --check-prefix CHECK1 \ +; RUN: --implicit-check-not @kernel_A --implicit-check-not @kernel_C +; RUN: FileCheck %s -input-file=%t_2.ll --check-prefix CHECK2 \ +; RUN: --implicit-check-not @foo --implicit-check-not @bar \ +; RUN: --implicit-check-not @BAZ --implicit-check-not @kernel_B \ +; RUN: --implicit-check-not @kernel_C + +; CHECK0-DAG: define spir_kernel void @kernel_C +; CHECK0-DAG: define 
spir_func i32 @bar +; CHECK0-DAG: define spir_func void @baz +; CHECK0-DAG: define spir_func void @BAZ + +; CHECK1-DAG: define spir_kernel void @kernel_B +; CHECK1-DAG: define {{.*}}spir_func i32 @foo +; CHECK1-DAG: define spir_func i32 @bar +; CHECK1-DAG: define spir_func void @baz +; CHECK1-DAG: define spir_func void @BAZ + +; CHECK2-DAG: define spir_kernel void @kernel_A +; CHECK2-DAG: define {{.*}}spir_func void @baz + +define spir_func i32 @foo(i32 (i32, void ()*)* %ptr1, void ()* %ptr2) { + %1 = call spir_func i32 %ptr1(i32 42, void ()* %ptr2) + ret i32 %1 +} + +define spir_func i32 @bar(i32 %arg, void ()* %ptr) { + call spir_func void %ptr() + ret i32 %arg +} + +define spir_func void @baz() { + ret void +} + +define spir_func void @BAZ() { + ret void +} + +define spir_kernel void @kernel_A() #0 { + call spir_func void @baz() + ret void +} + +define spir_kernel void @kernel_B() #1 { + call spir_func i32 @foo(i32 (i32, void ()*)* null, void ()* null) + ret void +} + +define spir_kernel void @kernel_C() #2 { + call spir_func i32 @bar(i32 42, void ()* null) + ret void +} + +attributes #0 = { "module-id"="TU1.cpp" } +attributes #1 = { "module-id"="TU2.cpp" } +attributes #2 = { "module-id"="TU3.cpp" } diff --git a/llvm/test/tools/llvm-split/SplitByCategory/module-split-func-ptr.ll b/llvm/test/tools/llvm-split/SplitByCategory/module-split-func-ptr.ll new file mode 100644 index 0000000000000..316500a4c7611 --- /dev/null +++ b/llvm/test/tools/llvm-split/SplitByCategory/module-split-func-ptr.ll @@ -0,0 +1,38 @@ +; This test checks that Module splitting can properly perform device code split by tracking +; all uses of functions (not only direct calls). 
+ +; RUN: llvm-split -split-by-category=module-id -S < %s -o %t +; RUN: FileCheck %s -input-file=%t_0.ll --check-prefix=CHECK-IR0 +; RUN: FileCheck %s -input-file=%t_1.ll --check-prefix=CHECK-IR1 + +; CHECK-IR0: define dso_local spir_kernel void @kernelA +; +; CHECK-IR1: @FuncTable = weak global ptr @func +; CHECK-IR1: define {{.*}} i32 @func +; CHECK-IR1: define weak_odr dso_local spir_kernel void @kernelB + +@FuncTable = weak global ptr @func, align 8 + +define dso_local spir_func i32 @func(i32 %a) { +entry: + ret i32 %a +} + +define weak_odr dso_local spir_kernel void @kernelB() #0 { +entry: + %0 = call i32 @indirect_call(ptr addrspace(4) addrspacecast ( ptr getelementptr inbounds ( [1 x ptr] , ptr @FuncTable, i64 0, i64 0) to ptr addrspace(4)), i32 0) + ret void +} + +define dso_local spir_kernel void @kernelA() #1 { +entry: + ret void +} + +declare dso_local spir_func i32 @indirect_call(ptr addrspace(4), i32) local_unnamed_addr + +attributes #0 = { "module-id"="TU1.cpp" } +attributes #1 = { "module-id"="TU2.cpp" } + +; CHECK: kernel1 +; CHECK: kernel2 diff --git a/llvm/test/tools/llvm-split/SplitByCategory/one-kernel-per-module.ll b/llvm/test/tools/llvm-split/SplitByCategory/one-kernel-per-module.ll new file mode 100644 index 0000000000000..c95563016ba67 --- /dev/null +++ b/llvm/test/tools/llvm-split/SplitByCategory/one-kernel-per-module.ll @@ -0,0 +1,93 @@ +; Test checks "kernel" splitting mode. 
+ +; RUN: llvm-split -split-by-category=kernel -S < %s -o %t.files +; RUN: FileCheck %s -input-file=%t.files_0.ll --check-prefixes CHECK-MODULE0,CHECK +; RUN: FileCheck %s -input-file=%t.files_1.ll --check-prefixes CHECK-MODULE1,CHECK +; RUN: FileCheck %s -input-file=%t.files_2.ll --check-prefixes CHECK-MODULE2,CHECK + +;CHECK-MODULE0: @GV = internal addrspace(1) constant [1 x i32] [i32 42], align 4 +;CHECK-MODULE1-NOT: @GV +;CHECK-MODULE2-NOT: @GV +@GV = internal addrspace(1) constant [1 x i32] [i32 42], align 4 + +; CHECK-MODULE0-NOT: define dso_local spir_kernel void @TU0_kernelA +; CHECK-MODULE1-NOT: define dso_local spir_kernel void @TU0_kernelA +; CHECK-MODULE2: define dso_local spir_kernel void @TU0_kernelA +define dso_local spir_kernel void @TU0_kernelA() #0 { +entry: +; CHECK-MODULE2: call spir_func void @foo() + call spir_func void @foo() + ret void +} + +; CHECK-MODULE0-NOT: define {{.*}} spir_func void @foo() +; CHECK-MODULE1-NOT: define {{.*}} spir_func void @foo() +; CHECK-MODULE2: define {{.*}} spir_func void @foo() +define dso_local spir_func void @foo() { +entry: +; CHECK-MODULE2: call spir_func void @bar() + call spir_func void @bar() + ret void +} + +; CHECK-MODULE0-NOT: define {{.*}} spir_func void @bar() +; CHECK-MODULE1-NOT: define {{.*}} spir_func void @bar() +; CHECK-MODULE2: define {{.*}} spir_func void @bar() +define linkonce_odr dso_local spir_func void @bar() { +entry: + ret void +} + +; CHECK-MODULE0-NOT: define dso_local spir_kernel void @TU0_kernelB() +; CHECK-MODULE1: define dso_local spir_kernel void @TU0_kernelB() +; CHECK-MODULE2-NOT: define dso_local spir_kernel void @TU0_kernelB() +define dso_local spir_kernel void @TU0_kernelB() #0 { +entry: +; CHECK-MODULE1: call spir_func void @foo1() + call spir_func void @foo1() + ret void +} + +; CHECK-MODULE0-NOT: define {{.*}} spir_func void @foo1() +; CHECK-MODULE1: define {{.*}} spir_func void @foo1() +; CHECK-MODULE2-NOT: define {{.*}} spir_func void @foo1() +define dso_local 
spir_func void @foo1() { +entry: + ret void +} + +; CHECK-MODULE0: define dso_local spir_kernel void @TU1_kernel() +; CHECK-MODULE1-NOT: define dso_local spir_kernel void @TU1_kernel() +; CHECK-MODULE2-NOT: define dso_local spir_kernel void @TU1_kernel() +define dso_local spir_kernel void @TU1_kernel() #1 { +entry: +; CHECK-MODULE0: call spir_func void @foo2() + call spir_func void @foo2() + ret void +} + +; CHECK-MODULE0: define {{.*}} spir_func void @foo2() +; CHECK-MODULE1-NOT: define {{.*}} spir_func void @foo2() +; CHECK-MODULE2-NOT: define {{.*}} spir_func void @foo2() +define dso_local spir_func void @foo2() { +entry: +; CHECK-MODULE0: %0 = load i32, ptr addrspace(4) addrspacecast (ptr addrspace(1) @GV to ptr addrspace(4)), align 4 + %0 = load i32, ptr addrspace(4) getelementptr inbounds ([1 x i32], ptr addrspace(4) addrspacecast (ptr addrspace(1) @GV to ptr addrspace(4)), i64 0, i64 0), align 4 + ret void +} + +attributes #0 = { "module-id"="TU1.cpp" } +attributes #1 = { "module-id"="TU2.cpp" } + +; Metadata is saved in both modules. 
+; CHECK: !opencl.spir.version = !{!0, !0} +; CHECK: !spirv.Source = !{!1, !1} + +!opencl.spir.version = !{!0, !0} +!spirv.Source = !{!1, !1} + +; CHECK: !0 = !{i32 1, i32 2} +; CHECK: !1 = !{i32 4, i32 100000} + +!0 = !{i32 1, i32 2} +!1 = !{i32 4, i32 100000} diff --git a/llvm/test/tools/llvm-split/SplitByCategory/ptx-kernel-split.ll b/llvm/test/tools/llvm-split/SplitByCategory/ptx-kernel-split.ll new file mode 100644 index 0000000000000..efd1e04f22c8c --- /dev/null +++ b/llvm/test/tools/llvm-split/SplitByCategory/ptx-kernel-split.ll @@ -0,0 +1,17 @@ +; -- Per-kernel split +; RUN: llvm-split -split-by-category=kernel -S < %s -o %tC +; RUN: FileCheck %s -input-file=%tC_0.ll --check-prefixes CHECK-A0 +; RUN: FileCheck %s -input-file=%tC_1.ll --check-prefixes CHECK-A1 + +define dso_local ptx_kernel void @KernelA() { + ret void +} + +define dso_local ptx_kernel void @KernelB() { + ret void +} + +; CHECK-A0: define dso_local ptx_kernel void @KernelB() +; CHECK-A0-NOT: define dso_local ptx_kernel void @KernelA() +; CHECK-A1-NOT: define dso_local ptx_kernel void @KernelB() +; CHECK-A1: define dso_local ptx_kernel void @KernelA() diff --git a/llvm/test/tools/llvm-split/SplitByCategory/split-by-source.ll b/llvm/test/tools/llvm-split/SplitByCategory/split-by-source.ll new file mode 100644 index 0000000000000..54485b7b7f348 --- /dev/null +++ b/llvm/test/tools/llvm-split/SplitByCategory/split-by-source.ll @@ -0,0 +1,88 @@ +; Test checks that kernels are being split by attached module-id metadata and +; used functions are being moved with kernels that use them. 
+ +; RUN: llvm-split -split-by-category=module-id -S < %s -o %t +; RUN: FileCheck %s -input-file=%t_0.ll --check-prefixes CHECK-TU0,CHECK +; RUN: FileCheck %s -input-file=%t_1.ll --check-prefixes CHECK-TU1,CHECK + +; CHECK-TU1-NOT: @GV +; CHECK-TU0: @GV = internal addrspace(1) constant [1 x i32] [i32 42], align 4 +@GV = internal addrspace(1) constant [1 x i32] [i32 42], align 4 + +; CHECK-TU0-NOT: define dso_local spir_kernel void @TU1_kernelA +; CHECK-TU1: define dso_local spir_kernel void @TU1_kernelA +define dso_local spir_kernel void @TU1_kernelA() #0 { +entry: +; CHECK-TU1: call spir_func void @func1_TU1() + call spir_func void @func1_TU1() + ret void +} + +; CHECK-TU0-NOT: define {{.*}} spir_func void @func1_TU1() +; CHECK-TU1: define {{.*}} spir_func void @func1_TU1() +define dso_local spir_func void @func1_TU1() { +entry: +; CHECK-TU1: call spir_func void @func2_TU1() + call spir_func void @func2_TU1() + ret void +} + +; CHECK-TU0-NOT: define {{.*}} spir_func void @func2_TU1() +; CHECK-TU1: define {{.*}} spir_func void @func2_TU1() +define linkonce_odr dso_local spir_func void @func2_TU1() { +entry: + ret void +} + +; CHECK-TU0-NOT: define dso_local spir_kernel void @TU1_kernelB() +; CHECK-TU1: define dso_local spir_kernel void @TU1_kernelB() +define dso_local spir_kernel void @TU1_kernelB() #0 { +entry: +; CHECK-TU1: call spir_func void @func3_TU1() + call spir_func void @func3_TU1() + ret void +} + +; CHECK-TU0-NOT: define {{.*}} spir_func void @func3_TU1() +; CHECK-TU1: define {{.*}} spir_func void @func3_TU1() +define dso_local spir_func void @func3_TU1() { +entry: + ret void +} + +; CHECK-TU0-TXT: TU0_kernel +; CHECK-TU1-TXT-NOT: TU0_kernel + +; CHECK-TU0: define dso_local spir_kernel void @TU0_kernel() +; CHECK-TU1-NOT: define dso_local spir_kernel void @TU0_kernel() +define dso_local spir_kernel void @TU0_kernel() #1 { +entry: +; CHECK-TU0: call spir_func void @func_TU0() + call spir_func void @func_TU0() + ret void +} + +; CHECK-TU0: define {{.*}} 
spir_func void @func_TU0() +; CHECK-TU1-NOT: define {{.*}} spir_func void @func_TU0() +define dso_local spir_func void @func_TU0() { +entry: +; CHECK-TU0: %0 = load i32, ptr addrspace(4) addrspacecast (ptr addrspace(1) @GV to ptr addrspace(4)), align 4 + %0 = load i32, ptr addrspace(4) getelementptr inbounds ([1 x i32], ptr addrspace(4) addrspacecast (ptr addrspace(1) @GV to ptr addrspace(4)), i64 0, i64 0), align 4 + ret void +} + +attributes #0 = { "module-id"="TU1.cpp" } +attributes #1 = { "module-id"="TU2.cpp" } + +; Metadata is saved in both modules. +; CHECK: !opencl.spir.version = !{!0, !0} +; CHECK: !spirv.Source = !{!1, !1} + +!opencl.spir.version = !{!0, !0} +!spirv.Source = !{!1, !1} + +; CHECK: !0 = !{i32 1, i32 2} +; CHECK: !1 = !{i32 4, i32 100000} + +!0 = !{i32 1, i32 2} +!1 = !{i32 4, i32 100000} diff --git a/llvm/test/tools/llvm-split/SplitByCategory/split-with-kernel-declarations.ll b/llvm/test/tools/llvm-split/SplitByCategory/split-with-kernel-declarations.ll new file mode 100644 index 0000000000000..0c1bd8b5c5fba --- /dev/null +++ b/llvm/test/tools/llvm-split/SplitByCategory/split-with-kernel-declarations.ll @@ -0,0 +1,52 @@ +; The test checks that Module splitting does not treat declarations as entry points. 
+ +; RUN: llvm-split -split-by-category=module-id -S < %s -o %t1 +; RUN: FileCheck %s -input-file=%t1_0.ll --check-prefix CHECK-MODULE-ID0 +; RUN: FileCheck %s -input-file=%t1_1.ll --check-prefix CHECK-MODULE-ID1 + +; RUN: llvm-split -split-by-category=kernel -S < %s -o %t2 +; RUN: FileCheck %s -input-file=%t2_0.ll --check-prefix CHECK-PER-KERNEL0 +; RUN: FileCheck %s -input-file=%t2_1.ll --check-prefix CHECK-PER-KERNEL1 +; RUN: FileCheck %s -input-file=%t2_2.ll --check-prefix CHECK-PER-KERNEL2 + +; With module-id split, there should be two modules +; CHECK-MODULE-ID0-NOT: TU0 +; CHECK-MODULE-ID0-NOT: TU1_kernel1 +; CHECK-MODULE-ID0: TU1_kernel0 +; +; CHECK-MODULE-ID1-NOT: TU1 +; CHECK-MODULE-ID1: TU0_kernel0 +; CHECK-MODULE-ID1: TU0_kernel1 + +; With per-kernel split, there should be three modules. +; CHECK-PER-KERNEL0-NOT: TU0 +; CHECK-PER-KERNEL0-NOT: TU1_kernel1 +; CHECK-PER-KERNEL0: TU1_kernel0 +; +; CHECK-PER-KERNEL1-NOT: TU0_kernel0 +; CHECK-PER-KERNEL1-NOT: TU1 +; CHECK-PER-KERNEL1: TU0_kernel1 +; +; CHECK-PER-KERNEL2-NOT: TU0_kernel1 +; CHECK-PER-KERNEL2-NOT: TU1 +; CHECK-PER-KERNEL2: TU0_kernel0 + + +define spir_kernel void @TU0_kernel0() #0 { +entry: + ret void +} + +define spir_kernel void @TU0_kernel1() #0 { +entry: + ret void +} + +define spir_kernel void @TU1_kernel0() #1 { + ret void +} + +declare spir_kernel void @TU1_kernel1() #1 + +attributes #0 = { "module-id"="TU1.cpp" } +attributes #1 = { "module-id"="TU2.cpp" } diff --git a/llvm/tools/llvm-split/CMakeLists.txt b/llvm/tools/llvm-split/CMakeLists.txt index 1104e3145952c..b755755a984fc 100644 --- a/llvm/tools/llvm-split/CMakeLists.txt +++ b/llvm/tools/llvm-split/CMakeLists.txt @@ -12,6 +12,7 @@ set(LLVM_LINK_COMPONENTS Support Target TargetParser + ipo ) add_llvm_tool(llvm-split diff --git a/llvm/tools/llvm-split/llvm-split.cpp b/llvm/tools/llvm-split/llvm-split.cpp index 9f6678a1fa466..8e53647bc296e 100644 --- a/llvm/tools/llvm-split/llvm-split.cpp +++ b/llvm/tools/llvm-split/llvm-split.cpp @@ 
-11,14 +11,19 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Bitcode/BitcodeWriter.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/PassInstrumentation.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/Verifier.h" #include "llvm/IRReader/IRReader.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FileSystem.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" @@ -27,7 +32,9 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" #include "llvm/TargetParser/Triple.h" +#include "llvm/Transforms/IPO/GlobalDCE.h" #include "llvm/Transforms/Utils/SplitModule.h" +#include "llvm/Transforms/Utils/SplitModuleByCategory.h" using namespace llvm; @@ -70,6 +77,163 @@ static cl::opt MCPU("mcpu", cl::desc("Target CPU, ignored if --mtriple is not used"), cl::value_desc("cpu"), cl::cat(SplitCategory)); +enum class SplitByCategoryType { + SBCT_ByModuleId, + SBCT_ByKernel, + SBCT_None, +}; + +static cl::opt SplitByCategory( + "split-by-category", + cl::desc("Split by category. 
If present, splitting by category is used " + "with the specified categorization type."), + cl::Optional, cl::init(SplitByCategoryType::SBCT_None), + cl::values(clEnumValN(SplitByCategoryType::SBCT_ByModuleId, "module-id", + "one output module per translation unit marked with " + "\"module-id\" attribute"), + clEnumValN(SplitByCategoryType::SBCT_ByKernel, "kernel", + "one output module per kernel")), + cl::cat(SplitCategory)); + +static cl::opt OutputAssembly{ + "S", cl::desc("Write output as LLVM assembly"), cl::cat(SplitCategory)}; + +void writeStringToFile(StringRef Content, StringRef Path) { + std::error_code EC; + raw_fd_ostream OS(Path, EC); + if (EC) { + errs() << formatv("error opening file: {0}, error: {1}\n", Path, + EC.message()); + exit(1); + } + + OS << Content << "\n"; +} + +void writeModuleToFile(const Module &M, StringRef Path, bool OutputAssembly) { + int FD = -1; + if (std::error_code EC = sys::fs::openFileForWrite(Path, FD)) { + errs() << formatv("error opening file: {0}, error: {1}", Path, EC.message()) + << '\n'; + exit(1); + } + + raw_fd_ostream OS(FD, /*ShouldClose*/ true); + if (OutputAssembly) + M.print(OS, /*AssemblyAnnotationWriter*/ nullptr); + else + WriteBitcodeToFile(M, OS); +} + +/// FunctionCategorizer is used for splitting by category either by module-id or +/// by kernels. It doesn't provide categories for functions other than kernels. +/// Categorizer computes a string key for the given Function and records the +/// association between the string key and an integer category. If a string key +/// already belongs to some category, then the corresponding integer category +/// is returned. 
+class FunctionCategorizer { +public: + FunctionCategorizer(SplitByCategoryType Type) : Type(Type) {} + + FunctionCategorizer() = delete; + FunctionCategorizer(FunctionCategorizer &) = delete; + FunctionCategorizer &operator=(const FunctionCategorizer &) = delete; + FunctionCategorizer(FunctionCategorizer &&) = default; + FunctionCategorizer &operator=(FunctionCategorizer &&) = default; + + /// Returns integer specifying the category for the given \p F. + /// If the given function isn't a kernel then returns std::nullopt. + std::optional operator()(const Function &F) { + if (!isEntryPoint(F)) + return std::nullopt; // skip the function. + + auto StringKey = computeFunctionCategory(Type, F); + if (auto it = StrKeyToID.find(StringRef(StringKey)); it != StrKeyToID.end()) + return it->second; + + int ID = static_cast(StrKeyToID.size()); + return StrKeyToID.try_emplace(std::move(StringKey), ID).first->second; + } + +private: + static bool isEntryPoint(const Function &F) { + if (F.isDeclaration()) + return false; + + return F.getCallingConv() == CallingConv::SPIR_KERNEL || + F.getCallingConv() == CallingConv::AMDGPU_KERNEL || + F.getCallingConv() == CallingConv::PTX_Kernel; + } + + static SmallString<0> computeFunctionCategory(SplitByCategoryType Type, + const Function &F) { + static constexpr char ATTR_MODULE_ID[] = "module-id"; + SmallString<0> Key; + switch (Type) { + case SplitByCategoryType::SBCT_ByKernel: + Key = F.getName().str(); + break; + case SplitByCategoryType::SBCT_ByModuleId: + Key = F.getFnAttribute(ATTR_MODULE_ID).getValueAsString().str(); + break; + default: + llvm_unreachable("unexpected mode."); + } + + return Key; + } + +private: + struct KeyInfo { + static SmallString<0> getEmptyKey() { return SmallString<0>(""); } + + static SmallString<0> getTombstoneKey() { return SmallString<0>("-"); } + + static bool isEqual(const SmallString<0> &LHS, const SmallString<0> &RHS) { + return LHS == RHS; + } + + static unsigned getHashValue(const SmallString<0> &S) 
{ + return llvm::hash_value(StringRef(S)); + } + }; + + SplitByCategoryType Type; + DenseMap, int, KeyInfo> StrKeyToID; +}; + +void cleanupModule(Module &M) { + ModuleAnalysisManager MAM; + MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); + ModulePassManager MPM; + MPM.addPass(GlobalDCEPass()); // Delete unreachable globals. + MPM.run(M, MAM); +} + +Error runSplitModuleByCategory(std::unique_ptr M) { + size_t OutputID = 0; + auto PostSplitCallback = [&](std::unique_ptr MPart) { + if (verifyModule(*MPart)) { + errs() << "Broken Module!\n"; + exit(1); + } + + // TODO: DCE is a crucial pass since it removes unused declarations. + // At the moment, LIT checking can't be perfomed without DCE. + cleanupModule(*MPart); + size_t ID = OutputID; + ++OutputID; + StringRef ModuleSuffix = OutputAssembly ? ".ll" : ".bc"; + std::string ModulePath = + (Twine(OutputFilename) + "_" + Twine(ID) + ModuleSuffix).str(); + writeModuleToFile(*MPart, ModulePath, OutputAssembly); + }; + + auto Categorizer = FunctionCategorizer(SplitByCategory); + splitModuleByCategory(std::move(M), Categorizer, PostSplitCallback); + return Error::success(); +} + int main(int argc, char **argv) { InitLLVM X(argc, argv); @@ -123,6 +287,17 @@ int main(int argc, char **argv) { Out->keep(); }; + if (SplitByCategory != SplitByCategoryType::SBCT_None) { + auto E = runSplitModuleByCategory(std::move(M)); + if (E) { + errs() << E << "\n"; + Err.print(argv[0], errs()); + return 1; + } + + return 0; + } + if (TM) { if (PreserveLocals) { errs() << "warning: --preserve-locals has no effect when using "