@@ -7310,6 +7310,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7310
7310
OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
7311
7311
Constant *OutlinedFnID, ArrayRef<int32_t > NumTeams,
7312
7312
ArrayRef<int32_t > NumThreads, SmallVectorImpl<Value *> &Args,
7313
+ Value *IfCond,
7313
7314
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7314
7315
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
7315
7316
bool HasNoWait = false ) {
@@ -7354,9 +7355,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7354
7355
return Error::success ();
7355
7356
};
7356
7357
7357
- // If we don't have an ID for the target region, it means an offload entry
7358
- // wasn't created. In this case we just run the host fallback directly.
7359
- if (!OutlinedFnID) {
7358
+ auto &&EmitTargetCallElse =
7359
+ [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
7360
+ OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
7360
7361
// Assume no error was returned because EmitTargetCallFallbackCB doesn't
7361
7362
// produce any.
7362
7363
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail ([&]() {
@@ -7372,65 +7373,87 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7372
7373
}());
7373
7374
7374
7375
Builder.restoreIP (AfterIP);
7375
- return ;
7376
- }
7376
+ return Error::success ();
7377
+ };
7378
+
7379
+ auto &&EmitTargetCallThen =
7380
+ [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
7381
+ OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
7382
+ OpenMPIRBuilder::TargetDataInfo Info (
7383
+ /* RequiresDevicePointerInfo=*/ false ,
7384
+ /* SeparateBeginEndCalls=*/ true );
7385
+
7386
+ OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB (Builder.saveIP ());
7387
+ OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7388
+ OMPBuilder.emitOffloadingArraysAndArgs (AllocaIP, Builder.saveIP (), Info,
7389
+ RTArgs, MapInfo,
7390
+ /* IsNonContiguous=*/ true ,
7391
+ /* ForEndCall=*/ false );
7392
+
7393
+ SmallVector<Value *, 3 > NumTeamsC;
7394
+ SmallVector<Value *, 3 > NumThreadsC;
7395
+ for (auto V : NumTeams)
7396
+ NumTeamsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7397
+ for (auto V : NumThreads)
7398
+ NumThreadsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7399
+
7400
+ unsigned NumTargetItems = Info.NumberOfPtrs ;
7401
+ // TODO: Use correct device ID
7402
+ Value *DeviceID = Builder.getInt64 (OMP_DEVICEID_UNDEF);
7403
+ uint32_t SrcLocStrSize;
7404
+ Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr (SrcLocStrSize);
7405
+ Value *RTLoc = OMPBuilder.getOrCreateIdent (SrcLocStr, SrcLocStrSize,
7406
+ llvm::omp::IdentFlag (0 ), 0 );
7407
+ // TODO: Use correct NumIterations
7408
+ Value *NumIterations = Builder.getInt64 (0 );
7409
+ // TODO: Use correct DynCGGroupMem
7410
+ Value *DynCGGroupMem = Builder.getInt32 (0 );
7411
+
7412
+ KArgs = OpenMPIRBuilder::TargetKernelArgs (
7413
+ NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
7414
+ DynCGGroupMem, HasNoWait);
7415
+
7416
+ // Assume no error was returned because TaskBodyCB and
7417
+ // EmitTargetCallFallbackCB don't produce any.
7418
+ OpenMPIRBuilder::InsertPointTy AfterIP = cantFail ([&]() {
7419
+ // The presence of certain clauses on the target directive requires the
7420
+ // explicit generation of the target task.
7421
+ if (RequiresOuterTargetTask)
7422
+ return OMPBuilder.emitTargetTask (TaskBodyCB, DeviceID, RTLoc, AllocaIP,
7423
+ Dependencies, HasNoWait);
7424
+
7425
+ return OMPBuilder.emitKernelLaunch (Builder, OutlinedFnID,
7426
+ EmitTargetCallFallbackCB, KArgs,
7427
+ DeviceID, RTLoc, AllocaIP);
7428
+ }());
7377
7429
7378
- OpenMPIRBuilder::TargetDataInfo Info (
7379
- /* RequiresDevicePointerInfo= */ false ,
7380
- /* SeparateBeginEndCalls= */ true ) ;
7430
+ Builder. restoreIP (AfterIP);
7431
+ return Error::success ();
7432
+ } ;
7381
7433
7382
- OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB (Builder. saveIP ());
7383
- OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7384
- OMPBuilder. emitOffloadingArraysAndArgs (AllocaIP, Builder. saveIP (), Info,
7385
- RTArgs, MapInfo,
7386
- /* IsNonContiguous= */ true ,
7387
- /* ForEndCall= */ false );
7434
+ // If we don't have an ID for the target region, it means an offload entry
7435
+ // wasn't created. In this case we just run the host fallback directly.
7436
+ if (!OutlinedFnID) {
7437
+ cantFail ( EmitTargetCallElse (AllocaIP, Builder. saveIP ()));
7438
+ return ;
7439
+ }
7388
7440
7389
- SmallVector<Value *, 3 > NumTeamsC;
7390
- SmallVector<Value *, 3 > NumThreadsC;
7391
- for (auto V : NumTeams)
7392
- NumTeamsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7393
- for (auto V : NumThreads)
7394
- NumThreadsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7441
+ // If there's no IF clause, only generate the kernel launch code path.
7442
+ if (!IfCond) {
7443
+ cantFail (EmitTargetCallThen (AllocaIP, Builder.saveIP ()));
7444
+ return ;
7445
+ }
7395
7446
7396
- unsigned NumTargetItems = Info.NumberOfPtrs ;
7397
- // TODO: Use correct device ID
7398
- Value *DeviceID = Builder.getInt64 (OMP_DEVICEID_UNDEF);
7399
- uint32_t SrcLocStrSize;
7400
- Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr (SrcLocStrSize);
7401
- Value *RTLoc = OMPBuilder.getOrCreateIdent (SrcLocStr, SrcLocStrSize,
7402
- llvm::omp::IdentFlag (0 ), 0 );
7403
- // TODO: Use correct NumIterations
7404
- Value *NumIterations = Builder.getInt64 (0 );
7405
- // TODO: Use correct DynCGGroupMem
7406
- Value *DynCGGroupMem = Builder.getInt32 (0 );
7407
-
7408
- KArgs = OpenMPIRBuilder::TargetKernelArgs (
7409
- NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
7410
- DynCGGroupMem, HasNoWait);
7411
-
7412
- // Assume no error was returned because TaskBodyCB and
7413
- // EmitTargetCallFallbackCB don't produce any.
7414
- OpenMPIRBuilder::InsertPointTy AfterIP = cantFail ([&]() {
7415
- // The presence of certain clauses on the target directive require the
7416
- // explicit generation of the target task.
7417
- if (RequiresOuterTargetTask)
7418
- return OMPBuilder.emitTargetTask (TaskBodyCB, DeviceID, RTLoc, AllocaIP,
7419
- Dependencies, HasNoWait);
7420
-
7421
- return OMPBuilder.emitKernelLaunch (Builder, OutlinedFnID,
7422
- EmitTargetCallFallbackCB, KArgs,
7423
- DeviceID, RTLoc, AllocaIP);
7424
- }());
7425
-
7426
- Builder.restoreIP (AfterIP);
7447
+ cantFail (OMPBuilder.emitIfClause (IfCond, EmitTargetCallThen,
7448
+ EmitTargetCallElse, AllocaIP));
7427
7449
}
7428
7450
7429
7451
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget (
7430
7452
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
7431
7453
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7432
7454
ArrayRef<int32_t > NumTeams, ArrayRef<int32_t > NumThreads,
7433
- SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
7455
+ SmallVectorImpl<Value *> &Args, Value *IfCond,
7456
+ GenMapInfoCallbackTy GenMapInfoCB,
7434
7457
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7435
7458
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7436
7459
SmallVector<DependData> Dependencies, bool HasNowait) {
@@ -7455,7 +7478,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
7455
7478
// that represents the target region. Do that now.
7456
7479
if (!Config.isTargetDevice ())
7457
7480
emitTargetCall (*this , Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
7458
- NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait);
7481
+ NumThreads, Args, IfCond, GenMapInfoCB, Dependencies,
7482
+ HasNowait);
7459
7483
return Builder.saveIP ();
7460
7484
}
7461
7485
0 commit comments