From 9e6aa2ff97d9a917b44c3bdb28a05f5e58a6964e Mon Sep 17 00:00:00 2001 From: David Truby Date: Wed, 24 Jan 2024 15:31:15 +0000 Subject: [PATCH 1/2] [mlir][flang][openmp] Rework parallel reduction operations This patch reworks the way that parallel reduction operations function to better match the expected semantics from the OpenMP specification. Previously specific omp.reduction operations were used inside the region, meaning that the reduction only applied when the correct operation was used, whereas the specification states that any change to the variable inside the region should be taken into account for the reduction. The new semantics create a private reduction variable as a block argument which should be used normally for all operations on that variable in the region; this private variable is then combined with the others into the shared variable. This way no special omp.reduction operations are needed inside the region. This patch only makes the change for the `parallel` operation, the change for the `wsloop` operation will be in a separate patch. --- flang/lib/Lower/OpenMP.cpp | 78 ++++++++++++++----- .../OpenMP/FIR/parallel-reduction-add.f90 | 26 ++++--- .../Lower/OpenMP/parallel-reduction-add.f90 | 26 +++++-- mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 9 +-- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 68 ++++++++++++++++ .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 16 +++- mlir/test/Dialect/OpenMP/ops.mlir | 39 ++++++---- mlir/test/Target/LLVMIR/openmp-reduction.mlir | 12 ++- 8 files changed, 211 insertions(+), 63 deletions(-) diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp index fd18b212bad51..746dc2b62787d 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -621,10 +621,12 @@ class ClauseProcessor { llvm::SmallVectorImpl *mapSymLocs = nullptr, llvm::SmallVectorImpl *mapSymbols = nullptr) const; - bool processReduction( - mlir::Location currentLocation, - llvm::SmallVectorImpl &reductionVars, - llvm::SmallVectorImpl &reductionDeclSymbols) const; + bool + processReduction(mlir::Location currentLocation, + llvm::SmallVectorImpl &reductionVars, + llvm::SmallVectorImpl &reductionDeclSymbols, + llvm::SmallVectorImpl + *reductionSymbols = nullptr) const; bool processSectionsReduction(mlir::Location currentLocation) const; bool processTo(llvm::SmallVectorImpl &result) const; bool @@ -1079,12 +1081,14 @@ class ReductionProcessor { /// Creates a reduction declaration and associates it with an OpenMP block /// directive. - static void addReductionDecl( - mlir::Location currentLocation, - Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpReductionClause &reduction, - llvm::SmallVectorImpl &reductionVars, - llvm::SmallVectorImpl &reductionDeclSymbols) { + static void + addReductionDecl(mlir::Location currentLocation, + Fortran::lower::AbstractConverter &converter, + const Fortran::parser::OmpReductionClause &reduction, + llvm::SmallVectorImpl &reductionVars, + llvm::SmallVectorImpl &reductionDeclSymbols, + llvm::SmallVectorImpl + *reductionSymbols = nullptr) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::omp::ReductionDeclareOp decl; const auto &redOperator{ @@ -1114,6 +1118,8 @@ class ReductionProcessor { if (const auto *name{ Fortran::parser::Unwrap(ompObject)}) { if (const Fortran::semantics::Symbol * symbol{name->symbol}) { + if (reductionSymbols) + reductionSymbols->push_back(symbol); mlir::Value symVal = converter.getSymbolAddress(*symbol); if (auto declOp = symVal.getDefiningOp()) symVal = declOp.getBase(); @@ -1148,6 +1154,8 @@ class ReductionProcessor { if (const auto *name{ Fortran::parser::Unwrap(ompObject)}) { if (const Fortran::semantics::Symbol * symbol{name->symbol}) { + if (reductionSymbols) + reductionSymbols->push_back(symbol); mlir::Value symVal = converter.getSymbolAddress(*symbol); if (auto declOp = symVal.getDefiningOp()) symVal = declOp.getBase(); @@ -1948,13 +1956,16 @@ bool ClauseProcessor::processMap( bool ClauseProcessor::processReduction( mlir::Location currentLocation, llvm::SmallVectorImpl &reductionVars, - llvm::SmallVectorImpl &reductionDeclSymbols) const { + llvm::SmallVectorImpl &reductionDeclSymbols, + llvm::SmallVectorImpl *reductionSymbols) + const { return findRepeatableClause( [&](const ClauseTy::Reduction *reductionClause, const Fortran::parser::CharBlock &) { ReductionProcessor rp; rp.addReductionDecl(currentLocation, converter, reductionClause->v, - reductionVars, reductionDeclSymbols); + reductionVars, reductionDeclSymbols, + reductionSymbols); }); } @@ -2304,6 +2315,14 @@ struct OpWithBodyGenInfo { return *this; } + OpWithBodyGenInfo & + setReductions(llvm::SmallVector *value1, + llvm::SmallVector *value2) { + reductionSymbols = value1; + reductionTypes = value2; + return *this; + } + OpWithBodyGenInfo &setGenRegionEntryCb(GenOMPRegionEntryCBFn value) { genRegionEntryCB = value; return *this; @@ -2323,6 +2342,11 @@ struct OpWithBodyGenInfo { const Fortran::parser::OmpClauseList *clauses = nullptr; /// [in] if provided, processes the construct's data-sharing attributes. DataSharingProcessor *dsp = nullptr; + /// [in] if provided, list of reduction symbols + llvm::SmallVector *reductionSymbols = + nullptr; + /// [in] if provided, list of reduction types + llvm::SmallVector *reductionTypes = nullptr; /// [in] if provided, emits the op's region entry. Otherwise, an emtpy block /// is created in the region. GenOMPRegionEntryCBFn genRegionEntryCB = nullptr; @@ -2567,6 +2591,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector allocateOperands, allocatorOperands, reductionVars; llvm::SmallVector reductionDeclSymbols; + llvm::SmallVector reductionSymbols; ClauseProcessor cp(converter, clauseList); cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel, @@ -2576,13 +2601,33 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, cp.processDefault(); cp.processAllocate(allocatorOperands, allocateOperands); if (!outerCombined) - cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols); + cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols, + &reductionSymbols); + + llvm::SmallVector reductionTypes; + reductionTypes.reserve(reductionVars.size()); + llvm::transform(reductionVars, std::back_inserter(reductionTypes), + [](mlir::Value v) { return v.getType(); }); + + auto reductionCallback = [&](mlir::Operation *op) { + llvm::SmallVector locs(reductionVars.size(), + currentLocation); + auto block = converter.getFirOpBuilder().createBlock(&op->getRegion(0), {}, + reductionTypes, locs); + for (auto [arg, prv] : + llvm::zip_equal(reductionSymbols, block->getArguments())) { + converter.bindSymbol(*arg, prv); + } + return reductionSymbols; + }; return genOpWithBody( OpWithBodyGenInfo(converter, currentLocation, eval) .setGenNested(genNested) .setOuterCombined(outerCombined) - .setClauses(&clauseList), + .setClauses(&clauseList) + .setReductions(&reductionSymbols, &reductionTypes) + .setGenRegionEntryCb(reductionCallback), /*resultTypes=*/mlir::TypeRange(), ifClauseOperand, numThreadsClauseOperand, allocateOperands, allocatorOperands, reductionVars, @@ -3634,10 +3679,8 @@ genOMP(Fortran::lower::AbstractConverter &converter, break; } - if (singleDirective) { - genOpenMPReduction(converter, beginClauseList); + if (singleDirective) return; - } // Codegen for combined directives bool combinedDirective = false; @@ -3673,7 +3716,6 @@ genOMP(Fortran::lower::AbstractConverter &converter, ")"); genNestedEvaluations(converter, eval); - genOpenMPReduction(converter, beginClauseList); } static void diff --git a/flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90 b/flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90 index 6580aeb13ccd1..4b223e822760a 100644 --- a/flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90 +++ b/flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90 @@ -27,9 +27,11 @@ !CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_int_addEi"} !CHECK: %[[I_START:.*]] = arith.constant 0 : i32 !CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref -!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref) { -!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32 -!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref +!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV:.+]] : !fir.ref) { +!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref +!CHECK: %[[I_INCR:.+]] = arith.constant 1 : i32 +!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]] +!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref !CHECK: omp.terminator !CHECK: } !CHECK: return @@ -48,9 +50,11 @@ subroutine simple_int_add !CHECK: %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFsimple_real_addEr"} !CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32 !CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref -!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref) { -!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32 -!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref +!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[RREF]] -> %[[PRV:.+]] : !fir.ref) { +!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref +!CHECK: %[[R_INCR:.+]] = arith.constant 1.500000e+00 : f32 +!CHECK: %[[RES]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32 +!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref !CHECK: omp.terminator !CHECK: } !CHECK: return @@ -72,11 +76,15 @@ subroutine simple_real_add !CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref !CHECK: %[[I_START:.*]] = arith.constant 0 : i32 !CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref -!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref, @[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref) { +!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV0:.+]] : !fir.ref, @[[RED_F32_NAME]] %[[RREF]] -> %[[PRV1:.+]] : !fir.ref) { !CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32 -!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref +!CHECK: %[[LPRV1:.+]] = fir.load %[[PRV1]] : !fir.ref +!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[LPRV1]] {{.*}} : f32 +!CHECK: fir.store %[[RES1]] to %[[PRV1]] +!CHECK: %[[LPRV0:.+]] = fir.load %[[PRV0]] : !fir.ref !CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32 -!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref +!CHECK: %[[RES0:.+]] = arith.addi %[[LPRV0]], %[[I_INCR]] +!CHECK: fir.store %[[RES0]] to %[[PRV0]] !CHECK: omp.terminator !CHECK: } !CHECK: return diff --git a/flang/test/Lower/OpenMP/parallel-reduction-add.f90 b/flang/test/Lower/OpenMP/parallel-reduction-add.f90 index 81a93aebbd266..8f3ac3dc357af 100644 --- a/flang/test/Lower/OpenMP/parallel-reduction-add.f90 +++ b/flang/test/Lower/OpenMP/parallel-reduction-add.f90 @@ -28,9 +28,12 @@ !CHECK: %[[I_DECL:.*]]:2 = hlfir.declare %[[IREF]] {uniq_name = "_QFsimple_int_addEi"} : (!fir.ref) -> (!fir.ref, !fir.ref) !CHECK: %[[I_START:.*]] = arith.constant 0 : i32 !CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref -!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref) { +!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[PRV:.+]] : !fir.ref) { +!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref !CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32 -!CHECK: omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref +!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]] : i32 +!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : i32, !fir.ref !CHECK: omp.terminator !CHECK: } !CHECK: return @@ -50,9 +53,12 @@ subroutine simple_int_add !CHECK: %[[R_DECL:.*]]:2 = hlfir.declare %[[RREF]] {uniq_name = "_QFsimple_real_addEr"} : (!fir.ref) -> (!fir.ref, !fir.ref) !CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32 !CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref -!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref) { +!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[PRV:.+]] : !fir.ref) { +!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref !CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32 -!CHECK: omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref +!CHECK: %[[RES:.+]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32 +!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : f32, !fir.ref !CHECK: omp.terminator !CHECK: } !CHECK: return @@ -76,11 +82,17 @@ subroutine simple_real_add !CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref !CHECK: %[[I_START:.*]] = arith.constant 0 : i32 !CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref -!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref, @[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref) { +!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[IPRV:.+]] : !fir.ref, @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[RPRV:.+]] : !fir.ref) { +!CHECK: %[[IP_DECL:.+]]:2 = hlfir.declare %[[IPRV]] {{.*}} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[RP_DECL:.+]]:2 = hlfir.declare %[[RPRV]] {{.*}} : (!fir.ref) -> (!fir.ref, !fir.ref) !CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32 -!CHECK: omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref +!CHECK: %[[R_LPRV:.+]] = fir.load %[[RP_DECL]]#0 : !fir.ref +!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[R_LPRV]] {{.*}} : f32 +!CHECK: hlfir.assign %[[RES1]] to %[[RP_DECL]]#0 : f32, !fir.ref +!CHECK: %[[I_LPRV:.+]] = fir.load %[[IP_DECL]]#0 : !fir.ref !CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32 -!CHECK: omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref +!CHECK: %[[RES0:.+]] = arith.addi %[[I_LPRV]], %[[I_INCR]] : i32 +!CHECK: hlfir.assign %[[RES0]] to %[[IP_DECL]]#0 : i32, !fir.ref !CHECK: omp.terminator !CHECK: } !CHECK: return diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index ca36350548577..5d84217b9c701 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -191,11 +191,8 @@ def ParallelOp : OpenMP_Op<"parallel", [ unsigned getNumReductionVars() { return getReductionVars().size(); } }]; let assemblyFormat = [{ - oilist( `reduction` `(` - custom( - $reduction_vars, type($reduction_vars), $reductions - ) `)` - | `if` `(` $if_expr_var `:` type($if_expr_var) `)` + oilist( + `if` `(` $if_expr_var `:` type($if_expr_var) `)` | `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)` | `allocate` `(` custom( @@ -203,7 +200,7 @@ def ParallelOp : OpenMP_Op<"parallel", [ $allocators_vars, type($allocators_vars) ) `)` | `proc_bind` `(` custom($proc_bind_val) `)` - ) $region attr-dict + ) custom($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict }]; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 381f17d080419..394f0629ba6e4 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -21,6 +21,7 @@ #include "mlir/Interfaces/FoldInterfaces.h" #include "llvm/ADT/BitVector.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLForwardCompat.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" @@ -34,6 +35,7 @@ #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc" #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" +#include "mlir/Support/LogicalResult.h" using namespace mlir; using namespace mlir::omp; @@ -427,6 +429,71 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op, // Parser, printer and verifier for ReductionVarList //===----------------------------------------------------------------------===// +ParseResult +parseReductionClause(OpAsmParser &parser, Region ®ion, + SmallVectorImpl &operands, + SmallVectorImpl &types, ArrayAttr &reductionSymbols, + SmallVectorImpl &privates) { + if (failed(parser.parseOptionalKeyword("reduction"))) + return failure(); + + SmallVector reductionVec; + + if (failed( + parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() { + if (parser.parseAttribute(reductionVec.emplace_back()) || + parser.parseOperand(operands.emplace_back()) || + parser.parseArrow() || + parser.parseArgument(privates.emplace_back()) || + parser.parseColonType(types.emplace_back())) + return failure(); + return success(); + }))) + return failure(); + + for (auto [prv, type] : llvm::zip_equal(privates, types)) { + prv.type = type; + } + SmallVector reductions(reductionVec.begin(), reductionVec.end()); + reductionSymbols = ArrayAttr::get(parser.getContext(), reductions); + return success(); +} + +static void printReductionClause(OpAsmPrinter &p, Operation *op, Region ®ion, + ValueRange operands, TypeRange types, + ArrayAttr reductionSymbols) { + p << "reduction("; + llvm::interleaveComma(llvm::zip_equal(reductionSymbols, operands, + region.front().getArguments(), types), + p, [&p](auto t) { + auto [sym, op, arg, type] = t; + p << sym << " " << op << " -> " << arg << " : " + << type; + }); + p << ") "; +} + +static ParseResult +parseParallelRegion(OpAsmParser &parser, Region ®ion, + SmallVectorImpl &operands, + SmallVectorImpl &types, ArrayAttr &reductionSymbols) { + + llvm::SmallVector privates; + if (succeeded(parseReductionClause(parser, region, operands, types, + reductionSymbols, privates))) + return parser.parseRegion(region, privates); + + return parser.parseRegion(region); +} + +static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region ®ion, + ValueRange operands, TypeRange types, + ArrayAttr reductionSymbols) { + if (reductionSymbols) + printReductionClause(p, op, region, operands, types, reductionSymbols); + p.printRegion(region, /*printEntryBlockArgs=*/false); +} + /// reduction-entry-list ::= reduction-entry /// | reduction-entry-list `,` reduction-entry /// reduction-entry ::= symbol-ref `->` ssa-id `:` type @@ -1114,6 +1181,7 @@ parseLoopControl(OpAsmParser &parser, Region ®ion, loopVarTypes = SmallVector(ivs.size(), loopVarType); for (auto &iv : ivs) iv.type = loopVarType; + return parser.parseRegion(region, ivs); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 79956f82ed141..c87e895bb5404 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -1018,9 +1018,19 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, // Allocate reduction vars SmallVector privateReductionVariables; DenseMap reductionVariableMap; - allocReductionVars(opInst, builder, moduleTranslation, allocaIP, - reductionDecls, privateReductionVariables, - reductionVariableMap); + { + llvm::IRBuilderBase::InsertPointGuard guard(builder); + builder.restoreIP(allocaIP); + auto args = opInst.getRegion().getArguments(); + + for (std::size_t i = 0; i < opInst.getNumReductionVars(); ++i) { + llvm::Value *var = builder.CreateAlloca( + moduleTranslation.convertType(reductionDecls[i].getType())); + moduleTranslation.mapValue(args[i], var); + privateReductionVariables.push_back(var); + reductionVariableMap.try_emplace(opInst.getReductionVars()[i], var); + } + } // Store the mapping between reduction variables and their private copies on // ModuleTranslation stack. It can be then recovered when translating diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 65a704d18107b..651405964c067 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -640,11 +640,13 @@ func.func @wsloop_reduction(%lb : index, %ub : index, %step : index) { func.func @parallel_reduction() { %c1 = arith.constant 1 : i32 %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr - // CHECK: omp.parallel reduction(@add_f32 -> {{.+}} : !llvm.ptr) - omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) { + // CHECK: omp.parallel reduction(@add_f32 {{.+}} -> {{.+}} : !llvm.ptr) + omp.parallel reduction(@add_f32 %0 -> %prv : !llvm.ptr) { %1 = arith.constant 2.0 : f32 - // CHECK: omp.reduction %{{.+}}, %{{.+}} - omp.reduction %1, %0 : f32, !llvm.ptr + %2 = llvm.load %prv : !llvm.ptr -> f32 + // CHECK: llvm.fadd %{{.*}}, %{{.*}} : f32 + %3 = llvm.fadd %1, %2 : f32 + llvm.store %3, %prv : f32, !llvm.ptr omp.terminator } return @@ -654,13 +656,14 @@ func.func @parallel_reduction() { func.func @parallel_wsloop_reduction(%lb : index, %ub : index, %step : index) { %c1 = arith.constant 1 : i32 %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr - // CHECK: omp.parallel reduction(@add_f32 -> %{{.+}} : !llvm.ptr) { - omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) { + // CHECK: omp.parallel reduction(@add_f32 %{{.*}} -> %{{.+}} : !llvm.ptr) { + omp.parallel reduction(@add_f32 %0 -> %prv : !llvm.ptr) { // CHECK: omp.wsloop for (%{{.+}}) : index = (%{{.+}}) to (%{{.+}}) step (%{{.+}}) omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) { %1 = arith.constant 2.0 : f32 - // CHECK: omp.reduction %{{.+}}, %{{.+}} : f32, !llvm.ptr - omp.reduction %1, %0 : f32, !llvm.ptr + %2 = llvm.load %prv : !llvm.ptr -> f32 + // CHECK: llvm.fadd %{{.+}}, %{{.+}} : f32 + llvm.fadd %1, %2 : f32 // CHECK: omp.yield omp.yield } @@ -799,11 +802,14 @@ func.func @wsloop_reduction2(%lb : index, %ub : index, %step : index) { // CHECK-LABEL: func @parallel_reduction2 func.func @parallel_reduction2() { %0 = memref.alloca() : memref<1xf32> - // CHECK: omp.parallel reduction(@add2_f32 -> %{{.+}} : memref<1xf32>) - omp.parallel reduction(@add2_f32 -> %0 : memref<1xf32>) { + // CHECK: omp.parallel reduction(@add2_f32 %{{.+}} -> %{{.+}} : memref<1xf32>) + omp.parallel reduction(@add2_f32 %0 -> %prv : memref<1xf32>) { %1 = arith.constant 2.0 : f32 - // CHECK: omp.reduction - omp.reduction %1, %0 : f32, memref<1xf32> + %2 = arith.constant 0 : index + %3 = memref.load %prv[%2] : memref<1xf32> + // CHECK: llvm.fadd + %4 = llvm.fadd %1, %3 : f32 + memref.store %4, %prv[%2] : memref<1xf32> omp.terminator } return @@ -813,13 +819,14 @@ func.func @parallel_reduction2() { func.func @parallel_wsloop_reduction2(%lb : index, %ub : index, %step : index) { %c1 = arith.constant 1 : i32 %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr - // CHECK: omp.parallel reduction(@add2_f32 -> %{{.+}} : !llvm.ptr) { - omp.parallel reduction(@add2_f32 -> %0 : !llvm.ptr) { + // CHECK: omp.parallel reduction(@add2_f32 %{{.*}} -> %{{.+}} : !llvm.ptr) { + omp.parallel reduction(@add2_f32 %0 -> %prv : !llvm.ptr) { // CHECK: omp.wsloop for (%{{.+}}) : index = (%{{.+}}) to (%{{.+}}) step (%{{.+}}) omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) { %1 = arith.constant 2.0 : f32 - // CHECK: omp.reduction %{{.+}}, %{{.+}} : f32, !llvm.ptr - omp.reduction %1, %0 : f32, !llvm.ptr + %2 = llvm.load %prv : !llvm.ptr -> f32 + // CHECK: llvm.fadd %{{.+}}, %{{.+}} : f32 + %3 = llvm.fadd %1, %2 : f32 // CHECK: omp.yield omp.yield } diff --git a/mlir/test/Target/LLVMIR/openmp-reduction.mlir b/mlir/test/Target/LLVMIR/openmp-reduction.mlir index 93ab578df9e4e..dae83c0cf92ed 100644 --- a/mlir/test/Target/LLVMIR/openmp-reduction.mlir +++ b/mlir/test/Target/LLVMIR/openmp-reduction.mlir @@ -441,9 +441,11 @@ atomic { llvm.func @simple_reduction_parallel() { %c1 = llvm.mlir.constant(1 : i32) : i32 %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr - omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) { + omp.parallel reduction(@add_f32 %0 -> %prv : !llvm.ptr) { %1 = llvm.mlir.constant(2.0 : f32) : f32 - omp.reduction %1, %0 : f32, !llvm.ptr + %2 = llvm.load %prv : !llvm.ptr -> f32 + %3 = llvm.fadd %2, %1 : f32 + llvm.store %3, %prv : f32, !llvm.ptr omp.terminator } llvm.return @@ -512,10 +514,12 @@ llvm.func @parallel_nested_workshare_reduction(%ub : i64) { %lb = llvm.mlir.constant(1 : i64) : i64 %step = llvm.mlir.constant(1 : i64) : i64 - omp.parallel reduction(@add_i32 -> %0 : !llvm.ptr) { + omp.parallel reduction(@add_i32 %0 -> %prv : !llvm.ptr) { omp.wsloop for (%iv) : i64 = (%lb) to (%ub) step (%step) { %ival = llvm.trunc %iv : i64 to i32 - omp.reduction %ival, %0 : i32, !llvm.ptr + %lprv = llvm.load %prv : !llvm.ptr -> i32 + %add = llvm.add %lprv, %ival : i32 + llvm.store %add, %prv : i32, !llvm.ptr omp.yield } omp.terminator From c84f5259490ef697d2f0271af7d62dd888b2c094 Mon Sep 17 00:00:00 2001 From: Kiran Chandramohan Date: Wed, 7 Feb 2024 12:52:43 +0000 Subject: [PATCH 2/2] [Flang][OpenMP] Add new test to demonstrate reduction --- .../test/Lower/OpenMP/parallel-reduction.f90 | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 flang/test/Lower/OpenMP/parallel-reduction.f90 diff --git a/flang/test/Lower/OpenMP/parallel-reduction.f90 b/flang/test/Lower/OpenMP/parallel-reduction.f90 new file mode 100644 index 0000000000000..a07d118b0ba19 --- /dev/null +++ b/flang/test/Lower/OpenMP/parallel-reduction.f90 @@ -0,0 +1,38 @@ +! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s +! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s + +!CHECK: omp.reduction.declare @[[REDUCTION_DECLARE:[_a-z0-9]+]] : i32 init { +!CHECK: ^bb0(%{{.*}}: i32): +!CHECK: %[[I0:[_a-z0-9]+]] = arith.constant 0 : i32 +!CHECK: omp.yield(%[[I0]] : i32) +!CHECK: } combiner { +!CHECK: ^bb0(%[[C0:[_a-z0-9]+]]: i32, %[[C1:[_a-z0-9]+]]: i32): +!CHECK: %[[CR:[_a-z0-9]+]] = arith.addi %[[C0]], %[[C1]] : i32 +!CHECK: omp.yield(%[[CR]] : i32) +!CHECK: } +!CHECK: func.func @_QQmain() attributes {fir.bindc_name = "mn"} { +!CHECK: %[[RED_ACCUM_REF:[_a-z0-9]+]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"} +!CHECK: %[[RED_ACCUM_DECL:[_a-z0-9]+]]:2 = hlfir.declare %[[RED_ACCUM_REF]] {uniq_name = "_QFEi"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[C0:[_a-z0-9]+]] = arith.constant 0 : i32 +!CHECK: hlfir.assign %[[C0]] to %[[RED_ACCUM_DECL]]#0 : i32, !fir.ref +!CHECK: omp.parallel reduction(@[[REDUCTION_DECLARE]] %[[RED_ACCUM_DECL]]#0 -> %[[PRIVATE_RED:[a-z0-9]+]] : !fir.ref) { +!CHECK: %[[PRIVATE_DECL:[_a-z0-9]+]]:2 = hlfir.declare %[[PRIVATE_RED]] {uniq_name = "_QFEi"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[C1:[_a-z0-9]+]] = arith.constant 1 : i32 +!CHECK: hlfir.assign %[[C1]] to %[[PRIVATE_DECL]]#0 : i32, !fir.ref +!CHECK: omp.terminator +!CHECK: } +!CHECK: %[[RED_ACCUM_VAL:[_a-z0-9]+]] = fir.load %[[RED_ACCUM_DECL]]#0 : !fir.ref +!CHECK: {{.*}} = fir.call @_FortranAioOutputInteger32(%{{.*}}, %[[RED_ACCUM_VAL]]) fastmath : (!fir.ref, i32) -> i1 +!CHECK: return +!CHECK: } + +program mn + integer :: i + i = 0 + + !$omp parallel reduction(+:i) + i = 1 + !$omp end parallel + + print *, i +end program