Skip to content

Commit 84966b2

Browse files
committed
[SYCL/Driver] Add module splitting capabilities when compiling for NVPTX or AMDGCN.
The patch allows module splitting for NVPTX and AMDGCN targets. The driver works by wrapping the actions with a ForEachWrappingAction. To allow standard clang tools to work properly, SYCLPostLinkJobAction and FileTableTformJobAction report the underlying file type rather than TY_Tempfilelist or TY_Tempfiletable for those targets. Signed-off-by: Victor Lomuller <[email protected]>
1 parent 039fcba commit 84966b2

File tree

6 files changed

+136
-101
lines changed

6 files changed

+136
-101
lines changed

clang/include/clang/Driver/Action.h

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,15 @@ class SYCLPostLinkJobAction : public JobAction {
738738
void anchor() override;
739739

740740
public:
741-
SYCLPostLinkJobAction(Action *Input, types::ID OutputType);
741+
// The tempfiletable management relies on shadowing the main file type with
742+
// types::TY_Tempfiletable. The problem with shadowing is that it prevents its
743+
// integration with clang tools that rely on the file type to properly set
744+
// args.
745+
// We "trick" the driver by declaring the underlying file type and setting a
746+
// "true output type" which will be used by the SYCLPostLinkJobAction
747+
// to properly set the job.
748+
SYCLPostLinkJobAction(Action *Input, types::ID ShadowOutputType,
749+
types::ID TrueOutputType);
742750

743751
static bool classof(const Action *A) {
744752
return A->getKind() == SYCLPostLinkJobClass;
@@ -748,8 +756,11 @@ class SYCLPostLinkJobAction : public JobAction {
748756

749757
bool getRTSetsSpecConstants() const { return RTSetsSpecConsts; }
750758

759+
types::ID getTrueType() const { return TrueOutputType; }
760+
751761
private:
752762
bool RTSetsSpecConsts = true;
763+
types::ID TrueOutputType;
753764
};
754765

755766
class BackendCompileJobAction : public JobAction {
@@ -772,6 +783,9 @@ class FileTableTformJobAction : public JobAction {
772783
void anchor() override;
773784

774785
public:
786+
static constexpr const char *COL_CODE = "Code";
787+
static constexpr const char *COL_ZERO = "0";
788+
775789
struct Tform {
776790
enum Kind {
777791
EXTRACT,
@@ -792,8 +806,10 @@ class FileTableTformJobAction : public JobAction {
792806
SmallVector<std::string, 2> TheArgs;
793807
};
794808

795-
FileTableTformJobAction(Action *Input, types::ID OutputType);
796-
FileTableTformJobAction(ActionList &Inputs, types::ID OutputType);
809+
FileTableTformJobAction(Action *Input, types::ID ShadowOutputType,
810+
types::ID TrueOutputType);
811+
FileTableTformJobAction(ActionList &Inputs, types::ID ShadowOutputType,
812+
types::ID TrueOutputType);
797813

798814
// Deletes all columns except the one with given name.
799815
void addExtractColumnTform(StringRef ColumnName, bool WithColTitle = true);
@@ -821,7 +837,10 @@ class FileTableTformJobAction : public JobAction {
821837

822838
const ArrayRef<Tform> getTforms() const { return Tforms; }
823839

840+
types::ID getTrueType() const { return TrueOutputType; }
841+
824842
private:
843+
types::ID TrueOutputType;
825844
SmallVector<Tform, 2> Tforms; // transformation actions requested
826845

827846
// column to copy single file from if requested

clang/lib/Driver/Action.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -476,8 +476,11 @@ SPIRCheckJobAction::SPIRCheckJobAction(Action *Input, types::ID Type)
476476

477477
void SYCLPostLinkJobAction::anchor() {}
478478

479-
SYCLPostLinkJobAction::SYCLPostLinkJobAction(Action *Input, types::ID Type)
480-
: JobAction(SYCLPostLinkJobClass, Input, Type) {}
479+
SYCLPostLinkJobAction::SYCLPostLinkJobAction(Action *Input,
480+
types::ID ShadowOutputType,
481+
types::ID TrueOutputType)
482+
: JobAction(SYCLPostLinkJobClass, Input, ShadowOutputType),
483+
TrueOutputType(TrueOutputType) {}
481484

482485
void BackendCompileJobAction::anchor() {}
483486

@@ -491,12 +494,17 @@ BackendCompileJobAction::BackendCompileJobAction(Action *Input,
491494

492495
void FileTableTformJobAction::anchor() {}
493496

494-
FileTableTformJobAction::FileTableTformJobAction(Action *Input, types::ID Type)
495-
: JobAction(FileTableTformJobClass, Input, Type) {}
497+
FileTableTformJobAction::FileTableTformJobAction(Action *Input,
498+
types::ID ShadowOutputType,
499+
types::ID TrueOutputType)
500+
: JobAction(FileTableTformJobClass, Input, ShadowOutputType),
501+
TrueOutputType(TrueOutputType) {}
496502

497503
FileTableTformJobAction::FileTableTformJobAction(ActionList &Inputs,
498-
types::ID Type)
499-
: JobAction(FileTableTformJobClass, Inputs, Type) {}
504+
types::ID ShadowOutputType,
505+
types::ID TrueOutputType)
506+
: JobAction(FileTableTformJobClass, Inputs, ShadowOutputType),
507+
TrueOutputType(TrueOutputType) {}
500508

501509
void FileTableTformJobAction::addExtractColumnTform(StringRef ColumnName,
502510
bool WithColTitle) {

clang/lib/Driver/Driver.cpp

Lines changed: 63 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -3896,7 +3896,7 @@ class OffloadingActionBuilder final {
38963896
SmallVector<CudaArch, 8> GpuArchList;
38973897

38983898
/// Build the last steps for CUDA after all BC files have been linked.
3899-
Action *finalizeNVPTXDependences(Action *Input, const llvm::Triple &TT) {
3899+
JobAction *finalizeNVPTXDependences(Action *Input, const llvm::Triple &TT) {
39003900
auto *BA = C.getDriver().ConstructPhaseAction(
39013901
C, Args, phases::Backend, Input, AssociatedOffloadKind);
39023902
if (TT.getOS() != llvm::Triple::NVCL) {
@@ -3906,10 +3906,11 @@ class OffloadingActionBuilder final {
39063906
return C.MakeAction<LinkJobAction>(DeviceActions,
39073907
types::TY_CUDA_FATBIN);
39083908
}
3909-
return BA;
3909+
return cast<JobAction>(BA);
39103910
}
39113911

3912-
Action *finalizeAMDGCNDependences(Action *Input, const llvm::Triple &TT) {
3912+
JobAction *finalizeAMDGCNDependences(Action *Input,
3913+
const llvm::Triple &TT) {
39133914
auto *BA = C.getDriver().ConstructPhaseAction(
39143915
C, Args, phases::Backend, Input, AssociatedOffloadKind);
39153916

@@ -3919,7 +3920,7 @@ class OffloadingActionBuilder final {
39193920
ActionList AL = {AA};
39203921
Action *LinkAction = C.MakeAction<LinkJobAction>(AL, types::TY_Image);
39213922
ActionList HIPActions = {LinkAction};
3922-
Action *HIPFatBinary =
3923+
JobAction *HIPFatBinary =
39233924
C.MakeAction<LinkJobAction>(HIPActions, types::TY_HIP_FATBIN);
39243925
return HIPFatBinary;
39253926
}
@@ -4052,7 +4053,7 @@ class OffloadingActionBuilder final {
40524053
else
40534054
FullDeviceLinkAction = DeviceLinkAction;
40544055
auto *PostLinkAction = C.MakeAction<SYCLPostLinkJobAction>(
4055-
FullDeviceLinkAction, types::TY_LLVM_BC);
4056+
FullDeviceLinkAction, types::TY_LLVM_BC, types::TY_LLVM_BC);
40564057
auto *TranslateAction = C.MakeAction<SPIRVTranslatorJobAction>(
40574058
PostLinkAction, types::TY_Image);
40584059
SYCLLinkBinary = C.MakeAction<OffloadWrapperJobAction>(
@@ -4322,6 +4323,7 @@ class OffloadingActionBuilder final {
43224323
auto TT = SYCLTripleList[I];
43234324
auto isNVPTX = (*TC)->getTriple().isNVPTX();
43244325
auto isAMDGCN = (*TC)->getTriple().isAMDGCN();
4326+
auto isSPIR = (*TC)->getTriple().isSPIR();
43254327
bool isSpirvAOT = TT.getSubArch() == llvm::Triple::SPIRSubArch_fpga ||
43264328
TT.getSubArch() == llvm::Triple::SPIRSubArch_gen ||
43274329
TT.getSubArch() == llvm::Triple::SPIRSubArch_x86_64;
@@ -4333,8 +4335,6 @@ class OffloadingActionBuilder final {
43334335
// directly to the backend compilation step (aocr) or wrapper (aocx)
43344336
else if (types::isFPGA(Input->getType())) {
43354337
Action *FPGAAOTAction;
4336-
constexpr char COL_CODE[] = "Code";
4337-
constexpr char COL_ZERO[] = "0";
43384338
if (Input->getType() == types::TY_FPGA_AOCR ||
43394339
Input->getType() == types::TY_FPGA_AOCR_EMU)
43404340
// Generate AOCX/AOCR
@@ -4346,8 +4346,10 @@ class OffloadingActionBuilder final {
43464346
else
43474347
llvm_unreachable("Unexpected FPGA input type.");
43484348
auto *RenameAction = C.MakeAction<FileTableTformJobAction>(
4349-
FPGAAOTAction, types::TY_Tempfilelist);
4350-
RenameAction->addRenameColumnTform(COL_ZERO, COL_CODE);
4349+
FPGAAOTAction, types::TY_Tempfilelist, types::TY_Tempfilelist);
4350+
RenameAction->addRenameColumnTform(
4351+
FileTableTformJobAction::COL_ZERO,
4352+
FileTableTformJobAction::COL_CODE);
43514353
auto *DeviceWrappingAction = C.MakeAction<OffloadWrapperJobAction>(
43524354
RenameAction, types::TY_Object);
43534355
DA.add(*DeviceWrappingAction, **TC, /*BoundArch=*/nullptr,
@@ -4361,7 +4363,7 @@ class OffloadingActionBuilder final {
43614363
// The linkage actions subgraph leading to the offload wrapper.
43624364
// [cond] Means incoming/outgoing dependence is created only when cond
43634365
// is true. A function of:
4364-
// n - target is NVPTX
4366+
// n - target is NVPTX/AMDGCN
43654367
// a - SPIRV AOT compilation is requested
43664368
// s - device code split requested
43674369
// * - "all other cases"
@@ -4378,17 +4380,18 @@ class OffloadingActionBuilder final {
43784380
// .--------------------------------------.
43794381
// | PostLink |
43804382
// .--------------------------------------.
4381-
// [+n] [+*] [+]
4382-
// | | |
4383-
// .----------------. .-----------------. |
4384-
// | FileTableTform | | FileTableTform | |
4385-
// | (copy "Code") | | (extract "Code")| |
4386-
// .----------------. .-----------------. |
4387-
// [.] [-] |
4388-
// | | |
4383+
// [+*] [+]
4384+
// | |
4385+
// .-----------------. |
4386+
// | FileTableTform | |
4387+
// | (extract "Code")| |
4388+
// .-----------------. |
4389+
// [-] |
4390+
// --------------------| |
43894391
// [.] [-*] |
43904392
// .---------------. .-------------------. |
4391-
// | finalizeNVPTX | | SPIRVTranslator | |
4393+
// | finalizeNVPTX | | SPIRVTranslator | |
4394+
// | finalizeAMDGCN | | | |
43924395
// .---------------. .-------------------. |
43934396
// [.] [-as] [-!a] |
43944397
// | | | |
@@ -4398,13 +4401,13 @@ class OffloadingActionBuilder final {
43984401
// | .----------------. | |
43994402
// | [-s] | |
44004403
// | | | |
4401-
// [.] [-a] [-!a] [+]
4402-
// .------------------------------------.
4403-
// | FileTableTform |
4404-
// | (replace "Code") |
4405-
// .------------------------------------.
4406-
// |
4407-
// [+]
4404+
// | [-a] [-!a] [+]
4405+
// | .--------------------.
4406+
// -----------[-n]| FileTableTform |
4407+
// | (replace "Code") |
4408+
// .--------------------.
4409+
// |
4410+
// [+*]
44084411
// .--------------------------------------.
44094412
// | OffloadWrapper |
44104413
// .--------------------------------------.
@@ -4419,7 +4422,7 @@ class OffloadingActionBuilder final {
44194422
// When spv online link is supported by all backends, the fallback
44204423
// device libraries are only needed when current toolchain is using
44214424
// AOT compilation.
4422-
if (!isNVPTX && !isAMDGCN) {
4425+
if (isSPIR) {
44234426
SYCLDeviceLibLinked = addSYCLDeviceLibs(
44244427
*TC, FullLinkObjects, true,
44254428
C.getDefaultToolChain().getTriple().isWindowsMSVCEnvironment());
@@ -4431,18 +4434,7 @@ class OffloadingActionBuilder final {
44314434
C.MakeAction<LinkJobAction>(FullLinkObjects, types::TY_LLVM_BC);
44324435
else
44334436
FullDeviceLinkAction = DeviceLinkAction;
4434-
// setup some flags upfront
4435-
4436-
if ((isNVPTX || isAMDGCN) && DeviceCodeSplit) {
4437-
// TODO Temporary limitation, need to support code splitting for PTX
4438-
const Driver &D = C.getDriver();
4439-
const std::string &OptName =
4440-
D.getOpts()
4441-
.getOption(options::OPT_fsycl_device_code_split)
4442-
.getPrefixedName();
4443-
D.Diag(diag::err_drv_unsupported_opt_for_target)
4444-
<< OptName << (*TC)->getTriple().str();
4445-
}
4437+
44464438
// reflects whether current target is ahead-of-time and can't support
44474439
// runtime setting of specialization constants
44484440
bool isAOT = isNVPTX || isAMDGCN || isSpirvAOT;
@@ -4451,50 +4443,48 @@ class OffloadingActionBuilder final {
44514443
ActionList WrapperInputs;
44524444
// post link is not optional - even if not splitting, always need to
44534445
// process specialization constants
4446+
4447+
types::ID PostLinkOutType =
4448+
isSPIR ? types::TY_Tempfiletable : FullDeviceLinkAction->getType();
4449+
// For SPIR-V targets, force TY_Tempfiletable.
44544450
auto *PostLinkAction = C.MakeAction<SYCLPostLinkJobAction>(
4455-
FullDeviceLinkAction, types::TY_Tempfiletable);
4451+
FullDeviceLinkAction, PostLinkOutType, types::TY_Tempfiletable);
44564452
PostLinkAction->setRTSetsSpecConstants(!isAOT);
44574453

4458-
constexpr char COL_CODE[] = "Code";
4454+
auto *ExtractIRFilesAction = C.MakeAction<FileTableTformJobAction>(
4455+
PostLinkAction,
4456+
isSPIR ? types::TY_Tempfilelist : PostLinkAction->getType(),
4457+
types::TY_Tempfilelist);
4458+
// single column w/o title fits TY_Tempfilelist format
4459+
ExtractIRFilesAction->addExtractColumnTform(
4460+
FileTableTformJobAction::COL_CODE, false /*drop titles*/);
44594461

44604462
if (isNVPTX || isAMDGCN) {
4461-
// Make extraction copy the only remaining code file instead of
4462-
// creating a new table with a single entry.
4463-
// TODO: Process all PTX code files in file table to enable code
4464-
// splitting for PTX target.
4465-
auto *ExtractIRFilesAction = C.MakeAction<FileTableTformJobAction>(
4466-
PostLinkAction, types::TY_LLVM_BC);
4467-
ExtractIRFilesAction->addCopySingleFileTform(COL_CODE, 0);
4468-
4469-
Action *FinAction;
4470-
if (isNVPTX) {
4471-
FinAction = finalizeNVPTXDependences(ExtractIRFilesAction,
4472-
(*TC)->getTriple());
4473-
} else /* isAMDGCN */ {
4474-
FinAction = finalizeAMDGCNDependences(ExtractIRFilesAction,
4463+
JobAction *FinAction =
4464+
isNVPTX ? finalizeNVPTXDependences(ExtractIRFilesAction,
4465+
(*TC)->getTriple())
4466+
: finalizeAMDGCNDependences(ExtractIRFilesAction,
44754467
(*TC)->getTriple());
4476-
}
4477-
ActionList TformInputs{PostLinkAction, FinAction};
4468+
auto *ForEachWrapping = C.MakeAction<ForEachWrappingAction>(
4469+
ExtractIRFilesAction, FinAction);
44784470

4479-
// Replace the only code entry in the table, as confirmed by the
4480-
// previous transformation.
4471+
ActionList TformInputs{PostLinkAction, ForEachWrapping};
44814472
auto *ReplaceFilesAction = C.MakeAction<FileTableTformJobAction>(
4482-
TformInputs, types::TY_Tempfiletable);
4483-
ReplaceFilesAction->addReplaceCellTform(COL_CODE, 0);
4473+
TformInputs, types::TY_Tempfiletable, types::TY_Tempfiletable);
4474+
ReplaceFilesAction->addReplaceColumnTform(
4475+
FileTableTformJobAction::COL_CODE,
4476+
FileTableTformJobAction::COL_CODE);
4477+
44844478
WrapperInputs.push_back(ReplaceFilesAction);
44854479
} else {
44864480
// For SPIRV-based targets - translate to SPIRV then optionally
44874481
// compile ahead-of-time to native architecture
4488-
auto *ExtractIRFilesAction = C.MakeAction<FileTableTformJobAction>(
4489-
PostLinkAction, types::TY_Tempfilelist);
4490-
// single column w/o title fits TY_Tempfilelist format
4491-
ExtractIRFilesAction->addExtractColumnTform(COL_CODE,
4492-
false /*drop titles*/);
4493-
Action *BuildCodeAction = C.MakeAction<SPIRVTranslatorJobAction>(
4494-
ExtractIRFilesAction, types::TY_Tempfilelist);
4482+
Action *BuildCodeAction =
4483+
(Action *)C.MakeAction<SPIRVTranslatorJobAction>(
4484+
ExtractIRFilesAction, types::TY_Tempfilelist);
44954485

44964486
// After the Link, wrap the files before the final host link
4497-
if (isSpirvAOT) {
4487+
if (isAOT) {
44984488
types::ID OutType = types::TY_Tempfilelist;
44994489
if (!DeviceCodeSplit) {
45004490
OutType = (TT.getSubArch() == llvm::Triple::SPIRSubArch_fpga)
@@ -4525,8 +4515,10 @@ class OffloadingActionBuilder final {
45254515
}
45264516
ActionList TformInputs{PostLinkAction, BuildCodeAction};
45274517
auto *ReplaceFilesAction = C.MakeAction<FileTableTformJobAction>(
4528-
TformInputs, types::TY_Tempfiletable);
4529-
ReplaceFilesAction->addReplaceColumnTform(COL_CODE, COL_CODE);
4518+
TformInputs, types::TY_Tempfiletable, types::TY_Tempfiletable);
4519+
ReplaceFilesAction->addReplaceColumnTform(
4520+
FileTableTformJobAction::COL_CODE,
4521+
FileTableTformJobAction::COL_CODE);
45304522
WrapperInputs.push_back(ReplaceFilesAction);
45314523
}
45324524

clang/lib/Driver/ToolChains/Clang.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8833,8 +8833,10 @@ void SYCLPostLink::ConstructJob(Compilation &C, const JobAction &JA,
88338833
const InputInfoList &Inputs,
88348834
const llvm::opt::ArgList &TCArgs,
88358835
const char *LinkingOutput) const {
8836+
const SYCLPostLinkJobAction *SYCLPostLink =
8837+
dyn_cast<SYCLPostLinkJobAction>(&JA);
88368838
// Construct sycl-post-link command.
8837-
assert(isa<SYCLPostLinkJobAction>(JA) && "Expecting SYCL post link job!");
8839+
assert(SYCLPostLink && "Expecting SYCL post link job!");
88388840
ArgStringList CmdArgs;
88398841

88408842
// See if device code splitting is requested
@@ -8863,13 +8865,13 @@ void SYCLPostLink::ConstructJob(Compilation &C, const JobAction &JA,
88638865
// Enable PI program metadata
88648866
if (getToolChain().getTriple().isNVPTX())
88658867
addArgs(CmdArgs, TCArgs, {"-emit-program-metadata"});
8866-
if (JA.getType() == types::TY_LLVM_BC) {
8868+
if (SYCLPostLink->getTrueType() == types::TY_LLVM_BC) {
88678869
// single file output requested - this means only perform necessary IR
88688870
// transformations (like specialization constant intrinsic lowering) and
88698871
// output LLVMIR
88708872
addArgs(CmdArgs, TCArgs, {"-ir-output-only"});
88718873
} else {
8872-
assert(JA.getType() == types::TY_Tempfiletable);
8874+
assert(SYCLPostLink->getTrueType() == types::TY_Tempfiletable);
88738875
// Symbol file and specialization constant info generation is mandatory -
88748876
// add options unconditionally
88758877
addArgs(CmdArgs, TCArgs, {"-symbols"});
@@ -8884,8 +8886,7 @@ void SYCLPostLink::ConstructJob(Compilation &C, const JobAction &JA,
88848886
addArgs(CmdArgs, TCArgs,
88858887
{StringRef(getSYCLPostLinkOptimizationLevel(TCArgs))});
88868888
// specialization constants processing is mandatory
8887-
auto *SYCLPostLink = llvm::dyn_cast<SYCLPostLinkJobAction>(&JA);
8888-
if (SYCLPostLink && SYCLPostLink->getRTSetsSpecConstants())
8889+
if (SYCLPostLink->getRTSetsSpecConstants())
88898890
addArgs(CmdArgs, TCArgs, {"-spec-const=rt"});
88908891
else
88918892
addArgs(CmdArgs, TCArgs, {"-spec-const=default"});

0 commit comments

Comments
 (0)