From ede5e8c9b836302a5e3bd4e2d5f154d775184d1d Mon Sep 17 00:00:00 2001 From: "Sabianin, Maksim" Date: Fri, 14 Mar 2025 07:24:46 -0700 Subject: [PATCH 1/9] [offload][SYCL] Add SYCL Module splitting. This patch adds SYCL Module splitting - the necessary step in the SYCL compilation pipeline. Only 2 splitting modes are being added in this patch: by kernel and by source. The previous attempt was at #119713. In this patch there is no dependency in `TransformUtils` on "IPO" and on "Printing Passes". In this patch a module splitting is self-contained and it doesn't introduce linking issues. --- .../llvm/Transforms/Utils/SYCLSplitModule.h | 64 +++ .../include/llvm/Transforms/Utils/SYCLUtils.h | 26 ++ llvm/lib/Transforms/Utils/CMakeLists.txt | 2 + llvm/lib/Transforms/Utils/SYCLSplitModule.cpp | 401 ++++++++++++++++++ llvm/lib/Transforms/Utils/SYCLUtils.cpp | 26 ++ .../device-code-split/amd-kernel-split.ll | 17 + .../complex-indirect-call-chain.ll | 75 ++++ .../module-split-func-ptr.ll | 43 ++ .../one-kernel-per-module.ll | 108 +++++ .../SYCL/device-code-split/split-by-source.ll | 97 +++++ .../split-with-kernel-declarations.ll | 66 +++ llvm/tools/llvm-split/CMakeLists.txt | 1 + llvm/tools/llvm-split/llvm-split.cpp | 121 ++++++ 13 files changed, 1047 insertions(+) create mode 100644 llvm/include/llvm/Transforms/Utils/SYCLSplitModule.h create mode 100644 llvm/include/llvm/Transforms/Utils/SYCLUtils.h create mode 100644 llvm/lib/Transforms/Utils/SYCLSplitModule.cpp create mode 100644 llvm/lib/Transforms/Utils/SYCLUtils.cpp create mode 100644 llvm/test/tools/llvm-split/SYCL/device-code-split/amd-kernel-split.ll create mode 100644 llvm/test/tools/llvm-split/SYCL/device-code-split/complex-indirect-call-chain.ll create mode 100644 llvm/test/tools/llvm-split/SYCL/device-code-split/module-split-func-ptr.ll create mode 100644 llvm/test/tools/llvm-split/SYCL/device-code-split/one-kernel-per-module.ll create mode 100644 
llvm/test/tools/llvm-split/SYCL/device-code-split/split-by-source.ll create mode 100644 llvm/test/tools/llvm-split/SYCL/device-code-split/split-with-kernel-declarations.ll diff --git a/llvm/include/llvm/Transforms/Utils/SYCLSplitModule.h b/llvm/include/llvm/Transforms/Utils/SYCLSplitModule.h new file mode 100644 index 0000000000000..a3425d19b9c4b --- /dev/null +++ b/llvm/include/llvm/Transforms/Utils/SYCLSplitModule.h @@ -0,0 +1,64 @@ +//===-------- SYCLSplitModule.h - module split ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Functionality to split a module into callgraphs. A callgraph here is a set +// of entry points with all functions reachable from them via a call. The result +// of the split is new modules containing corresponding callgraph. +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_UTILS_SYCLSPLITMODULE_H +#define LLVM_TRANSFORMS_UTILS_SYCLSPLITMODULE_H + +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/StringRef.h" + +#include +#include +#include + +namespace llvm { + +class Module; + +enum class IRSplitMode { + IRSM_PER_TU, // one module per translation unit + IRSM_PER_KERNEL, // one module per kernel + IRSM_NONE // no splitting +}; + +/// \returns IRSplitMode value if \p S is recognized. Otherwise, std::nullopt is +/// returned. +std::optional convertStringToSplitMode(StringRef S); + +/// The structure represents a split LLVM Module accompanied by additional +/// information. Split Modules are being stored at disk due to the high RAM +/// consumption during the whole splitting process. 
+struct ModuleAndSYCLMetadata { + std::string ModuleFilePath; + std::string Symbols; + + ModuleAndSYCLMetadata() = default; + ModuleAndSYCLMetadata(const ModuleAndSYCLMetadata &) = default; + ModuleAndSYCLMetadata &operator=(const ModuleAndSYCLMetadata &) = default; + ModuleAndSYCLMetadata(ModuleAndSYCLMetadata &&) = default; + ModuleAndSYCLMetadata &operator=(ModuleAndSYCLMetadata &&) = default; + + ModuleAndSYCLMetadata(std::string_view File, std::string Symbols) + : ModuleFilePath(File), Symbols(std::move(Symbols)) {} +}; + +using PostSYCLSplitCallbackType = + function_ref Part, std::string Symbols)>; + +/// Splits the given module \p M according to the given \p Mode. +/// Every split image is passed to \p Callback. +void SYCLSplitModule(std::unique_ptr M, IRSplitMode Mode, + PostSYCLSplitCallbackType Callback); + +} // namespace llvm + +#endif // LLVM_TRANSFORMS_UTILS_SYCLSPLITMODULE_H diff --git a/llvm/include/llvm/Transforms/Utils/SYCLUtils.h b/llvm/include/llvm/Transforms/Utils/SYCLUtils.h new file mode 100644 index 0000000000000..75459eed6ac0f --- /dev/null +++ b/llvm/include/llvm/Transforms/Utils/SYCLUtils.h @@ -0,0 +1,26 @@ +//===------------ SYCLUtils.h - SYCL utility functions --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Utility functions for SYCL.
+//===----------------------------------------------------------------------===// +#ifndef LLVM_TRANSFORMS_UTILS_SYCLUTILS_H +#define LLVM_TRANSFORMS_UTILS_SYCLUTILS_H + +#include +#include + +namespace llvm { + +class raw_ostream; + +using SYCLStringTable = SmallVector>>; + +void writeSYCLStringTable(const SYCLStringTable &Table, raw_ostream &OS); + +} // namespace llvm + +#endif // LLVM_TRANSFORMS_UTILS_SYCLUTILS_H diff --git a/llvm/lib/Transforms/Utils/CMakeLists.txt b/llvm/lib/Transforms/Utils/CMakeLists.txt index 78cad0d253be8..0ba46bdadea8d 100644 --- a/llvm/lib/Transforms/Utils/CMakeLists.txt +++ b/llvm/lib/Transforms/Utils/CMakeLists.txt @@ -83,6 +83,8 @@ add_llvm_component_library(LLVMTransformUtils SizeOpts.cpp SplitModule.cpp StripNonLineTableDebugInfo.cpp + SYCLSplitModule.cpp + SYCLUtils.cpp SymbolRewriter.cpp UnifyFunctionExitNodes.cpp UnifyLoopExits.cpp diff --git a/llvm/lib/Transforms/Utils/SYCLSplitModule.cpp b/llvm/lib/Transforms/Utils/SYCLSplitModule.cpp new file mode 100644 index 0000000000000..18eca4237c8ae --- /dev/null +++ b/llvm/lib/Transforms/Utils/SYCLSplitModule.cpp @@ -0,0 +1,401 @@ +//===-------- SYCLSplitModule.cpp - Split a module into call graphs -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// See comments in the header. 
+//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/SYCLSplitModule.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/SYCLUtils.h" + +#include +#include + +using namespace llvm; + +#define DEBUG_TYPE "sycl-split-module" + +static bool isKernel(const Function &F) { + return F.getCallingConv() == CallingConv::SPIR_KERNEL || + F.getCallingConv() == CallingConv::AMDGPU_KERNEL; +} + +static bool isEntryPoint(const Function &F) { + // Skip declarations, if any: they should not be included into a vector of + // entry points groups or otherwise we will end up with incorrectly generated + // list of symbols. + if (F.isDeclaration()) + return false; + + // Kernels are always considered to be entry points + return isKernel(F); +} + +namespace { + +// A vector that contains all entry point functions in a split module. +using EntryPointSet = SetVector; + +/// Represents a named group of entry points.
+struct EntryPointGroup { + std::string GroupName; + EntryPointSet Functions; + + EntryPointGroup() = default; + EntryPointGroup(const EntryPointGroup &) = default; + EntryPointGroup &operator=(const EntryPointGroup &) = default; + EntryPointGroup(EntryPointGroup &&) = default; + EntryPointGroup &operator=(EntryPointGroup &&) = default; + + EntryPointGroup(StringRef GroupName, + EntryPointSet Functions = EntryPointSet()) + : GroupName(GroupName), Functions(std::move(Functions)) {} + + void clear() { + GroupName.clear(); + Functions.clear(); + } + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + LLVM_DUMP_METHOD void dump() const { + constexpr size_t INDENT = 4; + dbgs().indent(INDENT) << "ENTRY POINTS" + << " " << GroupName << " {\n"; + for (const Function *F : Functions) + dbgs().indent(INDENT) << " " << F->getName() << "\n"; + + dbgs().indent(INDENT) << "}\n"; + } +#endif +}; + +/// Annotates an llvm::Module with information necessary to perform and track +/// the result of device code (llvm::Module instances) splitting: +/// - entry points group from the module. 
+class ModuleDesc { + std::unique_ptr M; + EntryPointGroup EntryPoints; + +public: + ModuleDesc() = delete; + ModuleDesc(const ModuleDesc &) = delete; + ModuleDesc &operator=(const ModuleDesc &) = delete; + ModuleDesc(ModuleDesc &&) = default; + ModuleDesc &operator=(ModuleDesc &&) = default; + + ModuleDesc(std::unique_ptr M, + EntryPointGroup EntryPoints = EntryPointGroup()) + : M(std::move(M)), EntryPoints(std::move(EntryPoints)) { + assert(this->M && "Module should be non-null"); + } + + Module &getModule() { return *M; } + const Module &getModule() const { return *M; } + + std::unique_ptr releaseModule() { + EntryPoints.clear(); + return std::move(M); + } + + std::string makeSymbolTable() const { + SmallString<0> Data; + raw_svector_ostream OS(Data); + for (const Function *F : EntryPoints.Functions) + OS << F->getName() << '\n'; + + return std::string(OS.str()); + } + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + LLVM_DUMP_METHOD void dump() const { + dbgs() << "ModuleDesc[" << M->getName() << "] {\n"; + EntryPoints.dump(); + dbgs() << "}\n"; + } +#endif +}; + +// Represents "dependency" or "use" graph of global objects (functions and +// global variables) in a module. It is used during device code split to +// understand which global variables and functions (other than entry points) +// should be included into a split module. +// +// Nodes of the graph represent LLVM's GlobalObjects, edges "A" -> "B" represent +// the fact that if "A" is included into a module, then "B" should be included +// as well. +// +// Examples of dependencies which are represented in this graph: +// - Function FA calls function FB +// - Function FA uses global variable GA +// - Global variable GA references (initialized with) function FB +// - Function FA stores address of a function FB somewhere +// +// The following cases are treated as dependencies between global objects: +// 1. 
Global object B is used by global object A in any way (store, +// bitcast, phi node, call, etc.): "A" -> "B" edge will be added to the +// graph; +// 2. function A performs an indirect call of a function with signature S and +// there is a function B with signature S. "A" -> "B" edge will be added to +// the graph; +class DependencyGraph { +public: + using GlobalSet = SmallPtrSet; + + DependencyGraph(const Module &M) { + // Group functions by their signature to handle case (2) described above + DenseMap + FuncTypeToFuncsMap; + for (const auto &F : M.functions()) { + // Kernels can't be called (either directly or indirectly) in SYCL + if (isKernel(F)) + continue; + + FuncTypeToFuncsMap[F.getFunctionType()].insert(&F); + } + + for (const auto &F : M.functions()) { + // case (1), see comment above the class definition + for (const Value *U : F.users()) + addUserToGraphRecursively(cast(U), &F); + + // case (2), see comment above the class definition + for (const auto &I : instructions(F)) { + const auto *CI = dyn_cast(&I); + if (!CI || !CI->isIndirectCall()) // Direct calls were handled above + continue; + + const FunctionType *Signature = CI->getFunctionType(); + const auto &PotentialCallees = FuncTypeToFuncsMap[Signature]; + Graph[&F].insert(PotentialCallees.begin(), PotentialCallees.end()); + } + } + + // And every global variable (but their handling is a bit simpler) + for (const auto &GV : M.globals()) + for (const Value *U : GV.users()) + addUserToGraphRecursively(cast(U), &GV); + } + + iterator_range + dependencies(const GlobalValue *Val) const { + auto It = Graph.find(Val); + return (It == Graph.end()) + ?
make_range(EmptySet.begin(), EmptySet.end()) + : make_range(It->second.begin(), It->second.end()); + } + +private: + void addUserToGraphRecursively(const User *Root, const GlobalValue *V) { + SmallVector WorkList; + WorkList.push_back(Root); + + while (!WorkList.empty()) { + const User *U = WorkList.pop_back_val(); + if (const auto *I = dyn_cast(U)) { + const auto *UFunc = I->getFunction(); + Graph[UFunc].insert(V); + } else if (isa(U)) { + if (const auto *GV = dyn_cast(U)) + Graph[GV].insert(V); + // This could be a global variable or some constant expression (like + // bitcast or gep). We trace users of this constant further to reach + // global objects they are used by and add them to the graph. + for (const auto *UU : U->users()) + WorkList.push_back(UU); + } else + llvm_unreachable("Unhandled type of function user"); + } + } + + DenseMap Graph; + SmallPtrSet EmptySet; +}; + +void collectFunctionsAndGlobalVariablesToExtract( + SetVector &GVs, const Module &M, + const EntryPointGroup &ModuleEntryPoints, const DependencyGraph &DG) { + // We start with module entry points + for (const auto *F : ModuleEntryPoints.Functions) + GVs.insert(F); + + // Non-discardable global variables are also included in the initial set + for (const auto &GV : M.globals()) + if (!GV.isDiscardableIfUnused()) + GVs.insert(&GV); + + // GVs has SetVector type. This type inserts a value only if it is not yet + // present there. So, recursion is not expected here. + size_t Idx = 0; + while (Idx < GVs.size()) { + const GlobalValue *Obj = GVs[Idx++]; + + for (const GlobalValue *Dep : DG.dependencies(Obj)) { + if (const auto *Func = dyn_cast(Dep)) { + if (!Func->isDeclaration()) + GVs.insert(Func); + } else + GVs.insert(Dep); // Global variables are added unconditionally + } + } +} + +ModuleDesc extractSubModule(const Module &M, + const SetVector &GVs, + EntryPointGroup ModuleEntryPoints) { + // For each group of entry points collect all dependencies.
+ ValueToValueMapTy VMap; + // Clone definitions only for needed globals. Others will be added as + // declarations and removed later. + std::unique_ptr SubM = CloneModule( + M, VMap, [&](const GlobalValue *GV) { return GVs.count(GV); }); + // Replace entry points with cloned ones. + EntryPointSet NewEPs; + const EntryPointSet &EPs = ModuleEntryPoints.Functions; + std::for_each(EPs.begin(), EPs.end(), [&](const Function *F) { + NewEPs.insert(cast(VMap[F])); + }); + ModuleEntryPoints.Functions = std::move(NewEPs); + return ModuleDesc{std::move(SubM), std::move(ModuleEntryPoints)}; +} + +// The function produces a copy of input LLVM IR module M with only those +// functions and globals that can be called from entry points that are specified +// in ModuleEntryPoints vector, in addition to the entry point functions. +ModuleDesc extractCallGraph(const Module &M, EntryPointGroup ModuleEntryPoints, + const DependencyGraph &DG) { + SetVector GVs; + collectFunctionsAndGlobalVariablesToExtract(GVs, M, ModuleEntryPoints, DG); + + ModuleDesc SplitM = extractSubModule(M, GVs, std::move(ModuleEntryPoints)); + LLVM_DEBUG(SplitM.dump()); + return SplitM; +} + +using EntryPointGroupVec = SmallVector; + +/// Module Splitter. +/// It gets a module (in a form of module descriptor, to get additional info) +/// and a collection of entry points groups. Each group specifies subset entry +/// points from input module that should be included in a split module. 
+class ModuleSplitter { +private: + ModuleDesc Input; + EntryPointGroupVec Groups; + DependencyGraph DG; + +private: + EntryPointGroup drawEntryPointGroup() { + assert(Groups.size() > 0 && "Reached end of entry point groups list."); + EntryPointGroup Group = std::move(Groups.back()); + Groups.pop_back(); + return Group; + } + +public: + ModuleSplitter(ModuleDesc MD, EntryPointGroupVec GroupVec) + : Input(std::move(MD)), Groups(std::move(GroupVec)), + DG(Input.getModule()) { + assert(!Groups.empty() && "Entry points groups collection is empty!"); + } + + /// Gets next subsequence of entry points in an input module and provides + /// split submodule containing these entry points and their dependencies. + ModuleDesc getNextSplit() { + return extractCallGraph(Input.getModule(), drawEntryPointGroup(), DG); + } + + /// Check that there are still submodules to split. + bool hasMoreSplits() const { return Groups.size() > 0; } +}; + +} // namespace + +static EntryPointGroupVec selectEntryPointGroups(const Module &M, + IRSplitMode Mode) { + // std::map is used here to ensure stable ordering of entry point groups, + // which is based on their contents, this greatly helps LIT tests + std::map EntryPointsMap; + + static constexpr char ATTR_SYCL_MODULE_ID[] = "sycl-module-id"; + for (const auto &F : M.functions()) { + if (!isEntryPoint(F)) + continue; + + std::string Key; + switch (Mode) { + case IRSplitMode::IRSM_PER_KERNEL: + Key = F.getName(); + break; + case IRSplitMode::IRSM_PER_TU: + Key = F.getFnAttribute(ATTR_SYCL_MODULE_ID).getValueAsString(); + break; + case IRSplitMode::IRSM_NONE: + llvm_unreachable(""); + } + + EntryPointsMap[Key].insert(&F); + } + + EntryPointGroupVec Groups; + if (EntryPointsMap.empty()) { + // No entry points met, record this. 
+ Groups.emplace_back("-", EntryPointSet()); + } else { + Groups.reserve(EntryPointsMap.size()); + // Start with properties of a source module + for (auto &[Key, EntryPoints] : EntryPointsMap) + Groups.emplace_back(Key, std::move(EntryPoints)); + } + + return Groups; +} + +namespace llvm { + +std::optional convertStringToSplitMode(StringRef S) { + static const StringMap Values = { + {"source", IRSplitMode::IRSM_PER_TU}, + {"kernel", IRSplitMode::IRSM_PER_KERNEL}, + {"none", IRSplitMode::IRSM_NONE}}; + + auto It = Values.find(S); + if (It == Values.end()) + return std::nullopt; + + return It->second; +} + +void SYCLSplitModule(std::unique_ptr M, IRSplitMode Mode, + PostSYCLSplitCallbackType Callback) { + SmallVector OutputImages; + if (Mode == IRSplitMode::IRSM_NONE) { + auto MD = ModuleDesc(std::move(M)); + auto Symbols = MD.makeSymbolTable(); + Callback(std::move(MD.releaseModule()), std::move(Symbols)); + return; + } + + EntryPointGroupVec Groups = selectEntryPointGroups(*M, Mode); + ModuleDesc MD = std::move(M); + ModuleSplitter Splitter(std::move(MD), std::move(Groups)); + while (Splitter.hasMoreSplits()) { + ModuleDesc MD = Splitter.getNextSplit(); + auto Symbols = MD.makeSymbolTable(); + Callback(std::move(MD.releaseModule()), std::move(Symbols)); + } +} + +} // namespace llvm diff --git a/llvm/lib/Transforms/Utils/SYCLUtils.cpp b/llvm/lib/Transforms/Utils/SYCLUtils.cpp new file mode 100644 index 0000000000000..ad9864fadb828 --- /dev/null +++ b/llvm/lib/Transforms/Utils/SYCLUtils.cpp @@ -0,0 +1,26 @@ +//===------------ SYCLUtils.cpp - SYCL utility functions ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// SYCL utility functions. 
+//===----------------------------------------------------------------------===// +#include "llvm/Transforms/Utils/SYCLUtils.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/raw_ostream.h" + +namespace llvm { + +void writeSYCLStringTable(const SYCLStringTable &Table, raw_ostream &OS) { + assert(!Table.empty() && "table should contain at least column titles"); + assert(!Table[0].empty() && "table should be non-empty"); + OS << '[' << join(Table[0].begin(), Table[0].end(), "|") << "]\n"; + for (size_t I = 1, E = Table.size(); I != E; ++I) { + assert(Table[I].size() == Table[0].size() && "row's size should be equal"); + OS << join(Table[I].begin(), Table[I].end(), "|") << '\n'; + } +} + +} // namespace llvm diff --git a/llvm/test/tools/llvm-split/SYCL/device-code-split/amd-kernel-split.ll b/llvm/test/tools/llvm-split/SYCL/device-code-split/amd-kernel-split.ll new file mode 100644 index 0000000000000..a40a52107fb0c --- /dev/null +++ b/llvm/test/tools/llvm-split/SYCL/device-code-split/amd-kernel-split.ll @@ -0,0 +1,17 @@ +; -- Per-kernel split +; RUN: llvm-split -sycl-split=kernel -S < %s -o %tC +; RUN: FileCheck %s -input-file=%tC_0.ll --check-prefixes CHECK-A0 +; RUN: FileCheck %s -input-file=%tC_1.ll --check-prefixes CHECK-A1 + +define dso_local amdgpu_kernel void @KernelA() { + ret void +} + +define dso_local amdgpu_kernel void @KernelB() { + ret void +} + +; CHECK-A0: define dso_local amdgpu_kernel void @KernelB() +; CHECK-A0-NOT: define dso_local amdgpu_kernel void @KernelA() +; CHECK-A1-NOT: define dso_local amdgpu_kernel void @KernelB() +; CHECK-A1: define dso_local amdgpu_kernel void @KernelA() diff --git a/llvm/test/tools/llvm-split/SYCL/device-code-split/complex-indirect-call-chain.ll b/llvm/test/tools/llvm-split/SYCL/device-code-split/complex-indirect-call-chain.ll new file mode 100644 index 0000000000000..5a25e491b1b93 --- /dev/null +++ b/llvm/test/tools/llvm-split/SYCL/device-code-split/complex-indirect-call-chain.ll @@ -0,0 +1,75 @@ +; Check 
that Module splitting can trace through more complex call stacks +; involving several nested indirect calls. + +; RUN: llvm-split -sycl-split=source -S < %s -o %t +; RUN: FileCheck %s -input-file=%t_0.ll --check-prefix CHECK0 \ +; RUN: --implicit-check-not @foo --implicit-check-not @kernel_A \ +; RUN: --implicit-check-not @kernel_B --implicit-check-not @baz +; RUN: FileCheck %s -input-file=%t_1.ll --check-prefix CHECK1 \ +; RUN: --implicit-check-not @kernel_A --implicit-check-not @kernel_C +; RUN: FileCheck %s -input-file=%t_2.ll --check-prefix CHECK2 \ +; RUN: --implicit-check-not @foo --implicit-check-not @bar \ +; RUN: --implicit-check-not @BAZ --implicit-check-not @kernel_B \ +; RUN: --implicit-check-not @kernel_C + +; RUN: llvm-split -sycl-split=kernel -S < %s -o %t +; RUN: FileCheck %s -input-file=%t_0.ll --check-prefix CHECK0 \ +; RUN: --implicit-check-not @foo --implicit-check-not @kernel_A \ +; RUN: --implicit-check-not @kernel_B +; RUN: FileCheck %s -input-file=%t_1.ll --check-prefix CHECK1 \ +; RUN: --implicit-check-not @kernel_A --implicit-check-not @kernel_C +; RUN: FileCheck %s -input-file=%t_2.ll --check-prefix CHECK2 \ +; RUN: --implicit-check-not @foo --implicit-check-not @bar \ +; RUN: --implicit-check-not @BAZ --implicit-check-not @kernel_B \ +; RUN: --implicit-check-not @kernel_C + +; CHECK0-DAG: define spir_kernel void @kernel_C +; CHECK0-DAG: define spir_func i32 @bar +; CHECK0-DAG: define spir_func void @baz +; CHECK0-DAG: define spir_func void @BAZ + +; CHECK1-DAG: define spir_kernel void @kernel_B +; CHECK1-DAG: define {{.*}}spir_func i32 @foo +; CHECK1-DAG: define spir_func i32 @bar +; CHECK1-DAG: define spir_func void @baz +; CHECK1-DAG: define spir_func void @BAZ + +; CHECK2-DAG: define spir_kernel void @kernel_A +; CHECK2-DAG: define {{.*}}spir_func void @baz + +define spir_func i32 @foo(i32 (i32, void ()*)* %ptr1, void ()* %ptr2) { + %1 = call spir_func i32 %ptr1(i32 42, void ()* %ptr2) + ret i32 %1 +} + +define spir_func i32 @bar(i32 
%arg, void ()* %ptr) { + call spir_func void %ptr() + ret i32 %arg +} + +define spir_func void @baz() { + ret void +} + +define spir_func void @BAZ() { + ret void +} + +define spir_kernel void @kernel_A() #0 { + call spir_func void @baz() + ret void +} + +define spir_kernel void @kernel_B() #1 { + call spir_func i32 @foo(i32 (i32, void ()*)* null, void ()* null) + ret void +} + +define spir_kernel void @kernel_C() #2 { + call spir_func i32 @bar(i32 42, void ()* null) + ret void +} + +attributes #0 = { "sycl-module-id"="TU1.cpp" } +attributes #1 = { "sycl-module-id"="TU2.cpp" } +attributes #2 = { "sycl-module-id"="TU3.cpp" } diff --git a/llvm/test/tools/llvm-split/SYCL/device-code-split/module-split-func-ptr.ll b/llvm/test/tools/llvm-split/SYCL/device-code-split/module-split-func-ptr.ll new file mode 100644 index 0000000000000..c9289d78b1fda --- /dev/null +++ b/llvm/test/tools/llvm-split/SYCL/device-code-split/module-split-func-ptr.ll @@ -0,0 +1,43 @@ +; This test checks that Module splitting can properly perform device code split by tracking +; all uses of functions (not only direct calls). 
+ +; RUN: llvm-split -sycl-split=source -S < %s -o %t +; RUN: FileCheck %s -input-file=%t_0.sym --check-prefix=CHECK-SYM0 +; RUN: FileCheck %s -input-file=%t_1.sym --check-prefix=CHECK-SYM1 +; RUN: FileCheck %s -input-file=%t_0.ll --check-prefix=CHECK-IR0 +; RUN: FileCheck %s -input-file=%t_1.ll --check-prefix=CHECK-IR1 + +; CHECK-SYM0: kernelA +; CHECK-SYM1: kernelB +; +; CHECK-IR0: define dso_local spir_kernel void @kernelA +; +; CHECK-IR1: @FuncTable = weak global ptr @func +; CHECK-IR1: define {{.*}} i32 @func +; CHECK-IR1: define weak_odr dso_local spir_kernel void @kernelB + +@FuncTable = weak global ptr @func, align 8 + +define dso_local spir_func i32 @func(i32 %a) { +entry: + ret i32 %a +} + +define weak_odr dso_local spir_kernel void @kernelB() #0 { +entry: + %0 = call i32 @indirect_call(ptr addrspace(4) addrspacecast ( ptr getelementptr inbounds ( [1 x ptr] , ptr @FuncTable, i64 0, i64 0) to ptr addrspace(4)), i32 0) + ret void +} + +define dso_local spir_kernel void @kernelA() #1 { +entry: + ret void +} + +declare dso_local spir_func i32 @indirect_call(ptr addrspace(4), i32) local_unnamed_addr + +attributes #0 = { "sycl-module-id"="TU1.cpp" } +attributes #1 = { "sycl-module-id"="TU2.cpp" } + +; CHECK: kernel1 +; CHECK: kernel2 diff --git a/llvm/test/tools/llvm-split/SYCL/device-code-split/one-kernel-per-module.ll b/llvm/test/tools/llvm-split/SYCL/device-code-split/one-kernel-per-module.ll new file mode 100644 index 0000000000000..b949ab7530f39 --- /dev/null +++ b/llvm/test/tools/llvm-split/SYCL/device-code-split/one-kernel-per-module.ll @@ -0,0 +1,108 @@ +; Test checks "kernel" splitting mode. 
+ +; RUN: llvm-split -sycl-split=kernel -S < %s -o %t.files +; RUN: FileCheck %s -input-file=%t.files_0.ll --check-prefixes CHECK-MODULE0,CHECK +; RUN: FileCheck %s -input-file=%t.files_0.sym --check-prefixes CHECK-MODULE0-TXT +; RUN: FileCheck %s -input-file=%t.files_1.ll --check-prefixes CHECK-MODULE1,CHECK +; RUN: FileCheck %s -input-file=%t.files_1.sym --check-prefixes CHECK-MODULE1-TXT +; RUN: FileCheck %s -input-file=%t.files_2.ll --check-prefixes CHECK-MODULE2,CHECK +; RUN: FileCheck %s -input-file=%t.files_2.sym --check-prefixes CHECK-MODULE2-TXT + +;CHECK-MODULE0: @GV = internal addrspace(1) constant [1 x i32] [i32 42], align 4 +;CHECK-MODULE1-NOT: @GV +;CHECK-MODULE2-NOT: @GV +@GV = internal addrspace(1) constant [1 x i32] [i32 42], align 4 + +; CHECK-MODULE0-TXT-NOT: TU0_kernelA +; CHECK-MODULE1-TXT-NOT: TU0_kernelA +; CHECK-MODULE2-TXT: TU0_kernelA + +; CHECK-MODULE0-NOT: define dso_local spir_kernel void @TU0_kernelA +; CHECK-MODULE1-NOT: define dso_local spir_kernel void @TU0_kernelA +; CHECK-MODULE2: define dso_local spir_kernel void @TU0_kernelA +define dso_local spir_kernel void @TU0_kernelA() #0 { +entry: +; CHECK-MODULE2: call spir_func void @foo() + call spir_func void @foo() + ret void +} + +; CHECK-MODULE0-NOT: define {{.*}} spir_func void @foo() +; CHECK-MODULE1-NOT: define {{.*}} spir_func void @foo() +; CHECK-MODULE2: define {{.*}} spir_func void @foo() +define dso_local spir_func void @foo() { +entry: +; CHECK-MODULE2: call spir_func void @bar() + call spir_func void @bar() + ret void +} + +; CHECK-MODULE0-NOT: define {{.*}} spir_func void @bar() +; CHECK-MODULE1-NOT: define {{.*}} spir_func void @bar() +; CHECK-MODULE2: define {{.*}} spir_func void @bar() +define linkonce_odr dso_local spir_func void @bar() { +entry: + ret void +} + +; CHECK-MODULE0-TXT-NOT: TU0_kernelB +; CHECK-MODULE1-TXT: TU0_kernelB +; CHECK-MODULE2-TXT-NOT: TU0_kernelB + +; CHECK-MODULE0-NOT: define dso_local spir_kernel void @TU0_kernelB() +; CHECK-MODULE1: define
dso_local spir_kernel void @TU0_kernelB() +; CHECK-MODULE2-NOT: define dso_local spir_kernel void @TU0_kernelB() +define dso_local spir_kernel void @TU0_kernelB() #0 { +entry: +; CHECK-MODULE1: call spir_func void @foo1() + call spir_func void @foo1() + ret void +} + +; CHECK-MODULE0-NOT: define {{.*}} spir_func void @foo1() +; CHECK-MODULE1: define {{.*}} spir_func void @foo1() +; CHECK-MODULE2-NOT: define {{.*}} spir_func void @foo1() +define dso_local spir_func void @foo1() { +entry: + ret void +} + +; CHECK-MODULE0-TXT: TU1_kernel +; CHECK-MODULE1-TXT-NOT: TU1_kernel +; CHECK-MODULE2-TXT-NOT: TU1_kernel + +; CHECK-MODULE0: define dso_local spir_kernel void @TU1_kernel() +; CHECK-MODULE1-NOT: define dso_local spir_kernel void @TU1_kernel() +; CHECK-MODULE2-NOT: define dso_local spir_kernel void @TU1_kernel() +define dso_local spir_kernel void @TU1_kernel() #1 { +entry: +; CHECK-MODULE0: call spir_func void @foo2() + call spir_func void @foo2() + ret void +} + +; CHECK-MODULE0: define {{.*}} spir_func void @foo2() +; CHECK-MODULE1-NOT: define {{.*}} spir_func void @foo2() +; CHECK-MODULE2-NOT: define {{.*}} spir_func void @foo2() +define dso_local spir_func void @foo2() { +entry: +; CHECK-MODULE0: %0 = load i32, ptr addrspace(4) addrspacecast (ptr addrspace(1) @GV to ptr addrspace(4)), align 4 + %0 = load i32, ptr addrspace(4) getelementptr inbounds ([1 x i32], ptr addrspace(4) addrspacecast (ptr addrspace(1) @GV to ptr addrspace(4)), i64 0, i64 0), align 4 + ret void +} + +attributes #0 = { "sycl-module-id"="TU1.cpp" } +attributes #1 = { "sycl-module-id"="TU2.cpp" } + +; Metadata is saved in both modules. 
+; CHECK: !opencl.spir.version = !{!0, !0} +; CHECK: !spirv.Source = !{!1, !1} + +!opencl.spir.version = !{!0, !0} +!spirv.Source = !{!1, !1} + +; CHECK: !0 = !{i32 1, i32 2} +; CHECK: !1 = !{i32 4, i32 100000} + +!0 = !{i32 1, i32 2} +!1 = !{i32 4, i32 100000} diff --git a/llvm/test/tools/llvm-split/SYCL/device-code-split/split-by-source.ll b/llvm/test/tools/llvm-split/SYCL/device-code-split/split-by-source.ll new file mode 100644 index 0000000000000..6a4e543209526 --- /dev/null +++ b/llvm/test/tools/llvm-split/SYCL/device-code-split/split-by-source.ll @@ -0,0 +1,97 @@ +; Test checks that kernels are being split by attached TU metadata and +; used functions are being moved with kernels that use them. + +; RUN: llvm-split -sycl-split=source -S < %s -o %t +; RUN: FileCheck %s -input-file=%t_0.ll --check-prefixes CHECK-TU0,CHECK +; RUN: FileCheck %s -input-file=%t_1.ll --check-prefixes CHECK-TU1,CHECK +; RUN: FileCheck %s -input-file=%t_0.sym --check-prefixes CHECK-TU0-TXT +; RUN: FileCheck %s -input-file=%t_1.sym --check-prefixes CHECK-TU1-TXT + +; CHECK-TU1-NOT: @GV +; CHECK-TU0: @GV = internal addrspace(1) constant [1 x i32] [i32 42], align 4 +@GV = internal addrspace(1) constant [1 x i32] [i32 42], align 4 + +; CHECK-TU0-TXT-NOT: TU1_kernelA +; CHECK-TU1-TXT: TU1_kernelA + +; CHECK-TU0-NOT: define dso_local spir_kernel void @TU1_kernelA +; CHECK-TU1: define dso_local spir_kernel void @TU1_kernelA +define dso_local spir_kernel void @TU1_kernelA() #0 { +entry: +; CHECK-TU1: call spir_func void @func1_TU1() + call spir_func void @func1_TU1() + ret void +} + +; CHECK-TU0-NOT: define {{.*}} spir_func void @func1_TU1() +; CHECK-TU1: define {{.*}} spir_func void @func1_TU1() +define dso_local spir_func void @func1_TU1() { +entry: +; CHECK-TU1: call spir_func void @func2_TU1() + call spir_func void @func2_TU1() + ret void +} + +; CHECK-TU0-NOT: define {{.*}} spir_func void @func2_TU1() +; CHECK-TU1: define {{.*}} spir_func void @func2_TU1() +define linkonce_odr dso_local
spir_func void @func2_TU1() { +entry: + ret void +} + + +; CHECK-TU0-TXT-NOT: TU1_kernelB +; CHECK-TU1-TXT: TU1_kernelB + +; CHECK-TU0-NOT: define dso_local spir_kernel void @TU1_kernelB() +; CHECK-TU1: define dso_local spir_kernel void @TU1_kernelB() +define dso_local spir_kernel void @TU1_kernelB() #0 { +entry: +; CHECK-TU1: call spir_func void @func3_TU1() + call spir_func void @func3_TU1() + ret void +} + +; CHECK-TU0-NOT: define {{.*}} spir_func void @func3_TU1() +; CHECK-TU1: define {{.*}} spir_func void @func3_TU1() +define dso_local spir_func void @func3_TU1() { +entry: + ret void +} + +; CHECK-TU0-TXT: TU0_kernel +; CHECK-TU1-TXT-NOT: TU0_kernel + +; CHECK-TU0: define dso_local spir_kernel void @TU0_kernel() +; CHECK-TU1-NOT: define dso_local spir_kernel void @TU0_kernel() +define dso_local spir_kernel void @TU0_kernel() #1 { +entry: +; CHECK-TU0: call spir_func void @func_TU0() + call spir_func void @func_TU0() + ret void +} + +; CHECK-TU0: define {{.*}} spir_func void @func_TU0() +; CHECK-TU1-NOT: define {{.*}} spir_func void @func_TU0() +define dso_local spir_func void @func_TU0() { +entry: +; CHECK-TU0: %0 = load i32, ptr addrspace(4) addrspacecast (ptr addrspace(1) @GV to ptr addrspace(4)), align 4 + %0 = load i32, ptr addrspace(4) getelementptr inbounds ([1 x i32], ptr addrspace(4) addrspacecast (ptr addrspace(1) @GV to ptr addrspace(4)), i64 0, i64 0), align 4 + ret void +} + +attributes #0 = { "sycl-module-id"="TU1.cpp" } +attributes #1 = { "sycl-module-id"="TU2.cpp" } + +; Metadata is saved in both modules. 
+; CHECK: !opencl.spir.version = !{!0, !0} +; CHECK: !spirv.Source = !{!1, !1} + +!opencl.spir.version = !{!0, !0} +!spirv.Source = !{!1, !1} + +; CHECK: !0 = !{i32 1, i32 2} +; CHECK: !1 = !{i32 4, i32 100000} + +!0 = !{i32 1, i32 2} +!1 = !{i32 4, i32 100000} diff --git a/llvm/test/tools/llvm-split/SYCL/device-code-split/split-with-kernel-declarations.ll b/llvm/test/tools/llvm-split/SYCL/device-code-split/split-with-kernel-declarations.ll new file mode 100644 index 0000000000000..1f188d8e32db6 --- /dev/null +++ b/llvm/test/tools/llvm-split/SYCL/device-code-split/split-with-kernel-declarations.ll @@ -0,0 +1,66 @@ +; The test checks that Module splitting does not treat declarations as entry points. + +; RUN: llvm-split -sycl-split=source -S < %s -o %t1 +; RUN: FileCheck %s -input-file=%t1.table --check-prefix CHECK-PER-SOURCE-TABLE +; RUN: FileCheck %s -input-file=%t1_0.sym --check-prefix CHECK-PER-SOURCE-SYM0 +; RUN: FileCheck %s -input-file=%t1_1.sym --check-prefix CHECK-PER-SOURCE-SYM1 + +; RUN: llvm-split -sycl-split=kernel -S < %s -o %t2 +; RUN: FileCheck %s -input-file=%t2.table --check-prefix CHECK-PER-KERNEL-TABLE +; RUN: FileCheck %s -input-file=%t2_0.sym --check-prefix CHECK-PER-KERNEL-SYM0 +; RUN: FileCheck %s -input-file=%t2_1.sym --check-prefix CHECK-PER-KERNEL-SYM1 +; RUN: FileCheck %s -input-file=%t2_2.sym --check-prefix CHECK-PER-KERNEL-SYM2 + +; With per-source split, there should be two device images +; CHECK-PER-SOURCE-TABLE: [Code|Symbols] +; CHECK-PER-SOURCE-TABLE: {{.*}}_0.ll|{{.*}}_0.sym +; CHECK-PER-SOURCE-TABLE-NEXT: {{.*}}_1.ll|{{.*}}_1.sym +; CHECK-PER-SOURCE-TABLE-EMPTY: +; +; CHECK-PER-SOURCE-SYM0-NOT: TU1_kernel1 +; CHECK-PER-SOURCE-SYM0: TU1_kernel0 +; CHECK-PER-SOURCE-SYM0-EMPTY: +; +; CHECK-PER-SOURCE-SYM1-NOT: TU1_kernel1 +; CHECK-PER-SOURCE-SYM1: TU0_kernel0 +; CHECK-PER-SOURCE-SYM1-NEXT: TU0_kernel1 +; CHECK-PER-SOURCE-SYM1-EMPTY: + +; With per-kernel split, there should be three device images +; CHECK-PER-KERNEL-TABLE: 
[Code|Symbols] +; CHECK-PER-KERNEL-TABLE: {{.*}}_0.ll|{{.*}}_0.sym +; CHECK-PER-KERNEL-TABLE-NEXT: {{.*}}_1.ll|{{.*}}_1.sym +; CHECK-PER-KERNEL-TABLE-NEXT: {{.*}}_2.ll|{{.*}}_2.sym +; CHECK-PER-KERNEL-TABLE-EMPTY: +; +; CHECK-PER-KERNEL-SYM0-NOT: TU1_kernel1 +; CHECK-PER-KERNEL-SYM0: TU1_kernel0 +; CHECK-PER-KERNEL-SYM0-EMPTY: +; +; CHECK-PER-KERNEL-SYM1-NOT: TU1_kernel1 +; CHECK-PER-KERNEL-SYM1: TU0_kernel1 +; CHECK-PER-KERNEL-SYM1-EMPTY: +; +; CHECK-PER-KERNEL-SYM2-NOT: TU1_kernel1 +; CHECK-PER-KERNEL-SYM2: TU0_kernel0 +; CHECK-PER-KERNEL-SYM2-EMPTY: + + +define spir_kernel void @TU0_kernel0() #0 { +entry: + ret void +} + +define spir_kernel void @TU0_kernel1() #0 { +entry: + ret void +} + +define spir_kernel void @TU1_kernel0() #1 { + ret void +} + +declare spir_kernel void @TU1_kernel1() #1 + +attributes #0 = { "sycl-module-id"="TU1.cpp" } +attributes #1 = { "sycl-module-id"="TU2.cpp" } diff --git a/llvm/tools/llvm-split/CMakeLists.txt b/llvm/tools/llvm-split/CMakeLists.txt index 1104e3145952c..b755755a984fc 100644 --- a/llvm/tools/llvm-split/CMakeLists.txt +++ b/llvm/tools/llvm-split/CMakeLists.txt @@ -12,6 +12,7 @@ set(LLVM_LINK_COMPONENTS Support Target TargetParser + ipo ) add_llvm_tool(llvm-split diff --git a/llvm/tools/llvm-split/llvm-split.cpp b/llvm/tools/llvm-split/llvm-split.cpp index 9f6678a1fa466..f6e90985304d6 100644 --- a/llvm/tools/llvm-split/llvm-split.cpp +++ b/llvm/tools/llvm-split/llvm-split.cpp @@ -11,14 +11,19 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Bitcode/BitcodeWriter.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/PassInstrumentation.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/Verifier.h" #include "llvm/IRReader/IRReader.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FileSystem.h" +#include 
"llvm/Support/FormatVariadic.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" @@ -27,6 +32,9 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" #include "llvm/TargetParser/Triple.h" +#include "llvm/Transforms/IPO/GlobalDCE.h" +#include "llvm/Transforms/Utils/SYCLSplitModule.h" +#include "llvm/Transforms/Utils/SYCLUtils.h" #include "llvm/Transforms/Utils/SplitModule.h" using namespace llvm; @@ -70,6 +78,108 @@ static cl::opt MCPU("mcpu", cl::desc("Target CPU, ignored if --mtriple is not used"), cl::value_desc("cpu"), cl::cat(SplitCategory)); +static cl::opt SYCLSplitMode( + "sycl-split", + cl::desc("SYCL Split Mode. If present, SYCL splitting algorithm is used " + "with the specified mode."), + cl::Optional, cl::init(IRSplitMode::IRSM_NONE), + cl::values(clEnumValN(IRSplitMode::IRSM_PER_TU, "source", + "1 ouptput module per translation unit"), + clEnumValN(IRSplitMode::IRSM_PER_KERNEL, "kernel", + "1 output module per kernel")), + cl::cat(SplitCategory)); + +static cl::opt OutputAssembly{ + "S", cl::desc("Write output as LLVM assembly"), cl::cat(SplitCategory)}; + +void writeStringToFile(StringRef Content, StringRef Path) { + std::error_code EC; + raw_fd_ostream OS(Path, EC); + if (EC) { + errs() << formatv("error opening file: {0}, error: {1}\n", Path, + EC.message()); + exit(1); + } + + OS << Content << "\n"; +} + +void writeModuleToFile(const Module &M, StringRef Path, bool OutputAssembly) { + int FD = -1; + if (std::error_code EC = sys::fs::openFileForWrite(Path, FD)) { + errs() << formatv("error opening file: {0}, error: {1}", Path, EC.message()) + << '\n'; + exit(1); + } + + raw_fd_ostream OS(FD, /*ShouldClose*/ true); + if (OutputAssembly) + M.print(OS, /*AssemblyAnnotationWriter*/ nullptr); + else + WriteBitcodeToFile(M, OS); +} + +void writeSplitModulesAsTable(ArrayRef Modules, + StringRef Path) { + SmallVector> Columns; + Columns.emplace_back("Code"); + 
Columns.emplace_back("Symbols"); + + SYCLStringTable Table; + Table.emplace_back(std::move(Columns)); + for (const auto &[I, SM] : enumerate(Modules)) { + SmallString<128> SymbolsFile; + (Twine(Path) + "_" + Twine(I) + ".sym").toVector(SymbolsFile); + writeStringToFile(SM.Symbols, SymbolsFile); + SmallVector> Row; + Row.emplace_back(SM.ModuleFilePath); + Row.emplace_back(SymbolsFile); + Table.emplace_back(std::move(Row)); + } + + std::error_code EC; + raw_fd_ostream OS((Path + ".table").str(), EC); + if (EC) { + errs() << formatv("error opening file: {0}\n", Path); + exit(1); + } + + writeSYCLStringTable(Table, OS); +} + +void cleanupModule(Module &M) { + ModuleAnalysisManager MAM; + MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); + ModulePassManager MPM; + MPM.addPass(GlobalDCEPass()); // Delete unreachable globals. + MPM.run(M, MAM); +} + +Error runSYCLSplitModule(std::unique_ptr M) { + SmallVector SplitModules; + auto PostSYCLSplitCallback = [&](std::unique_ptr MPart, + std::string Symbols) { + if (verifyModule(*MPart)) { + errs() << "Broken Module!\n"; + exit(1); + } + + // TODO: DCE is a crucial pass in a SYCL post-link pipeline. + // At the moment, LIT checking can't be performed without DCE. + cleanupModule(*MPart); + size_t ID = SplitModules.size(); + StringRef ModuleSuffix = OutputAssembly ?
".ll" : ".bc"; + std::string ModulePath = + (Twine(OutputFilename) + "_" + Twine(ID) + ModuleSuffix).str(); + writeModuleToFile(*MPart, ModulePath, OutputAssembly); + SplitModules.emplace_back(std::move(ModulePath), std::move(Symbols)); + }; + + SYCLSplitModule(std::move(M), SYCLSplitMode, PostSYCLSplitCallback); + writeSplitModulesAsTable(SplitModules, OutputFilename); + return Error::success(); +} + int main(int argc, char **argv) { InitLLVM X(argc, argv); @@ -123,6 +233,17 @@ int main(int argc, char **argv) { Out->keep(); }; + if (SYCLSplitMode != IRSplitMode::IRSM_NONE) { + auto E = runSYCLSplitModule(std::move(M)); + if (E) { + errs() << E << "\n"; + Err.print(argv[0], errs()); + return 1; + } + + return 0; + } + if (TM) { if (PreserveLocals) { errs() << "warning: --preserve-locals has no effect when using " From c764d7f8bec51b8d7bc17ab99964c34542595826 Mon Sep 17 00:00:00 2001 From: "Sabianin, Maksim" Date: Thu, 3 Apr 2025 08:07:10 -0700 Subject: [PATCH 2/9] Move SYCL Module splitting into llvm/Frontend/SYCL/. 
--- llvm/include/llvm/Frontend/SYCL/SplitModule.h | 39 +++++++ .../SYCL/Utils.h} | 45 ++++---- .../include/llvm/Transforms/Utils/SYCLUtils.h | 26 ----- llvm/lib/Frontend/CMakeLists.txt | 1 + llvm/lib/Frontend/SYCL/CMakeLists.txt | 13 +++ .../SYCL/SplitModule.cpp} | 106 ++++++------------ llvm/lib/Frontend/SYCL/Utils.cpp | 79 +++++++++++++ llvm/lib/Transforms/Utils/CMakeLists.txt | 2 - llvm/lib/Transforms/Utils/SYCLUtils.cpp | 26 ----- .../device-code-split/ptx-kernel-split.ll | 17 +++ llvm/tools/llvm-split/CMakeLists.txt | 1 + llvm/tools/llvm-split/llvm-split.cpp | 28 ++--- 12 files changed, 220 insertions(+), 163 deletions(-) create mode 100644 llvm/include/llvm/Frontend/SYCL/SplitModule.h rename llvm/include/llvm/{Transforms/Utils/SYCLSplitModule.h => Frontend/SYCL/Utils.h} (53%) delete mode 100644 llvm/include/llvm/Transforms/Utils/SYCLUtils.h create mode 100644 llvm/lib/Frontend/SYCL/CMakeLists.txt rename llvm/lib/{Transforms/Utils/SYCLSplitModule.cpp => Frontend/SYCL/SplitModule.cpp} (81%) create mode 100644 llvm/lib/Frontend/SYCL/Utils.cpp delete mode 100644 llvm/lib/Transforms/Utils/SYCLUtils.cpp create mode 100644 llvm/test/tools/llvm-split/SYCL/device-code-split/ptx-kernel-split.ll diff --git a/llvm/include/llvm/Frontend/SYCL/SplitModule.h b/llvm/include/llvm/Frontend/SYCL/SplitModule.h new file mode 100644 index 0000000000000..2f332d045706f --- /dev/null +++ b/llvm/include/llvm/Frontend/SYCL/SplitModule.h @@ -0,0 +1,39 @@ +//===-------- SplitModule.h - module split ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Functionality to split a module by categories. 
+//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_UTILS_SYCLSPLITMODULE_H +#define LLVM_TRANSFORMS_UTILS_SYCLSPLITMODULE_H + +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/Frontend/SYCL/Utils.h" + +#include +#include +#include + +namespace llvm { + +class Module; +class Function; + +namespace sycl { + +using PostSplitCallbackType = function_ref Part)>; + +/// Splits the given module \p M. +/// Every split image is being passed to \p Callback for further possible +/// processing. +void splitModule(std::unique_ptr M, IRSplitMode Mode, + PostSplitCallbackType Callback); + +} // namespace sycl +} // namespace llvm + +#endif // LLVM_TRANSFORMS_UTILS_SYCLSPLITMODULE_H diff --git a/llvm/include/llvm/Transforms/Utils/SYCLSplitModule.h b/llvm/include/llvm/Frontend/SYCL/Utils.h similarity index 53% rename from llvm/include/llvm/Transforms/Utils/SYCLSplitModule.h rename to llvm/include/llvm/Frontend/SYCL/Utils.h index a3425d19b9c4b..6db1fb5a4715c 100644 --- a/llvm/include/llvm/Transforms/Utils/SYCLSplitModule.h +++ b/llvm/include/llvm/Frontend/SYCL/Utils.h @@ -1,28 +1,29 @@ -//===-------- SYCLSplitModule.h - module split ------------------*- C++ -*-===// +//===------------ Utils.h - SYCL utility functions ------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// Functionality to split a module into callgraphs. A callgraph here is a set -// of entry points with all functions reachable from them via a call. The result -// of the split is new modules containing corresponding callgraph. +// Utility functions for SYCL. 
//===----------------------------------------------------------------------===// +#ifndef LLVM_TRANSFORMS_UTILS_SYCLUTILS_H +#define LLVM_TRANSFORMS_UTILS_SYCLUTILS_H -#ifndef LLVM_TRANSFORMS_UTILS_SYCLSPLITMODULE_H -#define LLVM_TRANSFORMS_UTILS_SYCLSPLITMODULE_H - -#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" -#include -#include #include namespace llvm { class Module; +class Function; +class raw_ostream; + +namespace sycl { enum class IRSplitMode { IRSM_PER_TU, // one module per translation unit @@ -34,31 +35,33 @@ enum class IRSplitMode { /// returned. std::optional convertStringToSplitMode(StringRef S); -/// The structure represents a split LLVM Module accompanied by additional +/// The structure represents a LLVM Module accompanied by additional /// information. Split Modules are being stored at disk due to the high RAM /// consumption during the whole splitting process. struct ModuleAndSYCLMetadata { std::string ModuleFilePath; std::string Symbols; - ModuleAndSYCLMetadata() = default; + ModuleAndSYCLMetadata() = delete; ModuleAndSYCLMetadata(const ModuleAndSYCLMetadata &) = default; ModuleAndSYCLMetadata &operator=(const ModuleAndSYCLMetadata &) = default; ModuleAndSYCLMetadata(ModuleAndSYCLMetadata &&) = default; ModuleAndSYCLMetadata &operator=(ModuleAndSYCLMetadata &&) = default; - ModuleAndSYCLMetadata(std::string_view File, std::string Symbols) - : ModuleFilePath(File), Symbols(std::move(Symbols)) {} + ModuleAndSYCLMetadata(const Twine &File, std::string Symbols) + : ModuleFilePath(File.str()), Symbols(std::move(Symbols)) {} }; -using PostSYCLSplitCallbackType = - function_ref Part, std::string Symbols)>; +/// Checks whether the function is a SYCL entry point. 
+bool isEntryPoint(const Function &F); + +std::string makeSymbolTable(const Module &M); + +using StringTable = SmallVector>>; -/// Splits the given module \p M according to the given \p Settings. -/// Every split image is being passed to \p Callback. -void SYCLSplitModule(std::unique_ptr M, IRSplitMode Mode, - PostSYCLSplitCallbackType Callback); +void writeStringTable(const StringTable &Table, raw_ostream &OS); +} // namespace sycl } // namespace llvm -#endif // LLVM_TRANSFORMS_UTILS_SYCLSPLITMODULE_H +#endif // LLVM_TRANSFORMS_UTILS_SYCLUTILS_H diff --git a/llvm/include/llvm/Transforms/Utils/SYCLUtils.h b/llvm/include/llvm/Transforms/Utils/SYCLUtils.h deleted file mode 100644 index 75459eed6ac0f..0000000000000 --- a/llvm/include/llvm/Transforms/Utils/SYCLUtils.h +++ /dev/null @@ -1,26 +0,0 @@ -//===------------ SYCLUtils.h - SYCL utility functions --------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// Utility functions for SYCL. 
-//===----------------------------------------------------------------------===// -#ifndef LLVM_TRANSFORMS_UTILS_SYCLUTILS_H -#define LLVM_TRANSFORMS_UTILS_SYCLUTILS_H - -#include -#include - -namespace llvm { - -class raw_ostream; - -using SYCLStringTable = SmallVector>>; - -void writeSYCLStringTable(const SYCLStringTable &Table, raw_ostream &OS); - -} // namespace llvm - -#endif // LLVM_TRANSFORMS_UTILS_SYCLUTILS_H diff --git a/llvm/lib/Frontend/CMakeLists.txt b/llvm/lib/Frontend/CMakeLists.txt index b305ce7d771ce..00d51bd178974 100644 --- a/llvm/lib/Frontend/CMakeLists.txt +++ b/llvm/lib/Frontend/CMakeLists.txt @@ -4,3 +4,4 @@ add_subdirectory(HLSL) add_subdirectory(OpenACC) add_subdirectory(OpenMP) add_subdirectory(Offloading) +add_subdirectory(SYCL) diff --git a/llvm/lib/Frontend/SYCL/CMakeLists.txt b/llvm/lib/Frontend/SYCL/CMakeLists.txt new file mode 100644 index 0000000000000..893abcf9aebd8 --- /dev/null +++ b/llvm/lib/Frontend/SYCL/CMakeLists.txt @@ -0,0 +1,13 @@ +add_llvm_component_library(LLVMFrontendSYCL + SplitModule.cpp + Utils.cpp + + ADDITIONAL_HEADER_DIRS + ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend + ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend/SYCL + + LINK_COMPONENTS + Core + Support + TransformUtils + ) diff --git a/llvm/lib/Transforms/Utils/SYCLSplitModule.cpp b/llvm/lib/Frontend/SYCL/SplitModule.cpp similarity index 81% rename from llvm/lib/Transforms/Utils/SYCLSplitModule.cpp rename to llvm/lib/Frontend/SYCL/SplitModule.cpp index 18eca4237c8ae..e0ba06aa617b4 100644 --- a/llvm/lib/Transforms/Utils/SYCLSplitModule.cpp +++ b/llvm/lib/Frontend/SYCL/SplitModule.cpp @@ -1,4 +1,4 @@ -//===-------- SYCLSplitModule.cpp - Split a module into call graphs -------===// +//===-------- SplitModule.cpp - split a module by categories --------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -8,10 +8,11 @@ // See comments in the header. 
//===----------------------------------------------------------------------===// -#include "llvm/Transforms/Utils/SYCLSplitModule.h" +#include "llvm/Frontend/SYCL/SplitModule.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/Frontend/SYCL/Utils.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" @@ -19,39 +20,30 @@ #include "llvm/IR/Module.h" #include "llvm/Support/Debug.h" #include "llvm/Transforms/Utils/Cloning.h" -#include "llvm/Transforms/Utils/SYCLUtils.h" #include #include +#include using namespace llvm; +using namespace llvm::sycl; #define DEBUG_TYPE "sycl-split-module" -static bool isKernel(const Function &F) { - return F.getCallingConv() == CallingConv::SPIR_KERNEL || - F.getCallingConv() == CallingConv::AMDGPU_KERNEL; -} - -static bool isEntryPoint(const Function &F) { - // Skip declarations, if any: they should not be included into a vector of - // entry points groups or otherwise we will end up with incorrectly generated - // list of symbols. - if (F.isDeclaration()) - return false; +namespace { - // Kernels are always considered to be entry points - return isKernel(F); +bool isKernel(const Function &F) { + return F.getCallingConv() == CallingConv::SPIR_KERNEL || + F.getCallingConv() == CallingConv::AMDGPU_KERNEL || + F.getCallingConv() == CallingConv::PTX_Kernel; // TODO: add test. } -namespace { - -// A vector that contains all entry point functions in a split module. +// A vector that contains a group of function with the same category. using EntryPointSet = SetVector; -/// Represents a named group entry points. +/// Represents a group of functions with one category. 
struct EntryPointGroup { - std::string GroupName; + std::string GroupId; EntryPointSet Functions; EntryPointGroup() = default; @@ -60,20 +52,16 @@ struct EntryPointGroup { EntryPointGroup(EntryPointGroup &&) = default; EntryPointGroup &operator=(EntryPointGroup &&) = default; - EntryPointGroup(StringRef GroupName, - EntryPointSet Functions = EntryPointSet()) - : GroupName(GroupName), Functions(std::move(Functions)) {} + EntryPointGroup(std::string GroupId, EntryPointSet Functions = EntryPointSet()) + : GroupId(GroupId), Functions(std::move(Functions)) {} - void clear() { - GroupName.clear(); - Functions.clear(); - } + void clear() { Functions.clear(); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) LLVM_DUMP_METHOD void dump() const { constexpr size_t INDENT = 4; dbgs().indent(INDENT) << "ENTRY POINTS" - << " " << GroupName << " {\n"; + << " " << GroupId << " {\n"; for (const Function *F : Functions) dbgs().indent(INDENT) << " " << F->getName() << "\n"; @@ -83,7 +71,7 @@ struct EntryPointGroup { }; /// Annotates an llvm::Module with information necessary to perform and track -/// the result of device code (llvm::Module instances) splitting: +/// the result of code (llvm::Module instances) splitting: /// - entry points group from the module. class ModuleDesc { std::unique_ptr M; @@ -110,15 +98,6 @@ class ModuleDesc { return std::move(M); } - std::string makeSymbolTable() const { - SmallString<0> Data; - raw_svector_ostream OS(Data); - for (const Function *F : EntryPoints.Functions) - OS << F->getName() << '\n'; - - return std::string(OS.str()); - } - #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) LLVM_DUMP_METHOD void dump() const { dbgs() << "ModuleDesc[" << M->getName() << "] {\n"; @@ -159,7 +138,7 @@ class DependencyGraph { DenseMap FuncTypeToFuncsMap; for (const auto &F : M.functions()) { - // Kernels can't be called (either directly or indirectly) in SYCL + // Kernels can't be called (either directly or indirectly). 
if (isKernel(F)) continue; @@ -284,7 +263,7 @@ ModuleDesc extractCallGraph(const Module &M, EntryPointGroup ModuleEntryPoints, return SplitM; } -using EntryPointGroupVec = SmallVector; +using EntryPointGroupVec = SmallVector; /// Module Splitter. /// It gets a module (in a form of module descriptor, to get additional info) @@ -321,10 +300,7 @@ class ModuleSplitter { bool hasMoreSplits() const { return Groups.size() > 0; } }; -} // namespace - -static EntryPointGroupVec selectEntryPointGroups(const Module &M, - IRSplitMode Mode) { +EntryPointGroupVec selectEntryPointGroups(const Module &M, IRSplitMode Mode) { // std::map is used here to ensure stable ordering of entry point groups, // which is based on their contents, this greatly helps LIT tests std::map EntryPointsMap; @@ -353,38 +329,23 @@ static EntryPointGroupVec selectEntryPointGroups(const Module &M, if (EntryPointsMap.empty()) { // No entry points met, record this. Groups.emplace_back("-", EntryPointSet()); - } else { - Groups.reserve(EntryPointsMap.size()); - // Start with properties of a source module - for (auto &[Key, EntryPoints] : EntryPointsMap) - Groups.emplace_back(Key, std::move(EntryPoints)); + return Groups; } + Groups.reserve(EntryPointsMap.size()); + // Start with properties of a source module + for (auto &[Key, EntryPoints] : EntryPointsMap) + Groups.emplace_back(Key, std::move(EntryPoints)); + return Groups; } -namespace llvm { - -std::optional convertStringToSplitMode(StringRef S) { - static const StringMap Values = { - {"source", IRSplitMode::IRSM_PER_TU}, - {"kernel", IRSplitMode::IRSM_PER_KERNEL}, - {"none", IRSplitMode::IRSM_NONE}}; - - auto It = Values.find(S); - if (It == Values.end()) - return std::nullopt; - - return It->second; -} +} // namespace -void SYCLSplitModule(std::unique_ptr M, IRSplitMode Mode, - PostSYCLSplitCallbackType Callback) { - SmallVector OutputImages; +void llvm::sycl::splitModule(std::unique_ptr M, IRSplitMode Mode, + PostSplitCallbackType Callback) { if (Mode 
== IRSplitMode::IRSM_NONE) { - auto MD = ModuleDesc(std::move(M)); - auto Symbols = MD.makeSymbolTable(); - Callback(std::move(MD.releaseModule()), std::move(Symbols)); + Callback(std::move(M)); return; } @@ -393,9 +354,6 @@ void SYCLSplitModule(std::unique_ptr M, IRSplitMode Mode, ModuleSplitter Splitter(std::move(MD), std::move(Groups)); while (Splitter.hasMoreSplits()) { ModuleDesc MD = Splitter.getNextSplit(); - auto Symbols = MD.makeSymbolTable(); - Callback(std::move(MD.releaseModule()), std::move(Symbols)); + Callback(std::move(MD.releaseModule())); } } - -} // namespace llvm diff --git a/llvm/lib/Frontend/SYCL/Utils.cpp b/llvm/lib/Frontend/SYCL/Utils.cpp new file mode 100644 index 0000000000000..6fe578a1961de --- /dev/null +++ b/llvm/lib/Frontend/SYCL/Utils.cpp @@ -0,0 +1,79 @@ +//===------------ Utils.cpp - SYCL utility functions ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// SYCL utility functions. 
+//===----------------------------------------------------------------------===// +#include "llvm/Frontend/SYCL/Utils.h" + +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; +using namespace sycl; + +namespace { + +bool isKernel(const Function &F) { + return F.getCallingConv() == CallingConv::SPIR_KERNEL || + F.getCallingConv() == CallingConv::AMDGPU_KERNEL || + F.getCallingConv() == CallingConv::PTX_Kernel; +} + +} // anonymous namespace + +namespace llvm { +namespace sycl { + +std::optional convertStringToSplitMode(StringRef S) { + static const StringMap Values = { + {"source", IRSplitMode::IRSM_PER_TU}, + {"kernel", IRSplitMode::IRSM_PER_KERNEL}, + {"none", IRSplitMode::IRSM_NONE}}; + + auto It = Values.find(S); + if (It == Values.end()) + return std::nullopt; + + return It->second; +} + +bool isEntryPoint(const Function &F) { + // Skip declarations, if any: they should not be included into a vector of + // entry points groups or otherwise we will end up with incorrectly generated + // list of symbols. 
+ if (F.isDeclaration()) + return false; + + // Kernels are always considered to be entry points + return isKernel(F); +} + +std::string makeSymbolTable(const Module &M) { + SmallString<0> Data; + raw_svector_ostream OS(Data); + for (const auto &F : M) + if (isEntryPoint(F)) + OS << F.getName() << '\n'; + + return std::string(OS.str()); +} + +void writeStringTable(const StringTable &Table, raw_ostream &OS) { + assert(!Table.empty() && "table should contain at least column titles"); + assert(!Table[0].empty() && "table should be non-empty"); + OS << '[' << join(Table[0].begin(), Table[0].end(), "|") << "]\n"; + for (size_t I = 1, E = Table.size(); I != E; ++I) { + assert(Table[I].size() == Table[0].size() && "row's size should be equal"); + OS << join(Table[I].begin(), Table[I].end(), "|") << '\n'; + } +} + +} // namespace sycl +} // namespace llvm diff --git a/llvm/lib/Transforms/Utils/CMakeLists.txt b/llvm/lib/Transforms/Utils/CMakeLists.txt index 0ba46bdadea8d..78cad0d253be8 100644 --- a/llvm/lib/Transforms/Utils/CMakeLists.txt +++ b/llvm/lib/Transforms/Utils/CMakeLists.txt @@ -83,8 +83,6 @@ add_llvm_component_library(LLVMTransformUtils SizeOpts.cpp SplitModule.cpp StripNonLineTableDebugInfo.cpp - SYCLSplitModule.cpp - SYCLUtils.cpp SymbolRewriter.cpp UnifyFunctionExitNodes.cpp UnifyLoopExits.cpp diff --git a/llvm/lib/Transforms/Utils/SYCLUtils.cpp b/llvm/lib/Transforms/Utils/SYCLUtils.cpp deleted file mode 100644 index ad9864fadb828..0000000000000 --- a/llvm/lib/Transforms/Utils/SYCLUtils.cpp +++ /dev/null @@ -1,26 +0,0 @@ -//===------------ SYCLUtils.cpp - SYCL utility functions ------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// SYCL utility functions. 
-//===----------------------------------------------------------------------===// -#include "llvm/Transforms/Utils/SYCLUtils.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/Support/raw_ostream.h" - -namespace llvm { - -void writeSYCLStringTable(const SYCLStringTable &Table, raw_ostream &OS) { - assert(!Table.empty() && "table should contain at least column titles"); - assert(!Table[0].empty() && "table should be non-empty"); - OS << '[' << join(Table[0].begin(), Table[0].end(), "|") << "]\n"; - for (size_t I = 1, E = Table.size(); I != E; ++I) { - assert(Table[I].size() == Table[0].size() && "row's size should be equal"); - OS << join(Table[I].begin(), Table[I].end(), "|") << '\n'; - } -} - -} // namespace llvm diff --git a/llvm/test/tools/llvm-split/SYCL/device-code-split/ptx-kernel-split.ll b/llvm/test/tools/llvm-split/SYCL/device-code-split/ptx-kernel-split.ll new file mode 100644 index 0000000000000..0c40c1b4f4ff0 --- /dev/null +++ b/llvm/test/tools/llvm-split/SYCL/device-code-split/ptx-kernel-split.ll @@ -0,0 +1,17 @@ +; -- Per-kernel split +; RUN: llvm-split -sycl-split=kernel -S < %s -o %tC +; RUN: FileCheck %s -input-file=%tC_0.ll --check-prefixes CHECK-A0 +; RUN: FileCheck %s -input-file=%tC_1.ll --check-prefixes CHECK-A1 + +define dso_local ptx_kernel void @KernelA() { + ret void +} + +define dso_local ptx_kernel void @KernelB() { + ret void +} + +; CHECK-A0: define dso_local ptx_kernel void @KernelB() +; CHECK-A0-NOT: define dso_local ptx_kernel void @KernelA() +; CHECK-A1-NOT: define dso_local ptx_kernel void @KernelB() +; CHECK-A1: define dso_local ptx_kernel void @KernelA() diff --git a/llvm/tools/llvm-split/CMakeLists.txt b/llvm/tools/llvm-split/CMakeLists.txt index b755755a984fc..c80ca4aba6ec6 100644 --- a/llvm/tools/llvm-split/CMakeLists.txt +++ b/llvm/tools/llvm-split/CMakeLists.txt @@ -7,6 +7,7 @@ set(LLVM_LINK_COMPONENTS BitWriter CodeGen Core + FrontendSYCL IRReader MC Support diff --git a/llvm/tools/llvm-split/llvm-split.cpp 
b/llvm/tools/llvm-split/llvm-split.cpp index f6e90985304d6..3eeb069294e25 100644 --- a/llvm/tools/llvm-split/llvm-split.cpp +++ b/llvm/tools/llvm-split/llvm-split.cpp @@ -33,9 +33,9 @@ #include "llvm/Target/TargetMachine.h" #include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/IPO/GlobalDCE.h" -#include "llvm/Transforms/Utils/SYCLSplitModule.h" -#include "llvm/Transforms/Utils/SYCLUtils.h" +#include "llvm/Frontend/SYCL/Utils.h" #include "llvm/Transforms/Utils/SplitModule.h" +#include "llvm/Frontend/SYCL/SplitModule.h" using namespace llvm; @@ -78,14 +78,14 @@ static cl::opt MCPU("mcpu", cl::desc("Target CPU, ignored if --mtriple is not used"), cl::value_desc("cpu"), cl::cat(SplitCategory)); -static cl::opt SYCLSplitMode( +static cl::opt SYCLSplitMode( "sycl-split", cl::desc("SYCL Split Mode. If present, SYCL splitting algorithm is used " "with the specified mode."), - cl::Optional, cl::init(IRSplitMode::IRSM_NONE), - cl::values(clEnumValN(IRSplitMode::IRSM_PER_TU, "source", + cl::Optional, cl::init(sycl::IRSplitMode::IRSM_NONE), + cl::values(clEnumValN(sycl::IRSplitMode::IRSM_PER_TU, "source", "1 ouptput module per translation unit"), - clEnumValN(IRSplitMode::IRSM_PER_KERNEL, "kernel", + clEnumValN(sycl::IRSplitMode::IRSM_PER_KERNEL, "kernel", "1 output module per kernel")), cl::cat(SplitCategory)); @@ -119,13 +119,13 @@ void writeModuleToFile(const Module &M, StringRef Path, bool OutputAssembly) { WriteBitcodeToFile(M, OS); } -void writeSplitModulesAsTable(ArrayRef Modules, +void writeSplitModulesAsTable(ArrayRef Modules, StringRef Path) { SmallVector> Columns; Columns.emplace_back("Code"); Columns.emplace_back("Symbols"); - SYCLStringTable Table; + sycl::StringTable Table; Table.emplace_back(std::move(Columns)); for (const auto &[I, SM] : enumerate(Modules)) { SmallString<128> SymbolsFile; @@ -144,7 +144,7 @@ void writeSplitModulesAsTable(ArrayRef Modules, exit(1); } - writeSYCLStringTable(Table, OS); + sycl::writeStringTable(Table, OS); } void 
cleanupModule(Module &M) { @@ -156,9 +156,8 @@ void cleanupModule(Module &M) { } Error runSYCLSplitModule(std::unique_ptr M) { - SmallVector SplitModules; - auto PostSYCLSplitCallback = [&](std::unique_ptr MPart, - std::string Symbols) { + SmallVector SplitModules; + auto PostSplitCallback = [&](std::unique_ptr MPart) { if (verifyModule(*MPart)) { errs() << "Broken Module!\n"; exit(1); @@ -172,10 +171,11 @@ Error runSYCLSplitModule(std::unique_ptr M) { std::string ModulePath = (Twine(OutputFilename) + "_" + Twine(ID) + ModuleSuffix).str(); writeModuleToFile(*MPart, ModulePath, OutputAssembly); + auto Symbols = sycl::makeSymbolTable(*MPart); SplitModules.emplace_back(std::move(ModulePath), std::move(Symbols)); }; - SYCLSplitModule(std::move(M), SYCLSplitMode, PostSYCLSplitCallback); + sycl::splitModule(std::move(M), SYCLSplitMode, PostSplitCallback); writeSplitModulesAsTable(SplitModules, OutputFilename); return Error::success(); } @@ -233,7 +233,7 @@ int main(int argc, char **argv) { Out->keep(); }; - if (SYCLSplitMode != IRSplitMode::IRSM_NONE) { + if (SYCLSplitMode != sycl::IRSplitMode::IRSM_NONE) { auto E = runSYCLSplitModule(std::move(M)); if (E) { errs() << E << "\n"; From c69c62ed850dd25affae8676b8198d90f54e9e26 Mon Sep 17 00:00:00 2001 From: "Sabianin, Maksim" Date: Thu, 24 Apr 2025 07:41:24 -0700 Subject: [PATCH 3/9] apply clang-format --- llvm/include/llvm/Frontend/SYCL/SplitModule.h | 6 +++--- llvm/include/llvm/Frontend/SYCL/Utils.h | 6 +++--- llvm/lib/Frontend/SYCL/SplitModule.cpp | 7 ++++--- llvm/tools/llvm-split/llvm-split.cpp | 4 ++-- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/llvm/include/llvm/Frontend/SYCL/SplitModule.h b/llvm/include/llvm/Frontend/SYCL/SplitModule.h index 2f332d045706f..47c501a9ace57 100644 --- a/llvm/include/llvm/Frontend/SYCL/SplitModule.h +++ b/llvm/include/llvm/Frontend/SYCL/SplitModule.h @@ -8,8 +8,8 @@ // Functionality to split a module by categories. 
//===----------------------------------------------------------------------===// -#ifndef LLVM_TRANSFORMS_UTILS_SYCLSPLITMODULE_H -#define LLVM_TRANSFORMS_UTILS_SYCLSPLITMODULE_H +#ifndef LLVM_FRONTEND_SYCL_SPLIT_MODULE_H +#define LLVM_FRONTEND_SYCL_SPLIT_MODULE_H #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/Frontend/SYCL/Utils.h" @@ -36,4 +36,4 @@ void splitModule(std::unique_ptr M, IRSplitMode Mode, } // namespace sycl } // namespace llvm -#endif // LLVM_TRANSFORMS_UTILS_SYCLSPLITMODULE_H +#endif // LLVM_FRONTEND_SYCL_SPLIT_MODULE_H diff --git a/llvm/include/llvm/Frontend/SYCL/Utils.h b/llvm/include/llvm/Frontend/SYCL/Utils.h index 6db1fb5a4715c..d90c7fbe32e24 100644 --- a/llvm/include/llvm/Frontend/SYCL/Utils.h +++ b/llvm/include/llvm/Frontend/SYCL/Utils.h @@ -7,8 +7,8 @@ //===----------------------------------------------------------------------===// // Utility functions for SYCL. //===----------------------------------------------------------------------===// -#ifndef LLVM_TRANSFORMS_UTILS_SYCLUTILS_H -#define LLVM_TRANSFORMS_UTILS_SYCLUTILS_H +#ifndef LLVM_FRONTEND_SYCL_UTILS_H +#define LLVM_FRONTEND_SYCL_UTILS_H #include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" @@ -64,4 +64,4 @@ void writeStringTable(const StringTable &Table, raw_ostream &OS); } // namespace sycl } // namespace llvm -#endif // LLVM_TRANSFORMS_UTILS_SYCLUTILS_H +#endif // LLVM_FRONTEND_SYCL_UTILS_H diff --git a/llvm/lib/Frontend/SYCL/SplitModule.cpp b/llvm/lib/Frontend/SYCL/SplitModule.cpp index e0ba06aa617b4..658e1cd5befc6 100644 --- a/llvm/lib/Frontend/SYCL/SplitModule.cpp +++ b/llvm/lib/Frontend/SYCL/SplitModule.cpp @@ -22,8 +22,8 @@ #include "llvm/Transforms/Utils/Cloning.h" #include -#include #include +#include using namespace llvm; using namespace llvm::sycl; @@ -52,7 +52,8 @@ struct EntryPointGroup { EntryPointGroup(EntryPointGroup &&) = default; EntryPointGroup &operator=(EntryPointGroup &&) = default; - EntryPointGroup(std::string GroupId, EntryPointSet 
Functions = EntryPointSet()) + EntryPointGroup(std::string GroupId, + EntryPointSet Functions = EntryPointSet()) : GroupId(GroupId), Functions(std::move(Functions)) {} void clear() { Functions.clear(); } @@ -343,7 +344,7 @@ EntryPointGroupVec selectEntryPointGroups(const Module &M, IRSplitMode Mode) { } // namespace void llvm::sycl::splitModule(std::unique_ptr M, IRSplitMode Mode, - PostSplitCallbackType Callback) { + PostSplitCallbackType Callback) { if (Mode == IRSplitMode::IRSM_NONE) { Callback(std::move(M)); return; diff --git a/llvm/tools/llvm-split/llvm-split.cpp b/llvm/tools/llvm-split/llvm-split.cpp index 3eeb069294e25..84ec7d4bdb59b 100644 --- a/llvm/tools/llvm-split/llvm-split.cpp +++ b/llvm/tools/llvm-split/llvm-split.cpp @@ -15,6 +15,8 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/Frontend/SYCL/SplitModule.h" +#include "llvm/Frontend/SYCL/Utils.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/PassInstrumentation.h" #include "llvm/IR/PassManager.h" @@ -33,9 +35,7 @@ #include "llvm/Target/TargetMachine.h" #include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/IPO/GlobalDCE.h" -#include "llvm/Frontend/SYCL/Utils.h" #include "llvm/Transforms/Utils/SplitModule.h" -#include "llvm/Frontend/SYCL/SplitModule.h" using namespace llvm; From 483933b604fa3fa569e4db8e51803cbec5ea067a Mon Sep 17 00:00:00 2001 From: "Sabianin, Maksim" Date: Thu, 24 Apr 2025 07:45:50 -0700 Subject: [PATCH 4/9] update comment --- llvm/include/llvm/Frontend/SYCL/SplitModule.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/include/llvm/Frontend/SYCL/SplitModule.h b/llvm/include/llvm/Frontend/SYCL/SplitModule.h index 47c501a9ace57..c886e8e5612d6 100644 --- a/llvm/include/llvm/Frontend/SYCL/SplitModule.h +++ b/llvm/include/llvm/Frontend/SYCL/SplitModule.h @@ -5,7 +5,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // 
//===----------------------------------------------------------------------===// -// Functionality to split a module by categories. +// Functionality to split a module for SYCL Offloading Kind. //===----------------------------------------------------------------------===// #ifndef LLVM_FRONTEND_SYCL_SPLIT_MODULE_H From 141c0395dd1e6cbe3d0242e65367778d79726a3f Mon Sep 17 00:00:00 2001 From: "Sabianin, Maksim" Date: Fri, 16 May 2025 10:06:09 -0700 Subject: [PATCH 5/9] Make splitting algorithm generic. Now function's interface accepts FunctionCategorizer. --- .../Utils.h => Transforms/Utils/SYCLUtils.h} | 42 +++++++++++-- .../Utils/SplitModuleByCategory.h} | 13 ++-- llvm/lib/Frontend/CMakeLists.txt | 1 - llvm/lib/Frontend/SYCL/CMakeLists.txt | 13 ---- llvm/lib/Transforms/Utils/CMakeLists.txt | 2 + .../Utils/SYCLUtils.cpp} | 58 +++++++++++++++--- .../Utils/SplitModuleByCategory.cpp} | 61 ++++++------------- llvm/tools/llvm-split/llvm-split.cpp | 8 ++- 8 files changed, 121 insertions(+), 77 deletions(-) rename llvm/include/llvm/{Frontend/SYCL/Utils.h => Transforms/Utils/SYCLUtils.h} (61%) rename llvm/include/llvm/{Frontend/SYCL/SplitModule.h => Transforms/Utils/SplitModuleByCategory.h} (65%) delete mode 100644 llvm/lib/Frontend/SYCL/CMakeLists.txt rename llvm/lib/{Frontend/SYCL/Utils.cpp => Transforms/Utils/SYCLUtils.cpp} (64%) rename llvm/lib/{Frontend/SYCL/SplitModule.cpp => Transforms/Utils/SplitModuleByCategory.cpp} (88%) diff --git a/llvm/include/llvm/Frontend/SYCL/Utils.h b/llvm/include/llvm/Transforms/Utils/SYCLUtils.h similarity index 61% rename from llvm/include/llvm/Frontend/SYCL/Utils.h rename to llvm/include/llvm/Transforms/Utils/SYCLUtils.h index d90c7fbe32e24..3519855a82657 100644 --- a/llvm/include/llvm/Frontend/SYCL/Utils.h +++ b/llvm/include/llvm/Transforms/Utils/SYCLUtils.h @@ -1,4 +1,4 @@ -//===------------ Utils.h - SYCL utility functions ------------------------===// +//===------------ SYCLUtils.h - SYCL utility functions --------------------===// 
// // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -10,11 +10,14 @@ #ifndef LLVM_FRONTEND_SYCL_UTILS_H #define LLVM_FRONTEND_SYCL_UTILS_H +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Hashing.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" +#include #include namespace llvm { @@ -35,6 +38,40 @@ enum class IRSplitMode { /// returned. std::optional convertStringToSplitMode(StringRef S); +/// FunctionCategorizer used for splitting in SYCL compilation flow. +class FunctionCategorizer { +public: + FunctionCategorizer(IRSplitMode SM); + + FunctionCategorizer() = delete; + FunctionCategorizer(FunctionCategorizer &) = delete; + FunctionCategorizer &operator=(const FunctionCategorizer &) = delete; + FunctionCategorizer(FunctionCategorizer &&) = default; + FunctionCategorizer &operator=(FunctionCategorizer &&) = default; + + /// Returns integer specifying the category for the entry point. + /// If the given function isn't an entry point then returns std::nullopt. + std::optional operator()(const Function &F); + +private: + struct KeyInfo { + static SmallString<0> getEmptyKey() { return SmallString<0>(""); } + + static SmallString<0> getTombstoneKey() { return SmallString<0>("-"); } + + static bool isEqual(const SmallString<0> &LHS, const SmallString<0> &RHS) { + return LHS == RHS; + } + + static unsigned getHashValue(const SmallString<0> &S) { + return llvm::hash_value(StringRef(S)); + } + }; + + IRSplitMode SM; + DenseMap, int, KeyInfo> StrKeyToID; +}; + /// The structure represents a LLVM Module accompanied by additional /// information. Split Modules are being stored at disk due to the high RAM /// consumption during the whole splitting process. 
@@ -52,9 +89,6 @@ struct ModuleAndSYCLMetadata { : ModuleFilePath(File.str()), Symbols(std::move(Symbols)) {} }; -/// Checks whether the function is a SYCL entry point. -bool isEntryPoint(const Function &F); - std::string makeSymbolTable(const Module &M); using StringTable = SmallVector>>; diff --git a/llvm/include/llvm/Frontend/SYCL/SplitModule.h b/llvm/include/llvm/Transforms/Utils/SplitModuleByCategory.h similarity index 65% rename from llvm/include/llvm/Frontend/SYCL/SplitModule.h rename to llvm/include/llvm/Transforms/Utils/SplitModuleByCategory.h index c886e8e5612d6..dc5f00e5060d1 100644 --- a/llvm/include/llvm/Frontend/SYCL/SplitModule.h +++ b/llvm/include/llvm/Transforms/Utils/SplitModuleByCategory.h @@ -1,18 +1,17 @@ -//===-------- SplitModule.h - module split ----------------------*- C++ -*-===// +//===-------- SplitModuleByCategory.h - module split ------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// Functionality to split a module for SYCL Offloading Kind. +// Functionality to split a module by categories. //===----------------------------------------------------------------------===// #ifndef LLVM_FRONTEND_SYCL_SPLIT_MODULE_H #define LLVM_FRONTEND_SYCL_SPLIT_MODULE_H #include "llvm/ADT/STLFunctionalExtras.h" -#include "llvm/Frontend/SYCL/Utils.h" #include #include @@ -25,13 +24,17 @@ class Function; namespace sycl { +/// FunctionCategorizer returns integer category for the given Function. +/// Otherwise, it returns std::nullopt if function doesn't have a category. +using FunctionCategorizer = function_ref(const Function &F)>; + using PostSplitCallbackType = function_ref Part)>; /// Splits the given module \p M. 
/// Every split image is being passed to \p Callback for further possible /// processing. -void splitModule(std::unique_ptr M, IRSplitMode Mode, - PostSplitCallbackType Callback); +void splitModuleByCategory(std::unique_ptr M, FunctionCategorizer FC, + PostSplitCallbackType Callback); } // namespace sycl } // namespace llvm diff --git a/llvm/lib/Frontend/CMakeLists.txt b/llvm/lib/Frontend/CMakeLists.txt index 00d51bd178974..b305ce7d771ce 100644 --- a/llvm/lib/Frontend/CMakeLists.txt +++ b/llvm/lib/Frontend/CMakeLists.txt @@ -4,4 +4,3 @@ add_subdirectory(HLSL) add_subdirectory(OpenACC) add_subdirectory(OpenMP) add_subdirectory(Offloading) -add_subdirectory(SYCL) diff --git a/llvm/lib/Frontend/SYCL/CMakeLists.txt b/llvm/lib/Frontend/SYCL/CMakeLists.txt deleted file mode 100644 index 893abcf9aebd8..0000000000000 --- a/llvm/lib/Frontend/SYCL/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -add_llvm_component_library(LLVMFrontendSYCL - SplitModule.cpp - Utils.cpp - - ADDITIONAL_HEADER_DIRS - ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend - ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend/SYCL - - LINK_COMPONENTS - Core - Support - TransformUtils - ) diff --git a/llvm/lib/Transforms/Utils/CMakeLists.txt b/llvm/lib/Transforms/Utils/CMakeLists.txt index 78cad0d253be8..6c3cd042fe602 100644 --- a/llvm/lib/Transforms/Utils/CMakeLists.txt +++ b/llvm/lib/Transforms/Utils/CMakeLists.txt @@ -82,7 +82,9 @@ add_llvm_component_library(LLVMTransformUtils SimplifyLibCalls.cpp SizeOpts.cpp SplitModule.cpp + SplitModuleByCategory.cpp StripNonLineTableDebugInfo.cpp + SYCLUtils.cpp SymbolRewriter.cpp UnifyFunctionExitNodes.cpp UnifyLoopExits.cpp diff --git a/llvm/lib/Frontend/SYCL/Utils.cpp b/llvm/lib/Transforms/Utils/SYCLUtils.cpp similarity index 64% rename from llvm/lib/Frontend/SYCL/Utils.cpp rename to llvm/lib/Transforms/Utils/SYCLUtils.cpp index 6fe578a1961de..e2fe097ade00a 100644 --- a/llvm/lib/Frontend/SYCL/Utils.cpp +++ b/llvm/lib/Transforms/Utils/SYCLUtils.cpp @@ -1,4 +1,4 @@ -//===------------ 
Utils.cpp - SYCL utility functions ----------------------===// +//===------------ SYCLUtils.cpp - SYCL utility functions ------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -7,7 +7,8 @@ //===----------------------------------------------------------------------===// // SYCL utility functions. //===----------------------------------------------------------------------===// -#include "llvm/Frontend/SYCL/Utils.h" + +#include "llvm/Transforms/Utils/SYCLUtils.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringMap.h" @@ -15,17 +16,48 @@ #include "llvm/IR/Module.h" #include "llvm/Support/raw_ostream.h" +#include + using namespace llvm; using namespace sycl; namespace { +SmallString<0> computeFunctionCategoryForSplitting(IRSplitMode SM, + const Function &F) { + static constexpr char ATTR_SYCL_MODULE_ID[] = "sycl-module-id"; + SmallString<0> Key; + switch (SM) { + case IRSplitMode::IRSM_PER_KERNEL: + Key = F.getName().str(); + break; + case IRSplitMode::IRSM_PER_TU: + Key = F.getFnAttribute(ATTR_SYCL_MODULE_ID).getValueAsString().str(); + break; + default: + llvm_unreachable("other modes aren't expected"); + } + + return Key; +} + bool isKernel(const Function &F) { return F.getCallingConv() == CallingConv::SPIR_KERNEL || F.getCallingConv() == CallingConv::AMDGPU_KERNEL || F.getCallingConv() == CallingConv::PTX_Kernel; } +bool isEntryPoint(const Function &F) { + // Skip declarations, if any: they should not be included into a vector of + // entry points groups or otherwise we will end up with incorrectly generated + // list of symbols. 
+ if (F.isDeclaration()) + return false; + + // Kernels are always considered to be entry points + return isKernel(F); +} + } // anonymous namespace namespace llvm { @@ -44,15 +76,21 @@ std::optional convertStringToSplitMode(StringRef S) { return It->second; } -bool isEntryPoint(const Function &F) { - // Skip declarations, if any: they should not be included into a vector of - // entry points groups or otherwise we will end up with incorrectly generated - // list of symbols. - if (F.isDeclaration()) - return false; +FunctionCategorizer::FunctionCategorizer(IRSplitMode SM) : SM(SM) { + if (SM == IRSplitMode::IRSM_NONE) + llvm_unreachable("FunctionCategorizer isn't supported to none splitting."); +} - // Kernels are always considered to be entry points - return isKernel(F); +std::optional FunctionCategorizer::operator()(const Function &F) { + if (!isEntryPoint(F)) + return std::nullopt; // skip the function. + + auto StringKey = computeFunctionCategoryForSplitting(SM, F); + if (auto it = StrKeyToID.find(StringRef(StringKey)); it != StrKeyToID.end()) + return it->second; + + int ID = static_cast(StrKeyToID.size()); + return StrKeyToID.try_emplace(std::move(StringKey), ID).first->second; } std::string makeSymbolTable(const Module &M) { diff --git a/llvm/lib/Frontend/SYCL/SplitModule.cpp b/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp similarity index 88% rename from llvm/lib/Frontend/SYCL/SplitModule.cpp rename to llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp index 658e1cd5befc6..4568c1f30c282 100644 --- a/llvm/lib/Frontend/SYCL/SplitModule.cpp +++ b/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp @@ -1,4 +1,4 @@ -//===-------- SplitModule.cpp - split a module by categories --------------===// +//===-------- SplitModuleByCategory.cpp - split a module by categories ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -8,11 +8,10 @@ // See comments in the header. //===----------------------------------------------------------------------===// -#include "llvm/Frontend/SYCL/SplitModule.h" +#include "llvm/Transforms/Utils/SplitModuleByCategory.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/StringExtras.h" -#include "llvm/Frontend/SYCL/Utils.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" @@ -35,7 +34,7 @@ namespace { bool isKernel(const Function &F) { return F.getCallingConv() == CallingConv::SPIR_KERNEL || F.getCallingConv() == CallingConv::AMDGPU_KERNEL || - F.getCallingConv() == CallingConv::PTX_Kernel; // TODO: add test. + F.getCallingConv() == CallingConv::PTX_Kernel; } // A vector that contains a group of function with the same category. @@ -43,7 +42,7 @@ using EntryPointSet = SetVector; /// Represents a group of functions with one category. struct EntryPointGroup { - std::string GroupId; + int ID; EntryPointSet Functions; EntryPointGroup() = default; @@ -52,9 +51,8 @@ struct EntryPointGroup { EntryPointGroup(EntryPointGroup &&) = default; EntryPointGroup &operator=(EntryPointGroup &&) = default; - EntryPointGroup(std::string GroupId, - EntryPointSet Functions = EntryPointSet()) - : GroupId(GroupId), Functions(std::move(Functions)) {} + EntryPointGroup(int ID, EntryPointSet Functions = EntryPointSet()) + : ID(ID), Functions(std::move(Functions)) {} void clear() { Functions.clear(); } @@ -62,7 +60,7 @@ struct EntryPointGroup { LLVM_DUMP_METHOD void dump() const { constexpr size_t INDENT = 4; dbgs().indent(INDENT) << "ENTRY POINTS" - << " " << GroupId << " {\n"; + << " " << ID << " {\n"; for (const Function *F : Functions) dbgs().indent(INDENT) << " " << F->getName() << "\n"; @@ -301,38 +299,23 @@ class ModuleSplitter { bool hasMoreSplits() const { return Groups.size() > 0; } }; -EntryPointGroupVec selectEntryPointGroups(const Module &M, IRSplitMode Mode) { +EntryPointGroupVec 
selectEntryPointGroups(const Module &M, + FunctionCategorizer FC) { // std::map is used here to ensure stable ordering of entry point groups, // which is based on their contents, this greatly helps LIT tests - std::map EntryPointsMap; + std::map EntryPointsMap; - static constexpr char ATTR_SYCL_MODULE_ID[] = "sycl-module-id"; for (const auto &F : M.functions()) { - if (!isEntryPoint(F)) - continue; - - std::string Key; - switch (Mode) { - case IRSplitMode::IRSM_PER_KERNEL: - Key = F.getName(); - break; - case IRSplitMode::IRSM_PER_TU: - Key = F.getFnAttribute(ATTR_SYCL_MODULE_ID).getValueAsString(); - break; - case IRSplitMode::IRSM_NONE: - llvm_unreachable(""); - } + if (auto Key = FC(F); Key) { + auto It = EntryPointsMap.find(*Key); + if (It == EntryPointsMap.end()) + It = EntryPointsMap.emplace(*Key, EntryPointSet()).first; - EntryPointsMap[Key].insert(&F); + It->second.insert(&F); + } } EntryPointGroupVec Groups; - if (EntryPointsMap.empty()) { - // No entry points met, record this. 
- Groups.emplace_back("-", EntryPointSet()); - return Groups; - } - Groups.reserve(EntryPointsMap.size()); // Start with properties of a source module for (auto &[Key, EntryPoints] : EntryPointsMap) @@ -343,14 +326,10 @@ EntryPointGroupVec selectEntryPointGroups(const Module &M, IRSplitMode Mode) { } // namespace -void llvm::sycl::splitModule(std::unique_ptr M, IRSplitMode Mode, - PostSplitCallbackType Callback) { - if (Mode == IRSplitMode::IRSM_NONE) { - Callback(std::move(M)); - return; - } - - EntryPointGroupVec Groups = selectEntryPointGroups(*M, Mode); +void llvm::sycl::splitModuleByCategory(std::unique_ptr M, + FunctionCategorizer FC, + PostSplitCallbackType Callback) { + EntryPointGroupVec Groups = selectEntryPointGroups(*M, FC); ModuleDesc MD = std::move(M); ModuleSplitter Splitter(std::move(MD), std::move(Groups)); while (Splitter.hasMoreSplits()) { diff --git a/llvm/tools/llvm-split/llvm-split.cpp b/llvm/tools/llvm-split/llvm-split.cpp index 84ec7d4bdb59b..d79005ebe6cd7 100644 --- a/llvm/tools/llvm-split/llvm-split.cpp +++ b/llvm/tools/llvm-split/llvm-split.cpp @@ -15,8 +15,6 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Bitcode/BitcodeWriter.h" -#include "llvm/Frontend/SYCL/SplitModule.h" -#include "llvm/Frontend/SYCL/Utils.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/PassInstrumentation.h" #include "llvm/IR/PassManager.h" @@ -35,7 +33,9 @@ #include "llvm/Target/TargetMachine.h" #include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/IPO/GlobalDCE.h" +#include "llvm/Transforms/Utils/SYCLUtils.h" #include "llvm/Transforms/Utils/SplitModule.h" +#include "llvm/Transforms/Utils/SplitModuleByCategory.h" using namespace llvm; @@ -175,7 +175,9 @@ Error runSYCLSplitModule(std::unique_ptr M) { SplitModules.emplace_back(std::move(ModulePath), std::move(Symbols)); }; - sycl::splitModule(std::move(M), SYCLSplitMode, PostSplitCallback); + auto Categorizer = sycl::FunctionCategorizer(SYCLSplitMode); + 
sycl::SplitModuleByCategory(std::move(M), std::move(Categorizer), + PostSplitCallback); writeSplitModulesAsTable(SplitModules, OutputFilename); return Error::success(); } From 1729c50aa76718a8bd0f5959649711b4261008d2 Mon Sep 17 00:00:00 2001 From: "Sabianin, Maksim" Date: Wed, 4 Jun 2025 07:01:52 -0700 Subject: [PATCH 6/9] Remove SYCL specialization from the PR. --- .../include/llvm/Transforms/Utils/SYCLUtils.h | 101 ------------- .../Transforms/Utils/SplitModuleByCategory.h | 34 +++-- llvm/lib/Transforms/Utils/CMakeLists.txt | 1 - llvm/lib/Transforms/Utils/SYCLUtils.cpp | 117 --------------- .../Utils/SplitModuleByCategory.cpp | 35 ++--- .../split-with-kernel-declarations.ll | 66 -------- .../amd-kernel-split.ll | 2 +- .../complex-indirect-call-chain.ll | 10 +- .../module-split-func-ptr.ll | 11 +- .../one-kernel-per-module.ll | 21 +-- .../ptx-kernel-split.ll | 2 +- .../split-by-source.ll | 17 +-- .../split-with-kernel-declarations.ll | 52 +++++++ llvm/tools/llvm-split/CMakeLists.txt | 1 - llvm/tools/llvm-split/llvm-split.cpp | 142 ++++++++++++------ 15 files changed, 202 insertions(+), 410 deletions(-) delete mode 100644 llvm/include/llvm/Transforms/Utils/SYCLUtils.h delete mode 100644 llvm/lib/Transforms/Utils/SYCLUtils.cpp delete mode 100644 llvm/test/tools/llvm-split/SYCL/device-code-split/split-with-kernel-declarations.ll rename llvm/test/tools/llvm-split/{SYCL/device-code-split => SplitByCategory}/amd-kernel-split.ll (89%) rename llvm/test/tools/llvm-split/{SYCL/device-code-split => SplitByCategory}/complex-indirect-call-chain.ll (90%) rename llvm/test/tools/llvm-split/{SYCL/device-code-split => SplitByCategory}/module-split-func-ptr.ll (76%) rename llvm/test/tools/llvm-split/{SYCL/device-code-split => SplitByCategory}/one-kernel-per-module.ll (81%) rename llvm/test/tools/llvm-split/{SYCL/device-code-split => SplitByCategory}/ptx-kernel-split.ll (89%) rename llvm/test/tools/llvm-split/{SYCL/device-code-split => SplitByCategory}/split-by-source.ll (84%) create 
mode 100644 llvm/test/tools/llvm-split/SplitByCategory/split-with-kernel-declarations.ll diff --git a/llvm/include/llvm/Transforms/Utils/SYCLUtils.h b/llvm/include/llvm/Transforms/Utils/SYCLUtils.h deleted file mode 100644 index 3519855a82657..0000000000000 --- a/llvm/include/llvm/Transforms/Utils/SYCLUtils.h +++ /dev/null @@ -1,101 +0,0 @@ -//===------------ SYCLUtils.h - SYCL utility functions --------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// Utility functions for SYCL. -//===----------------------------------------------------------------------===// -#ifndef LLVM_FRONTEND_SYCL_UTILS_H -#define LLVM_FRONTEND_SYCL_UTILS_H - -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/Hashing.h" -#include "llvm/ADT/SmallString.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Twine.h" - -#include -#include - -namespace llvm { - -class Module; -class Function; -class raw_ostream; - -namespace sycl { - -enum class IRSplitMode { - IRSM_PER_TU, // one module per translation unit - IRSM_PER_KERNEL, // one module per kernel - IRSM_NONE // no splitting -}; - -/// \returns IRSplitMode value if \p S is recognized. Otherwise, std::nullopt is -/// returned. -std::optional convertStringToSplitMode(StringRef S); - -/// FunctionCategorizer used for splitting in SYCL compilation flow. 
-class FunctionCategorizer { -public: - FunctionCategorizer(IRSplitMode SM); - - FunctionCategorizer() = delete; - FunctionCategorizer(FunctionCategorizer &) = delete; - FunctionCategorizer &operator=(const FunctionCategorizer &) = delete; - FunctionCategorizer(FunctionCategorizer &&) = default; - FunctionCategorizer &operator=(FunctionCategorizer &&) = default; - - /// Returns integer specifying the category for the entry point. - /// If the given function isn't an entry point then returns std::nullopt. - std::optional operator()(const Function &F); - -private: - struct KeyInfo { - static SmallString<0> getEmptyKey() { return SmallString<0>(""); } - - static SmallString<0> getTombstoneKey() { return SmallString<0>("-"); } - - static bool isEqual(const SmallString<0> &LHS, const SmallString<0> &RHS) { - return LHS == RHS; - } - - static unsigned getHashValue(const SmallString<0> &S) { - return llvm::hash_value(StringRef(S)); - } - }; - - IRSplitMode SM; - DenseMap, int, KeyInfo> StrKeyToID; -}; - -/// The structure represents a LLVM Module accompanied by additional -/// information. Split Modules are being stored at disk due to the high RAM -/// consumption during the whole splitting process. 
-struct ModuleAndSYCLMetadata { - std::string ModuleFilePath; - std::string Symbols; - - ModuleAndSYCLMetadata() = delete; - ModuleAndSYCLMetadata(const ModuleAndSYCLMetadata &) = default; - ModuleAndSYCLMetadata &operator=(const ModuleAndSYCLMetadata &) = default; - ModuleAndSYCLMetadata(ModuleAndSYCLMetadata &&) = default; - ModuleAndSYCLMetadata &operator=(ModuleAndSYCLMetadata &&) = default; - - ModuleAndSYCLMetadata(const Twine &File, std::string Symbols) - : ModuleFilePath(File.str()), Symbols(std::move(Symbols)) {} -}; - -std::string makeSymbolTable(const Module &M); - -using StringTable = SmallVector>>; - -void writeStringTable(const StringTable &Table, raw_ostream &OS); - -} // namespace sycl -} // namespace llvm - -#endif // LLVM_FRONTEND_SYCL_UTILS_H diff --git a/llvm/include/llvm/Transforms/Utils/SplitModuleByCategory.h b/llvm/include/llvm/Transforms/Utils/SplitModuleByCategory.h index dc5f00e5060d1..8e2a6127df7dd 100644 --- a/llvm/include/llvm/Transforms/Utils/SplitModuleByCategory.h +++ b/llvm/include/llvm/Transforms/Utils/SplitModuleByCategory.h @@ -8,8 +8,8 @@ // Functionality to split a module by categories. //===----------------------------------------------------------------------===// -#ifndef LLVM_FRONTEND_SYCL_SPLIT_MODULE_H -#define LLVM_FRONTEND_SYCL_SPLIT_MODULE_H +#ifndef LLVM_TRANSFORM_UTILS_SPLIT_MODULE_BY_CATEGORY_H +#define LLVM_TRANSFORM_UTILS_SPLIT_MODULE_BY_CATEGORY_H #include "llvm/ADT/STLFunctionalExtras.h" @@ -22,21 +22,23 @@ namespace llvm { class Module; class Function; -namespace sycl { - -/// FunctionCategorizer returns integer category for the given Function. -/// Otherwise, it returns std::nullopt if function doesn't have a category. -using FunctionCategorizer = function_ref(const Function &F)>; - -using PostSplitCallbackType = function_ref Part)>; - -/// Splits the given module \p M. 
-/// Every split image is being passed to \p Callback for further possible +/// Splits the given module \p M using the given \p FunctionCategorizer. +/// \p FunctionCategorizer returns integer category for an input Function. +/// It may return std::nullopt if a function doesn't have a category. +/// Module's functions are being grouped by categories. Every such group +/// populates a call graph containing group's functions themselves and all +/// reachable functions and globals. Split outputs are populated from each call +/// graph associated with some category. +/// +/// Every split output is being passed to \p Callback for further possible /// processing. -void splitModuleByCategory(std::unique_ptr M, FunctionCategorizer FC, - PostSplitCallbackType Callback); +/// +/// Currently, the supported targets are SPIRV, AMDGPU and NVPTX. +void splitModuleByCategory( + std::unique_ptr M, + function_ref(const Function &F)> FunctionCategorizer, + function_ref Part)> Callback); -} // namespace sycl } // namespace llvm -#endif // LLVM_FRONTEND_SYCL_SPLIT_MODULE_H +#endif // LLVM_TRANSFORM_UTILS_SPLIT_MODULE_BY_CATEGORY_H diff --git a/llvm/lib/Transforms/Utils/CMakeLists.txt b/llvm/lib/Transforms/Utils/CMakeLists.txt index 6c3cd042fe602..bcc081d1f91d3 100644 --- a/llvm/lib/Transforms/Utils/CMakeLists.txt +++ b/llvm/lib/Transforms/Utils/CMakeLists.txt @@ -84,7 +84,6 @@ add_llvm_component_library(LLVMTransformUtils SplitModule.cpp SplitModuleByCategory.cpp StripNonLineTableDebugInfo.cpp - SYCLUtils.cpp SymbolRewriter.cpp UnifyFunctionExitNodes.cpp UnifyLoopExits.cpp diff --git a/llvm/lib/Transforms/Utils/SYCLUtils.cpp b/llvm/lib/Transforms/Utils/SYCLUtils.cpp deleted file mode 100644 index e2fe097ade00a..0000000000000 --- a/llvm/lib/Transforms/Utils/SYCLUtils.cpp +++ /dev/null @@ -1,117 +0,0 @@ -//===------------ SYCLUtils.cpp - SYCL utility functions ------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
-// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// SYCL utility functions. -//===----------------------------------------------------------------------===// - -#include "llvm/Transforms/Utils/SYCLUtils.h" - -#include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/Module.h" -#include "llvm/Support/raw_ostream.h" - -#include - -using namespace llvm; -using namespace sycl; - -namespace { - -SmallString<0> computeFunctionCategoryForSplitting(IRSplitMode SM, - const Function &F) { - static constexpr char ATTR_SYCL_MODULE_ID[] = "sycl-module-id"; - SmallString<0> Key; - switch (SM) { - case IRSplitMode::IRSM_PER_KERNEL: - Key = F.getName().str(); - break; - case IRSplitMode::IRSM_PER_TU: - Key = F.getFnAttribute(ATTR_SYCL_MODULE_ID).getValueAsString().str(); - break; - default: - llvm_unreachable("other modes aren't expected"); - } - - return Key; -} - -bool isKernel(const Function &F) { - return F.getCallingConv() == CallingConv::SPIR_KERNEL || - F.getCallingConv() == CallingConv::AMDGPU_KERNEL || - F.getCallingConv() == CallingConv::PTX_Kernel; -} - -bool isEntryPoint(const Function &F) { - // Skip declarations, if any: they should not be included into a vector of - // entry points groups or otherwise we will end up with incorrectly generated - // list of symbols. 
- if (F.isDeclaration()) - return false; - - // Kernels are always considered to be entry points - return isKernel(F); -} - -} // anonymous namespace - -namespace llvm { -namespace sycl { - -std::optional convertStringToSplitMode(StringRef S) { - static const StringMap Values = { - {"source", IRSplitMode::IRSM_PER_TU}, - {"kernel", IRSplitMode::IRSM_PER_KERNEL}, - {"none", IRSplitMode::IRSM_NONE}}; - - auto It = Values.find(S); - if (It == Values.end()) - return std::nullopt; - - return It->second; -} - -FunctionCategorizer::FunctionCategorizer(IRSplitMode SM) : SM(SM) { - if (SM == IRSplitMode::IRSM_NONE) - llvm_unreachable("FunctionCategorizer isn't supported to none splitting."); -} - -std::optional FunctionCategorizer::operator()(const Function &F) { - if (!isEntryPoint(F)) - return std::nullopt; // skip the function. - - auto StringKey = computeFunctionCategoryForSplitting(SM, F); - if (auto it = StrKeyToID.find(StringRef(StringKey)); it != StrKeyToID.end()) - return it->second; - - int ID = static_cast(StrKeyToID.size()); - return StrKeyToID.try_emplace(std::move(StringKey), ID).first->second; -} - -std::string makeSymbolTable(const Module &M) { - SmallString<0> Data; - raw_svector_ostream OS(Data); - for (const auto &F : M) - if (isEntryPoint(F)) - OS << F.getName() << '\n'; - - return std::string(OS.str()); -} - -void writeStringTable(const StringTable &Table, raw_ostream &OS) { - assert(!Table.empty() && "table should contain at least column titles"); - assert(!Table[0].empty() && "table should be non-empty"); - OS << '[' << join(Table[0].begin(), Table[0].end(), "|") << "]\n"; - for (size_t I = 1, E = Table.size(); I != E; ++I) { - assert(Table[I].size() == Table[0].size() && "row's size should be equal"); - OS << join(Table[I].begin(), Table[I].end(), "|") << '\n'; - } -} - -} // namespace sycl -} // namespace llvm diff --git a/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp b/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp index 
4568c1f30c282..932f57a77a7bc 100644 --- a/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp +++ b/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp @@ -25,18 +25,11 @@ #include using namespace llvm; -using namespace llvm::sycl; -#define DEBUG_TYPE "sycl-split-module" +#define DEBUG_TYPE "split-module-by-category" namespace { -bool isKernel(const Function &F) { - return F.getCallingConv() == CallingConv::SPIR_KERNEL || - F.getCallingConv() == CallingConv::AMDGPU_KERNEL || - F.getCallingConv() == CallingConv::PTX_Kernel; -} - // A vector that contains a group of function with the same category. using EntryPointSet = SetVector; @@ -106,6 +99,12 @@ class ModuleDesc { #endif }; +bool isKernel(const Function &F) { + return F.getCallingConv() == CallingConv::SPIR_KERNEL || + F.getCallingConv() == CallingConv::AMDGPU_KERNEL || + F.getCallingConv() == CallingConv::PTX_Kernel; +} + // Represents "dependency" or "use" graph of global objects (functions and // global variables) in a module. It is used during device code split to // understand which global variables and functions (other than entry points) @@ -299,17 +298,18 @@ class ModuleSplitter { bool hasMoreSplits() const { return Groups.size() > 0; } }; -EntryPointGroupVec selectEntryPointGroups(const Module &M, - FunctionCategorizer FC) { +EntryPointGroupVec +selectEntryPointGroups(const Module &M, + function_ref(const Function &F)> FC) { // std::map is used here to ensure stable ordering of entry point groups, // which is based on their contents, this greatly helps LIT tests std::map EntryPointsMap; for (const auto &F : M.functions()) { - if (auto Key = FC(F); Key) { - auto It = EntryPointsMap.find(*Key); + if (auto Category = FC(F); Category) { + auto It = EntryPointsMap.find(*Category); if (It == EntryPointsMap.end()) - It = EntryPointsMap.emplace(*Key, EntryPointSet()).first; + It = EntryPointsMap.emplace(*Category, EntryPointSet()).first; It->second.insert(&F); } @@ -326,10 +326,11 @@ EntryPointGroupVec 
selectEntryPointGroups(const Module &M, } // namespace -void llvm::sycl::splitModuleByCategory(std::unique_ptr M, - FunctionCategorizer FC, - PostSplitCallbackType Callback) { - EntryPointGroupVec Groups = selectEntryPointGroups(*M, FC); +void llvm::splitModuleByCategory( + std::unique_ptr M, + function_ref(const Function &F)> FunctionCategorizer, + function_ref Part)> Callback) { + EntryPointGroupVec Groups = selectEntryPointGroups(*M, FunctionCategorizer); ModuleDesc MD = std::move(M); ModuleSplitter Splitter(std::move(MD), std::move(Groups)); while (Splitter.hasMoreSplits()) { diff --git a/llvm/test/tools/llvm-split/SYCL/device-code-split/split-with-kernel-declarations.ll b/llvm/test/tools/llvm-split/SYCL/device-code-split/split-with-kernel-declarations.ll deleted file mode 100644 index 1f188d8e32db6..0000000000000 --- a/llvm/test/tools/llvm-split/SYCL/device-code-split/split-with-kernel-declarations.ll +++ /dev/null @@ -1,66 +0,0 @@ -; The test checks that Module splitting does not treat declarations as entry points. 
- -; RUN: llvm-split -sycl-split=source -S < %s -o %t1 -; RUN: FileCheck %s -input-file=%t1.table --check-prefix CHECK-PER-SOURCE-TABLE -; RUN: FileCheck %s -input-file=%t1_0.sym --check-prefix CHECK-PER-SOURCE-SYM0 -; RUN: FileCheck %s -input-file=%t1_1.sym --check-prefix CHECK-PER-SOURCE-SYM1 - -; RUN: llvm-split -sycl-split=kernel -S < %s -o %t2 -; RUN: FileCheck %s -input-file=%t2.table --check-prefix CHECK-PER-KERNEL-TABLE -; RUN: FileCheck %s -input-file=%t2_0.sym --check-prefix CHECK-PER-KERNEL-SYM0 -; RUN: FileCheck %s -input-file=%t2_1.sym --check-prefix CHECK-PER-KERNEL-SYM1 -; RUN: FileCheck %s -input-file=%t2_2.sym --check-prefix CHECK-PER-KERNEL-SYM2 - -; With per-source split, there should be two device images -; CHECK-PER-SOURCE-TABLE: [Code|Symbols] -; CHECK-PER-SOURCE-TABLE: {{.*}}_0.ll|{{.*}}_0.sym -; CHECK-PER-SOURCE-TABLE-NEXT: {{.*}}_1.ll|{{.*}}_1.sym -; CHECK-PER-SOURCE-TABLE-EMPTY: -; -; CHECK-PER-SOURCE-SYM0-NOT: TU1_kernel1 -; CHECK-PER-SOURCE-SYM0: TU1_kernel0 -; CHECK-PER-SOURCE-SYM0-EMPTY: -; -; CHECK-PER-SOURCE-SYM1-NOT: TU1_kernel1 -; CHECK-PER-SOURCE-SYM1: TU0_kernel0 -; CHECK-PER-SOURCE-SYM1-NEXT: TU0_kernel1 -; CHECK-PER-SOURCE-SYM1-EMPTY: - -; With per-kernel split, there should be three device images -; CHECK-PER-KERNEL-TABLE: [Code|Symbols] -; CHECK-PER-KERNEL-TABLE: {{.*}}_0.ll|{{.*}}_0.sym -; CHECK-PER-KERNEL-TABLE-NEXT: {{.*}}_1.ll|{{.*}}_1.sym -; CHECK-PER-KERNEL-TABLE-NEXT: {{.*}}_2.ll|{{.*}}_2.sym -; CHECK-PER-KERNEL-TABLE-EMPTY: -; -; CHECK-PER-KERNEL-SYM0-NOT: TU1_kernel1 -; CHECK-PER-KERNEL-SYM0: TU1_kernel0 -; CHECK-PER-KERNEL-SYM0-EMPTY: -; -; CHECK-PER-KERNEL-SYM1-NOT: TU1_kernel1 -; CHECK-PER-KERNEL-SYM1: TU0_kernel1 -; CHECK-PER-KERNEL-SYM1-EMPTY: -; -; CHECK-PER-KERNEL-SYM2-NOT: TU1_kernel1 -; CHECK-PER-KERNEL-SYM2: TU0_kernel0 -; CHECK-PER-KERNEL-SYM2-EMPTY: - - -define spir_kernel void @TU0_kernel0() #0 { -entry: - ret void -} - -define spir_kernel void @TU0_kernel1() #0 { -entry: - ret void -} - -define 
spir_kernel void @TU1_kernel0() #1 { - ret void -} - -declare spir_kernel void @TU1_kernel1() #1 - -attributes #0 = { "sycl-module-id"="TU1.cpp" } -attributes #1 = { "sycl-module-id"="TU2.cpp" } diff --git a/llvm/test/tools/llvm-split/SYCL/device-code-split/amd-kernel-split.ll b/llvm/test/tools/llvm-split/SplitByCategory/amd-kernel-split.ll similarity index 89% rename from llvm/test/tools/llvm-split/SYCL/device-code-split/amd-kernel-split.ll rename to llvm/test/tools/llvm-split/SplitByCategory/amd-kernel-split.ll index a40a52107fb0c..41a4674118267 100644 --- a/llvm/test/tools/llvm-split/SYCL/device-code-split/amd-kernel-split.ll +++ b/llvm/test/tools/llvm-split/SplitByCategory/amd-kernel-split.ll @@ -1,5 +1,5 @@ ; -- Per-kernel split -; RUN: llvm-split -sycl-split=kernel -S < %s -o %tC +; RUN: llvm-split -split-by-category=kernel -S < %s -o %tC ; RUN: FileCheck %s -input-file=%tC_0.ll --check-prefixes CHECK-A0 ; RUN: FileCheck %s -input-file=%tC_1.ll --check-prefixes CHECK-A1 diff --git a/llvm/test/tools/llvm-split/SYCL/device-code-split/complex-indirect-call-chain.ll b/llvm/test/tools/llvm-split/SplitByCategory/complex-indirect-call-chain.ll similarity index 90% rename from llvm/test/tools/llvm-split/SYCL/device-code-split/complex-indirect-call-chain.ll rename to llvm/test/tools/llvm-split/SplitByCategory/complex-indirect-call-chain.ll index 5a25e491b1b93..80123d4dd8fb7 100644 --- a/llvm/test/tools/llvm-split/SYCL/device-code-split/complex-indirect-call-chain.ll +++ b/llvm/test/tools/llvm-split/SplitByCategory/complex-indirect-call-chain.ll @@ -1,7 +1,7 @@ ; Check that Module splitting can trace through more complex call stacks ; involving several nested indirect calls. 
-; RUN: llvm-split -sycl-split=source -S < %s -o %t +; RUN: llvm-split -split-by-category=module-id -S < %s -o %t ; RUN: FileCheck %s -input-file=%t_0.ll --check-prefix CHECK0 \ ; RUN: --implicit-check-not @foo --implicit-check-not @kernel_A \ ; RUN: --implicit-check-not @kernel_B --implicit-check-not @baz @@ -12,7 +12,7 @@ ; RUN: --implicit-check-not @BAZ --implicit-check-not @kernel_B \ ; RUN: --implicit-check-not @kernel_C -; RUN: llvm-split -sycl-split=kernel -S < %s -o %t +; RUN: llvm-split -split-by-category=kernel -S < %s -o %t ; RUN: FileCheck %s -input-file=%t_0.ll --check-prefix CHECK0 \ ; RUN: --implicit-check-not @foo --implicit-check-not @kernel_A \ ; RUN: --implicit-check-not @kernel_B @@ -70,6 +70,6 @@ define spir_kernel void @kernel_C() #2 { ret void } -attributes #0 = { "sycl-module-id"="TU1.cpp" } -attributes #1 = { "sycl-module-id"="TU2.cpp" } -attributes #2 = { "sycl-module-id"="TU3.cpp" } +attributes #0 = { "module-id"="TU1.cpp" } +attributes #1 = { "module-id"="TU2.cpp" } +attributes #2 = { "module-id"="TU3.cpp" } diff --git a/llvm/test/tools/llvm-split/SYCL/device-code-split/module-split-func-ptr.ll b/llvm/test/tools/llvm-split/SplitByCategory/module-split-func-ptr.ll similarity index 76% rename from llvm/test/tools/llvm-split/SYCL/device-code-split/module-split-func-ptr.ll rename to llvm/test/tools/llvm-split/SplitByCategory/module-split-func-ptr.ll index c9289d78b1fda..316500a4c7611 100644 --- a/llvm/test/tools/llvm-split/SYCL/device-code-split/module-split-func-ptr.ll +++ b/llvm/test/tools/llvm-split/SplitByCategory/module-split-func-ptr.ll @@ -1,15 +1,10 @@ ; This test checks that Module splitting can properly perform device code split by tracking ; all uses of functions (not only direct calls). 
-; RUN: llvm-split -sycl-split=source -S < %s -o %t -; RUN: FileCheck %s -input-file=%t_0.sym --check-prefix=CHECK-SYM0 -; RUN: FileCheck %s -input-file=%t_1.sym --check-prefix=CHECK-SYM1 +; RUN: llvm-split -split-by-category=module-id -S < %s -o %t ; RUN: FileCheck %s -input-file=%t_0.ll --check-prefix=CHECK-IR0 ; RUN: FileCheck %s -input-file=%t_1.ll --check-prefix=CHECK-IR1 -; CHECK-SYM0: kernelA -; CHECK-SYM1: kernelB -; ; CHECK-IR0: define dso_local spir_kernel void @kernelA ; ; CHECK-IR1: @FuncTable = weak global ptr @func @@ -36,8 +31,8 @@ entry: declare dso_local spir_func i32 @indirect_call(ptr addrspace(4), i32) local_unnamed_addr -attributes #0 = { "sycl-module-id"="TU1.cpp" } -attributes #1 = { "sycl-module-id"="TU2.cpp" } +attributes #0 = { "module-id"="TU1.cpp" } +attributes #1 = { "module-id"="TU2.cpp" } ; CHECK: kernel1 ; CHECK: kernel2 diff --git a/llvm/test/tools/llvm-split/SYCL/device-code-split/one-kernel-per-module.ll b/llvm/test/tools/llvm-split/SplitByCategory/one-kernel-per-module.ll similarity index 81% rename from llvm/test/tools/llvm-split/SYCL/device-code-split/one-kernel-per-module.ll rename to llvm/test/tools/llvm-split/SplitByCategory/one-kernel-per-module.ll index b949ab7530f39..c95563016ba67 100644 --- a/llvm/test/tools/llvm-split/SYCL/device-code-split/one-kernel-per-module.ll +++ b/llvm/test/tools/llvm-split/SplitByCategory/one-kernel-per-module.ll @@ -1,22 +1,15 @@ ; Test checks "kernel" splitting mode. 
-; RUN: llvm-split -sycl-split=kernel -S < %s -o %t.files +; RUN: llvm-split -split-by-category=kernel -S < %s -o %t.files ; RUN: FileCheck %s -input-file=%t.files_0.ll --check-prefixes CHECK-MODULE0,CHECK -; RUN: FileCheck %s -input-file=%t.files_0.sym --check-prefixes CHECK-MODULE0-TXT ; RUN: FileCheck %s -input-file=%t.files_1.ll --check-prefixes CHECK-MODULE1,CHECK -; RUN: FileCheck %s -input-file=%t.files_1.sym --check-prefixes CHECK-MODULE1-TXT ; RUN: FileCheck %s -input-file=%t.files_2.ll --check-prefixes CHECK-MODULE2,CHECK -; RUN: FileCheck %s -input-file=%t.files_2.sym --check-prefixes CHECK-MODULE2-TXT ;CHECK-MODULE0: @GV = internal addrspace(1) constant [1 x i32] [i32 42], align 4 ;CHECK-MODULE1-NOT: @GV ;CHECK-MODULE2-NOT: @GV @GV = internal addrspace(1) constant [1 x i32] [i32 42], align 4 -; CHECK-MODULE0-TXT-NOT: T0_kernelA -; CHECK-MODULE1-TXT-NOT: TU0_kernelA -; CHECK-MODULE2-TXT: TU0_kernelA - ; CHECK-MODULE0-NOT: define dso_local spir_kernel void @TU0_kernelA ; CHECK-MODULE1-NOT: define dso_local spir_kernel void @TU0_kernelA ; CHECK-MODULE2: define dso_local spir_kernel void @TU0_kernelA @@ -45,10 +38,6 @@ entry: ret void } -; CHECK-MODULE0-TXT-NOT: TU0_kernelB -; CHECK-MODULE1-TXT: TU0_kernelB -; CHECK-MODULE2-TXT-NOT: TU0_kernelB - ; CHECK-MODULE0-NOT: define dso_local spir_kernel void @TU0_kernelB() ; CHECK-MODULE1: define dso_local spir_kernel void @TU0_kernelB() ; CHECK-MODULE2-NOT: define dso_local spir_kernel void @TU0_kernelB() @@ -67,10 +56,6 @@ entry: ret void } -; CHECK-MODULE0-TXT: TU1_kernel -; CHECK-MODULE1-TXT-NOT: TU1_kernel -; CHECK-MODULE2-TXT-NOT: TU1_kernel - ; CHECK-MODULE0: define dso_local spir_kernel void @TU1_kernel() ; CHECK-MODULE1-NOT: define dso_local spir_kernel void @TU1_kernel() ; CHECK-MODULE2-NOT: define dso_local spir_kernel void @TU1_kernel() @@ -91,8 +76,8 @@ entry: ret void } -attributes #0 = { "sycl-module-id"="TU1.cpp" } -attributes #1 = { "sycl-module-id"="TU2.cpp" } +attributes #0 = { 
"module-id"="TU1.cpp" } +attributes #1 = { "module-id"="TU2.cpp" } ; Metadata is saved in both modules. ; CHECK: !opencl.spir.version = !{!0, !0} diff --git a/llvm/test/tools/llvm-split/SYCL/device-code-split/ptx-kernel-split.ll b/llvm/test/tools/llvm-split/SplitByCategory/ptx-kernel-split.ll similarity index 89% rename from llvm/test/tools/llvm-split/SYCL/device-code-split/ptx-kernel-split.ll rename to llvm/test/tools/llvm-split/SplitByCategory/ptx-kernel-split.ll index 0c40c1b4f4ff0..efd1e04f22c8c 100644 --- a/llvm/test/tools/llvm-split/SYCL/device-code-split/ptx-kernel-split.ll +++ b/llvm/test/tools/llvm-split/SplitByCategory/ptx-kernel-split.ll @@ -1,5 +1,5 @@ ; -- Per-kernel split -; RUN: llvm-split -sycl-split=kernel -S < %s -o %tC +; RUN: llvm-split -split-by-category=kernel -S < %s -o %tC ; RUN: FileCheck %s -input-file=%tC_0.ll --check-prefixes CHECK-A0 ; RUN: FileCheck %s -input-file=%tC_1.ll --check-prefixes CHECK-A1 diff --git a/llvm/test/tools/llvm-split/SYCL/device-code-split/split-by-source.ll b/llvm/test/tools/llvm-split/SplitByCategory/split-by-source.ll similarity index 84% rename from llvm/test/tools/llvm-split/SYCL/device-code-split/split-by-source.ll rename to llvm/test/tools/llvm-split/SplitByCategory/split-by-source.ll index 6a4e543209526..54485b7b7f348 100644 --- a/llvm/test/tools/llvm-split/SYCL/device-code-split/split-by-source.ll +++ b/llvm/test/tools/llvm-split/SplitByCategory/split-by-source.ll @@ -1,19 +1,14 @@ -; Test checks that kernels are being split by attached TU metadata and +; Test checks that kernels are being split by attached module-id metadata and ; used functions are being moved with kernels that use them. 
-; RUN: llvm-split -sycl-split=source -S < %s -o %t +; RUN: llvm-split -split-by-category=module-id -S < %s -o %t ; RUN: FileCheck %s -input-file=%t_0.ll --check-prefixes CHECK-TU0,CHECK ; RUN: FileCheck %s -input-file=%t_1.ll --check-prefixes CHECK-TU1,CHECK -; RUN: FileCheck %s -input-file=%t_0.sym --check-prefixes CHECK-TU0-TXT -; RUN: FileCheck %s -input-file=%t_1.sym --check-prefixes CHECK-TU1-TXT ; CHECK-TU1-NOT: @GV ; CHECK-TU0: @GV = internal addrspace(1) constant [1 x i32] [i32 42], align 4 @GV = internal addrspace(1) constant [1 x i32] [i32 42], align 4 -; CHECK-TU0-TXT-NOT: TU1_kernelA -; CHECK-TU1-TXT: TU1_kernelA - ; CHECK-TU0-NOT: define dso_local spir_kernel void @TU1_kernelA ; CHECK-TU1: define dso_local spir_kernel void @TU1_kernelA define dso_local spir_kernel void @TU1_kernelA() #0 { @@ -39,10 +34,6 @@ entry: ret void } - -; CHECK-TU0-TXT-NOT: TU1_kernelB -; CHECK-TU1-TXT: TU1_kernelB - ; CHECK-TU0-NOT: define dso_local spir_kernel void @TU1_kernelB() ; CHECK-TU1: define dso_local spir_kernel void @TU1_kernelB() define dso_local spir_kernel void @TU1_kernelB() #0 { @@ -80,8 +71,8 @@ entry: ret void } -attributes #0 = { "sycl-module-id"="TU1.cpp" } -attributes #1 = { "sycl-module-id"="TU2.cpp" } +attributes #0 = { "module-id"="TU1.cpp" } +attributes #1 = { "module-id"="TU2.cpp" } ; Metadata is saved in both modules. ; CHECK: !opencl.spir.version = !{!0, !0} diff --git a/llvm/test/tools/llvm-split/SplitByCategory/split-with-kernel-declarations.ll b/llvm/test/tools/llvm-split/SplitByCategory/split-with-kernel-declarations.ll new file mode 100644 index 0000000000000..0c1bd8b5c5fba --- /dev/null +++ b/llvm/test/tools/llvm-split/SplitByCategory/split-with-kernel-declarations.ll @@ -0,0 +1,52 @@ +; The test checks that Module splitting does not treat declarations as entry points. 
+ +; RUN: llvm-split -split-by-category=module-id -S < %s -o %t1 +; RUN: FileCheck %s -input-file=%t1_0.ll --check-prefix CHECK-MODULE-ID0 +; RUN: FileCheck %s -input-file=%t1_1.ll --check-prefix CHECK-MODULE-ID1 + +; RUN: llvm-split -split-by-category=kernel -S < %s -o %t2 +; RUN: FileCheck %s -input-file=%t2_0.ll --check-prefix CHECK-PER-KERNEL0 +; RUN: FileCheck %s -input-file=%t2_1.ll --check-prefix CHECK-PER-KERNEL1 +; RUN: FileCheck %s -input-file=%t2_2.ll --check-prefix CHECK-PER-KERNEL2 + +; With module-id split, there should be two modules +; CHECK-MODULE-ID0-NOT: TU0 +; CHECK-MODULE-ID0-NOT: TU1_kernel1 +; CHECK-MODULE-ID0: TU1_kernel0 +; +; CHECK-MODULE-ID1-NOT: TU1 +; CHECK-MODULE-ID1: TU0_kernel0 +; CHECK-MODULE-ID1: TU0_kernel1 + +; With per-kernel split, there should be three modules. +; CHECK-PER-KERNEL0-NOT: TU0 +; CHECK-PER-KERNEL0-NOT: TU1_kernel1 +; CHECK-PER-KERNEL0: TU1_kernel0 +; +; CHECK-PER-KERNEL1-NOT: TU0_kernel0 +; CHECK-PER-KERNEL1-NOT: TU1 +; CHECK-PER-KERNEL1: TU0_kernel1 +; +; CHECK-PER-KERNEL2-NOT: TU0_kernel1 +; CHECK-PER-KERNEL2-NOT: TU1 +; CHECK-PER-KERNEL2: TU0_kernel0 + + +define spir_kernel void @TU0_kernel0() #0 { +entry: + ret void +} + +define spir_kernel void @TU0_kernel1() #0 { +entry: + ret void +} + +define spir_kernel void @TU1_kernel0() #1 { + ret void +} + +declare spir_kernel void @TU1_kernel1() #1 + +attributes #0 = { "module-id"="TU1.cpp" } +attributes #1 = { "module-id"="TU2.cpp" } diff --git a/llvm/tools/llvm-split/CMakeLists.txt b/llvm/tools/llvm-split/CMakeLists.txt index c80ca4aba6ec6..b755755a984fc 100644 --- a/llvm/tools/llvm-split/CMakeLists.txt +++ b/llvm/tools/llvm-split/CMakeLists.txt @@ -7,7 +7,6 @@ set(LLVM_LINK_COMPONENTS BitWriter CodeGen Core - FrontendSYCL IRReader MC Support diff --git a/llvm/tools/llvm-split/llvm-split.cpp b/llvm/tools/llvm-split/llvm-split.cpp index d79005ebe6cd7..8e53647bc296e 100644 --- a/llvm/tools/llvm-split/llvm-split.cpp +++ b/llvm/tools/llvm-split/llvm-split.cpp @@ -33,7 
+33,6 @@ #include "llvm/Target/TargetMachine.h" #include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/IPO/GlobalDCE.h" -#include "llvm/Transforms/Utils/SYCLUtils.h" #include "llvm/Transforms/Utils/SplitModule.h" #include "llvm/Transforms/Utils/SplitModuleByCategory.h" @@ -78,15 +77,22 @@ static cl::opt MCPU("mcpu", cl::desc("Target CPU, ignored if --mtriple is not used"), cl::value_desc("cpu"), cl::cat(SplitCategory)); -static cl::opt SYCLSplitMode( - "sycl-split", - cl::desc("SYCL Split Mode. If present, SYCL splitting algorithm is used " - "with the specified mode."), - cl::Optional, cl::init(sycl::IRSplitMode::IRSM_NONE), - cl::values(clEnumValN(sycl::IRSplitMode::IRSM_PER_TU, "source", - "1 ouptput module per translation unit"), - clEnumValN(sycl::IRSplitMode::IRSM_PER_KERNEL, "kernel", - "1 output module per kernel")), +enum class SplitByCategoryType { + SBCT_ByModuleId, + SBCT_ByKernel, + SBCT_None, +}; + +static cl::opt SplitByCategory( + "split-by-category", + cl::desc("Split by category. 
If present, splitting by category is used " + "with the specified categorization type."), + cl::Optional, cl::init(SplitByCategoryType::SBCT_None), + cl::values(clEnumValN(SplitByCategoryType::SBCT_ByModuleId, "module-id", + "one output module per translation unit marked with " + "\"module-id\" attribute"), + clEnumValN(SplitByCategoryType::SBCT_ByKernel, "kernel", + "one output module per kernel")), cl::cat(SplitCategory)); static cl::opt OutputAssembly{ @@ -119,33 +125,82 @@ void writeModuleToFile(const Module &M, StringRef Path, bool OutputAssembly) { WriteBitcodeToFile(M, OS); } -void writeSplitModulesAsTable(ArrayRef Modules, - StringRef Path) { - SmallVector> Columns; - Columns.emplace_back("Code"); - Columns.emplace_back("Symbols"); - - sycl::StringTable Table; - Table.emplace_back(std::move(Columns)); - for (const auto &[I, SM] : enumerate(Modules)) { - SmallString<128> SymbolsFile; - (Twine(Path) + "_" + Twine(I) + ".sym").toVector(SymbolsFile); - writeStringToFile(SM.Symbols, SymbolsFile); - SmallVector> Row; - Row.emplace_back(SM.ModuleFilePath); - Row.emplace_back(SymbolsFile); - Table.emplace_back(std::move(Row)); +/// FunctionCategorizer is used for splitting by category either by module-id or +/// by kernels. It doesn't provide categories for functions other than kernels. +/// Categorizer computes a string key for the given Function and records the +/// association between the string key and an integer category. If a string key +/// already belongs to some category then the corresponding integer category +/// is returned.
+class FunctionCategorizer { +public: + FunctionCategorizer(SplitByCategoryType Type) : Type(Type) {} + + FunctionCategorizer() = delete; + FunctionCategorizer(FunctionCategorizer &) = delete; + FunctionCategorizer &operator=(const FunctionCategorizer &) = delete; + FunctionCategorizer(FunctionCategorizer &&) = default; + FunctionCategorizer &operator=(FunctionCategorizer &&) = default; + + /// Returns integer specifying the category for the given \p F. + /// If the given function isn't a kernel then returns std::nullopt. + std::optional operator()(const Function &F) { + if (!isEntryPoint(F)) + return std::nullopt; // skip the function. + + auto StringKey = computeFunctionCategory(Type, F); + if (auto it = StrKeyToID.find(StringRef(StringKey)); it != StrKeyToID.end()) + return it->second; + + int ID = static_cast(StrKeyToID.size()); + return StrKeyToID.try_emplace(std::move(StringKey), ID).first->second; } - std::error_code EC; - raw_fd_ostream OS((Path + ".table").str(), EC); - if (EC) { - errs() << formatv("error opening file: {0}\n", Path); - exit(1); +private: + static bool isEntryPoint(const Function &F) { + if (F.isDeclaration()) + return false; + + return F.getCallingConv() == CallingConv::SPIR_KERNEL || + F.getCallingConv() == CallingConv::AMDGPU_KERNEL || + F.getCallingConv() == CallingConv::PTX_Kernel; } - sycl::writeStringTable(Table, OS); -} + static SmallString<0> computeFunctionCategory(SplitByCategoryType Type, + const Function &F) { + static constexpr char ATTR_MODULE_ID[] = "module-id"; + SmallString<0> Key; + switch (Type) { + case SplitByCategoryType::SBCT_ByKernel: + Key = F.getName().str(); + break; + case SplitByCategoryType::SBCT_ByModuleId: + Key = F.getFnAttribute(ATTR_MODULE_ID).getValueAsString().str(); + break; + default: + llvm_unreachable("unexpected mode."); + } + + return Key; + } + +private: + struct KeyInfo { + static SmallString<0> getEmptyKey() { return SmallString<0>(""); } + + static SmallString<0> getTombstoneKey() { return 
SmallString<0>("-"); } + + static bool isEqual(const SmallString<0> &LHS, const SmallString<0> &RHS) { + return LHS == RHS; + } + + static unsigned getHashValue(const SmallString<0> &S) { + return llvm::hash_value(StringRef(S)); + } + }; + + SplitByCategoryType Type; + DenseMap, int, KeyInfo> StrKeyToID; +}; void cleanupModule(Module &M) { ModuleAnalysisManager MAM; @@ -155,30 +210,27 @@ void cleanupModule(Module &M) { MPM.run(M, MAM); } -Error runSYCLSplitModule(std::unique_ptr M) { - SmallVector SplitModules; +Error runSplitModuleByCategory(std::unique_ptr M) { + size_t OutputID = 0; auto PostSplitCallback = [&](std::unique_ptr MPart) { if (verifyModule(*MPart)) { errs() << "Broken Module!\n"; exit(1); } - // TODO: DCE is a crucial pass in a SYCL post-link pipeline. + // TODO: DCE is a crucial pass since it removes unused declarations. // At the moment, LIT checking can't be perfomed without DCE. cleanupModule(*MPart); - size_t ID = SplitModules.size(); + size_t ID = OutputID; + ++OutputID; StringRef ModuleSuffix = OutputAssembly ? 
".ll" : ".bc"; std::string ModulePath = (Twine(OutputFilename) + "_" + Twine(ID) + ModuleSuffix).str(); writeModuleToFile(*MPart, ModulePath, OutputAssembly); - auto Symbols = sycl::makeSymbolTable(*MPart); - SplitModules.emplace_back(std::move(ModulePath), std::move(Symbols)); }; - auto Categorizer = sycl::FunctionCategorizer(SYCLSplitMode); - sycl::SplitModuleByCategory(std::move(M), std::move(Categorizer), - PostSplitCallback); - writeSplitModulesAsTable(SplitModules, OutputFilename); + auto Categorizer = FunctionCategorizer(SplitByCategory); + splitModuleByCategory(std::move(M), Categorizer, PostSplitCallback); return Error::success(); } @@ -235,8 +287,8 @@ int main(int argc, char **argv) { Out->keep(); }; - if (SYCLSplitMode != sycl::IRSplitMode::IRSM_NONE) { - auto E = runSYCLSplitModule(std::move(M)); + if (SplitByCategory != SplitByCategoryType::SBCT_None) { + auto E = runSplitModuleByCategory(std::move(M)); if (E) { errs() << E << "\n"; Err.print(argv[0], errs()); From c249af18492fe8e2b76c27ab7c48a8df653b4b9c Mon Sep 17 00:00:00 2001 From: "Sabianin, Maksim" Date: Wed, 11 Jun 2025 07:10:54 -0700 Subject: [PATCH 7/9] address most of CR feedback --- .../Utils/SplitModuleByCategory.cpp | 61 +++++++------------ 1 file changed, 23 insertions(+), 38 deletions(-) diff --git a/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp b/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp index 932f57a77a7bc..dcc9624088c0b 100644 --- a/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp +++ b/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp @@ -39,10 +39,6 @@ struct EntryPointGroup { EntryPointSet Functions; EntryPointGroup() = default; - EntryPointGroup(const EntryPointGroup &) = default; - EntryPointGroup &operator=(const EntryPointGroup &) = default; - EntryPointGroup(EntryPointGroup &&) = default; - EntryPointGroup &operator=(EntryPointGroup &&) = default; EntryPointGroup(int ID, EntryPointSet Functions = EntryPointSet()) : ID(ID), Functions(std::move(Functions)) 
{} @@ -70,12 +66,6 @@ class ModuleDesc { EntryPointGroup EntryPoints; public: - ModuleDesc() = delete; - ModuleDesc(const ModuleDesc &) = delete; - ModuleDesc &operator=(const ModuleDesc &) = delete; - ModuleDesc(ModuleDesc &&) = default; - ModuleDesc &operator=(ModuleDesc &&) = default; - ModuleDesc(std::unique_ptr M, EntryPointGroup EntryPoints = EntryPointGroup()) : M(std::move(M)), EntryPoints(std::move(EntryPoints)) { @@ -135,7 +125,7 @@ class DependencyGraph { // Group functions by their signature to handle case (2) described above DenseMap FuncTypeToFuncsMap; - for (const auto &F : M.functions()) { + for (const Function &F : M.functions()) { // Kernels can't be called (either directly or indirectly). if (isKernel(F)) continue; @@ -143,25 +133,25 @@ class DependencyGraph { FuncTypeToFuncsMap[F.getFunctionType()].insert(&F); } - for (const auto &F : M.functions()) { + for (const Function &F : M.functions()) { // case (1), see comment above the class definition for (const Value *U : F.users()) addUserToGraphRecursively(cast(U), &F); // case (2), see comment above the class definition - for (const auto &I : instructions(F)) { - const auto *CI = dyn_cast(&I); - if (!CI || !CI->isIndirectCall()) // Direct calls were handled above + for (const Instruction &I : instructions(F)) { + const CallBase *CB = dyn_cast(&I); + if (!CB || !CB->isIndirectCall()) // Direct calls were handled above continue; - const FunctionType *Signature = CI->getFunctionType(); - const auto &PotentialCallees = FuncTypeToFuncsMap[Signature]; - Graph[&F].insert(PotentialCallees.begin(), PotentialCallees.end()); + const FunctionType *Signature = CB->getFunctionType(); + GlobalSet &PotentialCallees = FuncTypeToFuncsMap[Signature]; + Graph.emplace_or_assign(&F, std::move(PotentialCallees)); } } // And every global variable (but their handling is a bit simpler) - for (const auto &GV : M.globals()) + for (const GlobalVariable &GV : M.globals()) for (const Value *U : GV.users()) 
addUserToGraphRecursively(cast(U), &GV); } @@ -182,7 +172,7 @@ class DependencyGraph { while (!WorkList.empty()) { const User *U = WorkList.pop_back_val(); if (const auto *I = dyn_cast(U)) { - const auto *UFunc = I->getFunction(); + const Function *UFunc = I->getFunction(); Graph[UFunc].insert(V); } else if (isa(U)) { if (const auto *GV = dyn_cast(U)) @@ -190,10 +180,11 @@ class DependencyGraph { // This could be a global variable or some constant expression (like // bitcast or gep). We trace users of this constant further to reach // global objects they are used by and add them to the graph. - for (const auto *UU : U->users()) + for (const User *UU : U->users()) WorkList.push_back(UU); - } else + } else { llvm_unreachable("Unhandled type of function user"); + } } } @@ -205,11 +196,11 @@ void collectFunctionsAndGlobalVariablesToExtract( SetVector &GVs, const Module &M, const EntryPointGroup &ModuleEntryPoints, const DependencyGraph &DG) { // We start with module entry points - for (const auto *F : ModuleEntryPoints.Functions) + for (const Function *F : ModuleEntryPoints.Functions) GVs.insert(F); // Non-discardable global variables are also include into the initial set - for (const auto &GV : M.globals()) + for (const GlobalVariable &GV : M.globals()) if (!GV.isDiscardableIfUnused()) GVs.insert(&GV); @@ -223,8 +214,9 @@ void collectFunctionsAndGlobalVariablesToExtract( if (const auto *Func = dyn_cast(Dep)) { if (!Func->isDeclaration()) GVs.insert(Func); - } else + } else { GVs.insert(Dep); // Global variables are added unconditionally + } } } } @@ -237,13 +229,12 @@ ModuleDesc extractSubModule(const Module &M, // Clone definitions only for needed globals. Others will be added as // declarations and removed later. std::unique_ptr SubM = CloneModule( - M, VMap, [&](const GlobalValue *GV) { return GVs.count(GV); }); + M, VMap, [&](const GlobalValue *GV) { return GVs.contains(GV); }); // Replace entry points with cloned ones. 
EntryPointSet NewEPs; const EntryPointSet &EPs = ModuleEntryPoints.Functions; - std::for_each(EPs.begin(), EPs.end(), [&](const Function *F) { - NewEPs.insert(cast(VMap[F])); - }); + llvm::for_each( + EPs, [&](const Function *F) { NewEPs.insert(cast(VMap[F])); }); ModuleEntryPoints.Functions = std::move(NewEPs); return ModuleDesc{std::move(SubM), std::move(ModuleEntryPoints)}; } @@ -305,15 +296,9 @@ selectEntryPointGroups(const Module &M, // which is based on their contents, this greatly helps LIT tests std::map EntryPointsMap; - for (const auto &F : M.functions()) { - if (auto Category = FC(F); Category) { - auto It = EntryPointsMap.find(*Category); - if (It == EntryPointsMap.end()) - It = EntryPointsMap.emplace(*Category, EntryPointSet()).first; - - It->second.insert(&F); - } - } + for (const auto &F : M.functions()) + if (std::optional Category = FC(F); Category) + EntryPointsMap[*Category].insert(&F); EntryPointGroupVec Groups; Groups.reserve(EntryPointsMap.size()); From 7ad079e69c7af1a072994444c7515bdfb49687ea Mon Sep 17 00:00:00 2001 From: "Sabianin, Maksim" Date: Tue, 1 Jul 2025 06:45:11 -0700 Subject: [PATCH 8/9] change function's name and improve the documentation --- .../Transforms/Utils/SplitModuleByCategory.h | 44 ++++++++++++++----- .../Utils/SplitModuleByCategory.cpp | 34 +++++++------- llvm/tools/llvm-split/llvm-split.cpp | 31 ++++++------- 3 files changed, 63 insertions(+), 46 deletions(-) diff --git a/llvm/include/llvm/Transforms/Utils/SplitModuleByCategory.h b/llvm/include/llvm/Transforms/Utils/SplitModuleByCategory.h index 8e2a6127df7dd..3142ecc377412 100644 --- a/llvm/include/llvm/Transforms/Utils/SplitModuleByCategory.h +++ b/llvm/include/llvm/Transforms/Utils/SplitModuleByCategory.h @@ -22,21 +22,41 @@ namespace llvm { class Module; class Function; -/// Splits the given module \p M using the given \p FunctionCategorizer. -/// \p FunctionCategorizer returns integer category for an input Function. 
-/// It may return std::nullopt if a function doesn't have a category. -/// Module's functions are being grouped by categories. Every such group -/// populates a call graph containing group's functions themselves and all -/// reachable functions and globals. Split outputs are populated from each call -/// graph associated with some category. +/// Splits the given module \p M into parts. Every output part is being passed to +/// \p Callback for further possible processing. Every part corresponds to a +/// module's subset that is transitively reachable from some entry point group. +/// Every entry point group is defined by \p EntryPointCategorizer (EPC) as +/// follows: 1) If the function is not an entry point then Categorizer returns +/// std::nullopt. Therefore, the function doesn't belong to any group. However, +/// the function and global objects still can be associated with some output +/// parts if they are transitively used from some entry points. 2) If the function +/// belongs to some entry point group then EPC returns an integer which is an +/// identifier of the group. If two entry points belong to one group then EPC +/// returns the same identifier for both of them. /// -/// Every split output is being passed to \p Callback for further possible -/// processing. +/// Let A and B be global objects in the module. Transitive dependency relation +/// is defined such that: If global object A is used by global object B in any +/// way (e.g., store, bitcast, phi node, call), then "A" -> "B". Transitivity is +/// defined such that: If "A" -> "B" and "B" -> "C", then "A" -> "C". Examples +/// of dependencies: +/// - Function FA calls function FB +/// - Function FA uses global variable GA +/// - Global variable GA references (initialized with) function FB +/// - Function FA stores address of a function FB somewhere /// -/// Currently, the supported targets are SPIRV, AMDGPU and NVPTX.
-void splitModuleByCategory( +/// The following cases are treated as dependencies between global objects: +/// 1. Global object A is used within by a global object B in any way (store, +/// bitcast, phi node, call, etc.): "A" -> "B" edge will be added to the +/// graph; +/// 2. function A performs an indirect call of a function with signature S and +/// there is a function B with signature S. "A" -> "B" edge will be added to +/// the graph; +/// +/// FIXME: For now the algorithm supposes no recursion in the input Module. That +/// is going to be fixed in the near future. +void splitModuleTransitiveFromEntryPoints( std::unique_ptr M, - function_ref(const Function &F)> FunctionCategorizer, + function_ref(const Function &F)> EntryPointCategorizer, function_ref Part)> Callback); } // namespace llvm diff --git a/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp b/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp index dcc9624088c0b..f50f7a8e8a658 100644 --- a/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp +++ b/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp @@ -96,7 +96,7 @@ bool isKernel(const Function &F) { } // Represents "dependency" or "use" graph of global objects (functions and -// global variables) in a module. It is used during device code split to +// global variables) in a module. It is used during code split to // understand which global variables and functions (other than entry points) // should be included into a split module. // @@ -255,12 +255,12 @@ ModuleDesc extractCallGraph(const Module &M, EntryPointGroup ModuleEntryPoints, using EntryPointGroupVec = SmallVector; /// Module Splitter. -/// It gets a module (in a form of module descriptor, to get additional info) -/// and a collection of entry points groups. Each group specifies subset entry -/// points from input module that should be included in a split module. +/// It gets a module and a collection of entry points groups. 
+/// Each group specifies subset entry points from input module that should be +/// included in a split module. class ModuleSplitter { private: - ModuleDesc Input; + std::unique_ptr M; EntryPointGroupVec Groups; DependencyGraph DG; @@ -273,36 +273,33 @@ class ModuleSplitter { } public: - ModuleSplitter(ModuleDesc MD, EntryPointGroupVec GroupVec) - : Input(std::move(MD)), Groups(std::move(GroupVec)), - DG(Input.getModule()) { + ModuleSplitter(std::unique_ptr Module, EntryPointGroupVec GroupVec) + : M(std::move(Module)), Groups(std::move(GroupVec)), DG(*M) { assert(!Groups.empty() && "Entry points groups collection is empty!"); } /// Gets next subsequence of entry points in an input module and provides /// split submodule containing these entry points and their dependencies. ModuleDesc getNextSplit() { - return extractCallGraph(Input.getModule(), drawEntryPointGroup(), DG); + return extractCallGraph(*M, drawEntryPointGroup(), DG); } /// Check that there are still submodules to split. bool hasMoreSplits() const { return Groups.size() > 0; } }; -EntryPointGroupVec -selectEntryPointGroups(const Module &M, - function_ref(const Function &F)> FC) { +EntryPointGroupVec selectEntryPointGroups( + const Module &M, function_ref(const Function &F)> EPC) { // std::map is used here to ensure stable ordering of entry point groups, // which is based on their contents, this greatly helps LIT tests std::map EntryPointsMap; for (const auto &F : M.functions()) - if (std::optional Category = FC(F); Category) + if (std::optional Category = EPC(F); Category) EntryPointsMap[*Category].insert(&F); EntryPointGroupVec Groups; Groups.reserve(EntryPointsMap.size()); - // Start with properties of a source module for (auto &[Key, EntryPoints] : EntryPointsMap) Groups.emplace_back(Key, std::move(EntryPoints)); @@ -311,13 +308,12 @@ selectEntryPointGroups(const Module &M, } // namespace -void llvm::splitModuleByCategory( +void llvm::splitModuleTransitiveFromEntryPoints( std::unique_ptr M, - 
function_ref(const Function &F)> FunctionCategorizer, + function_ref(const Function &F)> EntryPointCategorizer, function_ref Part)> Callback) { - EntryPointGroupVec Groups = selectEntryPointGroups(*M, FunctionCategorizer); - ModuleDesc MD = std::move(M); - ModuleSplitter Splitter(std::move(MD), std::move(Groups)); + EntryPointGroupVec Groups = selectEntryPointGroups(*M, EntryPointCategorizer); + ModuleSplitter Splitter(std::move(M), std::move(Groups)); while (Splitter.hasMoreSplits()) { ModuleDesc MD = Splitter.getNextSplit(); Callback(std::move(MD.releaseModule())); diff --git a/llvm/tools/llvm-split/llvm-split.cpp b/llvm/tools/llvm-split/llvm-split.cpp index 8e53647bc296e..97713c481a71a 100644 --- a/llvm/tools/llvm-split/llvm-split.cpp +++ b/llvm/tools/llvm-split/llvm-split.cpp @@ -125,21 +125,21 @@ void writeModuleToFile(const Module &M, StringRef Path, bool OutputAssembly) { WriteBitcodeToFile(M, OS); } -/// FunctionCategorizer is used for splitting by category either by module-id or -/// by kernels. It doesn't provide categories for functions other than kernels. -/// Categorizer computes a string key for the given Function and records the -/// association between the string key and an integer category. If a string key -/// is already belongs to some category than the corresponding integer category -/// is returned. -class FunctionCategorizer { +/// EntryPointCategorizer is used for splitting by category either by module-id +/// or by kernels. It doesn't provide categories for functions other than +/// kernels. Categorizer computes a string key for the given Function and +/// records the association between the string key and an integer category. If a +/// string key already belongs to some category, then the corresponding +/// integer category is returned. 
+class EntryPointCategorizer { public: - FunctionCategorizer(SplitByCategoryType Type) : Type(Type) {} + EntryPointCategorizer(SplitByCategoryType Type) : Type(Type) {} - FunctionCategorizer() = delete; - FunctionCategorizer(FunctionCategorizer &) = delete; - FunctionCategorizer &operator=(const FunctionCategorizer &) = delete; - FunctionCategorizer(FunctionCategorizer &&) = default; - FunctionCategorizer &operator=(FunctionCategorizer &&) = default; + EntryPointCategorizer() = delete; + EntryPointCategorizer(EntryPointCategorizer &) = delete; + EntryPointCategorizer &operator=(const EntryPointCategorizer &) = delete; + EntryPointCategorizer(EntryPointCategorizer &&) = default; + EntryPointCategorizer &operator=(EntryPointCategorizer &&) = default; /// Returns integer specifying the category for the given \p F. /// If the given function isn't a kernel then returns std::nullopt. @@ -229,8 +229,9 @@ Error runSplitModuleByCategory(std::unique_ptr M) { writeModuleToFile(*MPart, ModulePath, OutputAssembly); }; - auto Categorizer = FunctionCategorizer(SplitByCategory); - splitModuleByCategory(std::move(M), Categorizer, PostSplitCallback); + auto Categorizer = EntryPointCategorizer(SplitByCategory); + splitModuleTransitiveFromEntryPoints(std::move(M), Categorizer, + PostSplitCallback); return Error::success(); } From 7c96d33fd57617de65575292eee886d07707ef55 Mon Sep 17 00:00:00 2001 From: "Sabianin, Maksim" Date: Tue, 1 Jul 2025 07:23:56 -0700 Subject: [PATCH 9/9] Change some comments --- .../Transforms/Utils/SplitModuleByCategory.h | 50 +++++++++---------- .../Utils/SplitModuleByCategory.cpp | 1 - 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/llvm/include/llvm/Transforms/Utils/SplitModuleByCategory.h b/llvm/include/llvm/Transforms/Utils/SplitModuleByCategory.h index 3142ecc377412..b32cfaf7859ab 100644 --- a/llvm/include/llvm/Transforms/Utils/SplitModuleByCategory.h +++ b/llvm/include/llvm/Transforms/Utils/SplitModuleByCategory.h @@ -22,38 +22,38 @@ 
namespace llvm { class Module; class Function; -/// Splits the given module \p M in parts. Every output part is being passed to -/// \p Callback for further possible processing. Every part corresponds to a -/// module's subset that is transitively reachable from some entry point group. -/// Every entry point group is defined by \p EntryPointCategorizer (EPC) as -/// follows: 1) If the function is not an entry point then Categorizer returns -/// std::nullopt. Therefore, the function doesn't belong to any group. However, -/// the function and global objects still can be associated with some output -/// parts if it is transitively used from some entry points. 2) If the function -/// belongs to some entry point group then EPC returns an integer which is an -/// identifier of the group. If two entry point belong to one group then EPC -/// returns exact identifiers for both of them. +/// Splits the given module \p M into parts. Each output part is passed to +/// \p Callback for further possible processing. Each part corresponds to a +/// subset of the module that is transitively reachable from some entry point +/// group. Each entry point group is defined by \p EntryPointCategorizer (EPC) +/// as follows: 1) If the function is not an entry point, then the Categorizer +/// returns std::nullopt. Therefore, the function doesn't belong to any group. +/// However, the function and global objects can still be associated with some +/// output parts if they are transitively used from some entry points. 2) If the +/// function belongs to an entry point group, then EPC returns an integer which +/// is an identifier of the group. If two entry points belong to one group, then +/// EPC returns the same identifier for both of them. /// -/// Let A and B be global objects in the module. Transitive dependency relation -/// is defined such that: If global object A is used by global object B in any -/// way (e.g., store, bitcast, phi node, call), then "A" -> "B". 
Transitivity is -/// defined such that: If "A" -> "B" and "B" -> "C", then "A" -> "C". Examples -/// of dependencies: +/// Let A and B be global objects in the module. The transitive dependency +/// relation is defined such that: If global object A is used by global object B +/// in any way (e.g., store, bitcast, phi node, call), then "A" -> "B". +/// Transitivity is defined such that: If "A" -> "B" and "B" -> "C", then "A" -> +/// "C". Examples of dependencies: /// - Function FA calls function FB /// - Function FA uses global variable GA -/// - Global variable GA references (initialized with) function FB -/// - Function FA stores address of a function FB somewhere +/// - Global variable GA references (is initialized with) function FB +/// - Function FA stores the address of function FB somewhere /// /// The following cases are treated as dependencies between global objects: -/// 1. Global object A is used within by a global object B in any way (store, -/// bitcast, phi node, call, etc.): "A" -> "B" edge will be added to the +/// 1. Global object A is used by global object B in any way (store, +/// bitcast, phi node, call, etc.): an "A" -> "B" edge will be added to the /// graph; -/// 2. function A performs an indirect call of a function with signature S and -/// there is a function B with signature S. "A" -> "B" edge will be added to -/// the graph; +/// 2. Function A performs an indirect call of a function with signature S, and +/// there is a function B with signature S. An "A" -> "B" edge will be added +/// to the graph; /// -/// FIXME: For now the algorithm supposes no recursion in the input Module. That -/// is going to be fixed in the near future. +/// FIXME: For now, the algorithm assumes no recursion in the input Module. This +/// will be addressed in the near future. 
void splitModuleTransitiveFromEntryPoints( std::unique_ptr M, function_ref(const Function &F)> EntryPointCategorizer, diff --git a/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp b/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp index f50f7a8e8a658..fad0af0088cc9 100644 --- a/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp +++ b/llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp @@ -224,7 +224,6 @@ void collectFunctionsAndGlobalVariablesToExtract( ModuleDesc extractSubModule(const Module &M, const SetVector &GVs, EntryPointGroup ModuleEntryPoints) { - // For each group of entry points collect all dependencies. ValueToValueMapTy VMap; // Clone definitions only for needed globals. Others will be added as // declarations and removed later.