@@ -618,12 +618,12 @@ class ClauseProcessor {
618
618
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr ,
619
619
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
620
620
*mapSymbols = nullptr ) const ;
621
- bool processReduction (
622
- mlir::Location currentLocation,
623
- llvm::SmallVectorImpl<mlir::Value> &reductionVars,
624
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
625
- llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *reductionSymbols =
626
- nullptr ) const ;
621
+ bool
622
+ processReduction ( mlir::Location currentLocation,
623
+ llvm::SmallVectorImpl<mlir::Value> &reductionVars,
624
+ llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
625
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
626
+ *reductionSymbols = nullptr ) const ;
627
627
bool processSectionsReduction (mlir::Location currentLocation) const ;
628
628
bool processTo (llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const ;
629
629
bool
@@ -1074,14 +1074,14 @@ class ReductionProcessor {
1074
1074
1075
1075
// / Creates a reduction declaration and associates it with an OpenMP block
1076
1076
// / directive.
1077
- static void addReductionDecl (
1078
- mlir::Location currentLocation,
1079
- Fortran::lower::AbstractConverter &converter,
1080
- const Fortran::parser::OmpReductionClause &reduction,
1081
- llvm::SmallVectorImpl<mlir::Value> &reductionVars,
1082
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
1083
- llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *reductionSymbols =
1084
- nullptr ) {
1077
+ static void
1078
+ addReductionDecl ( mlir::Location currentLocation,
1079
+ Fortran::lower::AbstractConverter &converter,
1080
+ const Fortran::parser::OmpReductionClause &reduction,
1081
+ llvm::SmallVectorImpl<mlir::Value> &reductionVars,
1082
+ llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
1083
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
1084
+ *reductionSymbols = nullptr ) {
1085
1085
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
1086
1086
mlir::omp::ReductionDeclareOp decl;
1087
1087
const auto &redOperator{
@@ -1110,7 +1110,7 @@ class ReductionProcessor {
1110
1110
for (const Fortran::parser::OmpObject &ompObject : objectList.v ) {
1111
1111
if (const auto *name{
1112
1112
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
1113
- if (Fortran::semantics::Symbol * symbol{name->symbol }) {
1113
+ if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
1114
1114
if (reductionSymbols)
1115
1115
reductionSymbols->push_back (symbol);
1116
1116
mlir::Value symVal = converter.getSymbolAddress (*symbol);
@@ -1144,7 +1144,7 @@ class ReductionProcessor {
1144
1144
for (const Fortran::parser::OmpObject &ompObject : objectList.v ) {
1145
1145
if (const auto *name{
1146
1146
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
1147
- if (Fortran::semantics::Symbol * symbol{name->symbol }) {
1147
+ if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
1148
1148
if (reductionSymbols)
1149
1149
reductionSymbols->push_back (symbol);
1150
1150
mlir::Value symVal = converter.getSymbolAddress (*symbol);
@@ -1941,7 +1941,7 @@ bool ClauseProcessor::processReduction(
1941
1941
mlir::Location currentLocation,
1942
1942
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
1943
1943
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
1944
- llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *reductionSymbols)
1944
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSymbols)
1945
1945
const {
1946
1946
return findRepeatableClause<ClauseTy::Reduction>(
1947
1947
[&](const ClauseTy::Reduction *reductionClause,
@@ -2258,8 +2258,11 @@ static void createBodyOfOp(
2258
2258
Op &op, Fortran::lower::AbstractConverter &converter, mlir::Location &loc,
2259
2259
Fortran::lower::pft::Evaluation &eval, bool genNested,
2260
2260
const Fortran::parser::OmpClauseList *clauses = nullptr ,
2261
- const llvm::SmallVector<const Fortran::semantics::Symbol *> &args = {},
2262
- bool outerCombined = false , DataSharingProcessor *dsp = nullptr ) {
2261
+ const llvm::SmallVector<const Fortran::semantics::Symbol *> &loopArgs = {},
2262
+ bool outerCombined = false , DataSharingProcessor *dsp = nullptr ,
2263
+ const llvm::SmallVector<const Fortran::semantics::Symbol *> &reductionArgs =
2264
+ {},
2265
+ const llvm::SmallVector<mlir::Type> &reductionTypes = {}) {
2263
2266
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
2264
2267
2265
2268
auto insertMarker = [](fir::FirOpBuilder &builder) {
@@ -2272,24 +2275,32 @@ static void createBodyOfOp(
2272
2275
// argument. Also update the symbol's address with the mlir argument value.
2273
2276
// e.g. For loops the argument is the induction variable. And all further
2274
2277
// uses of the induction variable should use this mlir value.
2275
- if (args .size ()) {
2278
+ if (loopArgs .size ()) {
2276
2279
std::size_t loopVarTypeSize = 0 ;
2277
- for (const Fortran::semantics::Symbol *arg : args )
2280
+ for (const Fortran::semantics::Symbol *arg : loopArgs )
2278
2281
loopVarTypeSize = std::max (loopVarTypeSize, arg->GetUltimate ().size ());
2279
2282
mlir::Type loopVarType = getLoopVarType (converter, loopVarTypeSize);
2280
- llvm::SmallVector<mlir::Type> tiv (args .size (), loopVarType);
2281
- llvm::SmallVector<mlir::Location> locs (args .size (), loc);
2283
+ llvm::SmallVector<mlir::Type> tiv (loopArgs .size (), loopVarType);
2284
+ llvm::SmallVector<mlir::Location> locs (loopArgs .size (), loc);
2282
2285
firOpBuilder.createBlock (&op.getRegion (), {}, tiv, locs);
2283
2286
// The argument is not currently in memory, so make a temporary for the
2284
2287
// argument, and store it there, then bind that location to the argument.
2285
2288
mlir::Operation *storeOp = nullptr ;
2286
- for (auto [argIndex, argSymbol] : llvm::enumerate (args )) {
2289
+ for (auto [argIndex, argSymbol] : llvm::enumerate (loopArgs )) {
2287
2290
mlir::Value indexVal =
2288
2291
fir::getBase (op.getRegion ().front ().getArgument (argIndex));
2289
2292
storeOp =
2290
2293
createAndSetPrivatizedLoopVar (converter, loc, indexVal, argSymbol);
2291
2294
}
2292
2295
firOpBuilder.setInsertionPointAfter (storeOp);
2296
+ } else if (reductionArgs.size ()) {
2297
+ llvm::SmallVector<mlir::Location> locs (reductionArgs.size (), loc);
2298
+ auto block =
2299
+ firOpBuilder.createBlock (&op.getRegion (), {}, reductionTypes, locs);
2300
+ for (auto [arg, prv] :
2301
+ llvm::zip_equal (reductionArgs, block->getArguments ())) {
2302
+ converter.bindSymbol (*arg, prv);
2303
+ }
2293
2304
} else {
2294
2305
firOpBuilder.createBlock (&op.getRegion ());
2295
2306
}
@@ -2390,8 +2401,8 @@ static void createBodyOfOp(
2390
2401
assert (tempDsp.has_value ());
2391
2402
tempDsp->processStep2 (op, isLoop);
2392
2403
} else {
2393
- if (isLoop && args .size () > 0 )
2394
- dsp->setLoopIV (converter.getSymbolAddress (*args [0 ]));
2404
+ if (isLoop && loopArgs .size () > 0 )
2405
+ dsp->setLoopIV (converter.getSymbolAddress (*loopArgs [0 ]));
2395
2406
dsp->processStep2 (op, isLoop);
2396
2407
}
2397
2408
}
@@ -2476,7 +2487,8 @@ static OpTy genOpWithBody(Fortran::lower::AbstractConverter &converter,
2476
2487
currentLocation, std::forward<Args>(args)...);
2477
2488
createBodyOfOp<OpTy>(op, converter, currentLocation, eval, genNested,
2478
2489
clauseList,
2479
- /* args=*/ {}, outerCombined);
2490
+ /* loopArgs=*/ {}, outerCombined, /* dsp=*/ nullptr ,
2491
+ /* reductionArgs=*/ {}, /* reductionTypes=*/ {});
2480
2492
return op;
2481
2493
}
2482
2494
@@ -2513,7 +2525,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
2513
2525
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
2514
2526
reductionVars;
2515
2527
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
2516
- llvm::SmallVector<Fortran::semantics::Symbol *> reductionSymbols;
2528
+ llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
2517
2529
2518
2530
ClauseProcessor cp (converter, clauseList);
2519
2531
cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
@@ -2526,9 +2538,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
2526
2538
cp.processReduction (currentLocation, reductionVars, reductionDeclSymbols,
2527
2539
&reductionSymbols);
2528
2540
2529
- auto op = genOpWithBody<mlir::omp::ParallelOp>(
2530
- converter, eval, genNested, currentLocation, outerCombined, &clauseList,
2531
- /* resultTypes=*/ mlir::TypeRange (), ifClauseOperand,
2541
+ auto op = converter.getFirOpBuilder ().create <mlir::omp::ParallelOp>(
2542
+ currentLocation, mlir::TypeRange (), ifClauseOperand,
2532
2543
numThreadsClauseOperand, allocateOperands, allocatorOperands,
2533
2544
reductionVars,
2534
2545
reductionDeclSymbols.empty ()
@@ -2537,21 +2548,14 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
2537
2548
reductionDeclSymbols),
2538
2549
procBindKindAttr);
2539
2550
2540
- // Add reduction block arguments
2541
- if (!reductionVars.empty ()) {
2542
- mlir::Block ®ionBlock = op.getRegion ().front ();
2543
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
2544
- for (auto [val, sym] : llvm::zip_equal (reductionVars, reductionSymbols)) {
2545
- auto savedIP = firOpBuilder.getInsertionPoint ();
2546
- firOpBuilder.setInsertionPointToStart (®ionBlock);
2547
- auto prv = regionBlock.addArgument (val.getType (), op.getLoc ());
2548
- converter.bindSymbol (*sym, prv);
2549
- val.replaceUsesWithIf (prv, [®ionBlock](mlir::OpOperand &use) {
2550
- return use.getOwner ()->getBlock () == ®ionBlock;
2551
- });
2552
- firOpBuilder.setInsertionPoint (®ionBlock, savedIP);
2553
- }
2554
- }
2551
+ llvm::SmallVector<mlir::Type> reductionTypes;
2552
+ reductionTypes.reserve (reductionVars.size ());
2553
+ llvm::transform (reductionVars, std::back_inserter (reductionTypes),
2554
+ [](mlir::Value v) { return v.getType (); });
2555
+ createBodyOfOp<mlir::omp::ParallelOp>(op, converter, currentLocation, eval,
2556
+ genNested, &clauseList, /* loopArgs=*/ {},
2557
+ outerCombined, /* dsp=*/ nullptr ,
2558
+ reductionSymbols, reductionTypes);
2555
2559
2556
2560
return op;
2557
2561
}
0 commit comments