-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[OMPIRBuilder][MLIR] Add support for target 'if' clause #122478
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-openmp @llvm/pr-subscribers-flang-openmp Author: Sergio Afonso (skatrak) ChangesThis patch implements support for handling the 'if' clause of OpenMP 'target' constructs in the OMPIRBuilder and updates MLIR to LLVM IR translation of the Full diff: https://github.com/llvm/llvm-project/pull/122478.diff 6 Files Affected:
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 4ce47b1c05d9b0..b1a23996c7bdd2 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2965,21 +2965,25 @@ class OpenMPIRBuilder {
/// \param NumThreads Number of teams specified in the thread_limit clause.
/// \param Inputs The input values to the region that will be passed.
/// as arguments to the outlined function.
+ /// \param IfCond value of the `if` clause.
/// \param BodyGenCB Callback that will generate the region code.
/// \param ArgAccessorFuncCB Callback that will generate accessors
/// instructions for passed in target arguments where neccessary
/// \param Dependencies A vector of DependData objects that carry
- // dependency information as passed in the depend clause
- // \param HasNowait Whether the target construct has a `nowait` clause or not.
- InsertPointOrErrorTy createTarget(
- const LocationDescription &Loc, bool IsOffloadEntry,
- OpenMPIRBuilder::InsertPointTy AllocaIP,
- OpenMPIRBuilder::InsertPointTy CodeGenIP,
- TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
- ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
- GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB,
- TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
- SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
+ /// dependency information as passed in the depend clause
+ /// \param HasNowait Whether the target construct has a `nowait` clause or
+ /// not.
+ InsertPointOrErrorTy
+ createTarget(const LocationDescription &Loc, bool IsOffloadEntry,
+ OpenMPIRBuilder::InsertPointTy AllocaIP,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP,
+ TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
+ ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
+ Value *IfCond, GenMapInfoCallbackTy GenMapInfoCB,
+ TargetBodyGenCallbackTy BodyGenCB,
+ TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
+ SmallVector<DependData> Dependencies = {},
+ bool HasNowait = false);
/// Returns __kmpc_for_static_init_* runtime function for the specified
/// size \a IVSize and sign \a IVSigned. Will create a distribute call
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index db77c6a5869764..0e190f4c64a8b3 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -7310,6 +7310,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
+ Value *IfCond,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
bool HasNoWait = false) {
@@ -7354,9 +7355,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
return Error::success();
};
- // If we don't have an ID for the target region, it means an offload entry
- // wasn't created. In this case we just run the host fallback directly.
- if (!OutlinedFnID) {
+ auto &&EmitTargetCallElse =
+ [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
// Assume no error was returned because EmitTargetCallFallbackCB doesn't
// produce any.
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
@@ -7372,65 +7373,87 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
}());
Builder.restoreIP(AfterIP);
- return;
- }
+ return Error::success();
+ };
+
+ auto &&EmitTargetCallThen =
+ [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
+ OpenMPIRBuilder::TargetDataInfo Info(
+ /*RequiresDevicePointerInfo=*/false,
+ /*SeparateBeginEndCalls=*/true);
+
+ OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
+ OpenMPIRBuilder::TargetDataRTArgs RTArgs;
+ OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
+ RTArgs, MapInfo,
+ /*IsNonContiguous=*/true,
+ /*ForEndCall=*/false);
+
+ SmallVector<Value *, 3> NumTeamsC;
+ SmallVector<Value *, 3> NumThreadsC;
+ for (auto V : NumTeams)
+ NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+ for (auto V : NumThreads)
+ NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+
+ unsigned NumTargetItems = Info.NumberOfPtrs;
+ // TODO: Use correct device ID
+ Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
+ uint32_t SrcLocStrSize;
+ Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
+ Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
+ llvm::omp::IdentFlag(0), 0);
+ // TODO: Use correct NumIterations
+ Value *NumIterations = Builder.getInt64(0);
+ // TODO: Use correct DynCGGroupMem
+ Value *DynCGGroupMem = Builder.getInt32(0);
+
+ KArgs = OpenMPIRBuilder::TargetKernelArgs(
+ NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
+ DynCGGroupMem, HasNoWait);
+
+ // Assume no error was returned because TaskBodyCB and
+ // EmitTargetCallFallbackCB don't produce any.
+ OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
+ // The presence of certain clauses on the target directive require the
+ // explicit generation of the target task.
+ if (RequiresOuterTargetTask)
+ return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
+ Dependencies, HasNoWait);
+
+ return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
+ EmitTargetCallFallbackCB, KArgs,
+ DeviceID, RTLoc, AllocaIP);
+ }());
- OpenMPIRBuilder::TargetDataInfo Info(
- /*RequiresDevicePointerInfo=*/false,
- /*SeparateBeginEndCalls=*/true);
+ Builder.restoreIP(AfterIP);
+ return Error::success();
+ };
- OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
- OpenMPIRBuilder::TargetDataRTArgs RTArgs;
- OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
- RTArgs, MapInfo,
- /*IsNonContiguous=*/true,
- /*ForEndCall=*/false);
+ // If we don't have an ID for the target region, it means an offload entry
+ // wasn't created. In this case we just run the host fallback directly.
+ if (!OutlinedFnID) {
+ cantFail(EmitTargetCallElse(AllocaIP, Builder.saveIP()));
+ return;
+ }
- SmallVector<Value *, 3> NumTeamsC;
- SmallVector<Value *, 3> NumThreadsC;
- for (auto V : NumTeams)
- NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
- for (auto V : NumThreads)
- NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+ // If there's no IF clause, only generate the kernel launch code path.
+ if (!IfCond) {
+ cantFail(EmitTargetCallThen(AllocaIP, Builder.saveIP()));
+ return;
+ }
- unsigned NumTargetItems = Info.NumberOfPtrs;
- // TODO: Use correct device ID
- Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
- uint32_t SrcLocStrSize;
- Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
- Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
- llvm::omp::IdentFlag(0), 0);
- // TODO: Use correct NumIterations
- Value *NumIterations = Builder.getInt64(0);
- // TODO: Use correct DynCGGroupMem
- Value *DynCGGroupMem = Builder.getInt32(0);
-
- KArgs = OpenMPIRBuilder::TargetKernelArgs(
- NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
- DynCGGroupMem, HasNoWait);
-
- // Assume no error was returned because TaskBodyCB and
- // EmitTargetCallFallbackCB don't produce any.
- OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
- // The presence of certain clauses on the target directive require the
- // explicit generation of the target task.
- if (RequiresOuterTargetTask)
- return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
- Dependencies, HasNoWait);
-
- return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
- EmitTargetCallFallbackCB, KArgs,
- DeviceID, RTLoc, AllocaIP);
- }());
-
- Builder.restoreIP(AfterIP);
+ cantFail(OMPBuilder.emitIfClause(IfCond, EmitTargetCallThen,
+ EmitTargetCallElse, AllocaIP));
}
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads,
- SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
+ SmallVectorImpl<Value *> &Args, Value *IfCond,
+ GenMapInfoCallbackTy GenMapInfoCB,
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
SmallVector<DependData> Dependencies, bool HasNowait) {
@@ -7455,7 +7478,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
- NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait);
+ NumThreads, Args, IfCond, GenMapInfoCB, Dependencies,
+ HasNowait);
return Builder.saveIP();
}
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index cdca725b147436..94dce5243d7004 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6232,7 +6232,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
ASSERT_EXPECTED_INIT(
OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
- Builder.saveIP(), EntryInfo, -1, 0, Inputs,
+ Builder.saveIP(), EntryInfo, /*NumTeams=*/-1,
+ /*NumThreads=*/0, Inputs, /*IfCond=*/nullptr,
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);
OMPBuilder.finalize();
@@ -6343,8 +6344,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
ASSERT_EXPECTED_INIT(
OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
- EntryInfo, /*NumTeams=*/-1,
- /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
+ EntryInfo, /*NumTeams=*/-1, /*NumThreads=*/0,
+ CapturedArgs, /*IfCond=*/nullptr, GenMapInfoCB,
BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);
@@ -6500,8 +6501,8 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
ASSERT_EXPECTED_INIT(
OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
- EntryInfo, /*NumTeams=*/-1,
- /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
+ EntryInfo, /*NumTeams=*/-1, /*NumThreads=*/0,
+ CapturedArgs, /*IfCond=*/nullptr, GenMapInfoCB,
BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index a364098e0bd8a6..0c637bd32ab3f3 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -285,7 +285,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkBare(op, result);
checkDevice(op, result);
checkHasDeviceAddr(op, result);
- checkIf(op, result);
checkInReduction(op, result);
checkIsDevicePtr(op, result);
// Privatization clauses are supported, except on some situations, so we
@@ -4112,11 +4111,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
findAllocaInsertPoint(builder, moduleTranslation);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+ llvm::Value *ifCond = nullptr;
+ if (Value targetIfCond = targetOp.getIfExpr())
+ ifCond = moduleTranslation.lookupValue(targetIfCond);
+
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
moduleTranslation.getOpenMPBuilder()->createTarget(
ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
- defaultValTeams, defaultValThreads, kernelInput, genMapInfoCB, bodyCB,
- argAccessorCB, dds, targetOp.getNowait());
+ defaultValTeams, defaultValThreads, kernelInput, ifCond, genMapInfoCB,
+ bodyCB, argAccessorCB, dds, targetOp.getNowait());
if (failed(handleError(afterIP, opInst)))
return failure();
diff --git a/mlir/test/Target/LLVMIR/omptarget-if.mlir b/mlir/test/Target/LLVMIR/omptarget-if.mlir
new file mode 100644
index 00000000000000..706ad4411438ba
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-if.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+ llvm.func @target_if_variable(%x : i1) {
+ omp.target if(%x) {
+ omp.terminator
+ }
+ llvm.return
+ }
+
+ // CHECK-LABEL: define void @target_if_variable(
+ // CHECK-SAME: i1 %[[IF_COND:.*]])
+ // CHECK: br i1 %[[IF_COND]], label %[[THEN_LABEL:.*]], label %[[ELSE_LABEL:.*]]
+
+ // CHECK: [[THEN_LABEL]]:
+ // CHECK-NOT: {{^.*}}:
+ // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel
+ // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0
+ // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]]
+
+ // CHECK: [[OFFLOAD_FAIL_LABEL]]:
+ // CHECK-NEXT: call void @[[FALLBACK_FN:__omp_offloading_.*_.*_target_if_variable_l.*]]()
+ // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]]
+
+ // CHECK: [[OFFLOAD_CONT_LABEL]]:
+ // CHECK-NEXT: br label %[[END_LABEL:.*]]
+
+ // CHECK: [[ELSE_LABEL]]:
+ // CHECK-NEXT: call void @[[FALLBACK_FN]]()
+ // CHECK-NEXT: br label %[[END_LABEL]]
+
+ llvm.func @target_if_true() {
+ %0 = llvm.mlir.constant(true) : i1
+ omp.target if(%0) {
+ omp.terminator
+ }
+ llvm.return
+ }
+
+ // CHECK-LABEL: define void @target_if_true()
+ // CHECK-NOT: {{^.*}}:
+ // CHECK: br label %[[ENTRY:.*]]
+
+ // CHECK: [[ENTRY]]:
+ // CHECK-NOT: {{^.*}}:
+ // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel
+ // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0
+ // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]]
+
+ // CHECK: [[OFFLOAD_FAIL_LABEL]]:
+ // CHECK-NEXT: call void @[[FALLBACK_FN:.*]]()
+ // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]]
+
+ llvm.func @target_if_false() {
+ %0 = llvm.mlir.constant(false) : i1
+ omp.target if(%0) {
+ omp.terminator
+ }
+ llvm.return
+ }
+
+ // CHECK-LABEL: define void @target_if_false()
+ // CHECK-NEXT: br label %[[ENTRY:.*]]
+
+ // CHECK: [[ENTRY]]:
+ // CHECK-NEXT: call void @__omp_offloading_{{.*}}_{{.*}}_target_if_false_l{{.*}}()
+}
+
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 83a0990d631620..4e0925c833c3b7 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -266,17 +266,6 @@ llvm.func @target_has_device_addr(%x : !llvm.ptr) {
// -----
-llvm.func @target_if(%x : i1) {
- // expected-error@below {{not yet implemented: Unhandled clause if in omp.target operation}}
- // expected-error@below {{LLVM Translation failed for operation: omp.target}}
- omp.target if(%x) {
- omp.terminator
- }
- llvm.return
-}
-
-// -----
-
omp.declare_reduction @add_f32 : f32
init {
^bb0(%arg: f32):
|
3aea11b
to
83c1d79
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, Sergio! LGTM
cef1269
to
364cd46
Compare
83c1d79
to
8c348ba
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
d14f385
to
cb022ad
Compare
This patch implements support for handling the 'if' clause of OpenMP 'target' constructs in the OMPIRBuilder and updates MLIR to LLVM IR translation of the `omp.target` MLIR operation to make use of this new feature.
cb022ad
to
9fe2c5f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This patch implements support for handling the 'if' clause of OpenMP 'target' constructs in the OMPIRBuilder and updates MLIR to LLVM IR translation of the
omp.target
MLIR operation to make use of this new feature.