From 9fe2c5f70d35641715191b9c0ba1b8e33045d530 Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Fri, 10 Jan 2025 15:40:05 +0000 Subject: [PATCH] [OMPIRBuilder][MLIR] Add support for target 'if' clause This patch implements support for handling the 'if' clause of OpenMP 'target' constructs in the OMPIRBuilder and updates MLIR to LLVM IR translation of the `omp.target` MLIR operation to make use of this new feature. --- .../llvm/Frontend/OpenMP/OMPIRBuilder.h | 14 +- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 210 ++++++++++-------- .../Frontend/OpenMPIRBuilderTest.cpp | 41 ++-- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 11 +- mlir/test/Target/LLVMIR/omptarget-if.mlir | 68 ++++++ mlir/test/Target/LLVMIR/openmp-todo.mlir | 11 - 6 files changed, 220 insertions(+), 135 deletions(-) create mode 100644 mlir/test/Target/LLVMIR/omptarget-if.mlir diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 7eceec3d8cf8f..6b6e5bc19d95a 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -2994,27 +2994,29 @@ class OpenMPIRBuilder { /// \param Loc where the target data construct was encountered. /// \param IsOffloadEntry whether it is an offload entry. /// \param CodeGenIP The insertion point where the call to the outlined - /// function should be emitted. + /// function should be emitted. /// \param EntryInfo The entry information about the function. /// \param DefaultAttrs Structure containing the default attributes, including /// numbers of threads and teams to launch the kernel with. /// \param RuntimeAttrs Structure containing the runtime numbers of threads /// and teams to launch the kernel with. + /// \param IfCond value of the `if` clause. /// \param Inputs The input values to the region that will be passed. - /// as arguments to the outlined function. + /// as arguments to the outlined function. 
/// \param BodyGenCB Callback that will generate the region code. /// \param ArgAccessorFuncCB Callback that will generate accessors - /// instructions for passed in target arguments where neccessary + /// instructions for passed in target arguments where necessary /// \param Dependencies A vector of DependData objects that carry - // dependency information as passed in the depend clause - // \param HasNowait Whether the target construct has a `nowait` clause or not. + /// dependency information as passed in the depend clause + /// \param HasNowait Whether the target construct has a `nowait` clause or + /// not. InsertPointOrErrorTy createTarget( const LocationDescription &Loc, bool IsOffloadEntry, OpenMPIRBuilder::InsertPointTy AllocaIP, OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, const TargetKernelDefaultAttrs &DefaultAttrs, - const TargetKernelRuntimeAttrs &RuntimeAttrs, + const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond, SmallVectorImpl &Inputs, GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB, TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 3d461f0ad4228..c6603635d5e28 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -5308,8 +5308,8 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop, Value *Alignment = AlignedItem.second; Instruction *loadInst = dyn_cast(AlignedPtr); Builder.SetInsertPoint(loadInst->getNextNode()); - Builder.CreateAlignmentAssumption(F->getDataLayout(), - AlignedPtr, Alignment); + Builder.CreateAlignmentAssumption(F->getDataLayout(), AlignedPtr, + Alignment); } Builder.restoreIP(IP); } @@ -5457,16 +5457,16 @@ static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) { Loop *L = LI.getLoopFor(CLI->getHeader()); assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop"); - 
TargetTransformInfo::UnrollingPreferences UP = - gatherUnrollingPreferences(L, SE, TTI, - /*BlockFrequencyInfo=*/nullptr, - /*ProfileSummaryInfo=*/nullptr, ORE, static_cast(OptLevel), - /*UserThreshold=*/std::nullopt, - /*UserCount=*/std::nullopt, - /*UserAllowPartial=*/true, - /*UserAllowRuntime=*/true, - /*UserUpperBound=*/std::nullopt, - /*UserFullUnrollMaxCount=*/std::nullopt); + TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences( + L, SE, TTI, + /*BlockFrequencyInfo=*/nullptr, + /*ProfileSummaryInfo=*/nullptr, ORE, static_cast(OptLevel), + /*UserThreshold=*/std::nullopt, + /*UserCount=*/std::nullopt, + /*UserAllowPartial=*/true, + /*UserAllowRuntime=*/true, + /*UserUpperBound=*/std::nullopt, + /*UserFullUnrollMaxCount=*/std::nullopt); UP.Force = true; @@ -7340,7 +7340,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, OpenMPIRBuilder::InsertPointTy AllocaIP, const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs, const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs, - Function *OutlinedFn, Constant *OutlinedFnID, + Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID, SmallVectorImpl &Args, OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB, SmallVector Dependencies = {}, @@ -7386,9 +7386,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, return Error::success(); }; - // If we don't have an ID for the target region, it means an offload entry - // wasn't created. In this case we just run the host fallback directly. - if (!OutlinedFnID) { + auto &&EmitTargetCallElse = + [&](OpenMPIRBuilder::InsertPointTy AllocaIP, + OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error { // Assume no error was returned because EmitTargetCallFallbackCB doesn't // produce any. 
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() { @@ -7404,102 +7404,126 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, }()); Builder.restoreIP(AfterIP); - return; - } - - OpenMPIRBuilder::TargetDataInfo Info( - /*RequiresDevicePointerInfo=*/false, - /*SeparateBeginEndCalls=*/true); - - OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP()); - OpenMPIRBuilder::TargetDataRTArgs RTArgs; - OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info, - RTArgs, MapInfo, - /*IsNonContiguous=*/true, - /*ForEndCall=*/false); - - SmallVector NumTeamsC; - for (auto [DefaultVal, RuntimeVal] : - zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams)) - NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal)); - - // Calculate number of threads: 0 if no clauses specified, otherwise it is the - // minimum between optional THREAD_LIMIT and NUM_THREADS clauses. - auto InitMaxThreadsClause = [&Builder](Value *Clause) { - if (Clause) - Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(), - /*isSigned=*/false); - return Clause; + return Error::success(); }; - auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) { - if (Clause) - Result = Result - ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause), + + auto &&EmitTargetCallThen = + [&](OpenMPIRBuilder::InsertPointTy AllocaIP, + OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error { + OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); + + OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP()); + OpenMPIRBuilder::TargetDataRTArgs RTArgs; + OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info, + RTArgs, MapInfo, + /*IsNonContiguous=*/true, + /*ForEndCall=*/false); + + SmallVector NumTeamsC; + for (auto [DefaultVal, RuntimeVal] : + zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams)) + NumTeamsC.push_back(RuntimeVal ? 
RuntimeVal + : Builder.getInt32(DefaultVal)); + + // Calculate number of threads: 0 if no clauses specified, otherwise it is + // the minimum between optional THREAD_LIMIT and NUM_THREADS clauses. + auto InitMaxThreadsClause = [&Builder](Value *Clause) { + if (Clause) + Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(), + /*isSigned=*/false); + return Clause; + }; + auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) { + if (Clause) + Result = + Result ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause), Result, Clause) : Clause; - }; + }; - // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so - // the NUM_THREADS clause is overriden by THREAD_LIMIT. - SmallVector NumThreadsC; - Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1 - ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads) - : nullptr; + // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so + // the NUM_THREADS clause is overridden by THREAD_LIMIT. + SmallVector NumThreadsC; + Value *MaxThreadsClause = + RuntimeAttrs.TeamsThreadLimit.size() == 1 + ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads) + : nullptr; - for (auto [TeamsVal, TargetVal] : zip_equal(RuntimeAttrs.TeamsThreadLimit, - RuntimeAttrs.TargetThreadLimit)) { - Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal); - Value *NumThreads = InitMaxThreadsClause(TargetVal); + for (auto [TeamsVal, TargetVal] : zip_equal( + RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) { + Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal); + Value *NumThreads = InitMaxThreadsClause(TargetVal); - CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads); - CombineMaxThreadsClauses(MaxThreadsClause, NumThreads); + CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads); + CombineMaxThreadsClauses(MaxThreadsClause, NumThreads); - NumThreadsC.push_back(NumThreads ? 
NumThreads : Builder.getInt32(0)); - } + NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0)); + } - unsigned NumTargetItems = Info.NumberOfPtrs; - // TODO: Use correct device ID - Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF); - uint32_t SrcLocStrSize; - Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize); - Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize, - llvm::omp::IdentFlag(0), 0); + unsigned NumTargetItems = Info.NumberOfPtrs; + // TODO: Use correct device ID + Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF); + uint32_t SrcLocStrSize; + Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize); + Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize, + llvm::omp::IdentFlag(0), 0); - Value *TripCount = RuntimeAttrs.LoopTripCount - ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount, - Builder.getInt64Ty(), - /*isSigned=*/false) - : Builder.getInt64(0); + Value *TripCount = RuntimeAttrs.LoopTripCount + ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount, + Builder.getInt64Ty(), + /*isSigned=*/false) + : Builder.getInt64(0); - // TODO: Use correct DynCGGroupMem - Value *DynCGGroupMem = Builder.getInt32(0); + // TODO: Use correct DynCGGroupMem + Value *DynCGGroupMem = Builder.getInt32(0); - KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount, - NumTeamsC, NumThreadsC, - DynCGGroupMem, HasNoWait); + KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount, + NumTeamsC, NumThreadsC, + DynCGGroupMem, HasNoWait); - // Assume no error was returned because TaskBodyCB and - // EmitTargetCallFallbackCB don't produce any. - OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() { - // The presence of certain clauses on the target directive require the - // explicit generation of the target task. 
- if (RequiresOuterTargetTask) - return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP, - Dependencies, HasNoWait); + // Assume no error was returned because TaskBodyCB and + // EmitTargetCallFallbackCB don't produce any. + OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() { + // The presence of certain clauses on the target directive require the + // explicit generation of the target task. + if (RequiresOuterTargetTask) + return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP, + Dependencies, HasNoWait); + + return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID, + EmitTargetCallFallbackCB, KArgs, + DeviceID, RTLoc, AllocaIP); + }()); + + Builder.restoreIP(AfterIP); + return Error::success(); + }; - return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID, - EmitTargetCallFallbackCB, KArgs, - DeviceID, RTLoc, AllocaIP); - }()); + // If we don't have an ID for the target region, it means an offload entry + // wasn't created. In this case we just run the host fallback directly and + // ignore any potential 'if' clauses. + if (!OutlinedFnID) { + cantFail(EmitTargetCallElse(AllocaIP, Builder.saveIP())); + return; + } + + // If there's no 'if' clause, only generate the kernel launch code path. 
+ if (!IfCond) { + cantFail(EmitTargetCallThen(AllocaIP, Builder.saveIP())); + return; + } - Builder.restoreIP(AfterIP); + cantFail(OMPBuilder.emitIfClause(IfCond, EmitTargetCallThen, + EmitTargetCallElse, AllocaIP)); } OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, const TargetKernelDefaultAttrs &DefaultAttrs, - const TargetKernelRuntimeAttrs &RuntimeAttrs, + const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond, SmallVectorImpl &Args, GenMapInfoCallbackTy GenMapInfoCB, OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc, OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, @@ -7524,7 +7548,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( // to make a remote call (offload) to the previously outlined function // that represents the target region. Do that now. if (!Config.isTargetDevice()) - emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, + emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, IfCond, OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait); return Builder.saveIP(); diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index 3b571cce09a4f..a7b513bdfdc66 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -6243,8 +6243,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), Builder.saveIP(), EntryInfo, DefaultAttrs, - RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB, - SimpleArgAccessorCB)); + RuntimeAttrs, /*IfCond=*/nullptr, Inputs, + GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); Builder.restoreIP(AfterIP); OMPBuilder.finalize(); @@ -6402,11 +6402,12 @@ 
TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; - ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP, - OMPBuilder.createTarget( - Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, - GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); + ASSERT_EXPECTED_INIT( + OpenMPIRBuilder::InsertPointTy, AfterIP, + OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, + EntryInfo, DefaultAttrs, RuntimeAttrs, + /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB, + BodyGenCB, SimpleArgAccessorCB)); Builder.restoreIP(AfterIP); Builder.CreateRetVoid(); @@ -6561,8 +6562,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) { OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), Builder.saveIP(), EntryInfo, DefaultAttrs, - RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB, - SimpleArgAccessorCB)); + RuntimeAttrs, /*IfCond=*/nullptr, Inputs, + GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); Builder.restoreIP(AfterIP); OMPBuilder.finalize(); @@ -6660,11 +6661,12 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) { /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; - ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP, - OMPBuilder.createTarget( - Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, - GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); + ASSERT_EXPECTED_INIT( + OpenMPIRBuilder::InsertPointTy, AfterIP, + OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, + EntryInfo, DefaultAttrs, RuntimeAttrs, + /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB, + BodyGenCB, SimpleArgAccessorCB)); Builder.restoreIP(AfterIP); Builder.CreateRetVoid(); @@ 
-6774,11 +6776,12 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) { /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; - ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP, - OMPBuilder.createTarget( - Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, - GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); + ASSERT_EXPECTED_INIT( + OpenMPIRBuilder::InsertPointTy, AfterIP, + OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, + EntryInfo, DefaultAttrs, RuntimeAttrs, + /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB, + BodyGenCB, SimpleArgAccessorCB)); Builder.restoreIP(AfterIP); Builder.CreateRetVoid(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 0be515e63b470..abef2cb7411aa 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -183,10 +183,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { result = op.emitError("not yet implemented: host evaluation of loop " "bounds in omp.target operation"); }; - auto checkIf = [&todo](auto op, LogicalResult &result) { - if (op.getIfExpr()) - result = todo("if"); - }; auto checkInReduction = [&todo](auto op, LogicalResult &result) { if (!op.getInReductionVars().empty() || op.getInReductionByref() || op.getInReductionSyms()) @@ -306,7 +302,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { checkDevice(op, result); checkHasDeviceAddr(op, result); checkHostEval(op, result); - checkIf(op, result); checkInReduction(op, result); checkIsDevicePtr(op, result); checkPrivate(op, result); @@ -4378,10 +4373,14 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, findAllocaInsertPoint(builder, moduleTranslation); 
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); + llvm::Value *ifCond = nullptr; + if (Value targetIfCond = targetOp.getIfExpr()) + ifCond = moduleTranslation.lookupValue(targetIfCond); + llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = moduleTranslation.getOpenMPBuilder()->createTarget( ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo, - defaultAttrs, runtimeAttrs, kernelInput, genMapInfoCB, bodyCB, + defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB, argAccessorCB, dds, targetOp.getNowait()); if (failed(handleError(afterIP, opInst))) diff --git a/mlir/test/Target/LLVMIR/omptarget-if.mlir b/mlir/test/Target/LLVMIR/omptarget-if.mlir new file mode 100644 index 0000000000000..706ad4411438b --- /dev/null +++ b/mlir/test/Target/LLVMIR/omptarget-if.mlir @@ -0,0 +1,68 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} { + llvm.func @target_if_variable(%x : i1) { + omp.target if(%x) { + omp.terminator + } + llvm.return + } + + // CHECK-LABEL: define void @target_if_variable( + // CHECK-SAME: i1 %[[IF_COND:.*]]) + // CHECK: br i1 %[[IF_COND]], label %[[THEN_LABEL:.*]], label %[[ELSE_LABEL:.*]] + + // CHECK: [[THEN_LABEL]]: + // CHECK-NOT: {{^.*}}: + // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel + // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0 + // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]] + + // CHECK: [[OFFLOAD_FAIL_LABEL]]: + // CHECK-NEXT: call void @[[FALLBACK_FN:__omp_offloading_.*_.*_target_if_variable_l.*]]() + // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]] + + // CHECK: [[OFFLOAD_CONT_LABEL]]: + // CHECK-NEXT: br label %[[END_LABEL:.*]] + + // CHECK: [[ELSE_LABEL]]: + // CHECK-NEXT: call void @[[FALLBACK_FN]]() + // CHECK-NEXT: br label %[[END_LABEL]] + + llvm.func @target_if_true() { + %0 = llvm.mlir.constant(true) : i1 + 
omp.target if(%0) { + omp.terminator + } + llvm.return + } + + // CHECK-LABEL: define void @target_if_true() + // CHECK-NOT: {{^.*}}: + // CHECK: br label %[[ENTRY:.*]] + + // CHECK: [[ENTRY]]: + // CHECK-NOT: {{^.*}}: + // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel + // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0 + // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]] + + // CHECK: [[OFFLOAD_FAIL_LABEL]]: + // CHECK-NEXT: call void @[[FALLBACK_FN:.*]]() + // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]] + + llvm.func @target_if_false() { + %0 = llvm.mlir.constant(false) : i1 + omp.target if(%0) { + omp.terminator + } + llvm.return + } + + // CHECK-LABEL: define void @target_if_false() + // CHECK-NEXT: br label %[[ENTRY:.*]] + + // CHECK: [[ENTRY]]: + // CHECK-NEXT: call void @__omp_offloading_{{.*}}_{{.*}}_target_if_false_l{{.*}}() +} + diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index 392a6558dcfa6..c1e30964b2507 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -271,17 +271,6 @@ llvm.func @target_host_eval(%x : i32) { // ----- -llvm.func @target_if(%x : i1) { - // expected-error@below {{not yet implemented: Unhandled clause if in omp.target operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.target}} - omp.target if(%x) { - omp.terminator - } - llvm.return -} - -// ----- - omp.declare_reduction @add_f32 : f32 init { ^bb0(%arg: f32):