Skip to content

[SYCL] Bring back RTC support for AMD and Nvidia GPU targets #19342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 16, 2025
14 changes: 14 additions & 0 deletions clang/lib/Driver/ToolChains/Cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,20 @@ void CudaToolChain::AddIAMCUIncludeArgs(const ArgList &Args,
HostTC.AddIAMCUIncludeArgs(Args, CC1Args);
}

llvm::SmallVector<ToolChain::BitCodeLibraryInfo, 12>
CudaToolChain::getDeviceLibs(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kind ping @intel/dpcpp-clang-driver-reviewers -- any concerns about adding this method?

const llvm::opt::ArgList &DriverArgs,
const Action::OffloadKind DeviceOffloadingKind) const {
StringRef GpuArch = DriverArgs.getLastArgValue(options::OPT_march_EQ);
std::string LibDeviceFile = CudaInstallation.getLibDeviceFile(GpuArch);
if (LibDeviceFile.empty()) {
getDriver().Diag(diag::err_drv_no_cuda_libdevice) << GpuArch;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is -march guaranteed to be populated?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this PR also adds the first and only callsite for this method, and we unconditionally set -march in the ArgList. HIPAMDToolChain::getDeviceLibs makes a similar assumption that -mcpu is set.

return {};
}

return {BitCodeLibraryInfo{LibDeviceFile}};
}

SanitizerMask CudaToolChain::getSupportedSanitizers() const {
// The CudaToolChain only supports sanitizers in the sense that it allows
// sanitizer arguments on the command line if they are supported by the host
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/Driver/ToolChains/Cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ class LLVM_LIBRARY_VISIBILITY CudaToolChain : public NVPTXToolChain {
void AddIAMCUIncludeArgs(const llvm::opt::ArgList &DriverArgs,
llvm::opt::ArgStringList &CC1Args) const override;

llvm::SmallVector<BitCodeLibraryInfo, 12>
getDeviceLibs(const llvm::opt::ArgList &Args,
const Action::OffloadKind DeviceOffloadingKind) const override;

SanitizerMask getSupportedSanitizers() const override;

VersionTuple
Expand Down
11 changes: 8 additions & 3 deletions sycl-jit/jit-compiler/include/RTC.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,11 @@ class RTCResult {

/// Calculates a BLAKE3 hash of the pre-processed source string described by
/// \p SourceFile (considering any additional \p IncludeFiles) and the
/// concatenation of the \p UserArgs.
/// concatenation of the \p UserArgs, for a given \p Format.
JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs);
View<const char *> UserArgs,
BinaryFormat Format);

/// Compiles, links against device libraries, and finalizes the device code in
/// the source string described by \p SourceFile, considering any additional \p
Expand All @@ -191,10 +192,14 @@ JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
///
/// If \p SaveIR is true and \p CachedIR is empty, the LLVM module obtained from
/// the frontend invocation is wrapped in bitcode format in the result object.
///
/// \p Format describes the desired binary format of the compilation, which
/// corresponds to the backend that is being targeted.
JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs,
View<char> CachedIR, bool SaveIR);
View<char> CachedIR, bool SaveIR,
BinaryFormat Format);

/// Requests that the JIT binary referenced by \p Address is deleted from the
/// `JITContext`.
Expand Down
145 changes: 127 additions & 18 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@

#include "DeviceCompilation.h"
#include "ESIMD.h"
#include "JITBinaryInfo.h"
#include "translation/Translation.h"

#include <clang/Basic/DiagnosticDriver.h>
#include <clang/Basic/Version.h>
#include <clang/CodeGen/CodeGenAction.h>
#include <clang/Driver/Compilation.h>
#include <clang/Driver/Driver.h>
#include <clang/Driver/Options.h>
#include <clang/Driver/ToolChain.h>
#include <clang/Frontend/ChainedDiagnosticConsumer.h>
#include <clang/Frontend/CompilerInstance.h>
#include <clang/Frontend/FrontendActions.h>
Expand Down Expand Up @@ -52,6 +56,7 @@ using namespace llvm::opt;
using namespace llvm::sycl;
using namespace llvm::module_split;
using namespace llvm::util;
using namespace llvm::vfs;
using namespace jit_compiler;

#ifdef _GNU_SOURCE
Expand Down Expand Up @@ -313,7 +318,7 @@ class LLVMDiagnosticWrapper : public llvm::DiagnosticHandler {
} // anonymous namespace

static void adjustArgs(const InputArgList &UserArgList,
const std::string &DPCPPRoot,
const std::string &DPCPPRoot, BinaryFormat Format,
SmallVectorImpl<std::string> &CommandLine) {
DerivedArgList DAL{UserArgList};
const auto &OptTable = getDriverOptTable();
Expand All @@ -326,6 +331,17 @@ static void adjustArgs(const InputArgList &UserArgList,
// unused argument warning.
DAL.AddFlagArg(nullptr, OptTable.getOption(OPT_Qunused_arguments));

if (Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN) {
auto [CPU, Features] =
Translator::getTargetCPUAndFeatureAttrs(nullptr, "", Format);
(void)Features;
StringRef OT = Format == BinaryFormat::PTX ? "nvptx64-nvidia-cuda"
: "amdgcn-amd-amdhsa";
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_fsycl_targets_EQ), OT);
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_Xsycl_backend_EQ), OT);
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_offload_arch_EQ), CPU);
}

ArgStringList ASL;
for_each(DAL, [&DAL, &ASL](Arg *A) { A->render(DAL, ASL); });
for_each(UserArgList,
Expand Down Expand Up @@ -362,10 +378,9 @@ static void setupTool(ClangTool &Tool, const std::string &DPCPPRoot,
});
}

Expected<std::string>
jit_compiler::calculateHash(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
const InputArgList &UserArgList) {
Expected<std::string> jit_compiler::calculateHash(
InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
const InputArgList &UserArgList, BinaryFormat Format) {
TimeTraceScope TTS{"calculateHash"};

const std::string &DPCPPRoot = getDPCPPRoot();
Expand All @@ -374,7 +389,7 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
}

SmallVector<std::string> CommandLine;
adjustArgs(UserArgList, DPCPPRoot, CommandLine);
adjustArgs(UserArgList, DPCPPRoot, Format, CommandLine);

FixedCompilationDatabase DB{".", CommandLine};
ClangTool Tool{DB, {SourceFile.Path}};
Expand All @@ -400,11 +415,10 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
return createStringError("Calculating source hash failed");
}

Expected<ModuleUPtr>
jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
const InputArgList &UserArgList,
std::string &BuildLog, LLVMContext &Context) {
Expected<ModuleUPtr> jit_compiler::compileDeviceCode(
InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
const InputArgList &UserArgList, std::string &BuildLog,
LLVMContext &Context, BinaryFormat Format) {
TimeTraceScope TTS{"compileDeviceCode"};

const std::string &DPCPPRoot = getDPCPPRoot();
Expand All @@ -413,7 +427,7 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
}

SmallVector<std::string> CommandLine;
adjustArgs(UserArgList, DPCPPRoot, CommandLine);
adjustArgs(UserArgList, DPCPPRoot, Format, CommandLine);

FixedCompilationDatabase DB{".", CommandLine};
ClangTool Tool{DB, {SourceFile.Path}};
Expand All @@ -431,12 +445,22 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
return createStringError(BuildLog);
}

// This function is a simplified copy of the device library selection process in
// `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V target
// (no AoT, no third-party GPUs, no native CPU). Keep in sync!
// This function is a simplified copy of the device library selection process
// in `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V or
// GPU target (no AoT, no native CPU). Keep in sync!
static bool getDeviceLibraries(const ArgList &Args,
SmallVectorImpl<std::string> &LibraryList,
DiagnosticsEngine &Diags) {
DiagnosticsEngine &Diags, BinaryFormat Format) {
// For CUDA/HIP we only need devicelib, so exit early here.
if (Format == BinaryFormat::PTX) {
LibraryList.push_back(
Args.MakeArgString("devicelib-nvptx64-nvidia-cuda.bc"));
return false;
} else if (Format == BinaryFormat::AMDGCN) {
LibraryList.push_back(Args.MakeArgString("devicelib-amdgcn-amd-amdhsa.bc"));
return false;
}

struct DeviceLibOptInfo {
StringRef DeviceLibName;
StringRef DeviceLibOption;
Expand Down Expand Up @@ -541,7 +565,8 @@ static Expected<ModuleUPtr> loadBitcodeLibrary(StringRef LibPath,

Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
const InputArgList &UserArgList,
std::string &BuildLog) {
std::string &BuildLog,
BinaryFormat Format) {
TimeTraceScope TTS{"linkDeviceLibraries"};

const std::string &DPCPPRoot = getDPCPPRoot();
Expand All @@ -556,11 +581,29 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
/* ShouldOwnClient=*/false);

SmallVector<std::string> LibNames;
bool FoundUnknownLib = getDeviceLibraries(UserArgList, LibNames, Diags);
const bool FoundUnknownLib =
getDeviceLibraries(UserArgList, LibNames, Diags, Format);
if (FoundUnknownLib) {
return createStringError("Could not determine list of device libraries: %s",
BuildLog.c_str());
}
const bool IsCudaHIP =
Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN;
if (IsCudaHIP) {
// Based on the OS and the format decide on the version of libspirv.
// NOTE: this will be problematic if cross-compiling between OSes.
std::string Libclc{"clc/"};
Libclc.append(
#ifdef _WIN32
"remangled-l32-signed_char.libspirv-"
#else
"remangled-l64-signed_char.libspirv-"
#endif
);
Libclc.append(Format == BinaryFormat::PTX ? "nvptx64-nvidia-cuda.bc"
: "amdgcn-amd-amdhsa.bc");
LibNames.push_back(Libclc);
}

LLVMContext &Context = Module.getContext();
for (const std::string &LibName : LibNames) {
Expand All @@ -578,6 +621,72 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
}
}

// For GPU targets we need to link against the vendor-provided libdevice.
if (IsCudaHIP) {
std::string Argv0 = DPCPPRoot + "/bin/clang++";
Triple T{Module.getTargetTriple()};
IntrusiveRefCntPtr<OverlayFileSystem> OFS{
new OverlayFileSystem{getRealFileSystem()}};
IntrusiveRefCntPtr<InMemoryFileSystem> VFS{new InMemoryFileSystem};
std::string CppFileName{"a.cpp"};
VFS->addFile(CppFileName, /*ModificationTime=*/0,
MemoryBuffer::getMemBuffer("", ""));
OFS->pushOverlay(VFS);
Driver D{Argv0, T.getTriple(), Diags, "dpcpp compiler driver", OFS};

SmallVector<std::string> CommandLine;
CommandLine.push_back(Argv0);
adjustArgs(UserArgList, DPCPPRoot, Format, CommandLine);
CommandLine.push_back(CppFileName);
SmallVector<const char *> CommandLineCStr(CommandLine.size());
llvm::transform(CommandLine, CommandLineCStr.begin(),
[](const auto &S) { return S.c_str(); });

Compilation *C = D.BuildCompilation(CommandLineCStr);
if (!C) {
return createStringError("Unable to construct driver for CUDA/HIP");
}

const ToolChain *OffloadTC =
C->getSingleOffloadToolChain<Action::OFK_SYCL>();
InputArgList EmptyArgList;
auto Archs =
D.getOffloadArchs(*C, EmptyArgList, Action::OFK_SYCL, OffloadTC);
assert(Archs.size() == 1 &&
"Offload toolchain should be configured to single architecture");
StringRef CPU = *Archs.begin();

// Pass only `-march=` or `-mcpu=` with the GPU arch determined by the
// driver to `getDeviceLibs`.
DerivedArgList CPUArgList{EmptyArgList};
if (Format == BinaryFormat::PTX) {
CPUArgList.AddJoinedArg(nullptr, D.getOpts().getOption(OPT_march_EQ),
CPU);
} else {
CPUArgList.AddJoinedArg(nullptr, D.getOpts().getOption(OPT_mcpu_EQ), CPU);
}

SmallVector<ToolChain::BitCodeLibraryInfo, 12> CommonDeviceLibs =
OffloadTC->getDeviceLibs(CPUArgList, Action::OffloadKind::OFK_SYCL);
Comment on lines +669 to +670
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to reviewers: This (and the setup of the clang::Driver instance above) is the new bit in this PR. We cannot use CudaToolChain and ROCMToolChain directly because they are marked as hidden symbols.

if (CommonDeviceLibs.empty()) {
return createStringError("Unable to find common device libraries");
}

for (auto &Lib : CommonDeviceLibs) {
ModuleUPtr LibModule;
if (auto Error =
loadBitcodeLibrary(Lib.Path, Context).moveInto(LibModule)) {
return Error;
}

if (Linker::linkModules(Module, std::move(LibModule),
Linker::LinkOnlyNeeded)) {
return createStringError("Unable to link device library %s: %s",
Lib.Path.c_str(), BuildLog.c_str());
}
}
}

return Error::success();
}

Expand Down
8 changes: 5 additions & 3 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#pragma once

#include "JITBinaryInfo.h"
#include "RTC.h"

#include <llvm/ADT/SmallVector.h>
Expand All @@ -24,16 +25,17 @@ using ModuleUPtr = std::unique_ptr<llvm::Module>;

llvm::Expected<std::string>
calculateHash(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
const llvm::opt::InputArgList &UserArgList);
const llvm::opt::InputArgList &UserArgList, BinaryFormat Format);

llvm::Expected<ModuleUPtr>
compileDeviceCode(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
const llvm::opt::InputArgList &UserArgList,
std::string &BuildLog, llvm::LLVMContext &Context);
std::string &BuildLog, llvm::LLVMContext &Context,
BinaryFormat Format);

llvm::Error linkDeviceLibraries(llvm::Module &Module,
const llvm::opt::InputArgList &UserArgList,
std::string &BuildLog);
std::string &BuildLog, BinaryFormat Format);

using PostLinkResult = std::pair<RTCBundleInfo, llvm::SmallVector<ModuleUPtr>>;
llvm::Expected<PostLinkResult>
Expand Down
Loading
Loading