@@ -5308,8 +5308,8 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
5308
5308
Value *Alignment = AlignedItem.second ;
5309
5309
Instruction *loadInst = dyn_cast<Instruction>(AlignedPtr);
5310
5310
Builder.SetInsertPoint (loadInst->getNextNode ());
5311
- Builder.CreateAlignmentAssumption (F->getDataLayout (),
5312
- AlignedPtr, Alignment);
5311
+ Builder.CreateAlignmentAssumption (F->getDataLayout (), AlignedPtr,
5312
+ Alignment);
5313
5313
}
5314
5314
Builder.restoreIP (IP);
5315
5315
}
@@ -5457,16 +5457,16 @@ static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
5457
5457
Loop *L = LI.getLoopFor (CLI->getHeader ());
5458
5458
assert (L && " Expecting CanonicalLoopInfo to be recognized as a loop" );
5459
5459
5460
- TargetTransformInfo::UnrollingPreferences UP =
5461
- gatherUnrollingPreferences ( L, SE, TTI,
5462
- /* BlockFrequencyInfo=*/ nullptr ,
5463
- /* ProfileSummaryInfo=*/ nullptr , ORE, static_cast <int >(OptLevel),
5464
- /* UserThreshold=*/ std::nullopt,
5465
- /* UserCount=*/ std::nullopt,
5466
- /* UserAllowPartial=*/ true ,
5467
- /* UserAllowRuntime=*/ true ,
5468
- /* UserUpperBound=*/ std::nullopt,
5469
- /* UserFullUnrollMaxCount=*/ std::nullopt);
5460
+ TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences (
5461
+ L, SE, TTI,
5462
+ /* BlockFrequencyInfo=*/ nullptr ,
5463
+ /* ProfileSummaryInfo=*/ nullptr , ORE, static_cast <int >(OptLevel),
5464
+ /* UserThreshold=*/ std::nullopt,
5465
+ /* UserCount=*/ std::nullopt,
5466
+ /* UserAllowPartial=*/ true ,
5467
+ /* UserAllowRuntime=*/ true ,
5468
+ /* UserUpperBound=*/ std::nullopt,
5469
+ /* UserFullUnrollMaxCount=*/ std::nullopt);
5470
5470
5471
5471
UP.Force = true ;
5472
5472
@@ -7340,7 +7340,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7340
7340
OpenMPIRBuilder::InsertPointTy AllocaIP,
7341
7341
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7342
7342
const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
7343
- Function *OutlinedFn, Constant *OutlinedFnID,
7343
+ Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
7344
7344
SmallVectorImpl<Value *> &Args,
7345
7345
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7346
7346
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
@@ -7386,9 +7386,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7386
7386
return Error::success ();
7387
7387
};
7388
7388
7389
- // If we don't have an ID for the target region, it means an offload entry
7390
- // wasn't created. In this case we just run the host fallback directly.
7391
- if (!OutlinedFnID) {
7389
+ auto &&EmitTargetCallElse =
7390
+ [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
7391
+ OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
7392
7392
// Assume no error was returned because EmitTargetCallFallbackCB doesn't
7393
7393
// produce any.
7394
7394
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail ([&]() {
@@ -7404,102 +7404,126 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7404
7404
}());
7405
7405
7406
7406
Builder.restoreIP (AfterIP);
7407
- return ;
7408
- }
7409
-
7410
- OpenMPIRBuilder::TargetDataInfo Info (
7411
- /* RequiresDevicePointerInfo=*/ false ,
7412
- /* SeparateBeginEndCalls=*/ true );
7413
-
7414
- OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB (Builder.saveIP ());
7415
- OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7416
- OMPBuilder.emitOffloadingArraysAndArgs (AllocaIP, Builder.saveIP (), Info,
7417
- RTArgs, MapInfo,
7418
- /* IsNonContiguous=*/ true ,
7419
- /* ForEndCall=*/ false );
7420
-
7421
- SmallVector<Value *, 3 > NumTeamsC;
7422
- for (auto [DefaultVal, RuntimeVal] :
7423
- zip_equal (DefaultAttrs.MaxTeams , RuntimeAttrs.MaxTeams ))
7424
- NumTeamsC.push_back (RuntimeVal ? RuntimeVal : Builder.getInt32 (DefaultVal));
7425
-
7426
- // Calculate number of threads: 0 if no clauses specified, otherwise it is the
7427
- // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7428
- auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7429
- if (Clause)
7430
- Clause = Builder.CreateIntCast (Clause, Builder.getInt32Ty (),
7431
- /* isSigned=*/ false );
7432
- return Clause;
7407
+ return Error::success ();
7433
7408
};
7434
- auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7435
- if (Clause)
7436
- Result = Result
7437
- ? Builder.CreateSelect (Builder.CreateICmpULT (Result, Clause),
7409
+
7410
+ auto &&EmitTargetCallThen =
7411
+ [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
7412
+ OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
7413
+ OpenMPIRBuilder::TargetDataInfo Info (
7414
+ /* RequiresDevicePointerInfo=*/ false ,
7415
+ /* SeparateBeginEndCalls=*/ true );
7416
+
7417
+ OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB (Builder.saveIP ());
7418
+ OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7419
+ OMPBuilder.emitOffloadingArraysAndArgs (AllocaIP, Builder.saveIP (), Info,
7420
+ RTArgs, MapInfo,
7421
+ /* IsNonContiguous=*/ true ,
7422
+ /* ForEndCall=*/ false );
7423
+
7424
+ SmallVector<Value *, 3 > NumTeamsC;
7425
+ for (auto [DefaultVal, RuntimeVal] :
7426
+ zip_equal (DefaultAttrs.MaxTeams , RuntimeAttrs.MaxTeams ))
7427
+ NumTeamsC.push_back (RuntimeVal ? RuntimeVal
7428
+ : Builder.getInt32 (DefaultVal));
7429
+
7430
+ // Calculate number of threads: 0 if no clauses specified, otherwise it is
7431
+ // the minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7432
+ auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7433
+ if (Clause)
7434
+ Clause = Builder.CreateIntCast (Clause, Builder.getInt32Ty (),
7435
+ /* isSigned=*/ false );
7436
+ return Clause;
7437
+ };
7438
+ auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7439
+ if (Clause)
7440
+ Result =
7441
+ Result ? Builder.CreateSelect (Builder.CreateICmpULT (Result, Clause),
7438
7442
Result, Clause)
7439
7443
: Clause;
7440
- };
7444
+ };
7441
7445
7442
- // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7443
- // the NUM_THREADS clause is overriden by THREAD_LIMIT.
7444
- SmallVector<Value *, 3 > NumThreadsC;
7445
- Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit .size () == 1
7446
- ? InitMaxThreadsClause (RuntimeAttrs.MaxThreads )
7447
- : nullptr ;
7446
+ // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7447
+ // the NUM_THREADS clause is overriden by THREAD_LIMIT.
7448
+ SmallVector<Value *, 3 > NumThreadsC;
7449
+ Value *MaxThreadsClause =
7450
+ RuntimeAttrs.TeamsThreadLimit .size () == 1
7451
+ ? InitMaxThreadsClause (RuntimeAttrs.MaxThreads )
7452
+ : nullptr ;
7448
7453
7449
- for (auto [TeamsVal, TargetVal] : zip_equal (RuntimeAttrs. TeamsThreadLimit ,
7450
- RuntimeAttrs.TargetThreadLimit )) {
7451
- Value *TeamsThreadLimitClause = InitMaxThreadsClause (TeamsVal);
7452
- Value *NumThreads = InitMaxThreadsClause (TargetVal);
7454
+ for (auto [TeamsVal, TargetVal] : zip_equal (
7455
+ RuntimeAttrs. TeamsThreadLimit , RuntimeAttrs.TargetThreadLimit )) {
7456
+ Value *TeamsThreadLimitClause = InitMaxThreadsClause (TeamsVal);
7457
+ Value *NumThreads = InitMaxThreadsClause (TargetVal);
7453
7458
7454
- CombineMaxThreadsClauses (TeamsThreadLimitClause, NumThreads);
7455
- CombineMaxThreadsClauses (MaxThreadsClause, NumThreads);
7459
+ CombineMaxThreadsClauses (TeamsThreadLimitClause, NumThreads);
7460
+ CombineMaxThreadsClauses (MaxThreadsClause, NumThreads);
7456
7461
7457
- NumThreadsC.push_back (NumThreads ? NumThreads : Builder.getInt32 (0 ));
7458
- }
7462
+ NumThreadsC.push_back (NumThreads ? NumThreads : Builder.getInt32 (0 ));
7463
+ }
7459
7464
7460
- unsigned NumTargetItems = Info.NumberOfPtrs ;
7461
- // TODO: Use correct device ID
7462
- Value *DeviceID = Builder.getInt64 (OMP_DEVICEID_UNDEF);
7463
- uint32_t SrcLocStrSize;
7464
- Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr (SrcLocStrSize);
7465
- Value *RTLoc = OMPBuilder.getOrCreateIdent (SrcLocStr, SrcLocStrSize,
7466
- llvm::omp::IdentFlag (0 ), 0 );
7465
+ unsigned NumTargetItems = Info.NumberOfPtrs ;
7466
+ // TODO: Use correct device ID
7467
+ Value *DeviceID = Builder.getInt64 (OMP_DEVICEID_UNDEF);
7468
+ uint32_t SrcLocStrSize;
7469
+ Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr (SrcLocStrSize);
7470
+ Value *RTLoc = OMPBuilder.getOrCreateIdent (SrcLocStr, SrcLocStrSize,
7471
+ llvm::omp::IdentFlag (0 ), 0 );
7467
7472
7468
- Value *TripCount = RuntimeAttrs.LoopTripCount
7469
- ? Builder.CreateIntCast (RuntimeAttrs.LoopTripCount ,
7470
- Builder.getInt64Ty (),
7471
- /* isSigned=*/ false )
7472
- : Builder.getInt64 (0 );
7473
+ Value *TripCount = RuntimeAttrs.LoopTripCount
7474
+ ? Builder.CreateIntCast (RuntimeAttrs.LoopTripCount ,
7475
+ Builder.getInt64Ty (),
7476
+ /* isSigned=*/ false )
7477
+ : Builder.getInt64 (0 );
7473
7478
7474
- // TODO: Use correct DynCGGroupMem
7475
- Value *DynCGGroupMem = Builder.getInt32 (0 );
7479
+ // TODO: Use correct DynCGGroupMem
7480
+ Value *DynCGGroupMem = Builder.getInt32 (0 );
7476
7481
7477
- KArgs = OpenMPIRBuilder::TargetKernelArgs (NumTargetItems, RTArgs, TripCount,
7478
- NumTeamsC, NumThreadsC,
7479
- DynCGGroupMem, HasNoWait);
7482
+ KArgs = OpenMPIRBuilder::TargetKernelArgs (NumTargetItems, RTArgs, TripCount,
7483
+ NumTeamsC, NumThreadsC,
7484
+ DynCGGroupMem, HasNoWait);
7480
7485
7481
- // Assume no error was returned because TaskBodyCB and
7482
- // EmitTargetCallFallbackCB don't produce any.
7483
- OpenMPIRBuilder::InsertPointTy AfterIP = cantFail ([&]() {
7484
- // The presence of certain clauses on the target directive require the
7485
- // explicit generation of the target task.
7486
- if (RequiresOuterTargetTask)
7487
- return OMPBuilder.emitTargetTask (TaskBodyCB, DeviceID, RTLoc, AllocaIP,
7488
- Dependencies, HasNoWait);
7486
+ // Assume no error was returned because TaskBodyCB and
7487
+ // EmitTargetCallFallbackCB don't produce any.
7488
+ OpenMPIRBuilder::InsertPointTy AfterIP = cantFail ([&]() {
7489
+ // The presence of certain clauses on the target directive require the
7490
+ // explicit generation of the target task.
7491
+ if (RequiresOuterTargetTask)
7492
+ return OMPBuilder.emitTargetTask (TaskBodyCB, DeviceID, RTLoc, AllocaIP,
7493
+ Dependencies, HasNoWait);
7494
+
7495
+ return OMPBuilder.emitKernelLaunch (Builder, OutlinedFnID,
7496
+ EmitTargetCallFallbackCB, KArgs,
7497
+ DeviceID, RTLoc, AllocaIP);
7498
+ }());
7499
+
7500
+ Builder.restoreIP (AfterIP);
7501
+ return Error::success ();
7502
+ };
7489
7503
7490
- return OMPBuilder.emitKernelLaunch (Builder, OutlinedFnID,
7491
- EmitTargetCallFallbackCB, KArgs,
7492
- DeviceID, RTLoc, AllocaIP);
7493
- }());
7504
+ // If we don't have an ID for the target region, it means an offload entry
7505
+ // wasn't created. In this case we just run the host fallback directly and
7506
+ // ignore any potential 'if' clauses.
7507
+ if (!OutlinedFnID) {
7508
+ cantFail (EmitTargetCallElse (AllocaIP, Builder.saveIP ()));
7509
+ return ;
7510
+ }
7511
+
7512
+ // If there's no 'if' clause, only generate the kernel launch code path.
7513
+ if (!IfCond) {
7514
+ cantFail (EmitTargetCallThen (AllocaIP, Builder.saveIP ()));
7515
+ return ;
7516
+ }
7494
7517
7495
- Builder.restoreIP (AfterIP);
7518
+ cantFail (OMPBuilder.emitIfClause (IfCond, EmitTargetCallThen,
7519
+ EmitTargetCallElse, AllocaIP));
7496
7520
}
7497
7521
7498
7522
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget (
7499
7523
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
7500
7524
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7501
7525
const TargetKernelDefaultAttrs &DefaultAttrs,
7502
- const TargetKernelRuntimeAttrs &RuntimeAttrs,
7526
+ const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
7503
7527
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
7504
7528
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7505
7529
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7524,7 +7548,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
7524
7548
// to make a remote call (offload) to the previously outlined function
7525
7549
// that represents the target region. Do that now.
7526
7550
if (!Config.isTargetDevice ())
7527
- emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
7551
+ emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, IfCond,
7528
7552
OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
7529
7553
HasNowait);
7530
7554
return Builder.saveIP ();
0 commit comments