diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 7eceec3d8cf8f..6b6e5bc19d95a 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -2994,27 +2994,29 @@ class OpenMPIRBuilder { /// \param Loc where the target data construct was encountered. /// \param IsOffloadEntry whether it is an offload entry. /// \param CodeGenIP The insertion point where the call to the outlined - /// function should be emitted. + /// function should be emitted. /// \param EntryInfo The entry information about the function. /// \param DefaultAttrs Structure containing the default attributes, including /// numbers of threads and teams to launch the kernel with. /// \param RuntimeAttrs Structure containing the runtime numbers of threads /// and teams to launch the kernel with. + /// \param IfCond value of the `if` clause. /// \param Inputs The input values to the region that will be passed. - /// as arguments to the outlined function. + /// as arguments to the outlined function. /// \param BodyGenCB Callback that will generate the region code. /// \param ArgAccessorFuncCB Callback that will generate accessors - /// instructions for passed in target arguments where neccessary + /// instructions for passed in target arguments where necessary /// \param Dependencies A vector of DependData objects that carry - // dependency information as passed in the depend clause - // \param HasNowait Whether the target construct has a `nowait` clause or not. + /// dependency information as passed in the depend clause + /// \param HasNowait Whether the target construct has a `nowait` clause or + /// not.
InsertPointOrErrorTy createTarget( const LocationDescription &Loc, bool IsOffloadEntry, OpenMPIRBuilder::InsertPointTy AllocaIP, OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, const TargetKernelDefaultAttrs &DefaultAttrs, - const TargetKernelRuntimeAttrs &RuntimeAttrs, + const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond, SmallVectorImpl &Inputs, GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB, TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 3d461f0ad4228..c6603635d5e28 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -5308,8 +5308,8 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop, Value *Alignment = AlignedItem.second; Instruction *loadInst = dyn_cast(AlignedPtr); Builder.SetInsertPoint(loadInst->getNextNode()); - Builder.CreateAlignmentAssumption(F->getDataLayout(), - AlignedPtr, Alignment); + Builder.CreateAlignmentAssumption(F->getDataLayout(), AlignedPtr, + Alignment); } Builder.restoreIP(IP); } @@ -5457,16 +5457,16 @@ static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) { Loop *L = LI.getLoopFor(CLI->getHeader()); assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop"); - TargetTransformInfo::UnrollingPreferences UP = - gatherUnrollingPreferences(L, SE, TTI, - /*BlockFrequencyInfo=*/nullptr, - /*ProfileSummaryInfo=*/nullptr, ORE, static_cast(OptLevel), - /*UserThreshold=*/std::nullopt, - /*UserCount=*/std::nullopt, - /*UserAllowPartial=*/true, - /*UserAllowRuntime=*/true, - /*UserUpperBound=*/std::nullopt, - /*UserFullUnrollMaxCount=*/std::nullopt); + TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences( + L, SE, TTI, + /*BlockFrequencyInfo=*/nullptr, + /*ProfileSummaryInfo=*/nullptr, ORE, static_cast(OptLevel), + /*UserThreshold=*/std::nullopt, + 
/*UserCount=*/std::nullopt, + /*UserAllowPartial=*/true, + /*UserAllowRuntime=*/true, + /*UserUpperBound=*/std::nullopt, + /*UserFullUnrollMaxCount=*/std::nullopt); UP.Force = true; @@ -7340,7 +7340,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, OpenMPIRBuilder::InsertPointTy AllocaIP, const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs, const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs, - Function *OutlinedFn, Constant *OutlinedFnID, + Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID, SmallVectorImpl &Args, OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB, SmallVector Dependencies = {}, @@ -7386,9 +7386,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, return Error::success(); }; - // If we don't have an ID for the target region, it means an offload entry - // wasn't created. In this case we just run the host fallback directly. - if (!OutlinedFnID) { + auto &&EmitTargetCallElse = + [&](OpenMPIRBuilder::InsertPointTy AllocaIP, + OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error { // Assume no error was returned because EmitTargetCallFallbackCB doesn't // produce any. OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() { @@ -7404,102 +7404,126 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, }()); Builder.restoreIP(AfterIP); - return; - } - - OpenMPIRBuilder::TargetDataInfo Info( - /*RequiresDevicePointerInfo=*/false, - /*SeparateBeginEndCalls=*/true); - - OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP()); - OpenMPIRBuilder::TargetDataRTArgs RTArgs; - OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info, - RTArgs, MapInfo, - /*IsNonContiguous=*/true, - /*ForEndCall=*/false); - - SmallVector NumTeamsC; - for (auto [DefaultVal, RuntimeVal] : - zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams)) - NumTeamsC.push_back(RuntimeVal ? 
RuntimeVal : Builder.getInt32(DefaultVal)); - - // Calculate number of threads: 0 if no clauses specified, otherwise it is the - // minimum between optional THREAD_LIMIT and NUM_THREADS clauses. - auto InitMaxThreadsClause = [&Builder](Value *Clause) { - if (Clause) - Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(), - /*isSigned=*/false); - return Clause; + return Error::success(); }; - auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) { - if (Clause) - Result = Result - ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause), + + auto &&EmitTargetCallThen = + [&](OpenMPIRBuilder::InsertPointTy AllocaIP, + OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error { + OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); + + OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP()); + OpenMPIRBuilder::TargetDataRTArgs RTArgs; + OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info, + RTArgs, MapInfo, + /*IsNonContiguous=*/true, + /*ForEndCall=*/false); + + SmallVector NumTeamsC; + for (auto [DefaultVal, RuntimeVal] : + zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams)) + NumTeamsC.push_back(RuntimeVal ? RuntimeVal + : Builder.getInt32(DefaultVal)); + + // Calculate number of threads: 0 if no clauses specified, otherwise it is + // the minimum between optional THREAD_LIMIT and NUM_THREADS clauses. + auto InitMaxThreadsClause = [&Builder](Value *Clause) { + if (Clause) + Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(), + /*isSigned=*/false); + return Clause; + }; + auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) { + if (Clause) + Result = + Result ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause), Result, Clause) : Clause; - }; + }; - // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so - // the NUM_THREADS clause is overriden by THREAD_LIMIT. 
- SmallVector NumThreadsC; - Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1 - ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads) - : nullptr; + // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so + // the NUM_THREADS clause is overridden by THREAD_LIMIT. + SmallVector NumThreadsC; + Value *MaxThreadsClause = + RuntimeAttrs.TeamsThreadLimit.size() == 1 + ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads) + : nullptr; - for (auto [TeamsVal, TargetVal] : zip_equal(RuntimeAttrs.TeamsThreadLimit, - RuntimeAttrs.TargetThreadLimit)) { - Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal); - Value *NumThreads = InitMaxThreadsClause(TargetVal); + for (auto [TeamsVal, TargetVal] : zip_equal( + RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) { + Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal); + Value *NumThreads = InitMaxThreadsClause(TargetVal); - CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads); - CombineMaxThreadsClauses(MaxThreadsClause, NumThreads); + CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads); + CombineMaxThreadsClauses(MaxThreadsClause, NumThreads); - NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0)); - } + NumThreadsC.push_back(NumThreads ?
NumThreads : Builder.getInt32(0)); + } - unsigned NumTargetItems = Info.NumberOfPtrs; - // TODO: Use correct device ID - Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF); - uint32_t SrcLocStrSize; - Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize); - Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize, - llvm::omp::IdentFlag(0), 0); + unsigned NumTargetItems = Info.NumberOfPtrs; + // TODO: Use correct device ID + Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF); + uint32_t SrcLocStrSize; + Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize); + Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize, + llvm::omp::IdentFlag(0), 0); - Value *TripCount = RuntimeAttrs.LoopTripCount - ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount, - Builder.getInt64Ty(), - /*isSigned=*/false) - : Builder.getInt64(0); + Value *TripCount = RuntimeAttrs.LoopTripCount + ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount, + Builder.getInt64Ty(), + /*isSigned=*/false) + : Builder.getInt64(0); - // TODO: Use correct DynCGGroupMem - Value *DynCGGroupMem = Builder.getInt32(0); + // TODO: Use correct DynCGGroupMem + Value *DynCGGroupMem = Builder.getInt32(0); - KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount, - NumTeamsC, NumThreadsC, - DynCGGroupMem, HasNoWait); + KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount, + NumTeamsC, NumThreadsC, + DynCGGroupMem, HasNoWait); - // Assume no error was returned because TaskBodyCB and - // EmitTargetCallFallbackCB don't produce any. - OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() { - // The presence of certain clauses on the target directive require the - // explicit generation of the target task. 
- if (RequiresOuterTargetTask) - return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP, - Dependencies, HasNoWait); + // Assume no error was returned because TaskBodyCB and + // EmitTargetCallFallbackCB don't produce any. + OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() { + // The presence of certain clauses on the target directive requires the + // explicit generation of the target task. + if (RequiresOuterTargetTask) + return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP, + Dependencies, HasNoWait); + + return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID, + EmitTargetCallFallbackCB, KArgs, + DeviceID, RTLoc, AllocaIP); + }()); + + Builder.restoreIP(AfterIP); + return Error::success(); + }; - return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID, - EmitTargetCallFallbackCB, KArgs, - DeviceID, RTLoc, AllocaIP); - }()); + // If we don't have an ID for the target region, it means an offload entry + // wasn't created. In this case we just run the host fallback directly and + // ignore any potential 'if' clauses. + if (!OutlinedFnID) { + cantFail(EmitTargetCallElse(AllocaIP, Builder.saveIP())); + return; + } + + // If there's no 'if' clause, only generate the kernel launch code path.
+ if (!IfCond) { + cantFail(EmitTargetCallThen(AllocaIP, Builder.saveIP())); + return; + } - Builder.restoreIP(AfterIP); + cantFail(OMPBuilder.emitIfClause(IfCond, EmitTargetCallThen, + EmitTargetCallElse, AllocaIP)); } OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, const TargetKernelDefaultAttrs &DefaultAttrs, - const TargetKernelRuntimeAttrs &RuntimeAttrs, + const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond, SmallVectorImpl &Args, GenMapInfoCallbackTy GenMapInfoCB, OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc, OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, @@ -7524,7 +7548,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( // to make a remote call (offload) to the previously outlined function // that represents the target region. Do that now. if (!Config.isTargetDevice()) - emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, + emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, IfCond, OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait); return Builder.saveIP(); diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index 3b571cce09a4f..a7b513bdfdc66 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -6243,8 +6243,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), Builder.saveIP(), EntryInfo, DefaultAttrs, - RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB, - SimpleArgAccessorCB)); + RuntimeAttrs, /*IfCond=*/nullptr, Inputs, + GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); Builder.restoreIP(AfterIP); OMPBuilder.finalize(); @@ -6402,11 +6402,12 @@ 
TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; - ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP, - OMPBuilder.createTarget( - Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, - GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); + ASSERT_EXPECTED_INIT( + OpenMPIRBuilder::InsertPointTy, AfterIP, + OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, + EntryInfo, DefaultAttrs, RuntimeAttrs, + /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB, + BodyGenCB, SimpleArgAccessorCB)); Builder.restoreIP(AfterIP); Builder.CreateRetVoid(); @@ -6561,8 +6562,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) { OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), Builder.saveIP(), EntryInfo, DefaultAttrs, - RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB, - SimpleArgAccessorCB)); + RuntimeAttrs, /*IfCond=*/nullptr, Inputs, + GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); Builder.restoreIP(AfterIP); OMPBuilder.finalize(); @@ -6660,11 +6661,12 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) { /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; - ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP, - OMPBuilder.createTarget( - Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, - GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); + ASSERT_EXPECTED_INIT( + OpenMPIRBuilder::InsertPointTy, AfterIP, + OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, + EntryInfo, DefaultAttrs, RuntimeAttrs, + /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB, + BodyGenCB, SimpleArgAccessorCB)); Builder.restoreIP(AfterIP); Builder.CreateRetVoid(); @@ 
-6774,11 +6776,12 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) { /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; - ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP, - OMPBuilder.createTarget( - Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, - GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); + ASSERT_EXPECTED_INIT( + OpenMPIRBuilder::InsertPointTy, AfterIP, + OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, + EntryInfo, DefaultAttrs, RuntimeAttrs, + /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB, + BodyGenCB, SimpleArgAccessorCB)); Builder.restoreIP(AfterIP); Builder.CreateRetVoid(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 0be515e63b470..abef2cb7411aa 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -183,10 +183,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { result = op.emitError("not yet implemented: host evaluation of loop " "bounds in omp.target operation"); }; - auto checkIf = [&todo](auto op, LogicalResult &result) { - if (op.getIfExpr()) - result = todo("if"); - }; auto checkInReduction = [&todo](auto op, LogicalResult &result) { if (!op.getInReductionVars().empty() || op.getInReductionByref() || op.getInReductionSyms()) @@ -306,7 +302,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { checkDevice(op, result); checkHasDeviceAddr(op, result); checkHostEval(op, result); - checkIf(op, result); checkInReduction(op, result); checkIsDevicePtr(op, result); checkPrivate(op, result); @@ -4378,10 +4373,14 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, findAllocaInsertPoint(builder, moduleTranslation); 
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); + llvm::Value *ifCond = nullptr; + if (Value targetIfCond = targetOp.getIfExpr()) + ifCond = moduleTranslation.lookupValue(targetIfCond); + llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = moduleTranslation.getOpenMPBuilder()->createTarget( ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo, - defaultAttrs, runtimeAttrs, kernelInput, genMapInfoCB, bodyCB, + defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB, argAccessorCB, dds, targetOp.getNowait()); if (failed(handleError(afterIP, opInst))) diff --git a/mlir/test/Target/LLVMIR/omptarget-if.mlir b/mlir/test/Target/LLVMIR/omptarget-if.mlir new file mode 100644 index 0000000000000..706ad4411438b --- /dev/null +++ b/mlir/test/Target/LLVMIR/omptarget-if.mlir @@ -0,0 +1,68 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} { + llvm.func @target_if_variable(%x : i1) { + omp.target if(%x) { + omp.terminator + } + llvm.return + } + + // CHECK-LABEL: define void @target_if_variable( + // CHECK-SAME: i1 %[[IF_COND:.*]]) + // CHECK: br i1 %[[IF_COND]], label %[[THEN_LABEL:.*]], label %[[ELSE_LABEL:.*]] + + // CHECK: [[THEN_LABEL]]: + // CHECK-NOT: {{^.*}}: + // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel + // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0 + // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]] + + // CHECK: [[OFFLOAD_FAIL_LABEL]]: + // CHECK-NEXT: call void @[[FALLBACK_FN:__omp_offloading_.*_.*_target_if_variable_l.*]]() + // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]] + + // CHECK: [[OFFLOAD_CONT_LABEL]]: + // CHECK-NEXT: br label %[[END_LABEL:.*]] + + // CHECK: [[ELSE_LABEL]]: + // CHECK-NEXT: call void @[[FALLBACK_FN]]() + // CHECK-NEXT: br label %[[END_LABEL]] + + llvm.func @target_if_true() { + %0 = llvm.mlir.constant(true) : i1 + 
omp.target if(%0) { + omp.terminator + } + llvm.return + } + + // CHECK-LABEL: define void @target_if_true() + // CHECK-NOT: {{^.*}}: + // CHECK: br label %[[ENTRY:.*]] + + // CHECK: [[ENTRY]]: + // CHECK-NOT: {{^.*}}: + // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel + // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0 + // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]] + + // CHECK: [[OFFLOAD_FAIL_LABEL]]: + // CHECK-NEXT: call void @[[FALLBACK_FN:.*]]() + // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]] + + llvm.func @target_if_false() { + %0 = llvm.mlir.constant(false) : i1 + omp.target if(%0) { + omp.terminator + } + llvm.return + } + + // CHECK-LABEL: define void @target_if_false() + // CHECK-NEXT: br label %[[ENTRY:.*]] + + // CHECK: [[ENTRY]]: + // CHECK-NEXT: call void @__omp_offloading_{{.*}}_{{.*}}_target_if_false_l{{.*}}() +} + diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index 392a6558dcfa6..c1e30964b2507 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -271,17 +271,6 @@ llvm.func @target_host_eval(%x : i32) { // ----- -llvm.func @target_if(%x : i1) { - // expected-error@below {{not yet implemented: Unhandled clause if in omp.target operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.target}} - omp.target if(%x) { - omp.terminator - } - llvm.return -} - -// ----- - omp.declare_reduction @add_f32 : f32 init { ^bb0(%arg: f32):