diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp index fd18b212bad51..746dc2b62787d 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -621,10 +621,12 @@ class ClauseProcessor { llvm::SmallVectorImpl *mapSymLocs = nullptr, llvm::SmallVectorImpl *mapSymbols = nullptr) const; - bool processReduction( - mlir::Location currentLocation, - llvm::SmallVectorImpl &reductionVars, - llvm::SmallVectorImpl &reductionDeclSymbols) const; + bool + processReduction(mlir::Location currentLocation, + llvm::SmallVectorImpl &reductionVars, + llvm::SmallVectorImpl &reductionDeclSymbols, + llvm::SmallVectorImpl + *reductionSymbols = nullptr) const; bool processSectionsReduction(mlir::Location currentLocation) const; bool processTo(llvm::SmallVectorImpl &result) const; bool @@ -1079,12 +1081,14 @@ class ReductionProcessor { /// Creates a reduction declaration and associates it with an OpenMP block /// directive. - static void addReductionDecl( - mlir::Location currentLocation, - Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpReductionClause &reduction, - llvm::SmallVectorImpl &reductionVars, - llvm::SmallVectorImpl &reductionDeclSymbols) { + static void + addReductionDecl(mlir::Location currentLocation, + Fortran::lower::AbstractConverter &converter, + const Fortran::parser::OmpReductionClause &reduction, + llvm::SmallVectorImpl &reductionVars, + llvm::SmallVectorImpl &reductionDeclSymbols, + llvm::SmallVectorImpl + *reductionSymbols = nullptr) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::omp::ReductionDeclareOp decl; const auto &redOperator{ @@ -1114,6 +1118,8 @@ class ReductionProcessor { if (const auto *name{ Fortran::parser::Unwrap(ompObject)}) { if (const Fortran::semantics::Symbol * symbol{name->symbol}) { + if (reductionSymbols) + reductionSymbols->push_back(symbol); mlir::Value symVal = converter.getSymbolAddress(*symbol); if (auto declOp = symVal.getDefiningOp()) symVal = declOp.getBase(); @@ -1148,6 +1154,8 @@ class ReductionProcessor { if (const auto *name{ Fortran::parser::Unwrap(ompObject)}) { if (const Fortran::semantics::Symbol * symbol{name->symbol}) { + if (reductionSymbols) + reductionSymbols->push_back(symbol); mlir::Value symVal = converter.getSymbolAddress(*symbol); if (auto declOp = symVal.getDefiningOp()) symVal = declOp.getBase(); @@ -1948,13 +1956,16 @@ bool ClauseProcessor::processMap( bool ClauseProcessor::processReduction( mlir::Location currentLocation, llvm::SmallVectorImpl &reductionVars, - llvm::SmallVectorImpl &reductionDeclSymbols) const { + llvm::SmallVectorImpl &reductionDeclSymbols, + llvm::SmallVectorImpl *reductionSymbols) + const { return findRepeatableClause( [&](const ClauseTy::Reduction *reductionClause, const Fortran::parser::CharBlock &) { ReductionProcessor rp; rp.addReductionDecl(currentLocation, converter, reductionClause->v, - reductionVars, reductionDeclSymbols); + reductionVars, reductionDeclSymbols, + reductionSymbols); }); } @@ -2304,6 +2315,14 @@ struct OpWithBodyGenInfo { return *this; } + OpWithBodyGenInfo & + setReductions(llvm::SmallVector *value1, + llvm::SmallVector *value2) { + reductionSymbols = value1; + reductionTypes = value2; + return *this; + } + OpWithBodyGenInfo &setGenRegionEntryCb(GenOMPRegionEntryCBFn value) { genRegionEntryCB = value; return *this; @@ -2323,6 +2342,11 @@ struct OpWithBodyGenInfo { const Fortran::parser::OmpClauseList *clauses = nullptr; /// [in] if provided, processes the construct's data-sharing attributes. DataSharingProcessor *dsp = nullptr; + /// [in] if provided, list of reduction symbols + llvm::SmallVector *reductionSymbols = + nullptr; + /// [in] if provided, list of reduction types + llvm::SmallVector *reductionTypes = nullptr; /// [in] if provided, emits the op's region entry. Otherwise, an emtpy block /// is created in the region. GenOMPRegionEntryCBFn genRegionEntryCB = nullptr; @@ -2567,6 +2591,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector allocateOperands, allocatorOperands, reductionVars; llvm::SmallVector reductionDeclSymbols; + llvm::SmallVector reductionSymbols; ClauseProcessor cp(converter, clauseList); cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel, @@ -2576,13 +2601,33 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, cp.processDefault(); cp.processAllocate(allocatorOperands, allocateOperands); if (!outerCombined) - cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols); + cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols, + &reductionSymbols); + + llvm::SmallVector reductionTypes; + reductionTypes.reserve(reductionVars.size()); + llvm::transform(reductionVars, std::back_inserter(reductionTypes), + [](mlir::Value v) { return v.getType(); }); + + auto reductionCallback = [&](mlir::Operation *op) { + llvm::SmallVector locs(reductionVars.size(), + currentLocation); + auto block = converter.getFirOpBuilder().createBlock(&op->getRegion(0), {}, + reductionTypes, locs); + for (auto [arg, prv] : + llvm::zip_equal(reductionSymbols, block->getArguments())) { + converter.bindSymbol(*arg, prv); + } + return reductionSymbols; + }; return genOpWithBody( OpWithBodyGenInfo(converter, currentLocation, eval) .setGenNested(genNested) .setOuterCombined(outerCombined) - .setClauses(&clauseList), + .setClauses(&clauseList) + .setReductions(&reductionSymbols, &reductionTypes) + .setGenRegionEntryCb(reductionCallback), /*resultTypes=*/mlir::TypeRange(), ifClauseOperand, numThreadsClauseOperand, allocateOperands, allocatorOperands, reductionVars, @@ -3634,10 +3679,8 @@ genOMP(Fortran::lower::AbstractConverter &converter, break; } - if (singleDirective) { - genOpenMPReduction(converter, beginClauseList); + if (singleDirective) return; - } // Codegen for combined directives bool combinedDirective = false; @@ -3673,7 +3716,6 @@ genOMP(Fortran::lower::AbstractConverter &converter, ")"); genNestedEvaluations(converter, eval); - genOpenMPReduction(converter, beginClauseList); } static void diff --git a/flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90 b/flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90 index 6580aeb13ccd1..4b223e822760a 100644 --- a/flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90 +++ b/flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90 @@ -27,9 +27,11 @@ !CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_int_addEi"} !CHECK: %[[I_START:.*]] = arith.constant 0 : i32 !CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref -!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref) { -!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32 -!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref +!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV:.+]] : !fir.ref) { +!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref +!CHECK: %[[I_INCR:.+]] = arith.constant 1 : i32 +!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]] +!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref !CHECK: omp.terminator !CHECK: } !CHECK: return @@ -48,9 +50,11 @@ subroutine simple_int_add !CHECK: %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFsimple_real_addEr"} !CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32 !CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref -!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref) { -!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32 -!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref +!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[RREF]] -> %[[PRV:.+]] : !fir.ref) { +!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref +!CHECK: %[[R_INCR:.+]] = arith.constant 1.500000e+00 : f32 +!CHECK: %[[RES]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32 +!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref !CHECK: omp.terminator !CHECK: } !CHECK: return @@ -72,11 +76,15 @@ subroutine simple_real_add !CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref !CHECK: %[[I_START:.*]] = arith.constant 0 : i32 !CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref -!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref, @[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref) { +!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV0:.+]] : !fir.ref, @[[RED_F32_NAME]] %[[RREF]] -> %[[PRV1:.+]] : !fir.ref) { !CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32 -!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref +!CHECK: %[[LPRV1:.+]] = fir.load %[[PRV1]] : !fir.ref +!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[LPRV1]] {{.*}} : f32 +!CHECK: fir.store %[[RES1]] to %[[PRV1]] +!CHECK: %[[LPRV0:.+]] = fir.load %[[PRV0]] : !fir.ref !CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32 -!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref +!CHECK: %[[RES0:.+]] = arith.addi %[[LPRV0]], %[[I_INCR]] +!CHECK: fir.store %[[RES0]] to %[[PRV0]] !CHECK: omp.terminator !CHECK: } !CHECK: return diff --git a/flang/test/Lower/OpenMP/parallel-reduction-add.f90 b/flang/test/Lower/OpenMP/parallel-reduction-add.f90 index 81a93aebbd266..8f3ac3dc357af 100644 --- a/flang/test/Lower/OpenMP/parallel-reduction-add.f90 +++ b/flang/test/Lower/OpenMP/parallel-reduction-add.f90 @@ -28,9 +28,12 @@ !CHECK: %[[I_DECL:.*]]:2 = hlfir.declare %[[IREF]] {uniq_name = "_QFsimple_int_addEi"} : (!fir.ref) -> (!fir.ref, !fir.ref) !CHECK: %[[I_START:.*]] = arith.constant 0 : i32 !CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref -!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref) { +!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[PRV:.+]] : !fir.ref) { +!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref !CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32 -!CHECK: omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref +!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]] : i32 +!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : i32, !fir.ref !CHECK: omp.terminator !CHECK: } !CHECK: return @@ -50,9 +53,12 @@ subroutine simple_int_add !CHECK: %[[R_DECL:.*]]:2 = hlfir.declare %[[RREF]] {uniq_name = "_QFsimple_real_addEr"} : (!fir.ref) -> (!fir.ref, !fir.ref) !CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32 !CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref -!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref) { +!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[PRV:.+]] : !fir.ref) { +!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref !CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32 -!CHECK: omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref +!CHECK: %[[RES:.+]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32 +!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : f32, !fir.ref !CHECK: omp.terminator !CHECK: } !CHECK: return @@ -76,11 +82,17 @@ subroutine simple_real_add !CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref !CHECK: %[[I_START:.*]] = arith.constant 0 : i32 !CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref -!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref, @[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref) { +!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[IPRV:.+]] : !fir.ref, @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[RPRV:.+]] : !fir.ref) { +!CHECK: %[[IP_DECL:.+]]:2 = hlfir.declare %[[IPRV]] {{.*}} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[RP_DECL:.+]]:2 = hlfir.declare %[[RPRV]] {{.*}} : (!fir.ref) -> (!fir.ref, !fir.ref) !CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32 -!CHECK: omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref +!CHECK: %[[R_LPRV:.+]] = fir.load %[[RP_DECL]]#0 : !fir.ref +!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[R_LPRV]] {{.*}} : f32 +!CHECK: hlfir.assign %[[RES1]] to %[[RP_DECL]]#0 : f32, !fir.ref +!CHECK: %[[I_LPRV:.+]] = fir.load %[[IP_DECL]]#0 : !fir.ref !CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32 -!CHECK: omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref +!CHECK: %[[RES0:.+]] = arith.addi %[[I_LPRV]], %[[I_INCR]] : i32 +!CHECK: hlfir.assign %[[RES0]] to %[[IP_DECL]]#0 : i32, !fir.ref !CHECK: omp.terminator !CHECK: } !CHECK: return diff --git a/flang/test/Lower/OpenMP/parallel-reduction.f90 b/flang/test/Lower/OpenMP/parallel-reduction.f90 new file mode 100644 index 0000000000000..a07d118b0ba19 --- /dev/null +++ b/flang/test/Lower/OpenMP/parallel-reduction.f90 @@ -0,0 +1,38 @@ +! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s +! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s + +!CHECK: omp.reduction.declare @[[REDUCTION_DECLARE:[_a-z0-9]+]] : i32 init { +!CHECK: ^bb0(%{{.*}}: i32): +!CHECK: %[[I0:[_a-z0-9]+]] = arith.constant 0 : i32 +!CHECK: omp.yield(%[[I0]] : i32) +!CHECK: } combiner { +!CHECK: ^bb0(%[[C0:[_a-z0-9]+]]: i32, %[[C1:[_a-z0-9]+]]: i32): +!CHECK: %[[CR:[_a-z0-9]+]] = arith.addi %[[C0]], %[[C1]] : i32 +!CHECK: omp.yield(%[[CR]] : i32) +!CHECK: } +!CHECK: func.func @_QQmain() attributes {fir.bindc_name = "mn"} { +!CHECK: %[[RED_ACCUM_REF:[_a-z0-9]+]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"} +!CHECK: %[[RED_ACCUM_DECL:[_a-z0-9]+]]:2 = hlfir.declare %[[RED_ACCUM_REF]] {uniq_name = "_QFEi"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[C0:[_a-z0-9]+]] = arith.constant 0 : i32 +!CHECK: hlfir.assign %[[C0]] to %[[RED_ACCUM_DECL]]#0 : i32, !fir.ref +!CHECK: omp.parallel reduction(@[[REDUCTION_DECLARE]] %[[RED_ACCUM_DECL]]#0 -> %[[PRIVATE_RED:[a-z0-9]+]] : !fir.ref) { +!CHECK: %[[PRIVATE_DECL:[_a-z0-9]+]]:2 = hlfir.declare %[[PRIVATE_RED]] {uniq_name = "_QFEi"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[C1:[_a-z0-9]+]] = arith.constant 1 : i32 +!CHECK: hlfir.assign %[[C1]] to %[[PRIVATE_DECL]]#0 : i32, !fir.ref +!CHECK: omp.terminator +!CHECK: } +!CHECK: %[[RED_ACCUM_VAL:[_a-z0-9]+]] = fir.load %[[RED_ACCUM_DECL]]#0 : !fir.ref +!CHECK: {{.*}} = fir.call @_FortranAioOutputInteger32(%{{.*}}, %[[RED_ACCUM_VAL]]) fastmath : (!fir.ref, i32) -> i1 +!CHECK: return +!CHECK: } + +program mn + integer :: i + i = 0 + + !$omp parallel reduction(+:i) + i = 1 + !$omp end parallel + + print *, i +end program diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index ca36350548577..5d84217b9c701 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -191,11 +191,8 @@ def ParallelOp : OpenMP_Op<"parallel", [ unsigned getNumReductionVars() { return getReductionVars().size(); } }]; let assemblyFormat = [{ - oilist( `reduction` `(` - custom( - $reduction_vars, type($reduction_vars), $reductions - ) `)` - | `if` `(` $if_expr_var `:` type($if_expr_var) `)` + oilist( + `if` `(` $if_expr_var `:` type($if_expr_var) `)` | `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)` | `allocate` `(` custom( @@ -203,7 +200,7 @@ def ParallelOp : OpenMP_Op<"parallel", [ $allocators_vars, type($allocators_vars) ) `)` | `proc_bind` `(` custom($proc_bind_val) `)` - ) $region attr-dict + ) custom($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict }]; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 381f17d080419..394f0629ba6e4 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -21,6 +21,7 @@ #include "mlir/Interfaces/FoldInterfaces.h" #include "llvm/ADT/BitVector.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLForwardCompat.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" @@ -34,6 +35,7 @@ #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc" #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" +#include "mlir/Support/LogicalResult.h" using namespace mlir; using namespace mlir::omp; @@ -427,6 +429,71 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op, // Parser, printer and verifier for ReductionVarList //===----------------------------------------------------------------------===// +ParseResult +parseReductionClause(OpAsmParser &parser, Region ®ion, + SmallVectorImpl &operands, + SmallVectorImpl &types, ArrayAttr &reductionSymbols, + SmallVectorImpl &privates) { + if (failed(parser.parseOptionalKeyword("reduction"))) + return failure(); + + SmallVector reductionVec; + + if (failed( + parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() { + if (parser.parseAttribute(reductionVec.emplace_back()) || + parser.parseOperand(operands.emplace_back()) || + parser.parseArrow() || + parser.parseArgument(privates.emplace_back()) || + parser.parseColonType(types.emplace_back())) + return failure(); + return success(); + }))) + return failure(); + + for (auto [prv, type] : llvm::zip_equal(privates, types)) { + prv.type = type; + } + SmallVector reductions(reductionVec.begin(), reductionVec.end()); + reductionSymbols = ArrayAttr::get(parser.getContext(), reductions); + return success(); +} + +static void printReductionClause(OpAsmPrinter &p, Operation *op, Region ®ion, + ValueRange operands, TypeRange types, + ArrayAttr reductionSymbols) { + p << "reduction("; + llvm::interleaveComma(llvm::zip_equal(reductionSymbols, operands, + region.front().getArguments(), types), + p, [&p](auto t) { + auto [sym, op, arg, type] = t; + p << sym << " " << op << " -> " << arg << " : " + << type; + }); + p << ") "; +} + +static ParseResult +parseParallelRegion(OpAsmParser &parser, Region ®ion, + SmallVectorImpl &operands, + SmallVectorImpl &types, ArrayAttr &reductionSymbols) { + + llvm::SmallVector privates; + if (succeeded(parseReductionClause(parser, region, operands, types, + reductionSymbols, privates))) + return parser.parseRegion(region, privates); + + return parser.parseRegion(region); +} + +static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region ®ion, + ValueRange operands, TypeRange types, + ArrayAttr reductionSymbols) { + if (reductionSymbols) + printReductionClause(p, op, region, operands, types, reductionSymbols); + p.printRegion(region, /*printEntryBlockArgs=*/false); +} + /// reduction-entry-list ::= reduction-entry /// | reduction-entry-list `,` reduction-entry /// reduction-entry ::= symbol-ref `->` ssa-id `:` type @@ -1114,6 +1181,7 @@ parseLoopControl(OpAsmParser &parser, Region ®ion, loopVarTypes = SmallVector(ivs.size(), loopVarType); for (auto &iv : ivs) iv.type = loopVarType; + return parser.parseRegion(region, ivs); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 79956f82ed141..c87e895bb5404 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -1018,9 +1018,19 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, // Allocate reduction vars SmallVector privateReductionVariables; DenseMap reductionVariableMap; - allocReductionVars(opInst, builder, moduleTranslation, allocaIP, - reductionDecls, privateReductionVariables, - reductionVariableMap); + { + llvm::IRBuilderBase::InsertPointGuard guard(builder); + builder.restoreIP(allocaIP); + auto args = opInst.getRegion().getArguments(); + + for (std::size_t i = 0; i < opInst.getNumReductionVars(); ++i) { + llvm::Value *var = builder.CreateAlloca( + moduleTranslation.convertType(reductionDecls[i].getType())); + moduleTranslation.mapValue(args[i], var); + privateReductionVariables.push_back(var); + reductionVariableMap.try_emplace(opInst.getReductionVars()[i], var); + } + } // Store the mapping between reduction variables and their private copies on // ModuleTranslation stack. It can be then recovered when translating diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 65a704d18107b..651405964c067 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -640,11 +640,13 @@ func.func @wsloop_reduction(%lb : index, %ub : index, %step : index) { func.func @parallel_reduction() { %c1 = arith.constant 1 : i32 %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr - // CHECK: omp.parallel reduction(@add_f32 -> {{.+}} : !llvm.ptr) - omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) { + // CHECK: omp.parallel reduction(@add_f32 {{.+}} -> {{.+}} : !llvm.ptr) + omp.parallel reduction(@add_f32 %0 -> %prv : !llvm.ptr) { %1 = arith.constant 2.0 : f32 - // CHECK: omp.reduction %{{.+}}, %{{.+}} - omp.reduction %1, %0 : f32, !llvm.ptr + %2 = llvm.load %prv : !llvm.ptr -> f32 + // CHECK: llvm.fadd %{{.*}}, %{{.*}} : f32 + %3 = llvm.fadd %1, %2 : f32 + llvm.store %3, %prv : f32, !llvm.ptr omp.terminator } return @@ -654,13 +656,14 @@ func.func @parallel_reduction() { func.func @parallel_wsloop_reduction(%lb : index, %ub : index, %step : index) { %c1 = arith.constant 1 : i32 %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr - // CHECK: omp.parallel reduction(@add_f32 -> %{{.+}} : !llvm.ptr) { - omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) { + // CHECK: omp.parallel reduction(@add_f32 %{{.*}} -> %{{.+}} : !llvm.ptr) { + omp.parallel reduction(@add_f32 %0 -> %prv : !llvm.ptr) { // CHECK: omp.wsloop for (%{{.+}}) : index = (%{{.+}}) to (%{{.+}}) step (%{{.+}}) omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) { %1 = arith.constant 2.0 : f32 - // CHECK: omp.reduction %{{.+}}, %{{.+}} : f32, !llvm.ptr - omp.reduction %1, %0 : f32, !llvm.ptr + %2 = llvm.load %prv : !llvm.ptr -> f32 + // CHECK: llvm.fadd %{{.+}}, %{{.+}} : f32 + llvm.fadd %1, %2 : f32 // CHECK: omp.yield omp.yield } @@ -799,11 +802,14 @@ func.func @wsloop_reduction2(%lb : index, %ub : index, %step : index) { // CHECK-LABEL: func @parallel_reduction2 func.func @parallel_reduction2() { %0 = memref.alloca() : memref<1xf32> - // CHECK: omp.parallel reduction(@add2_f32 -> %{{.+}} : memref<1xf32>) - omp.parallel reduction(@add2_f32 -> %0 : memref<1xf32>) { + // CHECK: omp.parallel reduction(@add2_f32 %{{.+}} -> %{{.+}} : memref<1xf32>) + omp.parallel reduction(@add2_f32 %0 -> %prv : memref<1xf32>) { %1 = arith.constant 2.0 : f32 - // CHECK: omp.reduction - omp.reduction %1, %0 : f32, memref<1xf32> + %2 = arith.constant 0 : index + %3 = memref.load %prv[%2] : memref<1xf32> + // CHECK: llvm.fadd + %4 = llvm.fadd %1, %3 : f32 + memref.store %4, %prv[%2] : memref<1xf32> omp.terminator } return @@ -813,13 +819,14 @@ func.func @parallel_reduction2() { func.func @parallel_wsloop_reduction2(%lb : index, %ub : index, %step : index) { %c1 = arith.constant 1 : i32 %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr - // CHECK: omp.parallel reduction(@add2_f32 -> %{{.+}} : !llvm.ptr) { - omp.parallel reduction(@add2_f32 -> %0 : !llvm.ptr) { + // CHECK: omp.parallel reduction(@add2_f32 %{{.*}} -> %{{.+}} : !llvm.ptr) { + omp.parallel reduction(@add2_f32 %0 -> %prv : !llvm.ptr) { // CHECK: omp.wsloop for (%{{.+}}) : index = (%{{.+}}) to (%{{.+}}) step (%{{.+}}) omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) { %1 = arith.constant 2.0 : f32 - // CHECK: omp.reduction %{{.+}}, %{{.+}} : f32, !llvm.ptr - omp.reduction %1, %0 : f32, !llvm.ptr + %2 = llvm.load %prv : !llvm.ptr -> f32 + // CHECK: llvm.fadd %{{.+}}, %{{.+}} : f32 + %3 = llvm.fadd %1, %2 : f32 // CHECK: omp.yield omp.yield } diff --git a/mlir/test/Target/LLVMIR/openmp-reduction.mlir b/mlir/test/Target/LLVMIR/openmp-reduction.mlir index 93ab578df9e4e..dae83c0cf92ed 100644 --- a/mlir/test/Target/LLVMIR/openmp-reduction.mlir +++ b/mlir/test/Target/LLVMIR/openmp-reduction.mlir @@ -441,9 +441,11 @@ atomic { llvm.func @simple_reduction_parallel() { %c1 = llvm.mlir.constant(1 : i32) : i32 %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr - omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) { + omp.parallel reduction(@add_f32 %0 -> %prv : !llvm.ptr) { %1 = llvm.mlir.constant(2.0 : f32) : f32 - omp.reduction %1, %0 : f32, !llvm.ptr + %2 = llvm.load %prv : !llvm.ptr -> f32 + %3 = llvm.fadd %2, %1 : f32 + llvm.store %3, %prv : f32, !llvm.ptr omp.terminator } llvm.return @@ -512,10 +514,12 @@ llvm.func @parallel_nested_workshare_reduction(%ub : i64) { %lb = llvm.mlir.constant(1 : i64) : i64 %step = llvm.mlir.constant(1 : i64) : i64 - omp.parallel reduction(@add_i32 -> %0 : !llvm.ptr) { + omp.parallel reduction(@add_i32 %0 -> %prv : !llvm.ptr) { omp.wsloop for (%iv) : i64 = (%lb) to (%ub) step (%step) { %ival = llvm.trunc %iv : i64 to i32 - omp.reduction %ival, %0 : i32, !llvm.ptr + %lprv = llvm.load %prv : !llvm.ptr -> i32 + %add = llvm.add %lprv, %ival : i32 + llvm.store %add, %prv : i32, !llvm.ptr omp.yield } omp.terminator