Skip to content

Commit 1c14ec3

Browse files
committed
Revert "[SYCL] RTC support for AMD and Nvidia GPU targets (#18918)"
This reverts commit 6d97d98.
1 parent e383f9f commit 1c14ec3

28 files changed

+148
-399
lines changed

sycl-jit/jit-compiler/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ target_include_directories(sycl-jit
6161
${LLVM_MAIN_INCLUDE_DIR}
6262
${LLVM_SPIRV_INCLUDE_DIRS}
6363
${LLVM_EXTERNAL_CLANG_SOURCE_DIR}/include
64-
${LLVM_EXTERNAL_CLANG_SOURCE_DIR}/lib
6564
${CMAKE_BINARY_DIR}/tools/clang/include
6665
)
6766
target_include_directories(sycl-jit

sycl-jit/jit-compiler/include/RTC.h

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,10 @@ class RTCResult {
176176

177177
/// Calculates a BLAKE3 hash of the pre-processed source string described by
178178
/// \p SourceFile (considering any additional \p IncludeFiles) and the
179-
/// concatenation of the \p UserArgs, for a given \p Format.
179+
/// concatenation of the \p UserArgs.
180180
JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
181181
View<InMemoryFile> IncludeFiles,
182-
View<const char *> UserArgs,
183-
BinaryFormat Format);
182+
View<const char *> UserArgs);
184183

185184
/// Compiles, links against device libraries, and finalizes the device code in
186185
/// the source string described by \p SourceFile, considering any additional \p
@@ -192,14 +191,10 @@ JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
192191
///
193192
/// If \p SaveIR is true and \p CachedIR is empty, the LLVM module obtained from
194193
/// the frontend invocation is wrapped in bitcode format in the result object.
195-
///
196-
/// \p BinaryFormat describes the desired format of the compilation - which
197-
/// corresponds to the backend that is being targeted.
198194
JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
199195
View<InMemoryFile> IncludeFiles,
200196
View<const char *> UserArgs,
201-
View<char> CachedIR, bool SaveIR,
202-
BinaryFormat Format);
197+
View<char> CachedIR, bool SaveIR);
203198

204199
/// Requests that the JIT binary referenced by \p Address is deleted from the
205200
/// `JITContext`.

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp

Lines changed: 19 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,11 @@
88

99
#include "DeviceCompilation.h"
1010
#include "ESIMD.h"
11-
#include "JITBinaryInfo.h"
12-
#include "translation/Translation.h"
1311

14-
#include <Driver/ToolChains/AMDGPU.h>
15-
#include <Driver/ToolChains/Cuda.h>
16-
#include <Driver/ToolChains/LazyDetector.h>
1712
#include <clang/Basic/DiagnosticDriver.h>
1813
#include <clang/Basic/Version.h>
1914
#include <clang/CodeGen/CodeGenAction.h>
2015
#include <clang/Driver/Compilation.h>
21-
#include <clang/Driver/Driver.h>
2216
#include <clang/Driver/Options.h>
2317
#include <clang/Frontend/ChainedDiagnosticConsumer.h>
2418
#include <clang/Frontend/CompilerInstance.h>
@@ -184,8 +178,7 @@ class RTCToolActionBase : public ToolAction {
184178
assert(!hasExecuted() && "Action should only be invoked on a single file");
185179

186180
// Create a compiler instance to handle the actual work.
187-
CompilerInstance Compiler(std::move(Invocation),
188-
std::move(PCHContainerOps));
181+
CompilerInstance Compiler(std::move(Invocation), std::move(PCHContainerOps));
189182
Compiler.setFileManager(Files);
190183
// Suppress summary with number of warnings and errors being printed to
191184
// stdout.
@@ -319,7 +312,7 @@ class LLVMDiagnosticWrapper : public llvm::DiagnosticHandler {
319312
} // anonymous namespace
320313

321314
static void adjustArgs(const InputArgList &UserArgList,
322-
const std::string &DPCPPRoot, BinaryFormat Format,
315+
const std::string &DPCPPRoot,
323316
SmallVectorImpl<std::string> &CommandLine) {
324317
DerivedArgList DAL{UserArgList};
325318
const auto &OptTable = getDriverOptTable();
@@ -332,23 +325,6 @@ static void adjustArgs(const InputArgList &UserArgList,
332325
// unused argument warning.
333326
DAL.AddFlagArg(nullptr, OptTable.getOption(OPT_Qunused_arguments));
334327

335-
if (Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN) {
336-
auto [CPU, Features] =
337-
Translator::getTargetCPUAndFeatureAttrs(nullptr, "", Format);
338-
(void)Features;
339-
if (Format == BinaryFormat::AMDGCN) {
340-
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_fsycl_targets_EQ),
341-
"amdgcn-amd-amdhsa");
342-
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_Xsycl_backend_EQ),
343-
"amdgcn-amd-amdhsa");
344-
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_offload_arch_EQ), CPU);
345-
} else {
346-
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_fsycl_targets_EQ),
347-
"nvptx64-nvidia-cuda");
348-
DAL.AddFlagArg(nullptr, OptTable.getOption(OPT_Xsycl_backend));
349-
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_cuda_gpu_arch_EQ), CPU);
350-
}
351-
}
352328
ArgStringList ASL;
353329
for_each(DAL, [&DAL, &ASL](Arg *A) { A->render(DAL, ASL); });
354330
for_each(UserArgList,
@@ -385,9 +361,10 @@ static void setupTool(ClangTool &Tool, const std::string &DPCPPRoot,
385361
});
386362
}
387363

388-
Expected<std::string> jit_compiler::calculateHash(
389-
InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
390-
const InputArgList &UserArgList, BinaryFormat Format) {
364+
Expected<std::string>
365+
jit_compiler::calculateHash(InMemoryFile SourceFile,
366+
View<InMemoryFile> IncludeFiles,
367+
const InputArgList &UserArgList) {
391368
TimeTraceScope TTS{"calculateHash"};
392369

393370
const std::string &DPCPPRoot = getDPCPPRoot();
@@ -396,7 +373,7 @@ Expected<std::string> jit_compiler::calculateHash(
396373
}
397374

398375
SmallVector<std::string> CommandLine;
399-
adjustArgs(UserArgList, DPCPPRoot, Format, CommandLine);
376+
adjustArgs(UserArgList, DPCPPRoot, CommandLine);
400377

401378
FixedCompilationDatabase DB{".", CommandLine};
402379
ClangTool Tool{DB, {SourceFile.Path}};
@@ -422,10 +399,11 @@ Expected<std::string> jit_compiler::calculateHash(
422399
return createStringError("Calculating source hash failed");
423400
}
424401

425-
Expected<ModuleUPtr> jit_compiler::compileDeviceCode(
426-
InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
427-
const InputArgList &UserArgList, std::string &BuildLog,
428-
LLVMContext &Context, BinaryFormat Format) {
402+
Expected<ModuleUPtr>
403+
jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
404+
View<InMemoryFile> IncludeFiles,
405+
const InputArgList &UserArgList,
406+
std::string &BuildLog, LLVMContext &Context) {
429407
TimeTraceScope TTS{"compileDeviceCode"};
430408

431409
const std::string &DPCPPRoot = getDPCPPRoot();
@@ -434,7 +412,7 @@ Expected<ModuleUPtr> jit_compiler::compileDeviceCode(
434412
}
435413

436414
SmallVector<std::string> CommandLine;
437-
adjustArgs(UserArgList, DPCPPRoot, Format, CommandLine);
415+
adjustArgs(UserArgList, DPCPPRoot, CommandLine);
438416

439417
FixedCompilationDatabase DB{".", CommandLine};
440418
ClangTool Tool{DB, {SourceFile.Path}};
@@ -452,22 +430,12 @@ Expected<ModuleUPtr> jit_compiler::compileDeviceCode(
452430
return createStringError(BuildLog);
453431
}
454432

455-
// This function is a simplified copy of the device library selection process
456-
// in `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V, or
457-
// GPU targets (no AoT, no native CPU). Keep in sync!
433+
// This function is a simplified copy of the device library selection process in
434+
// `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V target
435+
// (no AoT, no third-party GPUs, no native CPU). Keep in sync!
458436
static bool getDeviceLibraries(const ArgList &Args,
459437
SmallVectorImpl<std::string> &LibraryList,
460-
DiagnosticsEngine &Diags, BinaryFormat Format) {
461-
// For CUDA/HIP we only need devicelib, early exit here.
462-
if (Format == BinaryFormat::PTX) {
463-
LibraryList.push_back(
464-
Args.MakeArgString("devicelib-nvptx64-nvidia-cuda.bc"));
465-
return false;
466-
} else if (Format == BinaryFormat::AMDGCN) {
467-
LibraryList.push_back(Args.MakeArgString("devicelib-amdgcn-amd-amdhsa.bc"));
468-
return false;
469-
}
470-
438+
DiagnosticsEngine &Diags) {
471439
struct DeviceLibOptInfo {
472440
StringRef DeviceLibName;
473441
StringRef DeviceLibOption;
@@ -572,8 +540,7 @@ static Expected<ModuleUPtr> loadBitcodeLibrary(StringRef LibPath,
572540

573541
Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
574542
const InputArgList &UserArgList,
575-
std::string &BuildLog,
576-
BinaryFormat Format) {
543+
std::string &BuildLog) {
577544
TimeTraceScope TTS{"linkDeviceLibraries"};
578545

579546
const std::string &DPCPPRoot = getDPCPPRoot();
@@ -588,29 +555,11 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
588555
/* ShouldOwnClient=*/false);
589556

590557
SmallVector<std::string> LibNames;
591-
const bool FoundUnknownLib =
592-
getDeviceLibraries(UserArgList, LibNames, Diags, Format);
558+
bool FoundUnknownLib = getDeviceLibraries(UserArgList, LibNames, Diags);
593559
if (FoundUnknownLib) {
594560
return createStringError("Could not determine list of device libraries: %s",
595561
BuildLog.c_str());
596562
}
597-
const bool IsCudaHIP =
598-
Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN;
599-
if (IsCudaHIP) {
600-
// Based on the OS and the format decide on the version of libspirv.
601-
// NOTE: this will be problematic if cross-compiling between OSes.
602-
std::string Libclc{"clc/"};
603-
Libclc.append(
604-
#ifdef _WIN32
605-
"remangled-l32-signed_char.libspirv-"
606-
#else
607-
"remangled-l64-signed_char.libspirv-"
608-
#endif
609-
);
610-
Libclc.append(Format == BinaryFormat::PTX ? "nvptx64-nvidia-cuda.bc"
611-
: "amdgcn-amd-amdhsa.bc");
612-
LibNames.push_back(Libclc);
613-
}
614563

615564
LLVMContext &Context = Module.getContext();
616565
for (const std::string &LibName : LibNames) {
@@ -628,58 +577,6 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
628577
}
629578
}
630579

631-
// For GPU targets we need to link against vendor provided libdevice.
632-
if (IsCudaHIP) {
633-
Triple T{Module.getTargetTriple()};
634-
Driver D{(Twine(DPCPPRoot) + "/bin/clang++").str(), T.getTriple(), Diags};
635-
auto [CPU, Features] =
636-
Translator::getTargetCPUAndFeatureAttrs(&Module, "", Format);
637-
(void)Features;
638-
// Helper lambda to link modules.
639-
auto LinkInLib = [&](const StringRef LibDevice) -> Error {
640-
ModuleUPtr LibDeviceModule;
641-
if (auto Error = loadBitcodeLibrary(LibDevice, Context)
642-
.moveInto(LibDeviceModule)) {
643-
return Error;
644-
}
645-
if (Linker::linkModules(Module, std::move(LibDeviceModule),
646-
Linker::LinkOnlyNeeded)) {
647-
return createStringError("Unable to link libdevice: %s",
648-
BuildLog.c_str());
649-
}
650-
return Error::success();
651-
};
652-
SmallVector<std::string, 12> LibDeviceFiles;
653-
if (Format == BinaryFormat::PTX) {
654-
// For NVPTX we can get away with CudaInstallationDetector.
655-
LazyDetector<CudaInstallationDetector> CudaInstallation{D, T,
656-
UserArgList};
657-
auto LibDevice = CudaInstallation->getLibDeviceFile(CPU);
658-
if (LibDevice.empty()) {
659-
return createStringError("Unable to find Cuda libdevice");
660-
}
661-
LibDeviceFiles.push_back(LibDevice);
662-
} else {
663-
// AMDGPU requires entire toolchain in order to provide all common bitcode
664-
// libraries.
665-
clang::driver::toolchains::ROCMToolChain TC(D, T, UserArgList);
666-
auto CommonDeviceLibs = TC.getCommonDeviceLibNames(
667-
UserArgList, CPU, Action::OffloadKind::OFK_SYCL, false);
668-
if (CommonDeviceLibs.empty()) {
669-
return createStringError("Unable to find ROCm common device libraries");
670-
}
671-
for (auto &Lib : CommonDeviceLibs) {
672-
LibDeviceFiles.push_back(Lib.Path);
673-
}
674-
}
675-
for (auto &LibDeviceFile : LibDeviceFiles) {
676-
// llvm::Error converts to false on success.
677-
if (auto Error = LinkInLib(LibDeviceFile)) {
678-
return Error;
679-
}
680-
}
681-
}
682-
683580
return Error::success();
684581
}
685582

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
#pragma once
1010

11-
#include "JITBinaryInfo.h"
1211
#include "RTC.h"
1312

1413
#include <llvm/ADT/SmallVector.h>
@@ -25,17 +24,16 @@ using ModuleUPtr = std::unique_ptr<llvm::Module>;
2524

2625
llvm::Expected<std::string>
2726
calculateHash(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
28-
const llvm::opt::InputArgList &UserArgList, BinaryFormat Format);
27+
const llvm::opt::InputArgList &UserArgList);
2928

3029
llvm::Expected<ModuleUPtr>
3130
compileDeviceCode(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
3231
const llvm::opt::InputArgList &UserArgList,
33-
std::string &BuildLog, llvm::LLVMContext &Context,
34-
BinaryFormat Format);
32+
std::string &BuildLog, llvm::LLVMContext &Context);
3533

3634
llvm::Error linkDeviceLibraries(llvm::Module &Module,
3735
const llvm::opt::InputArgList &UserArgList,
38-
std::string &BuildLog, BinaryFormat Format);
36+
std::string &BuildLog);
3937

4038
using PostLinkResult = std::pair<RTCBundleInfo, llvm::SmallVector<ModuleUPtr>>;
4139
llvm::Expected<PostLinkResult>

sycl-jit/jit-compiler/lib/rtc/RTC.cpp

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "RTC.h"
10-
#include "JITBinaryInfo.h"
1110
#include "helper/ErrorHelper.h"
1211
#include "rtc/DeviceCompilation.h"
1312
#include "translation/SPIRVLLVMTranslation.h"
@@ -27,8 +26,7 @@ using namespace jit_compiler;
2726

2827
JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
2928
View<InMemoryFile> IncludeFiles,
30-
View<const char *> UserArgs,
31-
BinaryFormat Format) {
29+
View<const char *> UserArgs) {
3230
llvm::opt::InputArgList UserArgList;
3331
if (auto Error = parseUserArgs(UserArgs).moveInto(UserArgList)) {
3432
return errorTo<RTCHashResult>(std::move(Error),
@@ -38,8 +36,8 @@ JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
3836

3937
auto Start = std::chrono::high_resolution_clock::now();
4038
std::string Hash;
41-
if (auto Error = calculateHash(SourceFile, IncludeFiles, UserArgList, Format)
42-
.moveInto(Hash)) {
39+
if (auto Error =
40+
calculateHash(SourceFile, IncludeFiles, UserArgList).moveInto(Hash)) {
4341
return errorTo<RTCHashResult>(std::move(Error), "Hashing failed",
4442
/*IsHash=*/false);
4543
}
@@ -57,8 +55,7 @@ JIT_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
5755
JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
5856
View<InMemoryFile> IncludeFiles,
5957
View<const char *> UserArgs,
60-
View<char> CachedIR, bool SaveIR,
61-
BinaryFormat Format) {
58+
View<char> CachedIR, bool SaveIR) {
6259
llvm::LLVMContext Context;
6360
std::string BuildLog;
6461
configureDiagnostics(Context, BuildLog);
@@ -107,7 +104,7 @@ JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
107104
bool FromSource = !Module;
108105
if (FromSource) {
109106
if (auto Error = compileDeviceCode(SourceFile, IncludeFiles, UserArgList,
110-
BuildLog, Context, Format)
107+
BuildLog, Context)
111108
.moveInto(Module)) {
112109
return errorTo<RTCResult>(std::move(Error), "Device compilation failed");
113110
}
@@ -121,8 +118,7 @@ JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
121118
IR = RTCDeviceCodeIR{BCString.data(), BCString.data() + BCString.size()};
122119
}
123120

124-
if (auto Error =
125-
linkDeviceLibraries(*Module, UserArgList, BuildLog, Format)) {
121+
if (auto Error = linkDeviceLibraries(*Module, UserArgList, BuildLog)) {
126122
return errorTo<RTCResult>(std::move(Error), "Device linking failed");
127123
}
128124

@@ -135,9 +131,9 @@ JIT_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
135131

136132
for (auto [DevImgInfo, Module] :
137133
llvm::zip_equal(BundleInfo.DevImgInfos, Modules)) {
138-
if (auto Error =
139-
Translator::translate(*Module, JITContext::getInstance(), Format)
140-
.moveInto(DevImgInfo.BinaryInfo)) {
134+
if (auto Error = Translator::translate(*Module, JITContext::getInstance(),
135+
BinaryFormat::SPIRV)
136+
.moveInto(DevImgInfo.BinaryInfo)) {
141137
return errorTo<RTCResult>(std::move(Error), "SPIR-V translation failed");
142138
}
143139
}

0 commit comments

Comments
 (0)