-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[mlir][flang][openmp] Rework parallel reduction operations #79308
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-mlir Author: David Truby (DavidTruby) ChangesThis patch reworks the way that parallel reduction operations function to better match the expected semantics from the OpenMP specification. Previously specific omp.reduction operations were used inside the region, meaning that the reduction only applied when the correct operation was used, whereas the specification states that any change to the variable inside the region should be taken into account for the reduction. The new semantics create a private reduction variable as a block argument which should be used normally for all operations on that variable in the region; this private variable is then combined with the others into the shared variable. This way no special omp.reduction operations are needed inside the region. This patch only makes the change for the Patch is 23.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79308.diff 8 Files Affected:
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index c770d1c60718c35..f12111a370b1e9b 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -30,7 +30,7 @@
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Transforms/RegionUtils.h"
-#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Support/CommandLine.h"
@@ -618,7 +618,9 @@ class ClauseProcessor {
bool processReduction(
mlir::Location currentLocation,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const;
+ llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
+ llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *mapSymbols =
+ nullptr) const;
bool processSectionsReduction(mlir::Location currentLocation) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool
@@ -1177,12 +1179,14 @@ static mlir::Value getIfClauseOperand(
/// Creates a reduction declaration and associates it with an OpenMP block
/// directive.
-static void
-addReductionDecl(mlir::Location currentLocation,
- Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpReductionClause &reduction,
- llvm::SmallVectorImpl<mlir::Value> &reductionVars,
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) {
+static void addReductionDecl(
+ mlir::Location currentLocation,
+ Fortran::lower::AbstractConverter &converter,
+ const Fortran::parser::OmpReductionClause &reduction,
+ llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+ llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
+ llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *reductionSymbols =
+ nullptr) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::omp::ReductionDeclareOp decl;
const auto &redOperator{
@@ -1210,7 +1214,9 @@ addReductionDecl(mlir::Location currentLocation,
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
if (const auto *name{
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+ if (Fortran::semantics::Symbol * symbol{name->symbol}) {
+ if (reductionSymbols)
+ reductionSymbols->push_back(symbol);
mlir::Value symVal = converter.getSymbolAddress(*symbol);
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
symVal = declOp.getBase();
@@ -1249,7 +1255,9 @@ addReductionDecl(mlir::Location currentLocation,
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
if (const auto *name{
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+ if (Fortran::semantics::Symbol * symbol{name->symbol}) {
+ if (reductionSymbols)
+ reductionSymbols->push_back(symbol);
mlir::Value symVal = converter.getSymbolAddress(*symbol);
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
symVal = declOp.getBase();
@@ -1859,12 +1867,14 @@ bool ClauseProcessor::processMap(
bool ClauseProcessor::processReduction(
mlir::Location currentLocation,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const {
+ llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
+ llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *reductionSymbols)
+ const {
return findRepeatableClause<ClauseTy::Reduction>(
[&](const ClauseTy::Reduction *reductionClause,
const Fortran::parser::CharBlock &) {
addReductionDecl(currentLocation, converter, reductionClause->v,
- reductionVars, reductionDeclSymbols);
+ reductionVars, reductionDeclSymbols, reductionSymbols);
});
}
@@ -2357,6 +2367,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
reductionVars;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
+ llvm::SmallVector<Fortran::semantics::Symbol *> reductionSymbols;
ClauseProcessor cp(converter, clauseList);
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
@@ -2366,9 +2377,10 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
cp.processDefault();
cp.processAllocate(allocatorOperands, allocateOperands);
if (!outerCombined)
- cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
+ cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols,
+ &reductionSymbols);
- return genOpWithBody<mlir::omp::ParallelOp>(
+ auto op = genOpWithBody<mlir::omp::ParallelOp>(
converter, eval, currentLocation, outerCombined, &clauseList,
/*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
numThreadsClauseOperand, allocateOperands, allocatorOperands,
@@ -2378,6 +2390,21 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
reductionDeclSymbols),
procBindKindAttr);
+
+ // Add reduction block arguments
+ if (!reductionVars.empty()) {
+ mlir::Block ®ionBlock = op.getRegion().front();
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ for (auto [val, sym] : llvm::zip_equal(reductionVars, reductionSymbols)) {
+ auto savedIP = firOpBuilder.getInsertionPoint();
+ firOpBuilder.setInsertionPointToStart(®ionBlock);
+ auto prv = regionBlock.addArgument(val.getType(), op.getLoc());
+ converter.bindSymbol(*sym, prv);
+ firOpBuilder.setInsertionPoint(®ionBlock, savedIP);
+ }
+ }
+
+ return op;
}
static mlir::omp::SectionOp
@@ -3350,7 +3377,6 @@ genOMP(Fortran::lower::AbstractConverter &converter,
}
genNestedEvaluations(converter, eval);
- genOpenMPReduction(converter, beginClauseList);
}
static void
diff --git a/flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90 b/flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90
index 6580aeb13ccd1e1..4b223e822760a99 100644
--- a/flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90
+++ b/flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90
@@ -27,9 +27,11 @@
!CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_int_addEi"}
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
-!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>) {
-!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
-!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
+!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV:.+]] : !fir.ref<i32>) {
+!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<i32>
+!CHECK: %[[I_INCR:.+]] = arith.constant 1 : i32
+!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]]
+!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref<i32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
@@ -48,9 +50,11 @@ subroutine simple_int_add
!CHECK: %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFsimple_real_addEr"}
!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
-!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
-!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
-!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
+!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[RREF]] -> %[[PRV:.+]] : !fir.ref<f32>) {
+!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<f32>
+!CHECK: %[[R_INCR:.+]] = arith.constant 1.500000e+00 : f32
+!CHECK: %[[RES]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
+!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref<f32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
@@ -72,11 +76,15 @@ subroutine simple_real_add
!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
-!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
+!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV0:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[RREF]] -> %[[PRV1:.+]] : !fir.ref<f32>) {
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
-!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
+!CHECK: %[[LPRV1:.+]] = fir.load %[[PRV1]] : !fir.ref<f32>
+!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[LPRV1]] {{.*}} : f32
+!CHECK: fir.store %[[RES1]] to %[[PRV1]]
+!CHECK: %[[LPRV0:.+]] = fir.load %[[PRV0]] : !fir.ref<i32>
!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
-!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
+!CHECK: %[[RES0:.+]] = arith.addi %[[LPRV0]], %[[I_INCR]]
+!CHECK: fir.store %[[RES0]] to %[[PRV0]]
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-add.f90 b/flang/test/Lower/OpenMP/parallel-reduction-add.f90
index 81a93aebbd26619..d88b3aab2f351c8 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-add.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-add.f90
@@ -28,9 +28,12 @@
!CHECK: %[[I_DECL:.*]]:2 = hlfir.declare %[[IREF]] {uniq_name = "_QFsimple_int_addEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
-!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref<i32>) {
+!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
+!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
-!CHECK: omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref<i32>
+!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]] : i32
+!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
@@ -50,9 +53,12 @@ subroutine simple_int_add
!CHECK: %[[R_DECL:.*]]:2 = hlfir.declare %[[RREF]] {uniq_name = "_QFsimple_real_addEr"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
-!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref<f32>) {
+!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<f32>) {
+!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<f32>
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
-!CHECK: omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref<f32>
+!CHECK: %[[RES:.+]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
+!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : f32, !fir.ref<f32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
@@ -76,11 +82,17 @@ subroutine simple_real_add
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
-!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref<f32>) {
+!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[IPRV:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[RPRV:.+]] : !fir.ref<f32>) {
+!CHECK: %[[RP_DECL:.+]]:2 = hlfir.declare %[[RPRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+!CHECK: %[[IP_DECL:.+]]:2 = hlfir.declare %[[IPRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
-!CHECK: omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref<f32>
+!CHECK: %[[R_LPRV:.+]] = fir.load %[[RP_DECL]]#0 : !fir.ref<f32>
+!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[R_LPRV]] {{.*}} : f32
+!CHECK: hlfir.assign %[[RES1]] to %[[RP_DECL]]#0 : f32, !fir.ref<f32>
+!CHECK: %[[I_LPRV:.+]] = fir.load %[[IP_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
-!CHECK: omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref<i32>
+!CHECK: %[[RES0:.+]] = arith.addi %[[I_LPRV]], %[[I_INCR]] : i32
+!CHECK: hlfir.assign %[[RES0]] to %[[IP_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index d614f2666a85abc..b15f00656943b6e 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -200,11 +200,8 @@ def ParallelOp : OpenMP_Op<"parallel", [
unsigned getNumReductionVars() { return getReductionVars().size(); }
}];
let assemblyFormat = [{
- oilist( `reduction` `(`
- custom<ReductionVarList>(
- $reduction_vars, type($reduction_vars), $reductions
- ) `)`
- | `if` `(` $if_expr_var `:` type($if_expr_var) `)`
+ oilist(
+ `if` `(` $if_expr_var `:` type($if_expr_var) `)`
| `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)`
| `allocate` `(`
custom<AllocateAndAllocator>(
@@ -212,7 +209,7 @@ def ParallelOp : OpenMP_Op<"parallel", [
$allocators_vars, type($allocators_vars)
) `)`
| `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
- ) $region attr-dict
+ ) custom<WsReduction>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
}];
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 6e69cd0d386bd24..3b52b95c2982d03 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -21,11 +21,11 @@
#include "mlir/Interfaces/FoldInterfaces.h"
#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLForwardCompat.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
-#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include <cstddef>
#include <optional>
@@ -434,6 +434,55 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
// Parser, printer and verifier for ReductionVarList
//===----------------------------------------------------------------------===//
+static ParseResult
+parseWsReduction(OpAsmParser &parser, Region ®ion,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
+ SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
+
+ // possibly parse reduction
+ if (failed(parser.parseOptionalKeyword("reduction")))
+ return parser.parseRegion(region);
+
+ SmallVector<SymbolRefAttr> reductionVec;
+ SmallVector<OpAsmParser::Argument> privates;
+
+ if (failed(
+ parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
+ if (parser.parseAttribute(reductionVec.emplace_back()) ||
+ parser.parseOperand(operands.emplace_back()) ||
+ parser.parseArrow() ||
+ parser.parseArgument(privates.emplace_back()) ||
+ parser.parseColonType(types.emplace_back()))
+ return failure();
+ return success();
+ })))
+ return failure();
+
+ for (std::size_t i = 0; i < privates.size(); ++i) {
+ privates[i].type = types[i];
+ }
+ SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
+ reductionSymbols = ArrayAttr::get(parser.getContext(), reductions);
+ return parser.parseRegion(region, privates);
+}
+
+void printWsReduction(OpAsmPrinter &p, Operation *op, Region ®ion,
+ ValueRange operands, TypeRange types,
+ ArrayAttr reductionSymbols) {
+ if (reductionSymbols) {
+ p << "reduction(";
+ llvm::interleaveComma(llvm::zip_equal(reductionSymbols, operands,
+ region.front().getArguments(), types),
+ p, [&p](auto t) {
+ auto [sym, op, arg, type] = t;
+ p << sym << " " << op << " -> " << arg << " : "
+ << type;
+ });
+ p << ") ";
+ }
+ p.printRegion(region, /*printEntryBlockArgs=*/false);
+}
+
/// reduction-entry-list ::= reduction-entry
/// | reduction-entry-list `,` reduction-entry
/// reduction-entry ::= symbol-ref `->` ssa-id `:` type
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index e7aebc3ce4be569..02b0e587e462863 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1018,9 +1018,19 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
// Allocate reduction vars
SmallVector<llvm::Value *> privateReductionVariables;
DenseMap<Value, llvm::Value *> reductionVariableMap;
- allocReductionVars(opInst, builder, moduleTranslation, allocaIP,
- reductionDecls, privateReductionVariables,
- reductionVariableMap);
+ {
+ llvm::IRBuilderBase::InsertPointGuard guard(builder);
+ builder.restoreIP(allocaIP);
+ auto args = opInst.getRegion().getArguments();
+
+ for (std::size_t i = 0; i < opInst.getNumReductionVars(); ++i) {
+ llvm::Value *var = builder.CreateAlloca(
+ moduleTranslation.convertType(reductionDecls[i].getType()));
+ moduleTranslation.mapValue(args[i], var);
+ privateReductionVariables.push_back(var);
+ reductionVariableMap.try_emplace(opInst.getReductionVars()[i], var);
+ }
+ }
// Store the mapping between reduction variables and their private copies on
// ModuleTranslation stack. It can be then recovered when translating
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 3d4f6435572f7f5..50789a8157da77c 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -610,11 +610,13 @@ func.func @wsloop_reduction(%lb : index, %ub : index, %step : index) {
func.func @parallel_reduction() {
%c1 = arith.constant 1 : i32
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
- // CHECK: omp.parallel reduction(@add_f32 -> {{.+}} : !llvm.ptr)
- omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) {
+ // CHECK: omp.parallel reduction(@add_f32 {{.+}} -> {{.+}} : !llvm.ptr)
+ omp.parallel reduction(@add_f32 %0 -> %prv : !llvm.ptr) {
%1 = arith.constant 2.0 : f32
- // CHECK: omp.reduction %{{.+}}, %{{.+}}
- omp.reduction %1, %0 : f32, !llvm.ptr
+ %2 = llvm.load %prv : !llvm.ptr -> f32
+ // CHECK: llvm.fadd %{{.*}}, %{{.*}} : f32
+ %3 = llvm.fadd %1, %2 : f32
+ llvm.store %3, %prv : f32, !llvm.ptr
omp.terminator
}
return
@@ -624,13 +626,14 @@ func.func @parallel_reduction() {
func.func @parallel_wsloop_reduction(%lb : index, %ub : index, %step : index) {
%c1 = arith.constant 1 : i32
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
- // CHECK: omp.parallel reduction(@add_f32 -> %{{.+}} : !llvm.ptr) {
- omp.parallel reduction(@add_f3...
[truncated]
|
Please could you rebase this on top of the current HEAD. There has been quite a bit of churn in OpenMP.cpp. |
576993a
to
6e794d1
Compare
Sorry, this rebase took longer than I expected; it should be updated now. |
✅ With the latest revision this PR passed the C/C++ code formatter. |
7bd65b2
to
7890b53
Compare
if (singleDirective) { | ||
genOpenMPReduction(converter, beginClauseList); | ||
if (singleDirective) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is genOpenMPReduction
being removed here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If genOpenMPReduction is called her it replaces the normal operations (e.g. arith.addi
) with an omp.reduction
operation that is no longer used/needed with the new reduction lowering. At some point when all operations have moved to the new reduction style, those operations should be removed entirely from the OpenMP dialect.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
However we have not changed processing for the single
directive. So it should continue to use the old flow? Although, the lack of testing suggests it was never used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
singleDirective=true
doesn't appear to mean the single
directive is the one being processed, but rather that a single directive is being processed as opposed to a nested one. The variable name is quite unclear!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Regardless, of the directives processed by this function, only parallel
and teams
can actually have reduction clauses on them, and I don't think Flang supports teams
reductions yet?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, reduction for teams is not supported.
@@ -3522,7 +3551,6 @@ genOMP(Fortran::lower::AbstractConverter &converter, | |||
")"); | |||
|
|||
genNestedEvaluations(converter, eval); | |||
genOpenMPReduction(converter, beginClauseList); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is genOpenMPReduction
being removed here?
flang/lib/Lower/OpenMP.cpp
Outdated
mlir::Value indexVal = | ||
fir::getBase(op.getRegion().front().getArgument(argIndex)); | ||
storeOp = | ||
createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol); | ||
} | ||
firOpBuilder.setInsertionPointAfter(storeOp); | ||
} else if (reductionArgs.size()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Add a TODO to remove the else when reduction support for worksharing loop is added in the same way as this patch for parallel.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG.
c9ad4d7
to
10f9807
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
Could a new test be added that would show the new capabilities? |
I have added a new test test/Lower/OpenMP/parallel-reduction.f90 (https://github.com/llvm/llvm-project/pull/79308/files#diff-a659fd2e2c1970fcce5baa536110c82fbc58e30c522b8e36c769cd6eb8383436). The test should print the number of threads. Previously this test case would not work correctly since it always expected the reduction operation to be present in the region. Now we have removed this restriction allowing more cases to be handled and to match the spec better. If there are no concerns, I would like to rebase and submit this. |
This patch reworks the way that parallel reduction operations function to better match the expected semantics from the OpenMP specification. Previously specific omp.reduction operations were used inside the region, meaning that the reduction only applied when the correct operation was used, whereas the specification states that any change to the variable inside the region should be taken into account for the reduction. The new semantics create a private reduction variable as a block argument which should be used normally for all operations on that variable in the region; this private variable is then combined with the others into the shared variable. This way no special omp.reduction operations are needed inside the region. This patch only makes the change for the `parallel` operation, the change for the `wsloop` operation will be in a separate patch.
1ed4bc3
to
c84f525
Compare
This operation did not model the behaviour of reductions in the openmp standard. It has since been replaced by block arguments on the outer operation. See llvm#79308 and llvm#80019
This operation did not model the behaviour of reductions in the openmp standard. It has since been replaced by block arguments on the outer operation. See llvm#79308 and llvm#80019
This patch reworks the way that parallel reduction operations function to better match the expected semantics from the OpenMP specification. Previously specific omp.reduction operations were used inside the region, meaning that the reduction only applied when the correct operation was used, whereas the specification states that any change to the variable inside the region should be taken into account for the reduction.
The new semantics create a private reduction variable as a block argument which should be used normally for all operations on that variable in the region; this private variable is then combined with the others into the shared variable. This way no special omp.reduction operations are needed inside the region.
This patch only makes the change for the
parallel
operation, the change for thewsloop
operation will be in a separate patch.