Skip to content

Commit 7890b53

Browse files
committed
Move parsing and printing reductions to separate function
1 parent c7ea791 commit 7890b53

File tree

2 files changed

+46
-28
lines changed

2 files changed

+46
-28
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def ParallelOp : OpenMP_Op<"parallel", [
209209
$allocators_vars, type($allocators_vars)
210210
) `)`
211211
| `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
212-
) custom<WsReduction>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
212+
) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
213213
}];
214214
let hasVerifier = 1;
215215
}

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
3636
#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
3737
#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
38+
#include "mlir/Support/LogicalResult.h"
3839

3940
using namespace mlir;
4041
using namespace mlir::omp;
@@ -428,17 +429,15 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
428429
// Parser, printer and verifier for ReductionVarList
429430
//===----------------------------------------------------------------------===//
430431

431-
static ParseResult
432-
parseWsReduction(OpAsmParser &parser, Region &region,
433-
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
434-
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
435-
436-
// possibly parse reduction
432+
ParseResult parseReductionClause(
433+
OpAsmParser &parser, Region &region,
434+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
435+
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols,
436+
SmallVectorImpl<OpAsmParser::Argument> &privates) {
437437
if (failed(parser.parseOptionalKeyword("reduction")))
438-
return parser.parseRegion(region);
438+
return failure();
439439

440440
SmallVector<SymbolRefAttr> reductionVec;
441-
SmallVector<OpAsmParser::Argument> privates;
442441

443442
if (failed(
444443
parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
@@ -452,28 +451,46 @@ parseWsReduction(OpAsmParser &parser, Region &region,
452451
})))
453452
return failure();
454453

455-
for (std::size_t i = 0; i < privates.size(); ++i) {
456-
privates[i].type = types[i];
454+
for (auto [prv, type] : llvm::zip_equal(privates, types)) {
455+
prv.type = type;
457456
}
458457
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
459458
reductionSymbols = ArrayAttr::get(parser.getContext(), reductions);
460-
return parser.parseRegion(region, privates);
461-
}
462-
463-
void printWsReduction(OpAsmPrinter &p, Operation *op, Region &region,
464-
ValueRange operands, TypeRange types,
465-
ArrayAttr reductionSymbols) {
466-
if (reductionSymbols) {
467-
p << "reduction(";
468-
llvm::interleaveComma(llvm::zip_equal(reductionSymbols, operands,
469-
region.front().getArguments(), types),
470-
p, [&p](auto t) {
471-
auto [sym, op, arg, type] = t;
472-
p << sym << " " << op << " -> " << arg << " : "
473-
<< type;
474-
});
475-
p << ") ";
476-
}
459+
return success();
460+
}
461+
462+
static void printReductionClause(OpAsmPrinter &p, Operation *op, Region &region,
463+
ValueRange operands, TypeRange types,
464+
ArrayAttr reductionSymbols) {
465+
p << "reduction(";
466+
llvm::interleaveComma(llvm::zip_equal(reductionSymbols, operands,
467+
region.front().getArguments(), types),
468+
p, [&p](auto t) {
469+
auto [sym, op, arg, type] = t;
470+
p << sym << " " << op << " -> " << arg << " : "
471+
<< type;
472+
});
473+
p << ") ";
474+
}
475+
476+
static ParseResult
477+
parseParallelRegion(OpAsmParser &parser, Region &region,
478+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
479+
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
480+
481+
llvm::SmallVector<OpAsmParser::Argument> privates;
482+
if (succeeded(parseReductionClause(parser, region, operands, types,
483+
reductionSymbols, privates)))
484+
return parser.parseRegion(region, privates);
485+
486+
return parser.parseRegion(region);
487+
}
488+
489+
static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
490+
ValueRange operands, TypeRange types,
491+
ArrayAttr reductionSymbols) {
492+
if (reductionSymbols)
493+
printReductionClause(p, op, region, operands, types, reductionSymbols);
477494
p.printRegion(region, /*printEntryBlockArgs=*/false);
478495
}
479496

@@ -1164,6 +1181,7 @@ parseLoopControl(OpAsmParser &parser, Region &region,
11641181
loopVarTypes = SmallVector<Type>(ivs.size(), loopVarType);
11651182
for (auto &iv : ivs)
11661183
iv.type = loopVarType;
1184+
11671185
return parser.parseRegion(region, ivs);
11681186
}
11691187

0 commit comments

Comments
 (0)