Skip to content

Revert "[SYCL] RTC support for AMD and Nvidia GPU targets (#18918)" #19304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion sycl-jit/jit-compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ target_include_directories(sycl-jit
${LLVM_MAIN_INCLUDE_DIR}
${LLVM_SPIRV_INCLUDE_DIRS}
${LLVM_EXTERNAL_CLANG_SOURCE_DIR}/include
${LLVM_EXTERNAL_CLANG_SOURCE_DIR}/lib
${CMAKE_BINARY_DIR}/tools/clang/include
)
target_include_directories(sycl-jit
Expand Down
11 changes: 3 additions & 8 deletions sycl-jit/jit-compiler/include/RTC.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,10 @@ class RTCResult {

/// Calculates a BLAKE3 hash of the pre-processed source string described by
/// \p SourceFile (considering any additional \p IncludeFiles) and the
/// concatenation of the \p UserArgs, for a given \p Format.
/// concatenation of the \p UserArgs.
JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs,
BinaryFormat Format);
View<const char *> UserArgs);

/// Compiles, links against device libraries, and finalizes the device code in
/// the source string described by \p SourceFile, considering any additional \p
Expand All @@ -192,14 +191,10 @@ JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
///
/// If \p SaveIR is true and \p CachedIR is empty, the LLVM module obtained from
/// the frontend invocation is wrapped in bitcode format in the result object.
///
/// \p BinaryFormat describes the desired format of the compilation - which
/// corresponds to the backend that is being targeted.
JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs,
View<char> CachedIR, bool SaveIR,
BinaryFormat Format);
View<char> CachedIR, bool SaveIR);

/// Requests that the JIT binary referenced by \p Address is deleted from the
/// `JITContext`.
Expand Down
138 changes: 18 additions & 120 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,11 @@

#include "DeviceCompilation.h"
#include "ESIMD.h"
#include "JITBinaryInfo.h"
#include "translation/Translation.h"

#include <Driver/ToolChains/AMDGPU.h>
#include <Driver/ToolChains/Cuda.h>
#include <Driver/ToolChains/LazyDetector.h>
#include <clang/Basic/DiagnosticDriver.h>
#include <clang/Basic/Version.h>
#include <clang/CodeGen/CodeGenAction.h>
#include <clang/Driver/Compilation.h>
#include <clang/Driver/Driver.h>
#include <clang/Driver/Options.h>
#include <clang/Frontend/ChainedDiagnosticConsumer.h>
#include <clang/Frontend/CompilerInstance.h>
Expand Down Expand Up @@ -319,7 +313,7 @@ class LLVMDiagnosticWrapper : public llvm::DiagnosticHandler {
} // anonymous namespace

static void adjustArgs(const InputArgList &UserArgList,
const std::string &DPCPPRoot, BinaryFormat Format,
const std::string &DPCPPRoot,
SmallVectorImpl<std::string> &CommandLine) {
DerivedArgList DAL{UserArgList};
const auto &OptTable = getDriverOptTable();
Expand All @@ -332,23 +326,6 @@ static void adjustArgs(const InputArgList &UserArgList,
// unused argument warning.
DAL.AddFlagArg(nullptr, OptTable.getOption(OPT_Qunused_arguments));

if (Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN) {
auto [CPU, Features] =
Translator::getTargetCPUAndFeatureAttrs(nullptr, "", Format);
(void)Features;
if (Format == BinaryFormat::AMDGCN) {
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_fsycl_targets_EQ),
"amdgcn-amd-amdhsa");
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_Xsycl_backend_EQ),
"amdgcn-amd-amdhsa");
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_offload_arch_EQ), CPU);
} else {
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_fsycl_targets_EQ),
"nvptx64-nvidia-cuda");
DAL.AddFlagArg(nullptr, OptTable.getOption(OPT_Xsycl_backend));
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_cuda_gpu_arch_EQ), CPU);
}
}
ArgStringList ASL;
for_each(DAL, [&DAL, &ASL](Arg *A) { A->render(DAL, ASL); });
for_each(UserArgList,
Expand Down Expand Up @@ -385,9 +362,10 @@ static void setupTool(ClangTool &Tool, const std::string &DPCPPRoot,
});
}

Expected<std::string> jit_compiler::calculateHash(
InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
const InputArgList &UserArgList, BinaryFormat Format) {
Expected<std::string>
jit_compiler::calculateHash(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
const InputArgList &UserArgList) {
TimeTraceScope TTS{"calculateHash"};

const std::string &DPCPPRoot = getDPCPPRoot();
Expand All @@ -396,7 +374,7 @@ Expected<std::string> jit_compiler::calculateHash(
}

SmallVector<std::string> CommandLine;
adjustArgs(UserArgList, DPCPPRoot, Format, CommandLine);
adjustArgs(UserArgList, DPCPPRoot, CommandLine);

FixedCompilationDatabase DB{".", CommandLine};
ClangTool Tool{DB, {SourceFile.Path}};
Expand All @@ -422,10 +400,11 @@ Expected<std::string> jit_compiler::calculateHash(
return createStringError("Calculating source hash failed");
}

Expected<ModuleUPtr> jit_compiler::compileDeviceCode(
InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
const InputArgList &UserArgList, std::string &BuildLog,
LLVMContext &Context, BinaryFormat Format) {
Expected<ModuleUPtr>
jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
const InputArgList &UserArgList,
std::string &BuildLog, LLVMContext &Context) {
TimeTraceScope TTS{"compileDeviceCode"};

const std::string &DPCPPRoot = getDPCPPRoot();
Expand All @@ -434,7 +413,7 @@ Expected<ModuleUPtr> jit_compiler::compileDeviceCode(
}

SmallVector<std::string> CommandLine;
adjustArgs(UserArgList, DPCPPRoot, Format, CommandLine);
adjustArgs(UserArgList, DPCPPRoot, CommandLine);

FixedCompilationDatabase DB{".", CommandLine};
ClangTool Tool{DB, {SourceFile.Path}};
Expand All @@ -452,22 +431,12 @@ Expected<ModuleUPtr> jit_compiler::compileDeviceCode(
return createStringError(BuildLog);
}

// This function is a simplified copy of the device library selection process
// in `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V, or
// GPU targets (no AoT, no native CPU). Keep in sync!
// This function is a simplified copy of the device library selection process in
// `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V target
// (no AoT, no third-party GPUs, no native CPU). Keep in sync!
static bool getDeviceLibraries(const ArgList &Args,
SmallVectorImpl<std::string> &LibraryList,
DiagnosticsEngine &Diags, BinaryFormat Format) {
// For CUDA/HIP we only need devicelib, early exit here.
if (Format == BinaryFormat::PTX) {
LibraryList.push_back(
Args.MakeArgString("devicelib-nvptx64-nvidia-cuda.bc"));
return false;
} else if (Format == BinaryFormat::AMDGCN) {
LibraryList.push_back(Args.MakeArgString("devicelib-amdgcn-amd-amdhsa.bc"));
return false;
}

DiagnosticsEngine &Diags) {
struct DeviceLibOptInfo {
StringRef DeviceLibName;
StringRef DeviceLibOption;
Expand Down Expand Up @@ -572,8 +541,7 @@ static Expected<ModuleUPtr> loadBitcodeLibrary(StringRef LibPath,

Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
const InputArgList &UserArgList,
std::string &BuildLog,
BinaryFormat Format) {
std::string &BuildLog) {
TimeTraceScope TTS{"linkDeviceLibraries"};

const std::string &DPCPPRoot = getDPCPPRoot();
Expand All @@ -588,29 +556,11 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
/* ShouldOwnClient=*/false);

SmallVector<std::string> LibNames;
const bool FoundUnknownLib =
getDeviceLibraries(UserArgList, LibNames, Diags, Format);
bool FoundUnknownLib = getDeviceLibraries(UserArgList, LibNames, Diags);
if (FoundUnknownLib) {
return createStringError("Could not determine list of device libraries: %s",
BuildLog.c_str());
}
const bool IsCudaHIP =
Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN;
if (IsCudaHIP) {
// Based on the OS and the format decide on the version of libspirv.
// NOTE: this will be problematic if cross-compiling between OSes.
std::string Libclc{"clc/"};
Libclc.append(
#ifdef _WIN32
"remangled-l32-signed_char.libspirv-"
#else
"remangled-l64-signed_char.libspirv-"
#endif
);
Libclc.append(Format == BinaryFormat::PTX ? "nvptx64-nvidia-cuda.bc"
: "amdgcn-amd-amdhsa.bc");
LibNames.push_back(Libclc);
}

LLVMContext &Context = Module.getContext();
for (const std::string &LibName : LibNames) {
Expand All @@ -628,58 +578,6 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
}
}

// For GPU targets we need to link against vendor provided libdevice.
if (IsCudaHIP) {
Triple T{Module.getTargetTriple()};
Driver D{(Twine(DPCPPRoot) + "/bin/clang++").str(), T.getTriple(), Diags};
auto [CPU, Features] =
Translator::getTargetCPUAndFeatureAttrs(&Module, "", Format);
(void)Features;
// Helper lambda to link modules.
auto LinkInLib = [&](const StringRef LibDevice) -> Error {
ModuleUPtr LibDeviceModule;
if (auto Error = loadBitcodeLibrary(LibDevice, Context)
.moveInto(LibDeviceModule)) {
return Error;
}
if (Linker::linkModules(Module, std::move(LibDeviceModule),
Linker::LinkOnlyNeeded)) {
return createStringError("Unable to link libdevice: %s",
BuildLog.c_str());
}
return Error::success();
};
SmallVector<std::string, 12> LibDeviceFiles;
if (Format == BinaryFormat::PTX) {
// For NVPTX we can get away with CudaInstallationDetector.
LazyDetector<CudaInstallationDetector> CudaInstallation{D, T,
UserArgList};
auto LibDevice = CudaInstallation->getLibDeviceFile(CPU);
if (LibDevice.empty()) {
return createStringError("Unable to find Cuda libdevice");
}
LibDeviceFiles.push_back(LibDevice);
} else {
// AMDGPU requires entire toolchain in order to provide all common bitcode
// libraries.
clang::driver::toolchains::ROCMToolChain TC(D, T, UserArgList);
auto CommonDeviceLibs = TC.getCommonDeviceLibNames(
UserArgList, CPU, Action::OffloadKind::OFK_SYCL, false);
if (CommonDeviceLibs.empty()) {
return createStringError("Unable to find ROCm common device libraries");
}
for (auto &Lib : CommonDeviceLibs) {
LibDeviceFiles.push_back(Lib.Path);
}
}
for (auto &LibDeviceFile : LibDeviceFiles) {
// llvm::Error converts to false on success.
if (auto Error = LinkInLib(LibDeviceFile)) {
return Error;
}
}
}

return Error::success();
}

Expand Down
8 changes: 3 additions & 5 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

#pragma once

#include "JITBinaryInfo.h"
#include "RTC.h"

#include <llvm/ADT/SmallVector.h>
Expand All @@ -25,17 +24,16 @@ using ModuleUPtr = std::unique_ptr<llvm::Module>;

llvm::Expected<std::string>
calculateHash(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
const llvm::opt::InputArgList &UserArgList, BinaryFormat Format);
const llvm::opt::InputArgList &UserArgList);

llvm::Expected<ModuleUPtr>
compileDeviceCode(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
const llvm::opt::InputArgList &UserArgList,
std::string &BuildLog, llvm::LLVMContext &Context,
BinaryFormat Format);
std::string &BuildLog, llvm::LLVMContext &Context);

llvm::Error linkDeviceLibraries(llvm::Module &Module,
const llvm::opt::InputArgList &UserArgList,
std::string &BuildLog, BinaryFormat Format);
std::string &BuildLog);

using PostLinkResult = std::pair<RTCBundleInfo, llvm::SmallVector<ModuleUPtr>>;
llvm::Expected<PostLinkResult>
Expand Down
22 changes: 9 additions & 13 deletions sycl-jit/jit-compiler/lib/rtc/RTC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
//===----------------------------------------------------------------------===//

#include "RTC.h"
#include "JITBinaryInfo.h"
#include "helper/ErrorHelper.h"
#include "rtc/DeviceCompilation.h"
#include "translation/SPIRVLLVMTranslation.h"
Expand All @@ -27,8 +26,7 @@ using namespace jit_compiler;

JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs,
BinaryFormat Format) {
View<const char *> UserArgs) {
llvm::opt::InputArgList UserArgList;
if (auto Error = parseUserArgs(UserArgs).moveInto(UserArgList)) {
return errorTo<RTCHashResult>(std::move(Error),
Expand All @@ -38,8 +36,8 @@ JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,

auto Start = std::chrono::high_resolution_clock::now();
std::string Hash;
if (auto Error = calculateHash(SourceFile, IncludeFiles, UserArgList, Format)
.moveInto(Hash)) {
if (auto Error =
calculateHash(SourceFile, IncludeFiles, UserArgList).moveInto(Hash)) {
return errorTo<RTCHashResult>(std::move(Error), "Hashing failed",
/*IsHash=*/false);
}
Expand All @@ -57,8 +55,7 @@ JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs,
View<char> CachedIR, bool SaveIR,
BinaryFormat Format) {
View<char> CachedIR, bool SaveIR) {
llvm::LLVMContext Context;
std::string BuildLog;
configureDiagnostics(Context, BuildLog);
Expand Down Expand Up @@ -107,7 +104,7 @@ JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
bool FromSource = !Module;
if (FromSource) {
if (auto Error = compileDeviceCode(SourceFile, IncludeFiles, UserArgList,
BuildLog, Context, Format)
BuildLog, Context)
.moveInto(Module)) {
return errorTo<RTCResult>(std::move(Error), "Device compilation failed");
}
Expand All @@ -121,8 +118,7 @@ JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
IR = RTCDeviceCodeIR{BCString.data(), BCString.data() + BCString.size()};
}

if (auto Error =
linkDeviceLibraries(*Module, UserArgList, BuildLog, Format)) {
if (auto Error = linkDeviceLibraries(*Module, UserArgList, BuildLog)) {
return errorTo<RTCResult>(std::move(Error), "Device linking failed");
}

Expand All @@ -135,9 +131,9 @@ JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,

for (auto [DevImgInfo, Module] :
llvm::zip_equal(BundleInfo.DevImgInfos, Modules)) {
if (auto Error =
Translator::translate(*Module, JITContext::getInstance(), Format)
.moveInto(DevImgInfo.BinaryInfo)) {
if (auto Error = Translator::translate(*Module, JITContext::getInstance(),
BinaryFormat::SPIRV)
.moveInto(DevImgInfo.BinaryInfo)) {
return errorTo<RTCResult>(std::move(Error), "SPIR-V translation failed");
}
}
Expand Down
Loading