Skip to content

Commit 8c348ba

Browse files
committed
[OMPIRBuilder][MLIR] Add support for target 'if' clause
This patch implements support for handling the 'if' clause of OpenMP 'target' constructs in the OMPIRBuilder and updates MLIR to LLVM IR translation of the `omp.target` MLIR operation to make use of this new feature.
1 parent 364cd46 commit 8c348ba

File tree

6 files changed

+200
-118
lines changed

6 files changed

+200
-118
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2994,27 +2994,29 @@ class OpenMPIRBuilder {
29942994
/// \param Loc where the target data construct was encountered.
29952995
/// \param IsOffloadEntry whether it is an offload entry.
29962996
/// \param CodeGenIP The insertion point where the call to the outlined
2997-
/// function should be emitted.
2997+
/// function should be emitted.
29982998
/// \param EntryInfo The entry information about the function.
29992999
/// \param DefaultAttrs Structure containing the default attributes, including
30003000
/// numbers of threads and teams to launch the kernel with.
30013001
/// \param RuntimeAttrs Structure containing the runtime numbers of threads
30023002
/// and teams to launch the kernel with.
3003+
/// \param IfCond value of the `if` clause.
30033004
/// \param Inputs The input values to the region that will be passed.
3004-
/// as arguments to the outlined function.
3005+
/// as arguments to the outlined function.
30053006
/// \param BodyGenCB Callback that will generate the region code.
30063007
/// \param ArgAccessorFuncCB Callback that will generate accessors
3007-
/// instructions for passed in target arguments where neccessary
3008+
/// instructions for passed in target arguments where neccessary
30083009
/// \param Dependencies A vector of DependData objects that carry
3009-
// dependency information as passed in the depend clause
3010-
// \param HasNowait Whether the target construct has a `nowait` clause or not.
3010+
/// dependency information as passed in the depend clause
3011+
/// \param HasNowait Whether the target construct has a `nowait` clause or
3012+
/// not.
30113013
InsertPointOrErrorTy createTarget(
30123014
const LocationDescription &Loc, bool IsOffloadEntry,
30133015
OpenMPIRBuilder::InsertPointTy AllocaIP,
30143016
OpenMPIRBuilder::InsertPointTy CodeGenIP,
30153017
TargetRegionEntryInfo &EntryInfo,
30163018
const TargetKernelDefaultAttrs &DefaultAttrs,
3017-
const TargetKernelRuntimeAttrs &RuntimeAttrs,
3019+
const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
30183020
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
30193021
TargetBodyGenCallbackTy BodyGenCB,
30203022
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 105 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -7340,7 +7340,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73407340
OpenMPIRBuilder::InsertPointTy AllocaIP,
73417341
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
73427342
const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
7343-
Function *OutlinedFn, Constant *OutlinedFnID,
7343+
Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
73447344
SmallVectorImpl<Value *> &Args,
73457345
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
73467346
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
@@ -7386,9 +7386,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73867386
return Error::success();
73877387
};
73887388

7389-
// If we don't have an ID for the target region, it means an offload entry
7390-
// wasn't created. In this case we just run the host fallback directly.
7391-
if (!OutlinedFnID) {
7389+
auto &&EmitTargetCallElse =
7390+
[&](OpenMPIRBuilder::InsertPointTy AllocaIP,
7391+
OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
73927392
// Assume no error was returned because EmitTargetCallFallbackCB doesn't
73937393
// produce any.
73947394
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
@@ -7404,102 +7404,124 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
74047404
}());
74057405

74067406
Builder.restoreIP(AfterIP);
7407-
return;
7408-
}
7409-
7410-
OpenMPIRBuilder::TargetDataInfo Info(
7411-
/*RequiresDevicePointerInfo=*/false,
7412-
/*SeparateBeginEndCalls=*/true);
7413-
7414-
OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
7415-
OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7416-
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
7417-
RTArgs, MapInfo,
7418-
/*IsNonContiguous=*/true,
7419-
/*ForEndCall=*/false);
7420-
7421-
SmallVector<Value *, 3> NumTeamsC;
7422-
for (auto [DefaultVal, RuntimeVal] :
7423-
zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
7424-
NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
7425-
7426-
// Calculate number of threads: 0 if no clauses specified, otherwise it is the
7427-
// minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7428-
auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7429-
if (Clause)
7430-
Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
7431-
/*isSigned=*/false);
7432-
return Clause;
7433-
};
7434-
auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7435-
if (Clause)
7436-
Result = Result
7437-
? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
7438-
Result, Clause)
7439-
: Clause;
7407+
return Error::success();
74407408
};
74417409

7442-
// If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7443-
// the NUM_THREADS clause is overriden by THREAD_LIMIT.
7444-
SmallVector<Value *, 3> NumThreadsC;
7445-
Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
7446-
? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
7447-
: nullptr;
7410+
auto &&EmitTargetCallThen =
7411+
[&](OpenMPIRBuilder::InsertPointTy AllocaIP,
7412+
OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
7413+
OpenMPIRBuilder::TargetDataInfo Info(
7414+
/*RequiresDevicePointerInfo=*/false,
7415+
/*SeparateBeginEndCalls=*/true);
7416+
7417+
OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
7418+
OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7419+
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
7420+
RTArgs, MapInfo,
7421+
/*IsNonContiguous=*/true,
7422+
/*ForEndCall=*/false);
7423+
7424+
SmallVector<Value *, 3> NumTeamsC;
7425+
for (auto [DefaultVal, RuntimeVal] :
7426+
zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
7427+
NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
7428+
7429+
// Calculate number of threads: 0 if no clauses specified, otherwise it is the
7430+
// minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7431+
auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7432+
if (Clause)
7433+
Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
7434+
/*isSigned=*/false);
7435+
return Clause;
7436+
};
7437+
auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7438+
if (Clause)
7439+
Result = Result
7440+
? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
7441+
Result, Clause)
7442+
: Clause;
7443+
};
74487444

7449-
for (auto [TeamsVal, TargetVal] : zip_equal(RuntimeAttrs.TeamsThreadLimit,
7450-
RuntimeAttrs.TargetThreadLimit)) {
7451-
Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
7452-
Value *NumThreads = InitMaxThreadsClause(TargetVal);
7445+
// If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7446+
// the NUM_THREADS clause is overriden by THREAD_LIMIT.
7447+
SmallVector<Value *, 3> NumThreadsC;
7448+
Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
7449+
? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
7450+
: nullptr;
74537451

7454-
CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
7455-
CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
7452+
for (auto [TeamsVal, TargetVal] : zip_equal(RuntimeAttrs.TeamsThreadLimit,
7453+
RuntimeAttrs.TargetThreadLimit)) {
7454+
Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
7455+
Value *NumThreads = InitMaxThreadsClause(TargetVal);
74567456

7457-
NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
7458-
}
7457+
CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
7458+
CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
74597459

7460-
unsigned NumTargetItems = Info.NumberOfPtrs;
7461-
// TODO: Use correct device ID
7462-
Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
7463-
uint32_t SrcLocStrSize;
7464-
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
7465-
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
7466-
llvm::omp::IdentFlag(0), 0);
7460+
NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
7461+
}
74677462

7468-
Value *TripCount = RuntimeAttrs.LoopTripCount
7469-
? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
7470-
Builder.getInt64Ty(),
7471-
/*isSigned=*/false)
7472-
: Builder.getInt64(0);
7463+
unsigned NumTargetItems = Info.NumberOfPtrs;
7464+
// TODO: Use correct device ID
7465+
Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
7466+
uint32_t SrcLocStrSize;
7467+
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
7468+
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
7469+
llvm::omp::IdentFlag(0), 0);
74737470

7474-
// TODO: Use correct DynCGGroupMem
7475-
Value *DynCGGroupMem = Builder.getInt32(0);
7471+
Value *TripCount = RuntimeAttrs.LoopTripCount
7472+
? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
7473+
Builder.getInt64Ty(),
7474+
/*isSigned=*/false)
7475+
: Builder.getInt64(0);
74767476

7477-
KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
7478-
NumTeamsC, NumThreadsC,
7479-
DynCGGroupMem, HasNoWait);
7477+
// TODO: Use correct DynCGGroupMem
7478+
Value *DynCGGroupMem = Builder.getInt32(0);
74807479

7481-
// Assume no error was returned because TaskBodyCB and
7482-
// EmitTargetCallFallbackCB don't produce any.
7483-
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
7484-
// The presence of certain clauses on the target directive require the
7485-
// explicit generation of the target task.
7486-
if (RequiresOuterTargetTask)
7487-
return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
7488-
Dependencies, HasNoWait);
7480+
KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
7481+
NumTeamsC, NumThreadsC,
7482+
DynCGGroupMem, HasNoWait);
74897483

7490-
return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
7491-
EmitTargetCallFallbackCB, KArgs,
7492-
DeviceID, RTLoc, AllocaIP);
7493-
}());
7484+
// Assume no error was returned because TaskBodyCB and
7485+
// EmitTargetCallFallbackCB don't produce any.
7486+
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
7487+
// The presence of certain clauses on the target directive require the
7488+
// explicit generation of the target task.
7489+
if (RequiresOuterTargetTask)
7490+
return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
7491+
Dependencies, HasNoWait);
7492+
7493+
return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
7494+
EmitTargetCallFallbackCB, KArgs,
7495+
DeviceID, RTLoc, AllocaIP);
7496+
}());
7497+
7498+
Builder.restoreIP(AfterIP);
7499+
return Error::success();
7500+
};
7501+
7502+
// If we don't have an ID for the target region, it means an offload entry
7503+
// wasn't created. In this case we just run the host fallback directly and
7504+
// ignore any potential 'if' clauses.
7505+
if (!OutlinedFnID) {
7506+
cantFail(EmitTargetCallElse(AllocaIP, Builder.saveIP()));
7507+
return;
7508+
}
7509+
7510+
// If there's no 'if' clause, only generate the kernel launch code path.
7511+
if (!IfCond) {
7512+
cantFail(EmitTargetCallThen(AllocaIP, Builder.saveIP()));
7513+
return;
7514+
}
74947515

7495-
Builder.restoreIP(AfterIP);
7516+
cantFail(OMPBuilder.emitIfClause(IfCond, EmitTargetCallThen,
7517+
EmitTargetCallElse, AllocaIP));
74967518
}
74977519

74987520
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
74997521
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
75007522
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
75017523
const TargetKernelDefaultAttrs &DefaultAttrs,
7502-
const TargetKernelRuntimeAttrs &RuntimeAttrs,
7524+
const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
75037525
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
75047526
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
75057527
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7524,7 +7546,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
75247546
// to make a remote call (offload) to the previously outlined function
75257547
// that represents the target region. Do that now.
75267548
if (!Config.isTargetDevice())
7527-
emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
7549+
emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, IfCond,
75287550
OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
75297551
HasNowait);
75307552
return Builder.saveIP();

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6243,8 +6243,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
62436243
OpenMPIRBuilder::InsertPointTy, AfterIP,
62446244
OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
62456245
Builder.saveIP(), EntryInfo, DefaultAttrs,
6246-
RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB,
6247-
SimpleArgAccessorCB));
6246+
RuntimeAttrs, /*IfCond=*/nullptr, Inputs,
6247+
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
62486248
Builder.restoreIP(AfterIP);
62496249

62506250
OMPBuilder.finalize();
@@ -6402,11 +6402,12 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
64026402
/*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC,
64036403
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
64046404

6405-
ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
6406-
OMPBuilder.createTarget(
6407-
Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
6408-
EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
6409-
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
6405+
ASSERT_EXPECTED_INIT(
6406+
OpenMPIRBuilder::InsertPointTy, AfterIP,
6407+
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
6408+
EntryInfo, DefaultAttrs, RuntimeAttrs,
6409+
/*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB,
6410+
BodyGenCB, SimpleArgAccessorCB));
64106411
Builder.restoreIP(AfterIP);
64116412

64126413
Builder.CreateRetVoid();
@@ -6774,11 +6775,12 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
67746775
/*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC,
67756776
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
67766777

6777-
ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
6778-
OMPBuilder.createTarget(
6779-
Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
6780-
EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
6781-
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
6778+
ASSERT_EXPECTED_INIT(
6779+
OpenMPIRBuilder::InsertPointTy, AfterIP,
6780+
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
6781+
EntryInfo, DefaultAttrs, RuntimeAttrs,
6782+
/*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB,
6783+
BodyGenCB, SimpleArgAccessorCB));
67826784
Builder.restoreIP(AfterIP);
67836785

67846786
Builder.CreateRetVoid();

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
183183
result = op.emitError("not yet implemented: host evaluation of loop "
184184
"bounds in omp.target operation");
185185
};
186-
auto checkIf = [&todo](auto op, LogicalResult &result) {
187-
if (op.getIfExpr())
188-
result = todo("if");
189-
};
190186
auto checkInReduction = [&todo](auto op, LogicalResult &result) {
191187
if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
192188
op.getInReductionSyms())
@@ -306,7 +302,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
306302
checkDevice(op, result);
307303
checkHasDeviceAddr(op, result);
308304
checkHostEval(op, result);
309-
checkIf(op, result);
310305
checkInReduction(op, result);
311306
checkIsDevicePtr(op, result);
312307
checkPrivate(op, result);
@@ -4378,10 +4373,14 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
43784373
findAllocaInsertPoint(builder, moduleTranslation);
43794374
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
43804375

4376+
llvm::Value *ifCond = nullptr;
4377+
if (Value targetIfCond = targetOp.getIfExpr())
4378+
ifCond = moduleTranslation.lookupValue(targetIfCond);
4379+
43814380
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
43824381
moduleTranslation.getOpenMPBuilder()->createTarget(
43834382
ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
4384-
defaultAttrs, runtimeAttrs, kernelInput, genMapInfoCB, bodyCB,
4383+
defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
43854384
argAccessorCB, dds, targetOp.getNowait());
43864385

43874386
if (failed(handleError(afterIP, opInst)))

0 commit comments

Comments
 (0)