Skip to content

Commit c9ad4d7

Browse files
committed
Move block argument code into createBodyOfOp
1 parent 7890b53 commit c9ad4d7

File tree

2 files changed

+59
-55
lines changed

2 files changed

+59
-55
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 50 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -618,12 +618,12 @@ class ClauseProcessor {
618618
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
619619
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
620620
*mapSymbols = nullptr) const;
621-
bool processReduction(
622-
mlir::Location currentLocation,
623-
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
624-
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
625-
llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *reductionSymbols =
626-
nullptr) const;
621+
bool
622+
processReduction(mlir::Location currentLocation,
623+
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
624+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
625+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
626+
*reductionSymbols = nullptr) const;
627627
bool processSectionsReduction(mlir::Location currentLocation) const;
628628
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
629629
bool
@@ -1074,14 +1074,14 @@ class ReductionProcessor {
10741074

10751075
/// Creates a reduction declaration and associates it with an OpenMP block
10761076
/// directive.
1077-
static void addReductionDecl(
1078-
mlir::Location currentLocation,
1079-
Fortran::lower::AbstractConverter &converter,
1080-
const Fortran::parser::OmpReductionClause &reduction,
1081-
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
1082-
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
1083-
llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *reductionSymbols =
1084-
nullptr) {
1077+
static void
1078+
addReductionDecl(mlir::Location currentLocation,
1079+
Fortran::lower::AbstractConverter &converter,
1080+
const Fortran::parser::OmpReductionClause &reduction,
1081+
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
1082+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
1083+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
1084+
*reductionSymbols = nullptr) {
10851085
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
10861086
mlir::omp::ReductionDeclareOp decl;
10871087
const auto &redOperator{
@@ -1110,7 +1110,7 @@ class ReductionProcessor {
11101110
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
11111111
if (const auto *name{
11121112
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
1113-
if (Fortran::semantics::Symbol * symbol{name->symbol}) {
1113+
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
11141114
if (reductionSymbols)
11151115
reductionSymbols->push_back(symbol);
11161116
mlir::Value symVal = converter.getSymbolAddress(*symbol);
@@ -1144,7 +1144,7 @@ class ReductionProcessor {
11441144
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
11451145
if (const auto *name{
11461146
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
1147-
if (Fortran::semantics::Symbol * symbol{name->symbol}) {
1147+
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
11481148
if (reductionSymbols)
11491149
reductionSymbols->push_back(symbol);
11501150
mlir::Value symVal = converter.getSymbolAddress(*symbol);
@@ -1941,7 +1941,7 @@ bool ClauseProcessor::processReduction(
19411941
mlir::Location currentLocation,
19421942
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
19431943
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
1944-
llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *reductionSymbols)
1944+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSymbols)
19451945
const {
19461946
return findRepeatableClause<ClauseTy::Reduction>(
19471947
[&](const ClauseTy::Reduction *reductionClause,
@@ -2258,8 +2258,11 @@ static void createBodyOfOp(
22582258
Op &op, Fortran::lower::AbstractConverter &converter, mlir::Location &loc,
22592259
Fortran::lower::pft::Evaluation &eval, bool genNested,
22602260
const Fortran::parser::OmpClauseList *clauses = nullptr,
2261-
const llvm::SmallVector<const Fortran::semantics::Symbol *> &args = {},
2262-
bool outerCombined = false, DataSharingProcessor *dsp = nullptr) {
2261+
const llvm::SmallVector<const Fortran::semantics::Symbol *> &loopArgs = {},
2262+
bool outerCombined = false, DataSharingProcessor *dsp = nullptr,
2263+
const llvm::SmallVector<const Fortran::semantics::Symbol *> &reductionArgs =
2264+
{},
2265+
const llvm::SmallVector<mlir::Type> &reductionTypes = {}) {
22632266
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
22642267

22652268
auto insertMarker = [](fir::FirOpBuilder &builder) {
@@ -2272,24 +2275,32 @@ static void createBodyOfOp(
22722275
// argument. Also update the symbol's address with the mlir argument value.
22732276
// e.g. For loops the argument is the induction variable. And all further
22742277
// uses of the induction variable should use this mlir value.
2275-
if (args.size()) {
2278+
if (loopArgs.size()) {
22762279
std::size_t loopVarTypeSize = 0;
2277-
for (const Fortran::semantics::Symbol *arg : args)
2280+
for (const Fortran::semantics::Symbol *arg : loopArgs)
22782281
loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
22792282
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
2280-
llvm::SmallVector<mlir::Type> tiv(args.size(), loopVarType);
2281-
llvm::SmallVector<mlir::Location> locs(args.size(), loc);
2283+
llvm::SmallVector<mlir::Type> tiv(loopArgs.size(), loopVarType);
2284+
llvm::SmallVector<mlir::Location> locs(loopArgs.size(), loc);
22822285
firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs);
22832286
// The argument is not currently in memory, so make a temporary for the
22842287
// argument, and store it there, then bind that location to the argument.
22852288
mlir::Operation *storeOp = nullptr;
2286-
for (auto [argIndex, argSymbol] : llvm::enumerate(args)) {
2289+
for (auto [argIndex, argSymbol] : llvm::enumerate(loopArgs)) {
22872290
mlir::Value indexVal =
22882291
fir::getBase(op.getRegion().front().getArgument(argIndex));
22892292
storeOp =
22902293
createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
22912294
}
22922295
firOpBuilder.setInsertionPointAfter(storeOp);
2296+
} else if (reductionArgs.size()) {
2297+
llvm::SmallVector<mlir::Location> locs(reductionArgs.size(), loc);
2298+
auto block =
2299+
firOpBuilder.createBlock(&op.getRegion(), {}, reductionTypes, locs);
2300+
for (auto [arg, prv] :
2301+
llvm::zip_equal(reductionArgs, block->getArguments())) {
2302+
converter.bindSymbol(*arg, prv);
2303+
}
22932304
} else {
22942305
firOpBuilder.createBlock(&op.getRegion());
22952306
}
@@ -2390,8 +2401,8 @@ static void createBodyOfOp(
23902401
assert(tempDsp.has_value());
23912402
tempDsp->processStep2(op, isLoop);
23922403
} else {
2393-
if (isLoop && args.size() > 0)
2394-
dsp->setLoopIV(converter.getSymbolAddress(*args[0]));
2404+
if (isLoop && loopArgs.size() > 0)
2405+
dsp->setLoopIV(converter.getSymbolAddress(*loopArgs[0]));
23952406
dsp->processStep2(op, isLoop);
23962407
}
23972408
}
@@ -2476,7 +2487,8 @@ static OpTy genOpWithBody(Fortran::lower::AbstractConverter &converter,
24762487
currentLocation, std::forward<Args>(args)...);
24772488
createBodyOfOp<OpTy>(op, converter, currentLocation, eval, genNested,
24782489
clauseList,
2479-
/*args=*/{}, outerCombined);
2490+
/*loopArgs=*/{}, outerCombined, /*dsp=*/nullptr,
2491+
/*reductionArgs=*/{}, /*reductionTypes=*/{});
24802492
return op;
24812493
}
24822494

@@ -2513,7 +2525,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
25132525
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
25142526
reductionVars;
25152527
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
2516-
llvm::SmallVector<Fortran::semantics::Symbol *> reductionSymbols;
2528+
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
25172529

25182530
ClauseProcessor cp(converter, clauseList);
25192531
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
@@ -2526,9 +2538,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
25262538
cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols,
25272539
&reductionSymbols);
25282540

2529-
auto op = genOpWithBody<mlir::omp::ParallelOp>(
2530-
converter, eval, genNested, currentLocation, outerCombined, &clauseList,
2531-
/*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
2541+
auto op = converter.getFirOpBuilder().create<mlir::omp::ParallelOp>(
2542+
currentLocation, mlir::TypeRange(), ifClauseOperand,
25322543
numThreadsClauseOperand, allocateOperands, allocatorOperands,
25332544
reductionVars,
25342545
reductionDeclSymbols.empty()
@@ -2537,21 +2548,14 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
25372548
reductionDeclSymbols),
25382549
procBindKindAttr);
25392550

2540-
// Add reduction block arguments
2541-
if (!reductionVars.empty()) {
2542-
mlir::Block &regionBlock = op.getRegion().front();
2543-
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2544-
for (auto [val, sym] : llvm::zip_equal(reductionVars, reductionSymbols)) {
2545-
auto savedIP = firOpBuilder.getInsertionPoint();
2546-
firOpBuilder.setInsertionPointToStart(&regionBlock);
2547-
auto prv = regionBlock.addArgument(val.getType(), op.getLoc());
2548-
converter.bindSymbol(*sym, prv);
2549-
val.replaceUsesWithIf(prv, [&regionBlock](mlir::OpOperand &use) {
2550-
return use.getOwner()->getBlock() == &regionBlock;
2551-
});
2552-
firOpBuilder.setInsertionPoint(&regionBlock, savedIP);
2553-
}
2554-
}
2551+
llvm::SmallVector<mlir::Type> reductionTypes;
2552+
reductionTypes.reserve(reductionVars.size());
2553+
llvm::transform(reductionVars, std::back_inserter(reductionTypes),
2554+
[](mlir::Value v) { return v.getType(); });
2555+
createBodyOfOp<mlir::omp::ParallelOp>(op, converter, currentLocation, eval,
2556+
genNested, &clauseList, /*loopArgs=*/{},
2557+
outerCombined, /*dsp=*/nullptr,
2558+
reductionSymbols, reductionTypes);
25552559

25562560
return op;
25572561
}

flang/test/Lower/OpenMP/parallel-reduction-add.f90

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@
3030
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
3131
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
3232
!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
33-
!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<i32>
33+
!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<i32>
3434
!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
3535
!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]] : i32
36-
!CHECK: hlfir.assign %[[RES]] to %[[PRV]] : i32, !fir.ref<i32>
36+
!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : i32, !fir.ref<i32>
3737
!CHECK: omp.terminator
3838
!CHECK: }
3939
!CHECK: return
@@ -55,10 +55,10 @@ subroutine simple_int_add
5555
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
5656
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<f32>) {
5757
!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
58-
!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<f32>
58+
!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<f32>
5959
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
6060
!CHECK: %[[RES:.+]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
61-
!CHECK: hlfir.assign %[[RES]] to %[[PRV]] : f32, !fir.ref<f32>
61+
!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : f32, !fir.ref<f32>
6262
!CHECK: omp.terminator
6363
!CHECK: }
6464
!CHECK: return
@@ -83,16 +83,16 @@ subroutine simple_real_add
8383
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
8484
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
8585
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[IPRV:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[RPRV:.+]] : !fir.ref<f32>) {
86-
!CHECK: %[[RP_DECL:.+]]:2 = hlfir.declare %[[RPRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
8786
!CHECK: %[[IP_DECL:.+]]:2 = hlfir.declare %[[IPRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
87+
!CHECK: %[[RP_DECL:.+]]:2 = hlfir.declare %[[RPRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
8888
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
89-
!CHECK: %[[R_LPRV:.+]] = fir.load %[[RPRV]] : !fir.ref<f32>
89+
!CHECK: %[[R_LPRV:.+]] = fir.load %[[RP_DECL]]#0 : !fir.ref<f32>
9090
!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[R_LPRV]] {{.*}} : f32
91-
!CHECK: hlfir.assign %[[RES1]] to %[[RPRV]] : f32, !fir.ref<f32>
92-
!CHECK: %[[I_LPRV:.+]] = fir.load %[[IPRV]] : !fir.ref<i32>
91+
!CHECK: hlfir.assign %[[RES1]] to %[[RP_DECL]]#0 : f32, !fir.ref<f32>
92+
!CHECK: %[[I_LPRV:.+]] = fir.load %[[IP_DECL]]#0 : !fir.ref<i32>
9393
!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
9494
!CHECK: %[[RES0:.+]] = arith.addi %[[I_LPRV]], %[[I_INCR]] : i32
95-
!CHECK: hlfir.assign %[[RES0]] to %[[IPRV]] : i32, !fir.ref<i32>
95+
!CHECK: hlfir.assign %[[RES0]] to %[[IP_DECL]]#0 : i32, !fir.ref<i32>
9696
!CHECK: omp.terminator
9797
!CHECK: }
9898
!CHECK: return

0 commit comments

Comments
 (0)