Skip to content

Commit 3aea11b

Browse files
committed
[OMPIRBuilder][MLIR] Add support for target 'if' clause
This patch implements support for handling the 'if' clause of OpenMP 'target' constructs in the OMPIRBuilder and updates MLIR to LLVM IR translation of the `omp.target` MLIR operation to make use of this new feature.
1 parent cef1269 commit 3aea11b

File tree

6 files changed

+172
-83
lines changed

6 files changed

+172
-83
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2965,21 +2965,25 @@ class OpenMPIRBuilder {
29652965
/// \param NumThreads Number of threads specified in the thread_limit clause.
29662966
/// \param Inputs The input values to the region that will be passed
29672967
/// as arguments to the outlined function.
2968+
/// \param IfCond value of the `if` clause.
29682969
/// \param BodyGenCB Callback that will generate the region code.
29692970
/// \param ArgAccessorFuncCB Callback that will generate accessors
29702971
/// instructions for passed in target arguments where necessary
29712972
/// \param Dependencies A vector of DependData objects that carry
2972-
// dependency information as passed in the depend clause
2973-
// \param HasNowait Whether the target construct has a `nowait` clause or not.
2974-
InsertPointOrErrorTy createTarget(
2975-
const LocationDescription &Loc, bool IsOffloadEntry,
2976-
OpenMPIRBuilder::InsertPointTy AllocaIP,
2977-
OpenMPIRBuilder::InsertPointTy CodeGenIP,
2978-
TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
2979-
ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
2980-
GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB,
2981-
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
2982-
SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
2973+
/// dependency information as passed in the depend clause
2974+
/// \param HasNowait Whether the target construct has a `nowait` clause or
2975+
/// not.
2976+
InsertPointOrErrorTy
2977+
createTarget(const LocationDescription &Loc, bool IsOffloadEntry,
2978+
OpenMPIRBuilder::InsertPointTy AllocaIP,
2979+
OpenMPIRBuilder::InsertPointTy CodeGenIP,
2980+
TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
2981+
ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
2982+
Value *IfCond, GenMapInfoCallbackTy GenMapInfoCB,
2983+
TargetBodyGenCallbackTy BodyGenCB,
2984+
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
2985+
SmallVector<DependData> Dependencies = {},
2986+
bool HasNowait = false);
29832987

29842988
/// Returns __kmpc_for_static_init_* runtime function for the specified
29852989
/// size \a IVSize and sign \a IVSigned. Will create a distribute call

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 77 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -7310,6 +7310,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73107310
OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
73117311
Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
73127312
ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
7313+
Value *IfCond,
73137314
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
73147315
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
73157316
bool HasNoWait = false) {
@@ -7354,9 +7355,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73547355
return Error::success();
73557356
};
73567357

7357-
// If we don't have an ID for the target region, it means an offload entry
7358-
// wasn't created. In this case we just run the host fallback directly.
7359-
if (!OutlinedFnID) {
7358+
auto &&EmitTargetCallElse =
7359+
[&](OpenMPIRBuilder::InsertPointTy AllocaIP,
7360+
OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
73607361
// Assume no error was returned because EmitTargetCallFallbackCB doesn't
73617362
// produce any.
73627363
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
@@ -7372,65 +7373,87 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73727373
}());
73737374

73747375
Builder.restoreIP(AfterIP);
7375-
return;
7376-
}
7376+
return Error::success();
7377+
};
7378+
7379+
auto &&EmitTargetCallThen =
7380+
[&](OpenMPIRBuilder::InsertPointTy AllocaIP,
7381+
OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
7382+
OpenMPIRBuilder::TargetDataInfo Info(
7383+
/*RequiresDevicePointerInfo=*/false,
7384+
/*SeparateBeginEndCalls=*/true);
7385+
7386+
OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
7387+
OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7388+
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
7389+
RTArgs, MapInfo,
7390+
/*IsNonContiguous=*/true,
7391+
/*ForEndCall=*/false);
7392+
7393+
SmallVector<Value *, 3> NumTeamsC;
7394+
SmallVector<Value *, 3> NumThreadsC;
7395+
for (auto V : NumTeams)
7396+
NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
7397+
for (auto V : NumThreads)
7398+
NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
7399+
7400+
unsigned NumTargetItems = Info.NumberOfPtrs;
7401+
// TODO: Use correct device ID
7402+
Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
7403+
uint32_t SrcLocStrSize;
7404+
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
7405+
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
7406+
llvm::omp::IdentFlag(0), 0);
7407+
// TODO: Use correct NumIterations
7408+
Value *NumIterations = Builder.getInt64(0);
7409+
// TODO: Use correct DynCGGroupMem
7410+
Value *DynCGGroupMem = Builder.getInt32(0);
7411+
7412+
KArgs = OpenMPIRBuilder::TargetKernelArgs(
7413+
NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
7414+
DynCGGroupMem, HasNoWait);
7415+
7416+
// Assume no error was returned because TaskBodyCB and
7417+
// EmitTargetCallFallbackCB don't produce any.
7418+
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
7419+
// The presence of certain clauses on the target directive requires the
7420+
// explicit generation of the target task.
7421+
if (RequiresOuterTargetTask)
7422+
return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
7423+
Dependencies, HasNoWait);
7424+
7425+
return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
7426+
EmitTargetCallFallbackCB, KArgs,
7427+
DeviceID, RTLoc, AllocaIP);
7428+
}());
73777429

7378-
OpenMPIRBuilder::TargetDataInfo Info(
7379-
/*RequiresDevicePointerInfo=*/false,
7380-
/*SeparateBeginEndCalls=*/true);
7430+
Builder.restoreIP(AfterIP);
7431+
return Error::success();
7432+
};
73817433

7382-
OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
7383-
OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7384-
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
7385-
RTArgs, MapInfo,
7386-
/*IsNonContiguous=*/true,
7387-
/*ForEndCall=*/false);
7434+
// If we don't have an ID for the target region, it means an offload entry
7435+
// wasn't created. In this case we just run the host fallback directly.
7436+
if (!OutlinedFnID) {
7437+
cantFail(EmitTargetCallElse(AllocaIP, Builder.saveIP()));
7438+
return;
7439+
}
73887440

7389-
SmallVector<Value *, 3> NumTeamsC;
7390-
SmallVector<Value *, 3> NumThreadsC;
7391-
for (auto V : NumTeams)
7392-
NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
7393-
for (auto V : NumThreads)
7394-
NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
7441+
// If there's no IF clause, only generate the kernel launch code path.
7442+
if (!IfCond) {
7443+
cantFail(EmitTargetCallThen(AllocaIP, Builder.saveIP()));
7444+
return;
7445+
}
73957446

7396-
unsigned NumTargetItems = Info.NumberOfPtrs;
7397-
// TODO: Use correct device ID
7398-
Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
7399-
uint32_t SrcLocStrSize;
7400-
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
7401-
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
7402-
llvm::omp::IdentFlag(0), 0);
7403-
// TODO: Use correct NumIterations
7404-
Value *NumIterations = Builder.getInt64(0);
7405-
// TODO: Use correct DynCGGroupMem
7406-
Value *DynCGGroupMem = Builder.getInt32(0);
7407-
7408-
KArgs = OpenMPIRBuilder::TargetKernelArgs(
7409-
NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
7410-
DynCGGroupMem, HasNoWait);
7411-
7412-
// Assume no error was returned because TaskBodyCB and
7413-
// EmitTargetCallFallbackCB don't produce any.
7414-
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
7415-
// The presence of certain clauses on the target directive require the
7416-
// explicit generation of the target task.
7417-
if (RequiresOuterTargetTask)
7418-
return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
7419-
Dependencies, HasNoWait);
7420-
7421-
return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
7422-
EmitTargetCallFallbackCB, KArgs,
7423-
DeviceID, RTLoc, AllocaIP);
7424-
}());
7425-
7426-
Builder.restoreIP(AfterIP);
7447+
cantFail(OMPBuilder.emitIfClause(IfCond, EmitTargetCallThen,
7448+
EmitTargetCallElse, AllocaIP));
74277449
}
74287450

74297451
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
74307452
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
74317453
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
74327454
ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads,
7433-
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
7455+
SmallVectorImpl<Value *> &Args, Value *IfCond,
7456+
GenMapInfoCallbackTy GenMapInfoCB,
74347457
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
74357458
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
74367459
SmallVector<DependData> Dependencies, bool HasNowait) {
@@ -7455,7 +7478,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
74557478
// that represents the target region. Do that now.
74567479
if (!Config.isTargetDevice())
74577480
emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
7458-
NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait);
7481+
NumThreads, Args, IfCond, GenMapInfoCB, Dependencies,
7482+
HasNowait);
74597483
return Builder.saveIP();
74607484
}
74617485

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6232,7 +6232,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
62326232
ASSERT_EXPECTED_INIT(
62336233
OpenMPIRBuilder::InsertPointTy, AfterIP,
62346234
OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
6235-
Builder.saveIP(), EntryInfo, -1, 0, Inputs,
6235+
Builder.saveIP(), EntryInfo, /*NumTeams=*/-1,
6236+
/*NumThreads=*/0, Inputs, /*IfCond=*/nullptr,
62366237
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
62376238
Builder.restoreIP(AfterIP);
62386239
OMPBuilder.finalize();
@@ -6343,8 +6344,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
63436344
ASSERT_EXPECTED_INIT(
63446345
OpenMPIRBuilder::InsertPointTy, AfterIP,
63456346
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
6346-
EntryInfo, /*NumTeams=*/-1,
6347-
/*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
6347+
EntryInfo, /*NumTeams=*/-1, /*NumThreads=*/0,
6348+
CapturedArgs, /*IfCond=*/nullptr, GenMapInfoCB,
63486349
BodyGenCB, SimpleArgAccessorCB));
63496350
Builder.restoreIP(AfterIP);
63506351

@@ -6500,8 +6501,8 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
65006501
ASSERT_EXPECTED_INIT(
65016502
OpenMPIRBuilder::InsertPointTy, AfterIP,
65026503
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
6503-
EntryInfo, /*NumTeams=*/-1,
6504-
/*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
6504+
EntryInfo, /*NumTeams=*/-1, /*NumThreads=*/0,
6505+
CapturedArgs, /*IfCond=*/nullptr, GenMapInfoCB,
65056506
BodyGenCB, SimpleArgAccessorCB));
65066507
Builder.restoreIP(AfterIP);
65076508

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
285285
checkBare(op, result);
286286
checkDevice(op, result);
287287
checkHasDeviceAddr(op, result);
288-
checkIf(op, result);
289288
checkInReduction(op, result);
290289
checkIsDevicePtr(op, result);
291290
// Privatization clauses are supported, except on some situations, so we
@@ -4112,11 +4111,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
41124111
findAllocaInsertPoint(builder, moduleTranslation);
41134112
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
41144113

4114+
llvm::Value *ifCond = nullptr;
4115+
if (Value targetIfCond = targetOp.getIfExpr())
4116+
ifCond = moduleTranslation.lookupValue(targetIfCond);
4117+
41154118
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
41164119
moduleTranslation.getOpenMPBuilder()->createTarget(
41174120
ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
4118-
defaultValTeams, defaultValThreads, kernelInput, genMapInfoCB, bodyCB,
4119-
argAccessorCB, dds, targetOp.getNowait());
4121+
defaultValTeams, defaultValThreads, kernelInput, ifCond, genMapInfoCB,
4122+
bodyCB, argAccessorCB, dds, targetOp.getNowait());
41204123

41214124
if (failed(handleError(afterIP, opInst)))
41224125
return failure();
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
4+
llvm.func @target_if_variable(%x : i1) {
5+
omp.target if(%x) {
6+
omp.terminator
7+
}
8+
llvm.return
9+
}
10+
11+
// CHECK-LABEL: define void @target_if_variable(
12+
// CHECK-SAME: i1 %[[IF_COND:.*]])
13+
// CHECK: br i1 %[[IF_COND]], label %[[THEN_LABEL:.*]], label %[[ELSE_LABEL:.*]]
14+
15+
// CHECK: [[THEN_LABEL]]:
16+
// CHECK-NOT: {{^.*}}:
17+
// CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel
18+
// CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0
19+
// CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]]
20+
21+
// CHECK: [[OFFLOAD_FAIL_LABEL]]:
22+
// CHECK-NEXT: call void @[[FALLBACK_FN:__omp_offloading_.*_.*_target_if_variable_l.*]]()
23+
// CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]]
24+
25+
// CHECK: [[OFFLOAD_CONT_LABEL]]:
26+
// CHECK-NEXT: br label %[[END_LABEL:.*]]
27+
28+
// CHECK: [[ELSE_LABEL]]:
29+
// CHECK-NEXT: call void @[[FALLBACK_FN]]()
30+
// CHECK-NEXT: br label %[[END_LABEL]]
31+
32+
llvm.func @target_if_true() {
33+
%0 = llvm.mlir.constant(true) : i1
34+
omp.target if(%0) {
35+
omp.terminator
36+
}
37+
llvm.return
38+
}
39+
40+
// CHECK-LABEL: define void @target_if_true()
41+
// CHECK-NOT: {{^.*}}:
42+
// CHECK: br label %[[ENTRY:.*]]
43+
44+
// CHECK: [[ENTRY]]:
45+
// CHECK-NOT: {{^.*}}:
46+
// CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel
47+
// CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0
48+
// CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]]
49+
50+
// CHECK: [[OFFLOAD_FAIL_LABEL]]:
51+
// CHECK-NEXT: call void @[[FALLBACK_FN:.*]]()
52+
// CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]]
53+
54+
llvm.func @target_if_false() {
55+
%0 = llvm.mlir.constant(false) : i1
56+
omp.target if(%0) {
57+
omp.terminator
58+
}
59+
llvm.return
60+
}
61+
62+
// CHECK-LABEL: define void @target_if_false()
63+
// CHECK-NEXT: br label %[[ENTRY:.*]]
64+
65+
// CHECK: [[ENTRY]]:
66+
// CHECK-NEXT: call void @__omp_offloading_{{.*}}_{{.*}}_target_if_false_l{{.*}}()
67+
}
68+

mlir/test/Target/LLVMIR/openmp-todo.mlir

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -266,17 +266,6 @@ llvm.func @target_has_device_addr(%x : !llvm.ptr) {
266266

267267
// -----
268268

269-
llvm.func @target_if(%x : i1) {
270-
// expected-error@below {{not yet implemented: Unhandled clause if in omp.target operation}}
271-
// expected-error@below {{LLVM Translation failed for operation: omp.target}}
272-
omp.target if(%x) {
273-
omp.terminator
274-
}
275-
llvm.return
276-
}
277-
278-
// -----
279-
280269
omp.declare_reduction @add_f32 : f32
281270
init {
282271
^bb0(%arg: f32):

0 commit comments

Comments
 (0)