8
8
9
9
#include " DeviceCompilation.h"
10
10
#include " ESIMD.h"
11
+ #include " JITBinaryInfo.h"
12
+ #include " translation/Translation.h"
11
13
12
14
#include < clang/Basic/DiagnosticDriver.h>
13
15
#include < clang/Basic/Version.h>
14
16
#include < clang/CodeGen/CodeGenAction.h>
15
17
#include < clang/Driver/Compilation.h>
18
+ #include < clang/Driver/Driver.h>
16
19
#include < clang/Driver/Options.h>
20
+ #include < clang/Driver/ToolChain.h>
17
21
#include < clang/Frontend/ChainedDiagnosticConsumer.h>
18
22
#include < clang/Frontend/CompilerInstance.h>
19
23
#include < clang/Frontend/FrontendActions.h>
@@ -52,6 +56,7 @@ using namespace llvm::opt;
52
56
using namespace llvm ::sycl;
53
57
using namespace llvm ::module_split;
54
58
using namespace llvm ::util;
59
+ using namespace llvm ::vfs;
55
60
using namespace jit_compiler ;
56
61
57
62
#ifdef _GNU_SOURCE
@@ -313,7 +318,7 @@ class LLVMDiagnosticWrapper : public llvm::DiagnosticHandler {
313
318
} // anonymous namespace
314
319
315
320
static void adjustArgs (const InputArgList &UserArgList,
316
- const std::string &DPCPPRoot,
321
+ const std::string &DPCPPRoot, BinaryFormat Format,
317
322
SmallVectorImpl<std::string> &CommandLine) {
318
323
DerivedArgList DAL{UserArgList};
319
324
const auto &OptTable = getDriverOptTable ();
@@ -326,6 +331,17 @@ static void adjustArgs(const InputArgList &UserArgList,
326
331
// unused argument warning.
327
332
DAL.AddFlagArg (nullptr , OptTable.getOption (OPT_Qunused_arguments));
328
333
334
+ if (Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN) {
335
+ auto [CPU, Features] =
336
+ Translator::getTargetCPUAndFeatureAttrs (nullptr , " " , Format);
337
+ (void )Features;
338
+ StringRef OT = Format == BinaryFormat::PTX ? " nvptx64-nvidia-cuda"
339
+ : " amdgcn-amd-amdhsa" ;
340
+ DAL.AddJoinedArg (nullptr , OptTable.getOption (OPT_fsycl_targets_EQ), OT);
341
+ DAL.AddJoinedArg (nullptr , OptTable.getOption (OPT_Xsycl_backend_EQ), OT);
342
+ DAL.AddJoinedArg (nullptr , OptTable.getOption (OPT_offload_arch_EQ), CPU);
343
+ }
344
+
329
345
ArgStringList ASL;
330
346
for_each (DAL, [&DAL, &ASL](Arg *A) { A->render (DAL, ASL); });
331
347
for_each (UserArgList,
@@ -362,10 +378,9 @@ static void setupTool(ClangTool &Tool, const std::string &DPCPPRoot,
362
378
});
363
379
}
364
380
365
- Expected<std::string>
366
- jit_compiler::calculateHash (InMemoryFile SourceFile,
367
- View<InMemoryFile> IncludeFiles,
368
- const InputArgList &UserArgList) {
381
+ Expected<std::string> jit_compiler::calculateHash (
382
+ InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
383
+ const InputArgList &UserArgList, BinaryFormat Format) {
369
384
TimeTraceScope TTS{" calculateHash" };
370
385
371
386
const std::string &DPCPPRoot = getDPCPPRoot ();
@@ -374,7 +389,7 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
374
389
}
375
390
376
391
SmallVector<std::string> CommandLine;
377
- adjustArgs (UserArgList, DPCPPRoot, CommandLine);
392
+ adjustArgs (UserArgList, DPCPPRoot, Format, CommandLine);
378
393
379
394
FixedCompilationDatabase DB{" ." , CommandLine};
380
395
ClangTool Tool{DB, {SourceFile.Path }};
@@ -400,11 +415,10 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
400
415
return createStringError (" Calculating source hash failed" );
401
416
}
402
417
403
- Expected<ModuleUPtr>
404
- jit_compiler::compileDeviceCode (InMemoryFile SourceFile,
405
- View<InMemoryFile> IncludeFiles,
406
- const InputArgList &UserArgList,
407
- std::string &BuildLog, LLVMContext &Context) {
418
+ Expected<ModuleUPtr> jit_compiler::compileDeviceCode (
419
+ InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
420
+ const InputArgList &UserArgList, std::string &BuildLog,
421
+ LLVMContext &Context, BinaryFormat Format) {
408
422
TimeTraceScope TTS{" compileDeviceCode" };
409
423
410
424
const std::string &DPCPPRoot = getDPCPPRoot ();
@@ -413,7 +427,7 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
413
427
}
414
428
415
429
SmallVector<std::string> CommandLine;
416
- adjustArgs (UserArgList, DPCPPRoot, CommandLine);
430
+ adjustArgs (UserArgList, DPCPPRoot, Format, CommandLine);
417
431
418
432
FixedCompilationDatabase DB{" ." , CommandLine};
419
433
ClangTool Tool{DB, {SourceFile.Path }};
@@ -431,12 +445,22 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
431
445
return createStringError (BuildLog);
432
446
}
433
447
434
- // This function is a simplified copy of the device library selection process in
435
- // `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V target
436
- // (no AoT, no third-party GPUs , no native CPU). Keep in sync!
448
+ // This function is a simplified copy of the device library selection process
449
+ // in `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V or
450
+ // GPU target (no AoT, no native CPU). Keep in sync!
437
451
static bool getDeviceLibraries (const ArgList &Args,
438
452
SmallVectorImpl<std::string> &LibraryList,
439
- DiagnosticsEngine &Diags) {
453
+ DiagnosticsEngine &Diags, BinaryFormat Format) {
454
+ // For CUDA/HIP we only need devicelib, so exit early here.
455
+ if (Format == BinaryFormat::PTX) {
456
+ LibraryList.push_back (
457
+ Args.MakeArgString (" devicelib-nvptx64-nvidia-cuda.bc" ));
458
+ return false ;
459
+ } else if (Format == BinaryFormat::AMDGCN) {
460
+ LibraryList.push_back (Args.MakeArgString (" devicelib-amdgcn-amd-amdhsa.bc" ));
461
+ return false ;
462
+ }
463
+
440
464
struct DeviceLibOptInfo {
441
465
StringRef DeviceLibName;
442
466
StringRef DeviceLibOption;
@@ -541,7 +565,8 @@ static Expected<ModuleUPtr> loadBitcodeLibrary(StringRef LibPath,
541
565
542
566
Error jit_compiler::linkDeviceLibraries (llvm::Module &Module,
543
567
const InputArgList &UserArgList,
544
- std::string &BuildLog) {
568
+ std::string &BuildLog,
569
+ BinaryFormat Format) {
545
570
TimeTraceScope TTS{" linkDeviceLibraries" };
546
571
547
572
const std::string &DPCPPRoot = getDPCPPRoot ();
@@ -556,11 +581,29 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
556
581
/* ShouldOwnClient=*/ false );
557
582
558
583
SmallVector<std::string> LibNames;
559
- bool FoundUnknownLib = getDeviceLibraries (UserArgList, LibNames, Diags);
584
+ const bool FoundUnknownLib =
585
+ getDeviceLibraries (UserArgList, LibNames, Diags, Format);
560
586
if (FoundUnknownLib) {
561
587
return createStringError (" Could not determine list of device libraries: %s" ,
562
588
BuildLog.c_str ());
563
589
}
590
+ const bool IsCudaHIP =
591
+ Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN;
592
+ if (IsCudaHIP) {
593
+ // Based on the OS and the format decide on the version of libspirv.
594
+ // NOTE: this will be problematic if cross-compiling between OSes.
595
+ std::string Libclc{" clc/" };
596
+ Libclc.append (
597
+ #ifdef _WIN32
598
+ " remangled-l32-signed_char.libspirv-"
599
+ #else
600
+ " remangled-l64-signed_char.libspirv-"
601
+ #endif
602
+ );
603
+ Libclc.append (Format == BinaryFormat::PTX ? " nvptx64-nvidia-cuda.bc"
604
+ : " amdgcn-amd-amdhsa.bc" );
605
+ LibNames.push_back (Libclc);
606
+ }
564
607
565
608
LLVMContext &Context = Module.getContext ();
566
609
for (const std::string &LibName : LibNames) {
@@ -578,6 +621,72 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
578
621
}
579
622
}
580
623
624
+ // For GPU targets we need to link against vendor provided libdevice.
625
+ if (IsCudaHIP) {
626
+ std::string Argv0 = DPCPPRoot + " /bin/clang++" ;
627
+ Triple T{Module.getTargetTriple ()};
628
+ IntrusiveRefCntPtr<OverlayFileSystem> OFS{
629
+ new OverlayFileSystem{getRealFileSystem ()}};
630
+ IntrusiveRefCntPtr<InMemoryFileSystem> VFS{new InMemoryFileSystem};
631
+ std::string CppFileName{" a.cpp" };
632
+ VFS->addFile (CppFileName, /* ModificationTime=*/ 0 ,
633
+ MemoryBuffer::getMemBuffer (" " , " " ));
634
+ OFS->pushOverlay (VFS);
635
+ Driver D{Argv0, T.getTriple (), Diags, " dpcpp compiler driver" , OFS};
636
+
637
+ SmallVector<std::string> CommandLine;
638
+ CommandLine.push_back (Argv0);
639
+ adjustArgs (UserArgList, DPCPPRoot, Format, CommandLine);
640
+ CommandLine.push_back (CppFileName);
641
+ SmallVector<const char *> CommandLineCStr (CommandLine.size ());
642
+ llvm::transform (CommandLine, CommandLineCStr.begin (),
643
+ [](const auto &S) { return S.c_str (); });
644
+
645
+ Compilation *C = D.BuildCompilation (CommandLineCStr);
646
+ if (!C) {
647
+ return createStringError (" Unable to construct driver for CUDA/HIP" );
648
+ }
649
+
650
+ const ToolChain *OffloadTC =
651
+ C->getSingleOffloadToolChain <Action::OFK_SYCL>();
652
+ InputArgList EmptyArgList;
653
+ auto Archs =
654
+ D.getOffloadArchs (*C, EmptyArgList, Action::OFK_SYCL, OffloadTC);
655
+ assert (Archs.size () == 1 &&
656
+ " Offload toolchain should be configured to single architecture" );
657
+ StringRef CPU = *Archs.begin ();
658
+
659
+ // Pass only `-march=` or `-mcpu=` with the GPU arch determined by the
660
+ // driver to `getDeviceLibs`.
661
+ DerivedArgList CPUArgList{EmptyArgList};
662
+ if (Format == BinaryFormat::PTX) {
663
+ CPUArgList.AddJoinedArg (nullptr , D.getOpts ().getOption (OPT_march_EQ),
664
+ CPU);
665
+ } else {
666
+ CPUArgList.AddJoinedArg (nullptr , D.getOpts ().getOption (OPT_mcpu_EQ), CPU);
667
+ }
668
+
669
+ SmallVector<ToolChain::BitCodeLibraryInfo, 12 > CommonDeviceLibs =
670
+ OffloadTC->getDeviceLibs (CPUArgList, Action::OffloadKind::OFK_SYCL);
671
+ if (CommonDeviceLibs.empty ()) {
672
+ return createStringError (" Unable to find common device libraries" );
673
+ }
674
+
675
+ for (auto &Lib : CommonDeviceLibs) {
676
+ ModuleUPtr LibModule;
677
+ if (auto Error =
678
+ loadBitcodeLibrary (Lib.Path , Context).moveInto (LibModule)) {
679
+ return Error;
680
+ }
681
+
682
+ if (Linker::linkModules (Module, std::move (LibModule),
683
+ Linker::LinkOnlyNeeded)) {
684
+ return createStringError (" Unable to link device library %s: %s" ,
685
+ Lib.Path .c_str (), BuildLog.c_str ());
686
+ }
687
+ }
688
+ }
689
+
581
690
return Error::success ();
582
691
}
583
692
0 commit comments