Skip to content

[SYCL] RTC support for GPU targets #18918

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 38 commits into from
Jul 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
de5084d
[SYCL] RTC support for GPU targets
jchlanda Apr 29, 2025
50d8fea
PR feedback
jchlanda Jun 12, 2025
e7f6349
Bump sycl.hpp count - note, this include comes from RTC source
jchlanda Jun 12, 2025
f2f861d
PR feedback 2
jchlanda Jun 14, 2025
d750eaf
Fix handling of CPU/Features
jchlanda Jun 16, 2025
5a53343
Merge remote-tracking branch 'upstream/sycl' into jakub/rtc_gpu
jchlanda Jun 16, 2025
21ca552
No need for semicolon in env
jchlanda Jun 17, 2025
c6ce0c3
Merge remote-tracking branch 'upstream/sycl' into jakub/rtc_gpu
jchlanda Jun 17, 2025
07f6adf
Merge remote-tracking branch 'upstream/sycl' into jakub/rtc_gpu
jchlanda Jun 18, 2025
c737208
Merge remote-tracking branch 'upstream/sycl' into jakub/rtc_gpu
jchlanda Jun 18, 2025
9a0d655
Enable more RTC tests
jchlanda Jun 18, 2025
702fbc1
typo and device flags
jchlanda Jun 18, 2025
92bb53d
AMD arch substitution
jchlanda Jun 18, 2025
5791d9c
no run in rm
jchlanda Jun 18, 2025
08e1ae7
%run
jchlanda Jun 23, 2025
c14045e
Merge remote-tracking branch 'upstream/sycl' into jakub/rtc_gpu
jchlanda Jun 23, 2025
1fa0115
Missing SYCL_JIT_AMDGCN_PTX_TARGET_CPU
jchlanda Jun 23, 2025
c621906
sycl cache only on l0 and opencl
jchlanda Jun 23, 2025
f0e0f72
run -> run-aux
jchlanda Jun 24, 2025
4f1587b
space
jchlanda Jun 24, 2025
0ebc239
ls
jchlanda Jun 25, 2025
37981b9
Merge remote-tracking branch 'upstream/sycl' into jakub/rtc_gpu
jchlanda Jun 25, 2025
494f547
ls
jchlanda Jun 25, 2025
8e74e15
rm can not be in run line
jchlanda Jun 25, 2025
8eb60bd
rm can not be in run line
jchlanda Jun 25, 2025
ed00a53
Make sure to rm the temp dir after first run
jchlanda Jun 25, 2025
469a670
cache
jchlanda Jun 30, 2025
d9cd56a
cache
jchlanda Jun 30, 2025
efb5c32
correct num of sycl.hpp
jchlanda Jun 30, 2025
5c1bf6f
don't run cache on gpu
jchlanda Jun 30, 2025
4b55904
sycl_imf suffers from 18390
jchlanda Jul 1, 2025
9f520c0
Merge remote-tracking branch 'upstream/sycl' into jakub/rtc_gpu
jchlanda Jul 1, 2025
51065b9
Merge remote-tracking branch 'upstream/sycl' into jakub/rtc_gpu
jchlanda Jul 1, 2025
95ca6ef
Merge remote-tracking branch 'upstream/sycl' into jakub/rtc_gpu
jchlanda Jul 2, 2025
71cef2c
Merge remote-tracking branch 'upstream/sycl' into jakub/rtc_gpu
jchlanda Jul 2, 2025
bed6a20
no sycl hpp update
jchlanda Jul 2, 2025
63f0339
ignore features
jchlanda Jul 2, 2025
6351f2e
Merge remote-tracking branch 'upstream/sycl' into jakub/rtc_gpu
jchlanda Jul 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sycl-jit/jit-compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ target_include_directories(sycl-jit
${LLVM_MAIN_INCLUDE_DIR}
${LLVM_SPIRV_INCLUDE_DIRS}
${LLVM_EXTERNAL_CLANG_SOURCE_DIR}/include
${LLVM_EXTERNAL_CLANG_SOURCE_DIR}/lib
${CMAKE_BINARY_DIR}/tools/clang/include
)
target_include_directories(sycl-jit
Expand Down
11 changes: 8 additions & 3 deletions sycl-jit/jit-compiler/include/RTC.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,11 @@ class RTCResult {

/// Calculates a BLAKE3 hash of the pre-processed source string described by
/// \p SourceFile (considering any additional \p IncludeFiles) and the
/// concatenation of the \p UserArgs.
/// concatenation of the \p UserArgs, for a given \p Format.
JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs);
View<const char *> UserArgs,
BinaryFormat Format);

/// Compiles, links against device libraries, and finalizes the device code in
/// the source string described by \p SourceFile, considering any additional \p
Expand All @@ -191,10 +192,14 @@ JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
///
/// If \p SaveIR is true and \p CachedIR is empty, the LLVM module obtained from
/// the frontend invocation is wrapped in bitcode format in the result object.
///
/// \p BinaryFormat describes the desired format of the compilation - which
/// corresponds to the backend that is being targeted.
JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs,
View<char> CachedIR, bool SaveIR);
View<char> CachedIR, bool SaveIR,
BinaryFormat Format);

/// Requests that the JIT binary referenced by \p Address is deleted from the
/// `JITContext`.
Expand Down
141 changes: 122 additions & 19 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@

#include "DeviceCompilation.h"
#include "ESIMD.h"
#include "JITBinaryInfo.h"
#include "translation/Translation.h"

#include <Driver/ToolChains/AMDGPU.h>
#include <Driver/ToolChains/Cuda.h>
#include <Driver/ToolChains/LazyDetector.h>
#include <clang/Basic/DiagnosticDriver.h>
#include <clang/Basic/Version.h>
#include <clang/CodeGen/CodeGenAction.h>
#include <clang/Driver/Compilation.h>
#include <clang/Driver/Driver.h>
#include <clang/Driver/Options.h>
#include <clang/Frontend/ChainedDiagnosticConsumer.h>
#include <clang/Frontend/CompilerInstance.h>
Expand Down Expand Up @@ -178,7 +184,8 @@ class RTCToolActionBase : public ToolAction {
assert(!hasExecuted() && "Action should only be invoked on a single file");

// Create a compiler instance to handle the actual work.
CompilerInstance Compiler(std::move(Invocation), std::move(PCHContainerOps));
CompilerInstance Compiler(std::move(Invocation),
std::move(PCHContainerOps));
Compiler.setFileManager(Files);
// Suppress summary with number of warnings and errors being printed to
// stdout.
Expand Down Expand Up @@ -312,7 +319,7 @@ class LLVMDiagnosticWrapper : public llvm::DiagnosticHandler {
} // anonymous namespace

static void adjustArgs(const InputArgList &UserArgList,
const std::string &DPCPPRoot,
const std::string &DPCPPRoot, BinaryFormat Format,
SmallVectorImpl<std::string> &CommandLine) {
DerivedArgList DAL{UserArgList};
const auto &OptTable = getDriverOptTable();
Expand All @@ -325,6 +332,23 @@ static void adjustArgs(const InputArgList &UserArgList,
// unused argument warning.
DAL.AddFlagArg(nullptr, OptTable.getOption(OPT_Qunused_arguments));

if (Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN) {
auto [CPU, Features] =
Translator::getTargetCPUAndFeatureAttrs(nullptr, "", Format);
(void)Features;
if (Format == BinaryFormat::AMDGCN) {
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_fsycl_targets_EQ),
"amdgcn-amd-amdhsa");
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_Xsycl_backend_EQ),
"amdgcn-amd-amdhsa");
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_offload_arch_EQ), CPU);
} else {
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_fsycl_targets_EQ),
"nvptx64-nvidia-cuda");
DAL.AddFlagArg(nullptr, OptTable.getOption(OPT_Xsycl_backend));
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_cuda_gpu_arch_EQ), CPU);
}
}
ArgStringList ASL;
for_each(DAL, [&DAL, &ASL](Arg *A) { A->render(DAL, ASL); });
for_each(UserArgList,
Expand Down Expand Up @@ -361,10 +385,9 @@ static void setupTool(ClangTool &Tool, const std::string &DPCPPRoot,
});
}

Expected<std::string>
jit_compiler::calculateHash(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
const InputArgList &UserArgList) {
Expected<std::string> jit_compiler::calculateHash(
InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
const InputArgList &UserArgList, BinaryFormat Format) {
TimeTraceScope TTS{"calculateHash"};

const std::string &DPCPPRoot = getDPCPPRoot();
Expand All @@ -373,7 +396,7 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
}

SmallVector<std::string> CommandLine;
adjustArgs(UserArgList, DPCPPRoot, CommandLine);
adjustArgs(UserArgList, DPCPPRoot, Format, CommandLine);

FixedCompilationDatabase DB{".", CommandLine};
ClangTool Tool{DB, {SourceFile.Path}};
Expand All @@ -399,11 +422,10 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
return createStringError("Calculating source hash failed");
}

Expected<ModuleUPtr>
jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
const InputArgList &UserArgList,
std::string &BuildLog, LLVMContext &Context) {
Expected<ModuleUPtr> jit_compiler::compileDeviceCode(
InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
const InputArgList &UserArgList, std::string &BuildLog,
LLVMContext &Context, BinaryFormat Format) {
TimeTraceScope TTS{"compileDeviceCode"};

const std::string &DPCPPRoot = getDPCPPRoot();
Expand All @@ -412,7 +434,7 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
}

SmallVector<std::string> CommandLine;
adjustArgs(UserArgList, DPCPPRoot, CommandLine);
adjustArgs(UserArgList, DPCPPRoot, Format, CommandLine);

FixedCompilationDatabase DB{".", CommandLine};
ClangTool Tool{DB, {SourceFile.Path}};
Expand All @@ -430,12 +452,22 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
return createStringError(BuildLog);
}

// This function is a simplified copy of the device library selection process in
// `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V target
// (no AoT, no third-party GPUs, no native CPU). Keep in sync!
// This function is a simplified copy of the device library selection process
// in `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V, or
// GPU targets (no AoT, no native CPU). Keep in sync!
static bool getDeviceLibraries(const ArgList &Args,
SmallVectorImpl<std::string> &LibraryList,
DiagnosticsEngine &Diags) {
DiagnosticsEngine &Diags, BinaryFormat Format) {
// For CUDA/HIP we only need devicelib, early exit here.
if (Format == BinaryFormat::PTX) {
LibraryList.push_back(
Args.MakeArgString("devicelib-nvptx64-nvidia-cuda.bc"));
return false;
} else if (Format == BinaryFormat::AMDGCN) {
LibraryList.push_back(Args.MakeArgString("devicelib-amdgcn-amd-amdhsa.bc"));
return false;
}

struct DeviceLibOptInfo {
StringRef DeviceLibName;
StringRef DeviceLibOption;
Expand Down Expand Up @@ -540,7 +572,8 @@ static Expected<ModuleUPtr> loadBitcodeLibrary(StringRef LibPath,

Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
const InputArgList &UserArgList,
std::string &BuildLog) {
std::string &BuildLog,
BinaryFormat Format) {
TimeTraceScope TTS{"linkDeviceLibraries"};

const std::string &DPCPPRoot = getDPCPPRoot();
Expand All @@ -555,11 +588,29 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
/* ShouldOwnClient=*/false);

SmallVector<std::string> LibNames;
bool FoundUnknownLib = getDeviceLibraries(UserArgList, LibNames, Diags);
const bool FoundUnknownLib =
getDeviceLibraries(UserArgList, LibNames, Diags, Format);
if (FoundUnknownLib) {
return createStringError("Could not determine list of device libraries: %s",
BuildLog.c_str());
}
const bool IsCudaHIP =
Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN;
if (IsCudaHIP) {
// Based on the OS and the format decide on the version of libspirv.
// NOTE: this will be problematic if cross-compiling between OSes.
std::string Libclc{"clc/"};
Libclc.append(
#ifdef _WIN32
"remangled-l32-signed_char.libspirv-"
#else
"remangled-l64-signed_char.libspirv-"
#endif
);
Libclc.append(Format == BinaryFormat::PTX ? "nvptx64-nvidia-cuda.bc"
: "amdgcn-amd-amdhsa.bc");
LibNames.push_back(Libclc);
}

LLVMContext &Context = Module.getContext();
for (const std::string &LibName : LibNames) {
Expand All @@ -577,6 +628,58 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
}
}

// For GPU targets we need to link against vendor provided libdevice.
if (IsCudaHIP) {
Triple T{Module.getTargetTriple()};
Driver D{(Twine(DPCPPRoot) + "/bin/clang++").str(), T.getTriple(), Diags};
auto [CPU, Features] =
Translator::getTargetCPUAndFeatureAttrs(&Module, "", Format);
(void)Features;
// Helper lambda to link modules.
auto LinkInLib = [&](const StringRef LibDevice) -> Error {
ModuleUPtr LibDeviceModule;
if (auto Error = loadBitcodeLibrary(LibDevice, Context)
.moveInto(LibDeviceModule)) {
return Error;
}
if (Linker::linkModules(Module, std::move(LibDeviceModule),
Linker::LinkOnlyNeeded)) {
return createStringError("Unable to link libdevice: %s",
BuildLog.c_str());
}
return Error::success();
};
SmallVector<std::string, 12> LibDeviceFiles;
if (Format == BinaryFormat::PTX) {
// For NVPTX we can get away with CudaInstallationDetector.
LazyDetector<CudaInstallationDetector> CudaInstallation{D, T,
UserArgList};
auto LibDevice = CudaInstallation->getLibDeviceFile(CPU);
if (LibDevice.empty()) {
return createStringError("Unable to find Cuda libdevice");
}
LibDeviceFiles.push_back(LibDevice);
} else {
// AMDGPU requires entire toolchain in order to provide all common bitcode
// libraries.
clang::driver::toolchains::ROCMToolChain TC(D, T, UserArgList);
auto CommonDeviceLibs = TC.getCommonDeviceLibNames(
UserArgList, CPU, Action::OffloadKind::OFK_SYCL, false);
if (CommonDeviceLibs.empty()) {
return createStringError("Unable to find ROCm common device libraries");
}
for (auto &Lib : CommonDeviceLibs) {
LibDeviceFiles.push_back(Lib.Path);
}
}
for (auto &LibDeviceFile : LibDeviceFiles) {
// llvm::Error converts to false on success.
if (auto Error = LinkInLib(LibDeviceFile)) {
return Error;
}
}
}

return Error::success();
}

Expand Down
8 changes: 5 additions & 3 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#pragma once

#include "JITBinaryInfo.h"
#include "RTC.h"

#include <llvm/ADT/SmallVector.h>
Expand All @@ -24,16 +25,17 @@ using ModuleUPtr = std::unique_ptr<llvm::Module>;

llvm::Expected<std::string>
calculateHash(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
const llvm::opt::InputArgList &UserArgList);
const llvm::opt::InputArgList &UserArgList, BinaryFormat Format);

llvm::Expected<ModuleUPtr>
compileDeviceCode(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
const llvm::opt::InputArgList &UserArgList,
std::string &BuildLog, llvm::LLVMContext &Context);
std::string &BuildLog, llvm::LLVMContext &Context,
BinaryFormat Format);

llvm::Error linkDeviceLibraries(llvm::Module &Module,
const llvm::opt::InputArgList &UserArgList,
std::string &BuildLog);
std::string &BuildLog, BinaryFormat Format);

using PostLinkResult = std::pair<RTCBundleInfo, llvm::SmallVector<ModuleUPtr>>;
llvm::Expected<PostLinkResult>
Expand Down
22 changes: 13 additions & 9 deletions sycl-jit/jit-compiler/lib/rtc/RTC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "RTC.h"
#include "JITBinaryInfo.h"
#include "helper/ErrorHelper.h"
#include "rtc/DeviceCompilation.h"
#include "translation/SPIRVLLVMTranslation.h"
Expand All @@ -26,7 +27,8 @@ using namespace jit_compiler;

JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs) {
View<const char *> UserArgs,
BinaryFormat Format) {
llvm::opt::InputArgList UserArgList;
if (auto Error = parseUserArgs(UserArgs).moveInto(UserArgList)) {
return errorTo<RTCHashResult>(std::move(Error),
Expand All @@ -36,8 +38,8 @@ JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,

auto Start = std::chrono::high_resolution_clock::now();
std::string Hash;
if (auto Error =
calculateHash(SourceFile, IncludeFiles, UserArgList).moveInto(Hash)) {
if (auto Error = calculateHash(SourceFile, IncludeFiles, UserArgList, Format)
.moveInto(Hash)) {
return errorTo<RTCHashResult>(std::move(Error), "Hashing failed",
/*IsHash=*/false);
}
Expand All @@ -55,7 +57,8 @@ JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs,
View<char> CachedIR, bool SaveIR) {
View<char> CachedIR, bool SaveIR,
BinaryFormat Format) {
llvm::LLVMContext Context;
std::string BuildLog;
configureDiagnostics(Context, BuildLog);
Expand Down Expand Up @@ -104,7 +107,7 @@ JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
bool FromSource = !Module;
if (FromSource) {
if (auto Error = compileDeviceCode(SourceFile, IncludeFiles, UserArgList,
BuildLog, Context)
BuildLog, Context, Format)
.moveInto(Module)) {
return errorTo<RTCResult>(std::move(Error), "Device compilation failed");
}
Expand All @@ -118,7 +121,8 @@ JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
IR = RTCDeviceCodeIR{BCString.data(), BCString.data() + BCString.size()};
}

if (auto Error = linkDeviceLibraries(*Module, UserArgList, BuildLog)) {
if (auto Error =
linkDeviceLibraries(*Module, UserArgList, BuildLog, Format)) {
return errorTo<RTCResult>(std::move(Error), "Device linking failed");
}

Expand All @@ -131,9 +135,9 @@ JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,

for (auto [DevImgInfo, Module] :
llvm::zip_equal(BundleInfo.DevImgInfos, Modules)) {
if (auto Error = Translator::translate(*Module, JITContext::getInstance(),
BinaryFormat::SPIRV)
.moveInto(DevImgInfo.BinaryInfo)) {
if (auto Error =
Translator::translate(*Module, JITContext::getInstance(), Format)
.moveInto(DevImgInfo.BinaryInfo)) {
return errorTo<RTCResult>(std::move(Error), "SPIR-V translation failed");
}
}
Expand Down
Loading