@@ -7340,7 +7340,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7340
7340
OpenMPIRBuilder::InsertPointTy AllocaIP,
7341
7341
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7342
7342
const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
7343
- Function *OutlinedFn, Constant *OutlinedFnID,
7343
+ Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
7344
7344
SmallVectorImpl<Value *> &Args,
7345
7345
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7346
7346
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
@@ -7386,9 +7386,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7386
7386
return Error::success ();
7387
7387
};
7388
7388
7389
- // If we don't have an ID for the target region, it means an offload entry
7390
- // wasn't created. In this case we just run the host fallback directly.
7391
- if (!OutlinedFnID) {
7389
+ auto &&EmitTargetCallElse =
7390
+ [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
7391
+ OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
7392
7392
// Assume no error was returned because EmitTargetCallFallbackCB doesn't
7393
7393
// produce any.
7394
7394
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail ([&]() {
@@ -7404,102 +7404,124 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7404
7404
}());
7405
7405
7406
7406
Builder.restoreIP (AfterIP);
7407
- return ;
7408
- }
7409
-
7410
- OpenMPIRBuilder::TargetDataInfo Info (
7411
- /* RequiresDevicePointerInfo=*/ false ,
7412
- /* SeparateBeginEndCalls=*/ true );
7413
-
7414
- OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB (Builder.saveIP ());
7415
- OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7416
- OMPBuilder.emitOffloadingArraysAndArgs (AllocaIP, Builder.saveIP (), Info,
7417
- RTArgs, MapInfo,
7418
- /* IsNonContiguous=*/ true ,
7419
- /* ForEndCall=*/ false );
7420
-
7421
- SmallVector<Value *, 3 > NumTeamsC;
7422
- for (auto [DefaultVal, RuntimeVal] :
7423
- zip_equal (DefaultAttrs.MaxTeams , RuntimeAttrs.MaxTeams ))
7424
- NumTeamsC.push_back (RuntimeVal ? RuntimeVal : Builder.getInt32 (DefaultVal));
7425
-
7426
- // Calculate number of threads: 0 if no clauses specified, otherwise it is the
7427
- // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7428
- auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7429
- if (Clause)
7430
- Clause = Builder.CreateIntCast (Clause, Builder.getInt32Ty (),
7431
- /* isSigned=*/ false );
7432
- return Clause;
7433
- };
7434
- auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7435
- if (Clause)
7436
- Result = Result
7437
- ? Builder.CreateSelect (Builder.CreateICmpULT (Result, Clause),
7438
- Result, Clause)
7439
- : Clause;
7407
+ return Error::success ();
7440
7408
};
7441
7409
7442
- // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7443
- // the NUM_THREADS clause is overriden by THREAD_LIMIT.
7444
- SmallVector<Value *, 3 > NumThreadsC;
7445
- Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit .size () == 1
7446
- ? InitMaxThreadsClause (RuntimeAttrs.MaxThreads )
7447
- : nullptr ;
7410
+ auto &&EmitTargetCallThen =
7411
+ [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
7412
+ OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
7413
+ OpenMPIRBuilder::TargetDataInfo Info (
7414
+ /* RequiresDevicePointerInfo=*/ false ,
7415
+ /* SeparateBeginEndCalls=*/ true );
7416
+
7417
+ OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB (Builder.saveIP ());
7418
+ OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7419
+ OMPBuilder.emitOffloadingArraysAndArgs (AllocaIP, Builder.saveIP (), Info,
7420
+ RTArgs, MapInfo,
7421
+ /* IsNonContiguous=*/ true ,
7422
+ /* ForEndCall=*/ false );
7423
+
7424
+ SmallVector<Value *, 3 > NumTeamsC;
7425
+ for (auto [DefaultVal, RuntimeVal] :
7426
+ zip_equal (DefaultAttrs.MaxTeams , RuntimeAttrs.MaxTeams ))
7427
+ NumTeamsC.push_back (RuntimeVal ? RuntimeVal : Builder.getInt32 (DefaultVal));
7428
+
7429
+ // Calculate number of threads: 0 if no clauses specified, otherwise it is the
7430
+ // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7431
+ auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7432
+ if (Clause)
7433
+ Clause = Builder.CreateIntCast (Clause, Builder.getInt32Ty (),
7434
+ /* isSigned=*/ false );
7435
+ return Clause;
7436
+ };
7437
+ auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7438
+ if (Clause)
7439
+ Result = Result
7440
+ ? Builder.CreateSelect (Builder.CreateICmpULT (Result, Clause),
7441
+ Result, Clause)
7442
+ : Clause;
7443
+ };
7448
7444
7449
- for (auto [TeamsVal, TargetVal] : zip_equal (RuntimeAttrs.TeamsThreadLimit ,
7450
- RuntimeAttrs.TargetThreadLimit )) {
7451
- Value *TeamsThreadLimitClause = InitMaxThreadsClause (TeamsVal);
7452
- Value *NumThreads = InitMaxThreadsClause (TargetVal);
7445
+ // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7446
+ // the NUM_THREADS clause is overridden by THREAD_LIMIT.
7447
+ SmallVector<Value *, 3 > NumThreadsC;
7448
+ Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit .size () == 1
7449
+ ? InitMaxThreadsClause (RuntimeAttrs.MaxThreads )
7450
+ : nullptr ;
7453
7451
7454
- CombineMaxThreadsClauses (TeamsThreadLimitClause, NumThreads);
7455
- CombineMaxThreadsClauses (MaxThreadsClause, NumThreads);
7452
+ for (auto [TeamsVal, TargetVal] : zip_equal (RuntimeAttrs.TeamsThreadLimit ,
7453
+ RuntimeAttrs.TargetThreadLimit )) {
7454
+ Value *TeamsThreadLimitClause = InitMaxThreadsClause (TeamsVal);
7455
+ Value *NumThreads = InitMaxThreadsClause (TargetVal);
7456
7456
7457
- NumThreadsC. push_back (NumThreads ? NumThreads : Builder. getInt32 ( 0 ) );
7458
- }
7457
+ CombineMaxThreadsClauses (TeamsThreadLimitClause, NumThreads );
7458
+ CombineMaxThreadsClauses (MaxThreadsClause, NumThreads);
7459
7459
7460
- unsigned NumTargetItems = Info.NumberOfPtrs ;
7461
- // TODO: Use correct device ID
7462
- Value *DeviceID = Builder.getInt64 (OMP_DEVICEID_UNDEF);
7463
- uint32_t SrcLocStrSize;
7464
- Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr (SrcLocStrSize);
7465
- Value *RTLoc = OMPBuilder.getOrCreateIdent (SrcLocStr, SrcLocStrSize,
7466
- llvm::omp::IdentFlag (0 ), 0 );
7460
+ NumThreadsC.push_back (NumThreads ? NumThreads : Builder.getInt32 (0 ));
7461
+ }
7467
7462
7468
- Value *TripCount = RuntimeAttrs.LoopTripCount
7469
- ? Builder.CreateIntCast (RuntimeAttrs.LoopTripCount ,
7470
- Builder.getInt64Ty (),
7471
- /* isSigned=*/ false )
7472
- : Builder.getInt64 (0 );
7463
+ unsigned NumTargetItems = Info.NumberOfPtrs ;
7464
+ // TODO: Use correct device ID
7465
+ Value *DeviceID = Builder.getInt64 (OMP_DEVICEID_UNDEF);
7466
+ uint32_t SrcLocStrSize;
7467
+ Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr (SrcLocStrSize);
7468
+ Value *RTLoc = OMPBuilder.getOrCreateIdent (SrcLocStr, SrcLocStrSize,
7469
+ llvm::omp::IdentFlag (0 ), 0 );
7473
7470
7474
- // TODO: Use correct DynCGGroupMem
7475
- Value *DynCGGroupMem = Builder.getInt32 (0 );
7471
+ Value *TripCount = RuntimeAttrs.LoopTripCount
7472
+ ? Builder.CreateIntCast (RuntimeAttrs.LoopTripCount ,
7473
+ Builder.getInt64Ty (),
7474
+ /* isSigned=*/ false )
7475
+ : Builder.getInt64 (0 );
7476
7476
7477
- KArgs = OpenMPIRBuilder::TargetKernelArgs (NumTargetItems, RTArgs, TripCount,
7478
- NumTeamsC, NumThreadsC,
7479
- DynCGGroupMem, HasNoWait);
7477
+ // TODO: Use correct DynCGGroupMem
7478
+ Value *DynCGGroupMem = Builder.getInt32 (0 );
7480
7479
7481
- // Assume no error was returned because TaskBodyCB and
7482
- // EmitTargetCallFallbackCB don't produce any.
7483
- OpenMPIRBuilder::InsertPointTy AfterIP = cantFail ([&]() {
7484
- // The presence of certain clauses on the target directive require the
7485
- // explicit generation of the target task.
7486
- if (RequiresOuterTargetTask)
7487
- return OMPBuilder.emitTargetTask (TaskBodyCB, DeviceID, RTLoc, AllocaIP,
7488
- Dependencies, HasNoWait);
7480
+ KArgs = OpenMPIRBuilder::TargetKernelArgs (NumTargetItems, RTArgs, TripCount,
7481
+ NumTeamsC, NumThreadsC,
7482
+ DynCGGroupMem, HasNoWait);
7489
7483
7490
- return OMPBuilder.emitKernelLaunch (Builder, OutlinedFnID,
7491
- EmitTargetCallFallbackCB, KArgs,
7492
- DeviceID, RTLoc, AllocaIP);
7493
- }());
7484
+ // Assume no error was returned because TaskBodyCB and
7485
+ // EmitTargetCallFallbackCB don't produce any.
7486
+ OpenMPIRBuilder::InsertPointTy AfterIP = cantFail ([&]() {
7487
+ // The presence of certain clauses on the target directive requires the
7488
+ // explicit generation of the target task.
7489
+ if (RequiresOuterTargetTask)
7490
+ return OMPBuilder.emitTargetTask (TaskBodyCB, DeviceID, RTLoc, AllocaIP,
7491
+ Dependencies, HasNoWait);
7492
+
7493
+ return OMPBuilder.emitKernelLaunch (Builder, OutlinedFnID,
7494
+ EmitTargetCallFallbackCB, KArgs,
7495
+ DeviceID, RTLoc, AllocaIP);
7496
+ }());
7497
+
7498
+ Builder.restoreIP (AfterIP);
7499
+ return Error::success ();
7500
+ };
7501
+
7502
+ // If we don't have an ID for the target region, it means an offload entry
7503
+ // wasn't created. In this case we just run the host fallback directly and
7504
+ // ignore any potential 'if' clauses.
7505
+ if (!OutlinedFnID) {
7506
+ cantFail (EmitTargetCallElse (AllocaIP, Builder.saveIP ()));
7507
+ return ;
7508
+ }
7509
+
7510
+ // If there's no 'if' clause, only generate the kernel launch code path.
7511
+ if (!IfCond) {
7512
+ cantFail (EmitTargetCallThen (AllocaIP, Builder.saveIP ()));
7513
+ return ;
7514
+ }
7494
7515
7495
- Builder.restoreIP (AfterIP);
7516
+ cantFail (OMPBuilder.emitIfClause (IfCond, EmitTargetCallThen,
7517
+ EmitTargetCallElse, AllocaIP));
7496
7518
}
7497
7519
7498
7520
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget (
7499
7521
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
7500
7522
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7501
7523
const TargetKernelDefaultAttrs &DefaultAttrs,
7502
- const TargetKernelRuntimeAttrs &RuntimeAttrs,
7524
+ const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
7503
7525
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
7504
7526
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7505
7527
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7524,7 +7546,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
7524
7546
// to make a remote call (offload) to the previously outlined function
7525
7547
// that represents the target region. Do that now.
7526
7548
if (!Config.isTargetDevice ())
7527
- emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
7549
+ emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, IfCond,
7528
7550
OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
7529
7551
HasNowait);
7530
7552
return Builder.saveIP ();
0 commit comments