Skip to content

Commit e7e45b6

Browse files
committed
[SYCL] RTC support for GPU targets
This patch extends RTC support to GPU (AMD and Nvidia) targets. Additionally, it reinstates the __SYCL_PROGRAM_METADATA_TAG_NEED_FINALIZATION tag and splits the sycl.cpp RTC test file so that IMF is excluded from the body of the main test.
1 parent 57bdbe3 commit e7e45b6

File tree

15 files changed

+365
-119
lines changed

15 files changed

+365
-119
lines changed

sycl-jit/jit-compiler/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ target_include_directories(sycl-jit
6060
${LLVM_MAIN_INCLUDE_DIR}
6161
${LLVM_SPIRV_INCLUDE_DIRS}
6262
${LLVM_EXTERNAL_CLANG_SOURCE_DIR}/include
63+
${LLVM_EXTERNAL_CLANG_SOURCE_DIR}/lib
6364
${CMAKE_BINARY_DIR}/tools/clang/include
6465
)
6566
target_include_directories(sycl-jit

sycl-jit/jit-compiler/include/RTC.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,11 @@ class RTCResult {
176176

177177
/// Calculates a BLAKE3 hash of the pre-processed source string described by
178178
/// \p SourceFile (considering any additional \p IncludeFiles) and the
179-
/// concatenation of the \p UserArgs.
179+
/// concatenation of the \p UserArgs, for a given \p Format.
180180
JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
181181
View<InMemoryFile> IncludeFiles,
182-
View<const char *> UserArgs);
182+
View<const char *> UserArgs,
183+
BinaryFormat Format);
183184

184185
/// Compiles, links against device libraries, and finalizes the device code in
185186
/// the source string described by \p SourceFile, considering any additional \p
@@ -191,10 +192,14 @@ JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
191192
///
192193
/// If \p SaveIR is true and \p CachedIR is empty, the LLVM module obtained from
193194
/// the frontend invocation is wrapped in bitcode format in the result object.
195+
///
196+
/// \p Format describes the desired format of the compilation - which
197+
/// corresponds to the backend that is being targeted.
194198
JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
195199
View<InMemoryFile> IncludeFiles,
196200
View<const char *> UserArgs,
197-
View<char> CachedIR, bool SaveIR);
201+
View<char> CachedIR, bool SaveIR,
202+
BinaryFormat Format);
198203

199204
/// Requests that the JIT binary referenced by \p Address is deleted from the
200205
/// `JITContext`.

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp

Lines changed: 127 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include "DeviceCompilation.h"
1010
#include "ESIMD.h"
11+
#include "JITBinaryInfo.h"
12+
#include "translation/Translation.h"
1113

1214
#include <clang/Basic/DiagnosticDriver.h>
1315
#include <clang/Basic/Version.h>
@@ -22,6 +24,15 @@
2224
#include <clang/Frontend/Utils.h>
2325
#include <clang/Tooling/CompilationDatabase.h>
2426
#include <clang/Tooling/Tooling.h>
27+
#if defined(JIT_SUPPORT_PTX) || defined(JIT_SUPPORT_AMDGCN)
28+
#include <clang/Driver/Driver.h>
29+
#endif
30+
#ifdef JIT_SUPPORT_PTX
31+
#include <Driver/ToolChains/Cuda.h>
32+
#include <Driver/ToolChains/LazyDetector.h>
33+
#elif JIT_SUPPORT_AMDGCN
34+
#include <Driver/ToolChains/AMDGPU.h>
35+
#endif
2536

2637
#include <llvm/IR/DiagnosticInfo.h>
2738
#include <llvm/IR/DiagnosticPrinter.h>
@@ -178,7 +189,8 @@ class RTCToolActionBase : public ToolAction {
178189
assert(!hasExecuted() && "Action should only be invoked on a single file");
179190

180191
// Create a compiler instance to handle the actual work.
181-
CompilerInstance Compiler(std::move(Invocation), std::move(PCHContainerOps));
192+
CompilerInstance Compiler(std::move(Invocation),
193+
std::move(PCHContainerOps));
182194
Compiler.setFileManager(Files);
183195
// Suppress summary with number of warnings and errors being printed to
184196
// stdout.
@@ -361,10 +373,24 @@ static void setupTool(ClangTool &Tool, const std::string &DPCPPRoot,
361373
});
362374
}
363375

364-
Expected<std::string>
365-
jit_compiler::calculateHash(InMemoryFile SourceFile,
366-
View<InMemoryFile> IncludeFiles,
367-
const InputArgList &UserArgList) {
376+
// Appends to \p CommandLine the command-line flags that select a SYCL GPU
// offload target for \p Format.
//
// "-fsycl" is pushed unconditionally. For BinaryFormat::PTX the
// nvptx64-nvidia-cuda target is selected and the CPU is forwarded to the
// backend via "--cuda-gpu-arch="; for BinaryFormat::AMDGCN the
// amdgcn-amd-amdhsa target is selected and the CPU is forwarded via
// "--offload-arch=". For any other format nothing beyond "-fsycl" is added.
static void setGPUTarget(BinaryFormat Format,
377+
SmallVector<std::string> &CommandLine) {
378+
// No module is available at this point, so the helper is queried with a null
// module and an empty string; only the CPU component of the result is used.
auto [CPU, _] = Translator::getTargetCPUAndFeatureAttrs(nullptr, "", Format);
379+
CommandLine.push_back("-fsycl");
380+
if (Format == BinaryFormat::PTX) {
381+
CommandLine.push_back("-fsycl-targets=nvptx64-nvidia-cuda");
382+
// Here the backend option is passed without an explicit "=<triple>" suffix,
// unlike the AMDGCN branch below.
CommandLine.push_back("-Xsycl-target-backend");
383+
CommandLine.push_back("--cuda-gpu-arch=" + CPU);
384+
} else if (Format == BinaryFormat::AMDGCN) {
385+
CommandLine.push_back("-fsycl-targets=amdgcn-amd-amdhsa");
386+
CommandLine.push_back("-Xsycl-target-backend=amdgcn-amd-amdhsa");
387+
CommandLine.push_back("--offload-arch=" + CPU);
388+
}
389+
}
390+
391+
Expected<std::string> jit_compiler::calculateHash(
392+
InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
393+
const InputArgList &UserArgList, BinaryFormat Format) {
368394
TimeTraceScope TTS{"calculateHash"};
369395

370396
const std::string &DPCPPRoot = getDPCPPRoot();
@@ -373,6 +399,9 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
373399
}
374400

375401
SmallVector<std::string> CommandLine;
402+
if (Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN) {
403+
setGPUTarget(Format, CommandLine);
404+
}
376405
adjustArgs(UserArgList, DPCPPRoot, CommandLine);
377406

378407
FixedCompilationDatabase DB{".", CommandLine};
@@ -399,11 +428,10 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
399428
return createStringError("Calculating source hash failed");
400429
}
401430

402-
Expected<ModuleUPtr>
403-
jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
404-
View<InMemoryFile> IncludeFiles,
405-
const InputArgList &UserArgList,
406-
std::string &BuildLog, LLVMContext &Context) {
431+
Expected<ModuleUPtr> jit_compiler::compileDeviceCode(
432+
InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
433+
const InputArgList &UserArgList, std::string &BuildLog,
434+
LLVMContext &Context, BinaryFormat Format) {
407435
TimeTraceScope TTS{"compileDeviceCode"};
408436

409437
const std::string &DPCPPRoot = getDPCPPRoot();
@@ -412,6 +440,9 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
412440
}
413441

414442
SmallVector<std::string> CommandLine;
443+
if (Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN) {
444+
setGPUTarget(Format, CommandLine);
445+
}
415446
adjustArgs(UserArgList, DPCPPRoot, CommandLine);
416447

417448
FixedCompilationDatabase DB{".", CommandLine};
@@ -430,12 +461,22 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
430461
return createStringError(BuildLog);
431462
}
432463

433-
// This function is a simplified copy of the device library selection process in
434-
// `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V target
435-
// (no AoT, no third-party GPUs, no native CPU). Keep in sync!
464+
// This function is a simplified copy of the device library selection process
465+
// in `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V or
466+
// GPU target (no AoT, no native CPU). Keep in sync!
436467
static bool getDeviceLibraries(const ArgList &Args,
437468
SmallVectorImpl<std::string> &LibraryList,
438-
DiagnosticsEngine &Diags) {
469+
DiagnosticsEngine &Diags, BinaryFormat Format) {
470+
// For CUDA/HIP we only need devicelib, early exit here.
471+
if (Format == BinaryFormat::PTX) {
472+
LibraryList.push_back(
473+
Args.MakeArgString("devicelib-nvptx64-nvidia-cuda.bc"));
474+
return false;
475+
} else if (Format == BinaryFormat::AMDGCN) {
476+
LibraryList.push_back(Args.MakeArgString("devicelib-amdgcn-amd-amdhsa.bc"));
477+
return false;
478+
}
479+
439480
struct DeviceLibOptInfo {
440481
StringRef DeviceLibName;
441482
StringRef DeviceLibOption;
@@ -540,7 +581,8 @@ static Expected<ModuleUPtr> loadBitcodeLibrary(StringRef LibPath,
540581

541582
Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
542583
const InputArgList &UserArgList,
543-
std::string &BuildLog) {
584+
std::string &BuildLog,
585+
BinaryFormat Format) {
544586
TimeTraceScope TTS{"linkDeviceLibraries"};
545587

546588
const std::string &DPCPPRoot = getDPCPPRoot();
@@ -555,11 +597,29 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
555597
/* ShouldOwnClient=*/false);
556598

557599
SmallVector<std::string> LibNames;
558-
bool FoundUnknownLib = getDeviceLibraries(UserArgList, LibNames, Diags);
600+
const bool FoundUnknownLib =
601+
getDeviceLibraries(UserArgList, LibNames, Diags, Format);
559602
if (FoundUnknownLib) {
560603
return createStringError("Could not determine list of device libraries: %s",
561604
BuildLog.c_str());
562605
}
606+
const bool IsGPUTarget =
607+
Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN;
608+
if (IsGPUTarget) {
609+
// Based on the OS and the format decide on the version of libspirv.
610+
// NOTE: this will be problematic if cross-compiling between OSes.
611+
std::string Libclc{"clc/"};
612+
Libclc.append(
613+
#ifdef _WIN32
614+
"remangled-l32-signed_char.libspirv-"
615+
#else
616+
"remangled-l64-signed_char.libspirv-"
617+
#endif
618+
);
619+
Libclc.append(Format == BinaryFormat::PTX ? "nvptx64-nvidia-cuda.bc"
620+
: "amdgcn-amd-amdhsa.bc");
621+
LibNames.push_back(Libclc);
622+
}
563623

564624
LLVMContext &Context = Module.getContext();
565625
for (const std::string &LibName : LibNames) {
@@ -577,6 +637,57 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
577637
}
578638
}
579639

640+
// For GPU targets we need to link against vendor provided libdevice.
641+
if (IsGPUTarget) {
642+
Triple T{Module.getTargetTriple()};
643+
Driver D{(Twine(DPCPPRoot) + "/bin/clang++").str(), T.getTriple(), Diags};
644+
auto [CPU, _] =
645+
Translator::getTargetCPUAndFeatureAttrs(&Module, "", Format);
646+
// Helper lambda to link modules.
647+
auto LinkInLib = [&](const StringRef LibDevice) -> Error {
648+
ModuleUPtr LibDeviceModule;
649+
if (auto Error = loadBitcodeLibrary(LibDevice, Context)
650+
.moveInto(LibDeviceModule)) {
651+
return Error;
652+
}
653+
if (Linker::linkModules(Module, std::move(LibDeviceModule),
654+
Linker::LinkOnlyNeeded)) {
655+
return createStringError("Unable to link libdevice: %s",
656+
BuildLog.c_str());
657+
}
658+
return Error::success();
659+
};
660+
SmallVector<std::string, 12> LibDeviceFiles;
661+
#ifdef JIT_SUPPORT_PTX
662+
// For NVPTX we can get away with CudaInstallationDetector.
663+
LazyDetector<CudaInstallationDetector> CudaInstallation{D, T, UserArgList};
664+
auto LibDevice = CudaInstallation->getLibDeviceFile(CPU);
665+
if (LibDevice.empty()) {
666+
return createStringError("Unable to find Cuda libdevice");
667+
}
668+
LibDeviceFiles.push_back(LibDevice);
669+
#elif JIT_SUPPORT_AMDGCN
670+
// AMDGPU requires entire toolchain in order to provide all common bitcode
671+
// libraries.
672+
clang::driver::toolchains::ROCMToolChain TC(D, T, UserArgList);
673+
auto CommonDeviceLibs = TC.getCommonDeviceLibNames(
674+
UserArgList, CPU, Action::OffloadKind::OFK_SYCL, false);
675+
if (CommonDeviceLibs.empty()) {
676+
return createStringError("Unable to find ROCm common device libraries");
677+
}
678+
for (auto &Lib : CommonDeviceLibs) {
679+
LibDeviceFiles.push_back(Lib.Path);
680+
}
681+
#endif
682+
for (auto &LibDeviceFile : LibDeviceFiles) {
683+
auto Res = LinkInLib(LibDeviceFile);
684+
// llvm::Error converts to false on success.
685+
if (Res) {
686+
return Res;
687+
}
688+
}
689+
}
690+
580691
return Error::success();
581692
}
582693

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#pragma once
1010

11+
#include "JITBinaryInfo.h"
1112
#include "RTC.h"
1213

1314
#include <llvm/ADT/SmallVector.h>
@@ -24,16 +25,17 @@ using ModuleUPtr = std::unique_ptr<llvm::Module>;
2425

2526
llvm::Expected<std::string>
2627
calculateHash(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
27-
const llvm::opt::InputArgList &UserArgList);
28+
const llvm::opt::InputArgList &UserArgList, BinaryFormat Format);
2829

2930
llvm::Expected<ModuleUPtr>
3031
compileDeviceCode(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
3132
const llvm::opt::InputArgList &UserArgList,
32-
std::string &BuildLog, llvm::LLVMContext &Context);
33+
std::string &BuildLog, llvm::LLVMContext &Context,
34+
BinaryFormat Format);
3335

3436
llvm::Error linkDeviceLibraries(llvm::Module &Module,
3537
const llvm::opt::InputArgList &UserArgList,
36-
std::string &BuildLog);
38+
std::string &BuildLog, BinaryFormat Format);
3739

3840
using PostLinkResult = std::pair<RTCBundleInfo, llvm::SmallVector<ModuleUPtr>>;
3941
llvm::Expected<PostLinkResult>

sycl-jit/jit-compiler/lib/rtc/RTC.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "RTC.h"
10+
#include "JITBinaryInfo.h"
1011
#include "helper/ErrorHelper.h"
1112
#include "rtc/DeviceCompilation.h"
1213
#include "translation/SPIRVLLVMTranslation.h"
@@ -26,7 +27,8 @@ using namespace jit_compiler;
2627

2728
JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
2829
View<InMemoryFile> IncludeFiles,
29-
View<const char *> UserArgs) {
30+
View<const char *> UserArgs,
31+
BinaryFormat Format) {
3032
llvm::opt::InputArgList UserArgList;
3133
if (auto Error = parseUserArgs(UserArgs).moveInto(UserArgList)) {
3234
return errorTo<RTCHashResult>(std::move(Error),
@@ -36,8 +38,8 @@ JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
3638

3739
auto Start = std::chrono::high_resolution_clock::now();
3840
std::string Hash;
39-
if (auto Error =
40-
calculateHash(SourceFile, IncludeFiles, UserArgList).moveInto(Hash)) {
41+
if (auto Error = calculateHash(SourceFile, IncludeFiles, UserArgList, Format)
42+
.moveInto(Hash)) {
4143
return errorTo<RTCHashResult>(std::move(Error), "Hashing failed",
4244
/*IsHash=*/false);
4345
}
@@ -55,7 +57,8 @@ JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
5557
JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
5658
View<InMemoryFile> IncludeFiles,
5759
View<const char *> UserArgs,
58-
View<char> CachedIR, bool SaveIR) {
60+
View<char> CachedIR, bool SaveIR,
61+
BinaryFormat Format) {
5962
llvm::LLVMContext Context;
6063
std::string BuildLog;
6164
configureDiagnostics(Context, BuildLog);
@@ -104,7 +107,7 @@ JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
104107
bool FromSource = !Module;
105108
if (FromSource) {
106109
if (auto Error = compileDeviceCode(SourceFile, IncludeFiles, UserArgList,
107-
BuildLog, Context)
110+
BuildLog, Context, Format)
108111
.moveInto(Module)) {
109112
return errorTo<RTCResult>(std::move(Error), "Device compilation failed");
110113
}
@@ -118,7 +121,8 @@ JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
118121
IR = RTCDeviceCodeIR{BCString.data(), BCString.data() + BCString.size()};
119122
}
120123

121-
if (auto Error = linkDeviceLibraries(*Module, UserArgList, BuildLog)) {
124+
if (auto Error =
125+
linkDeviceLibraries(*Module, UserArgList, BuildLog, Format)) {
122126
return errorTo<RTCResult>(std::move(Error), "Device linking failed");
123127
}
124128

@@ -131,9 +135,9 @@ JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
131135

132136
for (auto [DevImgInfo, Module] :
133137
llvm::zip_equal(BundleInfo.DevImgInfos, Modules)) {
134-
if (auto Error = Translator::translate(*Module, JITContext::getInstance(),
135-
BinaryFormat::SPIRV)
136-
.moveInto(DevImgInfo.BinaryInfo)) {
138+
if (auto Error =
139+
Translator::translate(*Module, JITContext::getInstance(), Format)
140+
.moveInto(DevImgInfo.BinaryInfo)) {
137141
return errorTo<RTCResult>(std::move(Error), "SPIR-V translation failed");
138142
}
139143
}

0 commit comments

Comments
 (0)