8
8
9
9
#include " DeviceCompilation.h"
10
10
#include " ESIMD.h"
11
+ #include " JITBinaryInfo.h"
12
+ #include " translation/Translation.h"
11
13
12
14
#include < clang/Basic/DiagnosticDriver.h>
13
15
#include < clang/Basic/Version.h>
22
24
#include < clang/Frontend/Utils.h>
23
25
#include < clang/Tooling/CompilationDatabase.h>
24
26
#include < clang/Tooling/Tooling.h>
27
+ #if defined(JIT_SUPPORT_PTX) || defined(JIT_SUPPORT_AMDGCN)
28
+ #include < clang/Driver/Driver.h>
29
+ #endif
30
+ #ifdef JIT_SUPPORT_PTX
31
+ #include < Driver/ToolChains/Cuda.h>
32
+ #include < Driver/ToolChains/LazyDetector.h>
33
+ #elif JIT_SUPPORT_AMDGCN
34
+ #include < Driver/ToolChains/AMDGPU.h>
35
+ #endif
25
36
26
37
#include < llvm/IR/DiagnosticInfo.h>
27
38
#include < llvm/IR/DiagnosticPrinter.h>
@@ -178,7 +189,8 @@ class RTCToolActionBase : public ToolAction {
178
189
assert (!hasExecuted () && " Action should only be invoked on a single file" );
179
190
180
191
// Create a compiler instance to handle the actual work.
181
- CompilerInstance Compiler (std::move (Invocation), std::move (PCHContainerOps));
192
+ CompilerInstance Compiler (std::move (Invocation),
193
+ std::move (PCHContainerOps));
182
194
Compiler.setFileManager (Files);
183
195
// Suppress summary with number of warnings and errors being printed to
184
196
// stdout.
@@ -361,10 +373,24 @@ static void setupTool(ClangTool &Tool, const std::string &DPCPPRoot,
361
373
});
362
374
}
363
375
364
- Expected<std::string>
365
- jit_compiler::calculateHash (InMemoryFile SourceFile,
366
- View<InMemoryFile> IncludeFiles,
367
- const InputArgList &UserArgList) {
376
+ static void setGPUTarget (BinaryFormat Format,
377
+ SmallVector<std::string> &CommandLine) {
378
+ auto [CPU, _] = Translator::getTargetCPUAndFeatureAttrs (nullptr , " " , Format);
379
+ CommandLine.push_back (" -fsycl" );
380
+ if (Format == BinaryFormat::PTX) {
381
+ CommandLine.push_back (" -fsycl-targets=nvptx64-nvidia-cuda" );
382
+ CommandLine.push_back (" -Xsycl-target-backend" );
383
+ CommandLine.push_back (" --cuda-gpu-arch=" + CPU);
384
+ } else if (Format == BinaryFormat::AMDGCN) {
385
+ CommandLine.push_back (" -fsycl-targets=amdgcn-amd-amdhsa" );
386
+ CommandLine.push_back (" -Xsycl-target-backend=amdgcn-amd-amdhsa" );
387
+ CommandLine.push_back (" --offload-arch=" + CPU);
388
+ }
389
+ }
390
+
391
+ Expected<std::string> jit_compiler::calculateHash (
392
+ InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
393
+ const InputArgList &UserArgList, BinaryFormat Format) {
368
394
TimeTraceScope TTS{" calculateHash" };
369
395
370
396
const std::string &DPCPPRoot = getDPCPPRoot ();
@@ -373,6 +399,9 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
373
399
}
374
400
375
401
SmallVector<std::string> CommandLine;
402
+ if (Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN) {
403
+ setGPUTarget (Format, CommandLine);
404
+ }
376
405
adjustArgs (UserArgList, DPCPPRoot, CommandLine);
377
406
378
407
FixedCompilationDatabase DB{" ." , CommandLine};
@@ -399,11 +428,10 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
399
428
return createStringError (" Calculating source hash failed" );
400
429
}
401
430
402
- Expected<ModuleUPtr>
403
- jit_compiler::compileDeviceCode (InMemoryFile SourceFile,
404
- View<InMemoryFile> IncludeFiles,
405
- const InputArgList &UserArgList,
406
- std::string &BuildLog, LLVMContext &Context) {
431
+ Expected<ModuleUPtr> jit_compiler::compileDeviceCode (
432
+ InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
433
+ const InputArgList &UserArgList, std::string &BuildLog,
434
+ LLVMContext &Context, BinaryFormat Format) {
407
435
TimeTraceScope TTS{" compileDeviceCode" };
408
436
409
437
const std::string &DPCPPRoot = getDPCPPRoot ();
@@ -412,6 +440,9 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
412
440
}
413
441
414
442
SmallVector<std::string> CommandLine;
443
+ if (Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN) {
444
+ setGPUTarget (Format, CommandLine);
445
+ }
415
446
adjustArgs (UserArgList, DPCPPRoot, CommandLine);
416
447
417
448
FixedCompilationDatabase DB{" ." , CommandLine};
@@ -430,12 +461,22 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
430
461
return createStringError (BuildLog);
431
462
}
432
463
433
- // This function is a simplified copy of the device library selection process in
434
- // `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V target
435
- // (no AoT, no third-party GPUs , no native CPU). Keep in sync!
464
+ // This function is a simplified copy of the device library selection process
465
+ // in `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V, or
466
+ // GPU targets ( no AoT , no native CPU). Keep in sync!
436
467
static bool getDeviceLibraries (const ArgList &Args,
437
468
SmallVectorImpl<std::string> &LibraryList,
438
- DiagnosticsEngine &Diags) {
469
+ DiagnosticsEngine &Diags, BinaryFormat Format) {
470
+ // For CUDA/HIP we only need devicelib, early exit here.
471
+ if (Format == BinaryFormat::PTX) {
472
+ LibraryList.push_back (
473
+ Args.MakeArgString (" devicelib-nvptx64-nvidia-cuda.bc" ));
474
+ return false ;
475
+ } else if (Format == BinaryFormat::AMDGCN) {
476
+ LibraryList.push_back (Args.MakeArgString (" devicelib-amdgcn-amd-amdhsa.bc" ));
477
+ return false ;
478
+ }
479
+
439
480
struct DeviceLibOptInfo {
440
481
StringRef DeviceLibName;
441
482
StringRef DeviceLibOption;
@@ -540,7 +581,8 @@ static Expected<ModuleUPtr> loadBitcodeLibrary(StringRef LibPath,
540
581
541
582
Error jit_compiler::linkDeviceLibraries (llvm::Module &Module,
542
583
const InputArgList &UserArgList,
543
- std::string &BuildLog) {
584
+ std::string &BuildLog,
585
+ BinaryFormat Format) {
544
586
TimeTraceScope TTS{" linkDeviceLibraries" };
545
587
546
588
const std::string &DPCPPRoot = getDPCPPRoot ();
@@ -555,11 +597,29 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
555
597
/* ShouldOwnClient=*/ false );
556
598
557
599
SmallVector<std::string> LibNames;
558
- bool FoundUnknownLib = getDeviceLibraries (UserArgList, LibNames, Diags);
600
+ const bool FoundUnknownLib =
601
+ getDeviceLibraries (UserArgList, LibNames, Diags, Format);
559
602
if (FoundUnknownLib) {
560
603
return createStringError (" Could not determine list of device libraries: %s" ,
561
604
BuildLog.c_str ());
562
605
}
606
+ const bool IsGPUTarget =
607
+ Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN;
608
+ if (IsGPUTarget) {
609
+ // Based on the OS and the format decide on the version of libspirv.
610
+ // NOTE: this will be problematic if cross-compiling between OSes.
611
+ std::string Libclc{" clc/" };
612
+ Libclc.append (
613
+ #ifdef _WIN32
614
+ " remangled-l32-signed_char.libspirv-"
615
+ #else
616
+ " remangled-l64-signed_char.libspirv-"
617
+ #endif
618
+ );
619
+ Libclc.append (Format == BinaryFormat::PTX ? " nvptx64-nvidia-cuda.bc"
620
+ : " amdgcn-amd-amdhsa.bc" );
621
+ LibNames.push_back (Libclc);
622
+ }
563
623
564
624
LLVMContext &Context = Module.getContext ();
565
625
for (const std::string &LibName : LibNames) {
@@ -577,6 +637,57 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
577
637
}
578
638
}
579
639
640
+ // For GPU targets we need to link against vendor provided libdevice.
641
+ if (IsGPUTarget) {
642
+ Triple T{Module.getTargetTriple ()};
643
+ Driver D{(Twine (DPCPPRoot) + " /bin/clang++" ).str (), T.getTriple (), Diags};
644
+ auto [CPU, _] =
645
+ Translator::getTargetCPUAndFeatureAttrs (&Module, " " , Format);
646
+ // Helper lambda to link modules.
647
+ auto LinkInLib = [&](const StringRef LibDevice) -> Error {
648
+ ModuleUPtr LibDeviceModule;
649
+ if (auto Error = loadBitcodeLibrary (LibDevice, Context)
650
+ .moveInto (LibDeviceModule)) {
651
+ return Error;
652
+ }
653
+ if (Linker::linkModules (Module, std::move (LibDeviceModule),
654
+ Linker::LinkOnlyNeeded)) {
655
+ return createStringError (" Unable to link libdevice: %s" ,
656
+ BuildLog.c_str ());
657
+ }
658
+ return Error::success ();
659
+ };
660
+ SmallVector<std::string, 12 > LibDeviceFiles;
661
+ #ifdef JIT_SUPPORT_PTX
662
+ // For NVPTX we can get away with CudaInstallationDetector.
663
+ LazyDetector<CudaInstallationDetector> CudaInstallation{D, T, UserArgList};
664
+ auto LibDevice = CudaInstallation->getLibDeviceFile (CPU);
665
+ if (LibDevice.empty ()) {
666
+ return createStringError (" Unable to find Cuda libdevice" );
667
+ }
668
+ LibDeviceFiles.push_back (LibDevice);
669
+ #elif JIT_SUPPORT_AMDGCN
670
+ // AMDGPU requires entire toolchain in order to provide all common bitcode
671
+ // libraries.
672
+ clang::driver::toolchains::ROCMToolChain TC (D, T, UserArgList);
673
+ auto CommonDeviceLibs = TC.getCommonDeviceLibNames (
674
+ UserArgList, CPU, Action::OffloadKind::OFK_SYCL, false );
675
+ if (CommonDeviceLibs.empty ()) {
676
+ return createStringError (" Unable to find ROCm common device libraries" );
677
+ }
678
+ for (auto &Lib : CommonDeviceLibs) {
679
+ LibDeviceFiles.push_back (Lib.Path );
680
+ }
681
+ #endif
682
+ for (auto &LibDeviceFile : LibDeviceFiles) {
683
+ auto Res = LinkInLib (LibDeviceFile);
684
+ // llvm::Error converts to false on success.
685
+ if (Res) {
686
+ return Res;
687
+ }
688
+ }
689
+ }
690
+
580
691
return Error::success ();
581
692
}
582
693
0 commit comments