Skip to content

Commit 576993a

Browse files
committed
[mlir][flang][openmp] Rework parallel reduction operations
This patch reworks the way that parallel reduction operations function to better match the expected semantics from the OpenMP specification. Previously specific omp.reduction operations were used inside the region, meaning that the reduction only applied when the correct operation was used, whereas the specification states that any change to the variable inside the region should be taken into account for the reduction. The new semantics create a private reduction variable as a block argument which should be used normally for all operations on that variable in the region; this private variable is then combined with the others into the shared variable. This way no special omp.reduction operations are needed inside the region. This patch only makes the change for the `parallel` operation, the change for the `wsloop` operation will be in a separate patch.
1 parent da6806d commit 576993a

File tree

8 files changed

+174
-61
lines changed

8 files changed

+174
-61
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
3131
#include "mlir/Dialect/SCF/IR/SCF.h"
3232
#include "mlir/Transforms/RegionUtils.h"
33-
#include "llvm/ADT/STLExtras.h"
33+
#include "llvm/ADT/SmallVector.h"
3434
#include "llvm/Frontend/OpenMP/OMPConstants.h"
3535
#include "llvm/Support/CommandLine.h"
3636

@@ -618,7 +618,9 @@ class ClauseProcessor {
618618
bool processReduction(
619619
mlir::Location currentLocation,
620620
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
621-
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const;
621+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
622+
llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *mapSymbols =
623+
nullptr) const;
622624
bool processSectionsReduction(mlir::Location currentLocation) const;
623625
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
624626
bool
@@ -1177,12 +1179,14 @@ static mlir::Value getIfClauseOperand(
11771179

11781180
/// Creates a reduction declaration and associates it with an OpenMP block
11791181
/// directive.
1180-
static void
1181-
addReductionDecl(mlir::Location currentLocation,
1182-
Fortran::lower::AbstractConverter &converter,
1183-
const Fortran::parser::OmpReductionClause &reduction,
1184-
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
1185-
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) {
1182+
static void addReductionDecl(
1183+
mlir::Location currentLocation,
1184+
Fortran::lower::AbstractConverter &converter,
1185+
const Fortran::parser::OmpReductionClause &reduction,
1186+
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
1187+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
1188+
llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *reductionSymbols =
1189+
nullptr) {
11861190
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
11871191
mlir::omp::ReductionDeclareOp decl;
11881192
const auto &redOperator{
@@ -1210,7 +1214,9 @@ addReductionDecl(mlir::Location currentLocation,
12101214
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
12111215
if (const auto *name{
12121216
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
1213-
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
1217+
if (Fortran::semantics::Symbol * symbol{name->symbol}) {
1218+
if (reductionSymbols)
1219+
reductionSymbols->push_back(symbol);
12141220
mlir::Value symVal = converter.getSymbolAddress(*symbol);
12151221
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
12161222
symVal = declOp.getBase();
@@ -1249,7 +1255,9 @@ addReductionDecl(mlir::Location currentLocation,
12491255
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
12501256
if (const auto *name{
12511257
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
1252-
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
1258+
if (Fortran::semantics::Symbol * symbol{name->symbol}) {
1259+
if (reductionSymbols)
1260+
reductionSymbols->push_back(symbol);
12531261
mlir::Value symVal = converter.getSymbolAddress(*symbol);
12541262
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
12551263
symVal = declOp.getBase();
@@ -1859,12 +1867,14 @@ bool ClauseProcessor::processMap(
18591867
bool ClauseProcessor::processReduction(
18601868
mlir::Location currentLocation,
18611869
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
1862-
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const {
1870+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
1871+
llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *reductionSymbols)
1872+
const {
18631873
return findRepeatableClause<ClauseTy::Reduction>(
18641874
[&](const ClauseTy::Reduction *reductionClause,
18651875
const Fortran::parser::CharBlock &) {
18661876
addReductionDecl(currentLocation, converter, reductionClause->v,
1867-
reductionVars, reductionDeclSymbols);
1877+
reductionVars, reductionDeclSymbols, reductionSymbols);
18681878
});
18691879
}
18701880

@@ -2357,6 +2367,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
23572367
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
23582368
reductionVars;
23592369
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
2370+
llvm::SmallVector<Fortran::semantics::Symbol *> reductionSymbols;
23602371

23612372
ClauseProcessor cp(converter, clauseList);
23622373
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
@@ -2366,9 +2377,10 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
23662377
cp.processDefault();
23672378
cp.processAllocate(allocatorOperands, allocateOperands);
23682379
if (!outerCombined)
2369-
cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
2380+
cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols,
2381+
&reductionSymbols);
23702382

2371-
return genOpWithBody<mlir::omp::ParallelOp>(
2383+
auto op = genOpWithBody<mlir::omp::ParallelOp>(
23722384
converter, eval, currentLocation, outerCombined, &clauseList,
23732385
/*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
23742386
numThreadsClauseOperand, allocateOperands, allocatorOperands,
@@ -2378,6 +2390,21 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
23782390
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
23792391
reductionDeclSymbols),
23802392
procBindKindAttr);
2393+
2394+
// Add reduction block arguments
2395+
if (!reductionVars.empty()) {
2396+
mlir::Block &regionBlock = op.getRegion().front();
2397+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2398+
for (auto [val, sym] : llvm::zip_equal(reductionVars, reductionSymbols)) {
2399+
auto savedIP = firOpBuilder.getInsertionPoint();
2400+
firOpBuilder.setInsertionPointToStart(&regionBlock);
2401+
auto prv = regionBlock.addArgument(val.getType(), op.getLoc());
2402+
converter.bindSymbol(*sym, prv);
2403+
firOpBuilder.setInsertionPoint(&regionBlock, savedIP);
2404+
}
2405+
}
2406+
2407+
return op;
23812408
}
23822409

23832410
static mlir::omp::SectionOp
@@ -3350,7 +3377,6 @@ genOMP(Fortran::lower::AbstractConverter &converter,
33503377
}
33513378

33523379
genNestedEvaluations(converter, eval);
3353-
genOpenMPReduction(converter, beginClauseList);
33543380
}
33553381

33563382
static void

flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@
2727
!CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_int_addEi"}
2828
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
2929
!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
30-
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>) {
31-
!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
32-
!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
30+
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV:.+]] : !fir.ref<i32>) {
31+
!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<i32>
32+
!CHECK: %[[I_INCR:.+]] = arith.constant 1 : i32
33+
!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]]
34+
!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref<i32>
3335
!CHECK: omp.terminator
3436
!CHECK: }
3537
!CHECK: return
@@ -48,9 +50,11 @@ subroutine simple_int_add
4850
!CHECK: %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFsimple_real_addEr"}
4951
!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
5052
!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
51-
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
52-
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
53-
!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
53+
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[RREF]] -> %[[PRV:.+]] : !fir.ref<f32>) {
54+
!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<f32>
55+
!CHECK: %[[R_INCR:.+]] = arith.constant 1.500000e+00 : f32
56+
!CHECK: %[[RES]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
57+
!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref<f32>
5458
!CHECK: omp.terminator
5559
!CHECK: }
5660
!CHECK: return
@@ -72,11 +76,15 @@ subroutine simple_real_add
7276
!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
7377
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
7478
!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
75-
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
79+
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV0:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[RREF]] -> %[[PRV1:.+]] : !fir.ref<f32>) {
7680
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
77-
!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
81+
!CHECK: %[[LPRV1:.+]] = fir.load %[[PRV1]] : !fir.ref<f32>
82+
!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[LPRV1]] {{.*}} : f32
83+
!CHECK: fir.store %[[RES1]] to %[[PRV1]]
84+
!CHECK: %[[LPRV0:.+]] = fir.load %[[PRV0]] : !fir.ref<i32>
7885
!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
79-
!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
86+
!CHECK: %[[RES0:.+]] = arith.addi %[[LPRV0]], %[[I_INCR]]
87+
!CHECK: fir.store %[[RES0]] to %[[PRV0]]
8088
!CHECK: omp.terminator
8189
!CHECK: }
8290
!CHECK: return

flang/test/Lower/OpenMP/parallel-reduction-add.f90

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,12 @@
2828
!CHECK: %[[I_DECL:.*]]:2 = hlfir.declare %[[IREF]] {uniq_name = "_QFsimple_int_addEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
2929
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
3030
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
31-
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref<i32>) {
31+
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
32+
!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
33+
!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<i32>
3234
!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
33-
!CHECK: omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref<i32>
35+
!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]] : i32
36+
!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : i32, !fir.ref<i32>
3437
!CHECK: omp.terminator
3538
!CHECK: }
3639
!CHECK: return
@@ -50,9 +53,12 @@ subroutine simple_int_add
5053
!CHECK: %[[R_DECL:.*]]:2 = hlfir.declare %[[RREF]] {uniq_name = "_QFsimple_real_addEr"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
5154
!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
5255
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
53-
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref<f32>) {
56+
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<f32>) {
57+
!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
58+
!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<f32>
5459
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
55-
!CHECK: omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref<f32>
60+
!CHECK: %[[RES:.+]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
61+
!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : f32, !fir.ref<f32>
5662
!CHECK: omp.terminator
5763
!CHECK: }
5864
!CHECK: return
@@ -76,11 +82,17 @@ subroutine simple_real_add
7682
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
7783
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
7884
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
79-
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref<f32>) {
85+
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[IPRV:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[RPRV:.+]] : !fir.ref<f32>) {
86+
!CHECK: %[[RP_DECL:.+]]:2 = hlfir.declare %[[RPRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
87+
!CHECK: %[[IP_DECL:.+]]:2 = hlfir.declare %[[IPRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
8088
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
81-
!CHECK: omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref<f32>
89+
!CHECK: %[[R_LPRV:.+]] = fir.load %[[RP_DECL]]#0 : !fir.ref<f32>
90+
!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[R_LPRV]] {{.*}} : f32
91+
!CHECK: hlfir.assign %[[RES1]] to %[[RP_DECL]]#0 : f32, !fir.ref<f32>
92+
!CHECK: %[[I_LPRV:.+]] = fir.load %[[IP_DECL]]#0 : !fir.ref<i32>
8293
!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
83-
!CHECK: omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref<i32>
94+
!CHECK: %[[RES0:.+]] = arith.addi %[[I_LPRV]], %[[I_INCR]] : i32
95+
!CHECK: hlfir.assign %[[RES0]] to %[[IP_DECL]]#0 : i32, !fir.ref<i32>
8496
!CHECK: omp.terminator
8597
!CHECK: }
8698
!CHECK: return

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -200,19 +200,16 @@ def ParallelOp : OpenMP_Op<"parallel", [
200200
unsigned getNumReductionVars() { return getReductionVars().size(); }
201201
}];
202202
let assemblyFormat = [{
203-
oilist( `reduction` `(`
204-
custom<ReductionVarList>(
205-
$reduction_vars, type($reduction_vars), $reductions
206-
) `)`
207-
| `if` `(` $if_expr_var `:` type($if_expr_var) `)`
203+
oilist(
204+
`if` `(` $if_expr_var `:` type($if_expr_var) `)`
208205
| `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)`
209206
| `allocate` `(`
210207
custom<AllocateAndAllocator>(
211208
$allocate_vars, type($allocate_vars),
212209
$allocators_vars, type($allocators_vars)
213210
) `)`
214211
| `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
215-
) $region attr-dict
212+
) custom<WsReduction>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
216213
}];
217214
let hasVerifier = 1;
218215
}

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121
#include "mlir/Interfaces/FoldInterfaces.h"
2222

2323
#include "llvm/ADT/BitVector.h"
24+
#include "llvm/ADT/STLExtras.h"
2425
#include "llvm/ADT/STLForwardCompat.h"
2526
#include "llvm/ADT/SmallString.h"
2627
#include "llvm/ADT/StringExtras.h"
2728
#include "llvm/ADT/StringRef.h"
28-
#include "llvm/ADT/TypeSwitch.h"
2929
#include "llvm/Frontend/OpenMP/OMPConstants.h"
3030
#include <cstddef>
3131
#include <optional>
@@ -434,6 +434,55 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
434434
// Parser, printer and verifier for ReductionVarList
435435
//===----------------------------------------------------------------------===//
436436

437+
static ParseResult
438+
parseWsReduction(OpAsmParser &parser, Region &region,
439+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
440+
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
441+
442+
// possibly parse reduction
443+
if (failed(parser.parseOptionalKeyword("reduction")))
444+
return parser.parseRegion(region);
445+
446+
SmallVector<SymbolRefAttr> reductionVec;
447+
SmallVector<OpAsmParser::Argument> privates;
448+
449+
if (failed(
450+
parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
451+
if (parser.parseAttribute(reductionVec.emplace_back()) ||
452+
parser.parseOperand(operands.emplace_back()) ||
453+
parser.parseArrow() ||
454+
parser.parseArgument(privates.emplace_back()) ||
455+
parser.parseColonType(types.emplace_back()))
456+
return failure();
457+
return success();
458+
})))
459+
return failure();
460+
461+
for (std::size_t i = 0; i < privates.size(); ++i) {
462+
privates[i].type = types[i];
463+
}
464+
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
465+
reductionSymbols = ArrayAttr::get(parser.getContext(), reductions);
466+
return parser.parseRegion(region, privates);
467+
}
468+
469+
void printWsReduction(OpAsmPrinter &p, Operation *op, Region &region,
470+
ValueRange operands, TypeRange types,
471+
ArrayAttr reductionSymbols) {
472+
if (reductionSymbols) {
473+
p << "reduction(";
474+
llvm::interleaveComma(llvm::zip_equal(reductionSymbols, operands,
475+
region.front().getArguments(), types),
476+
p, [&p](auto t) {
477+
auto [sym, op, arg, type] = t;
478+
p << sym << " " << op << " -> " << arg << " : "
479+
<< type;
480+
});
481+
p << ") ";
482+
}
483+
p.printRegion(region, /*printEntryBlockArgs=*/false);
484+
}
485+
437486
/// reduction-entry-list ::= reduction-entry
438487
/// | reduction-entry-list `,` reduction-entry
439488
/// reduction-entry ::= symbol-ref `->` ssa-id `:` type

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,9 +1018,19 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10181018
// Allocate reduction vars
10191019
SmallVector<llvm::Value *> privateReductionVariables;
10201020
DenseMap<Value, llvm::Value *> reductionVariableMap;
1021-
allocReductionVars(opInst, builder, moduleTranslation, allocaIP,
1022-
reductionDecls, privateReductionVariables,
1023-
reductionVariableMap);
1021+
{
1022+
llvm::IRBuilderBase::InsertPointGuard guard(builder);
1023+
builder.restoreIP(allocaIP);
1024+
auto args = opInst.getRegion().getArguments();
1025+
1026+
for (std::size_t i = 0; i < opInst.getNumReductionVars(); ++i) {
1027+
llvm::Value *var = builder.CreateAlloca(
1028+
moduleTranslation.convertType(reductionDecls[i].getType()));
1029+
moduleTranslation.mapValue(args[i], var);
1030+
privateReductionVariables.push_back(var);
1031+
reductionVariableMap.try_emplace(opInst.getReductionVars()[i], var);
1032+
}
1033+
}
10241034

10251035
// Store the mapping between reduction variables and their private copies on
10261036
// ModuleTranslation stack. It can be then recovered when translating

0 commit comments

Comments
 (0)