35
35
#include " mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
36
36
#include " mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
37
37
#include " mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
38
+ #include " mlir/Support/LogicalResult.h"
38
39
39
40
using namespace mlir ;
40
41
using namespace mlir ::omp;
@@ -428,17 +429,15 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
428
429
// Parser, printer and verifier for ReductionVarList
429
430
// ===----------------------------------------------------------------------===//
430
431
431
- static ParseResult
432
- parseWsReduction (OpAsmParser &parser, Region ®ion,
433
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
434
- SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
435
-
436
- // possibly parse reduction
432
+ ParseResult parseReductionClause (
433
+ OpAsmParser &parser, Region ®ion,
434
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
435
+ SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols,
436
+ SmallVectorImpl<OpAsmParser::Argument> &privates) {
437
437
if (failed (parser.parseOptionalKeyword (" reduction" )))
438
- return parser. parseRegion (region );
438
+ return failure ( );
439
439
440
440
SmallVector<SymbolRefAttr> reductionVec;
441
- SmallVector<OpAsmParser::Argument> privates;
442
441
443
442
if (failed (
444
443
parser.parseCommaSeparatedList (OpAsmParser::Delimiter::Paren, [&]() {
@@ -452,28 +451,46 @@ parseWsReduction(OpAsmParser &parser, Region ®ion,
452
451
})))
453
452
return failure ();
454
453
455
- for (std:: size_t i = 0 ; i < privates. size (); ++i ) {
456
- privates[i] .type = types[i] ;
454
+ for (auto [prv, type] : llvm::zip_equal (privates, types) ) {
455
+ prv .type = type ;
457
456
}
458
457
SmallVector<Attribute> reductions (reductionVec.begin (), reductionVec.end ());
459
458
reductionSymbols = ArrayAttr::get (parser.getContext (), reductions);
460
- return parser.parseRegion (region, privates);
461
- }
462
-
463
- void printWsReduction (OpAsmPrinter &p, Operation *op, Region ®ion,
464
- ValueRange operands, TypeRange types,
465
- ArrayAttr reductionSymbols) {
466
- if (reductionSymbols) {
467
- p << " reduction(" ;
468
- llvm::interleaveComma (llvm::zip_equal (reductionSymbols, operands,
469
- region.front ().getArguments (), types),
470
- p, [&p](auto t) {
471
- auto [sym, op, arg, type] = t;
472
- p << sym << " " << op << " -> " << arg << " : "
473
- << type;
474
- });
475
- p << " ) " ;
476
- }
459
+ return success ();
460
+ }
461
+
462
+ static void printReductionClause (OpAsmPrinter &p, Operation *op, Region ®ion,
463
+ ValueRange operands, TypeRange types,
464
+ ArrayAttr reductionSymbols) {
465
+ p << " reduction(" ;
466
+ llvm::interleaveComma (llvm::zip_equal (reductionSymbols, operands,
467
+ region.front ().getArguments (), types),
468
+ p, [&p](auto t) {
469
+ auto [sym, op, arg, type] = t;
470
+ p << sym << " " << op << " -> " << arg << " : "
471
+ << type;
472
+ });
473
+ p << " ) " ;
474
+ }
475
+
476
+ static ParseResult
477
+ parseParallelRegion (OpAsmParser &parser, Region ®ion,
478
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
479
+ SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
480
+
481
+ llvm::SmallVector<OpAsmParser::Argument> privates;
482
+ if (succeeded (parseReductionClause (parser, region, operands, types,
483
+ reductionSymbols, privates)))
484
+ return parser.parseRegion (region, privates);
485
+
486
+ return parser.parseRegion (region);
487
+ }
488
+
489
+ static void printParallelRegion (OpAsmPrinter &p, Operation *op, Region ®ion,
490
+ ValueRange operands, TypeRange types,
491
+ ArrayAttr reductionSymbols) {
492
+ if (reductionSymbols)
493
+ printReductionClause (p, op, region, operands, types, reductionSymbols);
477
494
p.printRegion (region, /* printEntryBlockArgs=*/ false );
478
495
}
479
496
@@ -1164,6 +1181,7 @@ parseLoopControl(OpAsmParser &parser, Region ®ion,
1164
1181
loopVarTypes = SmallVector<Type>(ivs.size (), loopVarType);
1165
1182
for (auto &iv : ivs)
1166
1183
iv.type = loopVarType;
1184
+
1167
1185
return parser.parseRegion (region, ivs);
1168
1186
}
1169
1187
0 commit comments