Skip to content

Commit cc966df

Browse files
authored
[SYCL] Bring back RTC support for AMD and Nvidia GPU targets (#19342)
This PR brings back #18918 and #19302, and fixes the issue with shared library builds. The problem was that we accessed hidden symbols defined in headers from the `clang/lib` directory to obtain paths to the vendor-specific device library files. We now use the `ToolChain::getDeviceLibs` API, and supply a minimal implementation for the `CudaToolchain`. --------- Signed-off-by: Julian Oppermann <[email protected]>
1 parent 4d6499a commit cc966df

34 files changed

+442
-146
lines changed

clang/lib/Driver/ToolChains/Cuda.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,6 +1162,20 @@ void CudaToolChain::AddIAMCUIncludeArgs(const ArgList &Args,
11621162
HostTC.AddIAMCUIncludeArgs(Args, CC1Args);
11631163
}
11641164

1165+
/// Returns the CUDA libdevice bitcode file to link for the GPU architecture
/// given by the last `OPT_march_EQ` (`-march=`) value in \p DriverArgs.
///
/// On success the result holds exactly one entry built from the path reported
/// by `CudaInstallation.getLibDeviceFile`. If that path is empty, the
/// `err_drv_no_cuda_libdevice` diagnostic is emitted for the architecture and
/// an empty list is returned. \p DeviceOffloadingKind is not consulted here.
llvm::SmallVector<ToolChain::BitCodeLibraryInfo, 12>
1166+
CudaToolChain::getDeviceLibs(
1167+
const llvm::opt::ArgList &DriverArgs,
1168+
const Action::OffloadKind DeviceOffloadingKind) const {
1169+
// Only the last `-march=` occurrence is considered; the value is empty when
// the option is absent.
StringRef GpuArch = DriverArgs.getLastArgValue(options::OPT_march_EQ);
1170+
std::string LibDeviceFile = CudaInstallation.getLibDeviceFile(GpuArch);
1171+
// No libdevice file for this architecture: report it and return no libraries.
if (LibDeviceFile.empty()) {
1172+
getDriver().Diag(diag::err_drv_no_cuda_libdevice) << GpuArch;
1173+
return {};
1174+
}
1175+
1176+
return {BitCodeLibraryInfo{LibDeviceFile}};
1177+
}
1178+
11651179
SanitizerMask CudaToolChain::getSupportedSanitizers() const {
11661180
// The CudaToolChain only supports sanitizers in the sense that it allows
11671181
// sanitizer arguments on the command line if they are supported by the host

clang/lib/Driver/ToolChains/Cuda.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,10 @@ class LLVM_LIBRARY_VISIBILITY CudaToolChain : public NVPTXToolChain {
248248
void AddIAMCUIncludeArgs(const llvm::opt::ArgList &DriverArgs,
249249
llvm::opt::ArgStringList &CC1Args) const override;
250250

251+
llvm::SmallVector<BitCodeLibraryInfo, 12>
252+
getDeviceLibs(const llvm::opt::ArgList &Args,
253+
const Action::OffloadKind DeviceOffloadingKind) const override;
254+
251255
SanitizerMask getSupportedSanitizers() const override;
252256

253257
VersionTuple

sycl-jit/jit-compiler/include/RTC.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,11 @@ class RTCResult {
176176

177177
/// Calculates a BLAKE3 hash of the pre-processed source string described by
178178
/// \p SourceFile (considering any additional \p IncludeFiles) and the
179-
/// concatenation of the \p UserArgs.
179+
/// concatenation of the \p UserArgs, for a given \p Format.
180180
JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
181181
View<InMemoryFile> IncludeFiles,
182-
View<const char *> UserArgs);
182+
View<const char *> UserArgs,
183+
BinaryFormat Format);
183184

184185
/// Compiles, links against device libraries, and finalizes the device code in
185186
/// the source string described by \p SourceFile, considering any additional \p
@@ -191,10 +192,14 @@ JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
191192
///
192193
/// If \p SaveIR is true and \p CachedIR is empty, the LLVM module obtained from
193194
/// the frontend invocation is wrapped in bitcode format in the result object.
195+
///
196+
/// \p Format describes the desired binary format of the compilation, which
197+
/// corresponds to the backend that is being targeted.
194198
JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
195199
View<InMemoryFile> IncludeFiles,
196200
View<const char *> UserArgs,
197-
View<char> CachedIR, bool SaveIR);
201+
View<char> CachedIR, bool SaveIR,
202+
BinaryFormat Format);
198203

199204
/// Requests that the JIT binary referenced by \p Address is deleted from the
200205
/// `JITContext`.

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp

Lines changed: 127 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,16 @@
88

99
#include "DeviceCompilation.h"
1010
#include "ESIMD.h"
11+
#include "JITBinaryInfo.h"
12+
#include "translation/Translation.h"
1113

1214
#include <clang/Basic/DiagnosticDriver.h>
1315
#include <clang/Basic/Version.h>
1416
#include <clang/CodeGen/CodeGenAction.h>
1517
#include <clang/Driver/Compilation.h>
18+
#include <clang/Driver/Driver.h>
1619
#include <clang/Driver/Options.h>
20+
#include <clang/Driver/ToolChain.h>
1721
#include <clang/Frontend/ChainedDiagnosticConsumer.h>
1822
#include <clang/Frontend/CompilerInstance.h>
1923
#include <clang/Frontend/FrontendActions.h>
@@ -52,6 +56,7 @@ using namespace llvm::opt;
5256
using namespace llvm::sycl;
5357
using namespace llvm::module_split;
5458
using namespace llvm::util;
59+
using namespace llvm::vfs;
5560
using namespace jit_compiler;
5661

5762
#ifdef _GNU_SOURCE
@@ -313,7 +318,7 @@ class LLVMDiagnosticWrapper : public llvm::DiagnosticHandler {
313318
} // anonymous namespace
314319

315320
static void adjustArgs(const InputArgList &UserArgList,
316-
const std::string &DPCPPRoot,
321+
const std::string &DPCPPRoot, BinaryFormat Format,
317322
SmallVectorImpl<std::string> &CommandLine) {
318323
DerivedArgList DAL{UserArgList};
319324
const auto &OptTable = getDriverOptTable();
@@ -326,6 +331,17 @@ static void adjustArgs(const InputArgList &UserArgList,
326331
// unused argument warning.
327332
DAL.AddFlagArg(nullptr, OptTable.getOption(OPT_Qunused_arguments));
328333

334+
if (Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN) {
335+
auto [CPU, Features] =
336+
Translator::getTargetCPUAndFeatureAttrs(nullptr, "", Format);
337+
(void)Features;
338+
StringRef OT = Format == BinaryFormat::PTX ? "nvptx64-nvidia-cuda"
339+
: "amdgcn-amd-amdhsa";
340+
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_fsycl_targets_EQ), OT);
341+
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_Xsycl_backend_EQ), OT);
342+
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_offload_arch_EQ), CPU);
343+
}
344+
329345
ArgStringList ASL;
330346
for_each(DAL, [&DAL, &ASL](Arg *A) { A->render(DAL, ASL); });
331347
for_each(UserArgList,
@@ -362,10 +378,9 @@ static void setupTool(ClangTool &Tool, const std::string &DPCPPRoot,
362378
});
363379
}
364380

365-
Expected<std::string>
366-
jit_compiler::calculateHash(InMemoryFile SourceFile,
367-
View<InMemoryFile> IncludeFiles,
368-
const InputArgList &UserArgList) {
381+
Expected<std::string> jit_compiler::calculateHash(
382+
InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
383+
const InputArgList &UserArgList, BinaryFormat Format) {
369384
TimeTraceScope TTS{"calculateHash"};
370385

371386
const std::string &DPCPPRoot = getDPCPPRoot();
@@ -374,7 +389,7 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
374389
}
375390

376391
SmallVector<std::string> CommandLine;
377-
adjustArgs(UserArgList, DPCPPRoot, CommandLine);
392+
adjustArgs(UserArgList, DPCPPRoot, Format, CommandLine);
378393

379394
FixedCompilationDatabase DB{".", CommandLine};
380395
ClangTool Tool{DB, {SourceFile.Path}};
@@ -400,11 +415,10 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
400415
return createStringError("Calculating source hash failed");
401416
}
402417

403-
Expected<ModuleUPtr>
404-
jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
405-
View<InMemoryFile> IncludeFiles,
406-
const InputArgList &UserArgList,
407-
std::string &BuildLog, LLVMContext &Context) {
418+
Expected<ModuleUPtr> jit_compiler::compileDeviceCode(
419+
InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
420+
const InputArgList &UserArgList, std::string &BuildLog,
421+
LLVMContext &Context, BinaryFormat Format) {
408422
TimeTraceScope TTS{"compileDeviceCode"};
409423

410424
const std::string &DPCPPRoot = getDPCPPRoot();
@@ -413,7 +427,7 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
413427
}
414428

415429
SmallVector<std::string> CommandLine;
416-
adjustArgs(UserArgList, DPCPPRoot, CommandLine);
430+
adjustArgs(UserArgList, DPCPPRoot, Format, CommandLine);
417431

418432
FixedCompilationDatabase DB{".", CommandLine};
419433
ClangTool Tool{DB, {SourceFile.Path}};
@@ -431,12 +445,22 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
431445
return createStringError(BuildLog);
432446
}
433447

434-
// This function is a simplified copy of the device library selection process in
435-
// `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V target
436-
// (no AoT, no third-party GPUs, no native CPU). Keep in sync!
448+
// This function is a simplified copy of the device library selection process
449+
// in `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V or
450+
// GPU target (no AoT, no native CPU). Keep in sync!
437451
static bool getDeviceLibraries(const ArgList &Args,
438452
SmallVectorImpl<std::string> &LibraryList,
439-
DiagnosticsEngine &Diags) {
453+
DiagnosticsEngine &Diags, BinaryFormat Format) {
454+
// For CUDA/HIP we only need devicelib, early exit here.
455+
if (Format == BinaryFormat::PTX) {
456+
LibraryList.push_back(
457+
Args.MakeArgString("devicelib-nvptx64-nvidia-cuda.bc"));
458+
return false;
459+
} else if (Format == BinaryFormat::AMDGCN) {
460+
LibraryList.push_back(Args.MakeArgString("devicelib-amdgcn-amd-amdhsa.bc"));
461+
return false;
462+
}
463+
440464
struct DeviceLibOptInfo {
441465
StringRef DeviceLibName;
442466
StringRef DeviceLibOption;
@@ -541,7 +565,8 @@ static Expected<ModuleUPtr> loadBitcodeLibrary(StringRef LibPath,
541565

542566
Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
543567
const InputArgList &UserArgList,
544-
std::string &BuildLog) {
568+
std::string &BuildLog,
569+
BinaryFormat Format) {
545570
TimeTraceScope TTS{"linkDeviceLibraries"};
546571

547572
const std::string &DPCPPRoot = getDPCPPRoot();
@@ -556,11 +581,29 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
556581
/* ShouldOwnClient=*/false);
557582

558583
SmallVector<std::string> LibNames;
559-
bool FoundUnknownLib = getDeviceLibraries(UserArgList, LibNames, Diags);
584+
const bool FoundUnknownLib =
585+
getDeviceLibraries(UserArgList, LibNames, Diags, Format);
560586
if (FoundUnknownLib) {
561587
return createStringError("Could not determine list of device libraries: %s",
562588
BuildLog.c_str());
563589
}
590+
const bool IsCudaHIP =
591+
Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN;
592+
if (IsCudaHIP) {
593+
// Based on the OS and the format decide on the version of libspirv.
594+
// NOTE: this will be problematic if cross-compiling between OSes.
595+
std::string Libclc{"clc/"};
596+
Libclc.append(
597+
#ifdef _WIN32
598+
"remangled-l32-signed_char.libspirv-"
599+
#else
600+
"remangled-l64-signed_char.libspirv-"
601+
#endif
602+
);
603+
Libclc.append(Format == BinaryFormat::PTX ? "nvptx64-nvidia-cuda.bc"
604+
: "amdgcn-amd-amdhsa.bc");
605+
LibNames.push_back(Libclc);
606+
}
564607

565608
LLVMContext &Context = Module.getContext();
566609
for (const std::string &LibName : LibNames) {
@@ -578,6 +621,72 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
578621
}
579622
}
580623

624+
// For GPU targets we need to link against vendor provided libdevice.
625+
if (IsCudaHIP) {
626+
std::string Argv0 = DPCPPRoot + "/bin/clang++";
627+
Triple T{Module.getTargetTriple()};
628+
IntrusiveRefCntPtr<OverlayFileSystem> OFS{
629+
new OverlayFileSystem{getRealFileSystem()}};
630+
IntrusiveRefCntPtr<InMemoryFileSystem> VFS{new InMemoryFileSystem};
631+
std::string CppFileName{"a.cpp"};
632+
VFS->addFile(CppFileName, /*ModificationTime=*/0,
633+
MemoryBuffer::getMemBuffer("", ""));
634+
OFS->pushOverlay(VFS);
635+
Driver D{Argv0, T.getTriple(), Diags, "dpcpp compiler driver", OFS};
636+
637+
SmallVector<std::string> CommandLine;
638+
CommandLine.push_back(Argv0);
639+
adjustArgs(UserArgList, DPCPPRoot, Format, CommandLine);
640+
CommandLine.push_back(CppFileName);
641+
SmallVector<const char *> CommandLineCStr(CommandLine.size());
642+
llvm::transform(CommandLine, CommandLineCStr.begin(),
643+
[](const auto &S) { return S.c_str(); });
644+
645+
Compilation *C = D.BuildCompilation(CommandLineCStr);
646+
if (!C) {
647+
return createStringError("Unable to construct driver for CUDA/HIP");
648+
}
649+
650+
const ToolChain *OffloadTC =
651+
C->getSingleOffloadToolChain<Action::OFK_SYCL>();
652+
InputArgList EmptyArgList;
653+
auto Archs =
654+
D.getOffloadArchs(*C, EmptyArgList, Action::OFK_SYCL, OffloadTC);
655+
assert(Archs.size() == 1 &&
656+
"Offload toolchain should be configured to single architecture");
657+
StringRef CPU = *Archs.begin();
658+
659+
// Pass only `-march=` or `-mcpu=` with the GPU arch determined by the
660+
// driver to `getDeviceLibs`.
661+
DerivedArgList CPUArgList{EmptyArgList};
662+
if (Format == BinaryFormat::PTX) {
663+
CPUArgList.AddJoinedArg(nullptr, D.getOpts().getOption(OPT_march_EQ),
664+
CPU);
665+
} else {
666+
CPUArgList.AddJoinedArg(nullptr, D.getOpts().getOption(OPT_mcpu_EQ), CPU);
667+
}
668+
669+
SmallVector<ToolChain::BitCodeLibraryInfo, 12> CommonDeviceLibs =
670+
OffloadTC->getDeviceLibs(CPUArgList, Action::OffloadKind::OFK_SYCL);
671+
if (CommonDeviceLibs.empty()) {
672+
return createStringError("Unable to find common device libraries");
673+
}
674+
675+
for (auto &Lib : CommonDeviceLibs) {
676+
ModuleUPtr LibModule;
677+
if (auto Error =
678+
loadBitcodeLibrary(Lib.Path, Context).moveInto(LibModule)) {
679+
return Error;
680+
}
681+
682+
if (Linker::linkModules(Module, std::move(LibModule),
683+
Linker::LinkOnlyNeeded)) {
684+
return createStringError("Unable to link device library %s: %s",
685+
Lib.Path.c_str(), BuildLog.c_str());
686+
}
687+
}
688+
}
689+
581690
return Error::success();
582691
}
583692

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#pragma once
1010

11+
#include "JITBinaryInfo.h"
1112
#include "RTC.h"
1213

1314
#include <llvm/ADT/SmallVector.h>
@@ -24,16 +25,17 @@ using ModuleUPtr = std::unique_ptr<llvm::Module>;
2425

2526
llvm::Expected<std::string>
2627
calculateHash(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
27-
const llvm::opt::InputArgList &UserArgList);
28+
const llvm::opt::InputArgList &UserArgList, BinaryFormat Format);
2829

2930
llvm::Expected<ModuleUPtr>
3031
compileDeviceCode(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
3132
const llvm::opt::InputArgList &UserArgList,
32-
std::string &BuildLog, llvm::LLVMContext &Context);
33+
std::string &BuildLog, llvm::LLVMContext &Context,
34+
BinaryFormat Format);
3335

3436
llvm::Error linkDeviceLibraries(llvm::Module &Module,
3537
const llvm::opt::InputArgList &UserArgList,
36-
std::string &BuildLog);
38+
std::string &BuildLog, BinaryFormat Format);
3739

3840
using PostLinkResult = std::pair<RTCBundleInfo, llvm::SmallVector<ModuleUPtr>>;
3941
llvm::Expected<PostLinkResult>

0 commit comments

Comments
 (0)