Skip to content

[OMPIRBuilder][MLIR] Add support for target 'if' clause #122478

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2994,27 +2994,29 @@ class OpenMPIRBuilder {
/// \param Loc where the target data construct was encountered.
/// \param IsOffloadEntry whether it is an offload entry.
/// \param CodeGenIP The insertion point where the call to the outlined
/// function should be emitted.
/// function should be emitted.
/// \param EntryInfo The entry information about the function.
/// \param DefaultAttrs Structure containing the default attributes, including
/// numbers of threads and teams to launch the kernel with.
/// \param RuntimeAttrs Structure containing the runtime numbers of threads
/// and teams to launch the kernel with.
/// \param IfCond value of the `if` clause.
/// \param Inputs The input values to the region that will be passed.
/// as arguments to the outlined function.
/// as arguments to the outlined function.
/// \param BodyGenCB Callback that will generate the region code.
/// \param ArgAccessorFuncCB Callback that will generate accessors
/// instructions for passed in target arguments where neccessary
/// instructions for passed in target arguments where neccessary
/// \param Dependencies A vector of DependData objects that carry
// dependency information as passed in the depend clause
// \param HasNowait Whether the target construct has a `nowait` clause or not.
/// dependency information as passed in the depend clause
/// \param HasNowait Whether the target construct has a `nowait` clause or
/// not.
InsertPointOrErrorTy createTarget(
const LocationDescription &Loc, bool IsOffloadEntry,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
const TargetKernelRuntimeAttrs &RuntimeAttrs,
const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
TargetBodyGenCallbackTy BodyGenCB,
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
Expand Down
210 changes: 117 additions & 93 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5308,8 +5308,8 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
Value *Alignment = AlignedItem.second;
Instruction *loadInst = dyn_cast<Instruction>(AlignedPtr);
Builder.SetInsertPoint(loadInst->getNextNode());
Builder.CreateAlignmentAssumption(F->getDataLayout(),
AlignedPtr, Alignment);
Builder.CreateAlignmentAssumption(F->getDataLayout(), AlignedPtr,
Alignment);
}
Builder.restoreIP(IP);
}
Expand Down Expand Up @@ -5457,16 +5457,16 @@ static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
Loop *L = LI.getLoopFor(CLI->getHeader());
assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");

TargetTransformInfo::UnrollingPreferences UP =
gatherUnrollingPreferences(L, SE, TTI,
/*BlockFrequencyInfo=*/nullptr,
/*ProfileSummaryInfo=*/nullptr, ORE, static_cast<int>(OptLevel),
/*UserThreshold=*/std::nullopt,
/*UserCount=*/std::nullopt,
/*UserAllowPartial=*/true,
/*UserAllowRuntime=*/true,
/*UserUpperBound=*/std::nullopt,
/*UserFullUnrollMaxCount=*/std::nullopt);
TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
L, SE, TTI,
/*BlockFrequencyInfo=*/nullptr,
/*ProfileSummaryInfo=*/nullptr, ORE, static_cast<int>(OptLevel),
/*UserThreshold=*/std::nullopt,
/*UserCount=*/std::nullopt,
/*UserAllowPartial=*/true,
/*UserAllowRuntime=*/true,
/*UserUpperBound=*/std::nullopt,
/*UserFullUnrollMaxCount=*/std::nullopt);

UP.Force = true;

Expand Down Expand Up @@ -7340,7 +7340,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointTy AllocaIP,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
Function *OutlinedFn, Constant *OutlinedFnID,
Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
SmallVectorImpl<Value *> &Args,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
Expand Down Expand Up @@ -7386,9 +7386,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
return Error::success();
};

// If we don't have an ID for the target region, it means an offload entry
// wasn't created. In this case we just run the host fallback directly.
if (!OutlinedFnID) {
auto &&EmitTargetCallElse =
[&](OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
// Assume no error was returned because EmitTargetCallFallbackCB doesn't
// produce any.
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
Expand All @@ -7404,102 +7404,126 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
}());

Builder.restoreIP(AfterIP);
return;
}

OpenMPIRBuilder::TargetDataInfo Info(
/*RequiresDevicePointerInfo=*/false,
/*SeparateBeginEndCalls=*/true);

OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
OpenMPIRBuilder::TargetDataRTArgs RTArgs;
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
RTArgs, MapInfo,
/*IsNonContiguous=*/true,
/*ForEndCall=*/false);

SmallVector<Value *, 3> NumTeamsC;
for (auto [DefaultVal, RuntimeVal] :
zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));

// Calculate number of threads: 0 if no clauses specified, otherwise it is the
// minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
auto InitMaxThreadsClause = [&Builder](Value *Clause) {
if (Clause)
Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
/*isSigned=*/false);
return Clause;
return Error::success();
};
auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
if (Clause)
Result = Result
? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),

auto &&EmitTargetCallThen =
[&](OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
OpenMPIRBuilder::TargetDataInfo Info(
/*RequiresDevicePointerInfo=*/false,
/*SeparateBeginEndCalls=*/true);

OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
OpenMPIRBuilder::TargetDataRTArgs RTArgs;
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
RTArgs, MapInfo,
/*IsNonContiguous=*/true,
/*ForEndCall=*/false);

SmallVector<Value *, 3> NumTeamsC;
for (auto [DefaultVal, RuntimeVal] :
zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
NumTeamsC.push_back(RuntimeVal ? RuntimeVal
: Builder.getInt32(DefaultVal));

// Calculate number of threads: 0 if no clauses specified, otherwise it is
// the minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
auto InitMaxThreadsClause = [&Builder](Value *Clause) {
if (Clause)
Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
/*isSigned=*/false);
return Clause;
};
auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
if (Clause)
Result =
Result ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
Result, Clause)
: Clause;
};
};

// If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
// the NUM_THREADS clause is overriden by THREAD_LIMIT.
SmallVector<Value *, 3> NumThreadsC;
Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
: nullptr;
// If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
// the NUM_THREADS clause is overriden by THREAD_LIMIT.
SmallVector<Value *, 3> NumThreadsC;
Value *MaxThreadsClause =
RuntimeAttrs.TeamsThreadLimit.size() == 1
? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
: nullptr;

for (auto [TeamsVal, TargetVal] : zip_equal(RuntimeAttrs.TeamsThreadLimit,
RuntimeAttrs.TargetThreadLimit)) {
Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
Value *NumThreads = InitMaxThreadsClause(TargetVal);
for (auto [TeamsVal, TargetVal] : zip_equal(
RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
Value *NumThreads = InitMaxThreadsClause(TargetVal);

CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);

NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
}
NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
}

unsigned NumTargetItems = Info.NumberOfPtrs;
// TODO: Use correct device ID
Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
uint32_t SrcLocStrSize;
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
llvm::omp::IdentFlag(0), 0);
unsigned NumTargetItems = Info.NumberOfPtrs;
// TODO: Use correct device ID
Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
uint32_t SrcLocStrSize;
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
llvm::omp::IdentFlag(0), 0);

Value *TripCount = RuntimeAttrs.LoopTripCount
? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
Builder.getInt64Ty(),
/*isSigned=*/false)
: Builder.getInt64(0);
Value *TripCount = RuntimeAttrs.LoopTripCount
? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
Builder.getInt64Ty(),
/*isSigned=*/false)
: Builder.getInt64(0);

// TODO: Use correct DynCGGroupMem
Value *DynCGGroupMem = Builder.getInt32(0);
// TODO: Use correct DynCGGroupMem
Value *DynCGGroupMem = Builder.getInt32(0);

KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
NumTeamsC, NumThreadsC,
DynCGGroupMem, HasNoWait);
KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
NumTeamsC, NumThreadsC,
DynCGGroupMem, HasNoWait);

// Assume no error was returned because TaskBodyCB and
// EmitTargetCallFallbackCB don't produce any.
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
// The presence of certain clauses on the target directive require the
// explicit generation of the target task.
if (RequiresOuterTargetTask)
return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
Dependencies, HasNoWait);
// Assume no error was returned because TaskBodyCB and
// EmitTargetCallFallbackCB don't produce any.
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
// The presence of certain clauses on the target directive require the
// explicit generation of the target task.
if (RequiresOuterTargetTask)
return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
Dependencies, HasNoWait);

return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
EmitTargetCallFallbackCB, KArgs,
DeviceID, RTLoc, AllocaIP);
}());

Builder.restoreIP(AfterIP);
return Error::success();
};

return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
EmitTargetCallFallbackCB, KArgs,
DeviceID, RTLoc, AllocaIP);
}());
// If we don't have an ID for the target region, it means an offload entry
// wasn't created. In this case we just run the host fallback directly and
// ignore any potential 'if' clauses.
if (!OutlinedFnID) {
cantFail(EmitTargetCallElse(AllocaIP, Builder.saveIP()));
return;
}

// If there's no 'if' clause, only generate the kernel launch code path.
if (!IfCond) {
cantFail(EmitTargetCallThen(AllocaIP, Builder.saveIP()));
return;
}

Builder.restoreIP(AfterIP);
cantFail(OMPBuilder.emitIfClause(IfCond, EmitTargetCallThen,
EmitTargetCallElse, AllocaIP));
}

OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
const TargetKernelRuntimeAttrs &RuntimeAttrs,
const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
Expand All @@ -7524,7 +7548,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// to make a remote call (offload) to the previously outlined function
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, IfCond,
OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
HasNowait);
return Builder.saveIP();
Expand Down
41 changes: 22 additions & 19 deletions llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6243,8 +6243,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
Builder.saveIP(), EntryInfo, DefaultAttrs,
RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB,
SimpleArgAccessorCB));
RuntimeAttrs, /*IfCond=*/nullptr, Inputs,
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);

OMPBuilder.finalize();
Expand Down Expand Up @@ -6402,11 +6402,12 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
/*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC,
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};

ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(
Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
ASSERT_EXPECTED_INIT(
OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
EntryInfo, DefaultAttrs, RuntimeAttrs,
/*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB,
BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);

Builder.CreateRetVoid();
Expand Down Expand Up @@ -6561,8 +6562,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) {
OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
Builder.saveIP(), EntryInfo, DefaultAttrs,
RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB,
SimpleArgAccessorCB));
RuntimeAttrs, /*IfCond=*/nullptr, Inputs,
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);

OMPBuilder.finalize();
Expand Down Expand Up @@ -6660,11 +6661,12 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) {
/*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD,
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};

ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(
Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
ASSERT_EXPECTED_INIT(
OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
EntryInfo, DefaultAttrs, RuntimeAttrs,
/*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB,
BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);

Builder.CreateRetVoid();
Expand Down Expand Up @@ -6774,11 +6776,12 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
/*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC,
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};

ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(
Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
ASSERT_EXPECTED_INIT(
OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
EntryInfo, DefaultAttrs, RuntimeAttrs,
/*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB,
BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);

Builder.CreateRetVoid();
Expand Down
Loading
Loading