Skip to content

Commit 6e794d1

Browse files
committed
[mlir][flang][openmp] Rework parallel reduction operations
This patch reworks the way that parallel reduction operations function to better match the expected semantics from the OpenMP specification. Previously specific omp.reduction operations were used inside the region, meaning that the reduction only applied when the correct operation was used, whereas the specification states that any change to the variable inside the region should be taken into account for the reduction. The new semantics create a private reduction variable as a block argument which should be used normally for all operations on that variable in the region; this private variable is then combined with the others into the shared variable. This way no special omp.reduction operations are needed inside the region. This patch only makes the change for the `parallel` operation, the change for the `wsloop` operation will be in a separate patch.
1 parent b788d62 commit 6e794d1

File tree

8 files changed

+169
-57
lines changed

8 files changed

+169
-57
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,8 @@ class ClauseProcessor {
621621
bool processReduction(
622622
mlir::Location currentLocation,
623623
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
624-
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const;
624+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
625+
llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *reductionSymbols = nullptr) const;
625626
bool processSectionsReduction(mlir::Location currentLocation) const;
626627
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
627628
bool
@@ -1077,7 +1078,8 @@ class ReductionProcessor {
10771078
Fortran::lower::AbstractConverter &converter,
10781079
const Fortran::parser::OmpReductionClause &reduction,
10791080
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
1080-
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) {
1081+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
1082+
llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *reductionSymbols = nullptr) {
10811083
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
10821084
mlir::omp::ReductionDeclareOp decl;
10831085
const auto &redOperator{
@@ -1106,7 +1108,9 @@ class ReductionProcessor {
11061108
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
11071109
if (const auto *name{
11081110
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
1109-
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
1111+
if (Fortran::semantics::Symbol * symbol{name->symbol}) {
1112+
if (reductionSymbols)
1113+
reductionSymbols->push_back(symbol);
11101114
mlir::Value symVal = converter.getSymbolAddress(*symbol);
11111115
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
11121116
symVal = declOp.getBase();
@@ -1138,7 +1142,9 @@ class ReductionProcessor {
11381142
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
11391143
if (const auto *name{
11401144
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
1141-
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
1145+
if (Fortran::semantics::Symbol * symbol{name->symbol}) {
1146+
if (reductionSymbols)
1147+
reductionSymbols->push_back(symbol);
11421148
mlir::Value symVal = converter.getSymbolAddress(*symbol);
11431149
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
11441150
symVal = declOp.getBase();
@@ -1932,13 +1938,14 @@ bool ClauseProcessor::processMap(
19321938
bool ClauseProcessor::processReduction(
19331939
mlir::Location currentLocation,
19341940
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
1935-
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const {
1941+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
1942+
llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *reductionSymbols) const {
19361943
return findRepeatableClause<ClauseTy::Reduction>(
19371944
[&](const ClauseTy::Reduction *reductionClause,
19381945
const Fortran::parser::CharBlock &) {
19391946
ReductionProcessor rp;
19401947
rp.addReductionDecl(currentLocation, converter, reductionClause->v,
1941-
reductionVars, reductionDeclSymbols);
1948+
reductionVars, reductionDeclSymbols, reductionSymbols);
19421949
});
19431950
}
19441951

@@ -2502,6 +2509,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
25022509
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
25032510
reductionVars;
25042511
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
2512+
llvm::SmallVector<Fortran::semantics::Symbol *> reductionSymbols;
25052513

25062514
ClauseProcessor cp(converter, clauseList);
25072515
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
@@ -2511,9 +2519,10 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
25112519
cp.processDefault();
25122520
cp.processAllocate(allocatorOperands, allocateOperands);
25132521
if (!outerCombined)
2514-
cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
2522+
cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols,
2523+
&reductionSymbols);
25152524

2516-
return genOpWithBody<mlir::omp::ParallelOp>(
2525+
auto op = genOpWithBody<mlir::omp::ParallelOp>(
25172526
converter, eval, genNested, currentLocation, outerCombined, &clauseList,
25182527
/*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
25192528
numThreadsClauseOperand, allocateOperands, allocatorOperands,
@@ -2523,6 +2532,24 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
25232532
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
25242533
reductionDeclSymbols),
25252534
procBindKindAttr);
2535+
2536+
// Add reduction block arguments
2537+
if (!reductionVars.empty()) {
2538+
mlir::Block &regionBlock = op.getRegion().front();
2539+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2540+
for (auto [val, sym] : llvm::zip_equal(reductionVars, reductionSymbols)) {
2541+
auto savedIP = firOpBuilder.getInsertionPoint();
2542+
firOpBuilder.setInsertionPointToStart(&regionBlock);
2543+
auto prv = regionBlock.addArgument(val.getType(), op.getLoc());
2544+
converter.bindSymbol(*sym, prv);
2545+
val.replaceUsesWithIf(prv, [&regionBlock](mlir::OpOperand &use) {
2546+
return use.getOwner()->getBlock() == &regionBlock;
2547+
});
2548+
firOpBuilder.setInsertionPoint(&regionBlock, savedIP);
2549+
}
2550+
}
2551+
2552+
return op;
25262553
}
25272554

25282555
static mlir::omp::SectionOp
@@ -3483,10 +3510,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
34833510
break;
34843511
}
34853512

3486-
if (singleDirective) {
3487-
genOpenMPReduction(converter, beginClauseList);
3513+
if (singleDirective)
34883514
return;
3489-
}
34903515

34913516
// Codegen for combined directives
34923517
bool combinedDirective = false;
@@ -3522,7 +3547,6 @@ genOMP(Fortran::lower::AbstractConverter &converter,
35223547
")");
35233548

35243549
genNestedEvaluations(converter, eval);
3525-
genOpenMPReduction(converter, beginClauseList);
35263550
}
35273551

35283552
static void

flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@
2727
!CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_int_addEi"}
2828
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
2929
!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
30-
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>) {
31-
!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
32-
!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
30+
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV:.+]] : !fir.ref<i32>) {
31+
!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<i32>
32+
!CHECK: %[[I_INCR:.+]] = arith.constant 1 : i32
33+
!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]]
34+
!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref<i32>
3335
!CHECK: omp.terminator
3436
!CHECK: }
3537
!CHECK: return
@@ -48,9 +50,11 @@ subroutine simple_int_add
4850
!CHECK: %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFsimple_real_addEr"}
4951
!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
5052
!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
51-
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
52-
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
53-
!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
53+
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[RREF]] -> %[[PRV:.+]] : !fir.ref<f32>) {
54+
!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<f32>
55+
!CHECK: %[[R_INCR:.+]] = arith.constant 1.500000e+00 : f32
56+
!CHECK: %[[RES]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
57+
!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref<f32>
5458
!CHECK: omp.terminator
5559
!CHECK: }
5660
!CHECK: return
@@ -72,11 +76,15 @@ subroutine simple_real_add
7276
!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
7377
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
7478
!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
75-
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
79+
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV0:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[RREF]] -> %[[PRV1:.+]] : !fir.ref<f32>) {
7680
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
77-
!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
81+
!CHECK: %[[LPRV1:.+]] = fir.load %[[PRV1]] : !fir.ref<f32>
82+
!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[LPRV1]] {{.*}} : f32
83+
!CHECK: fir.store %[[RES1]] to %[[PRV1]]
84+
!CHECK: %[[LPRV0:.+]] = fir.load %[[PRV0]] : !fir.ref<i32>
7885
!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
79-
!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
86+
!CHECK: %[[RES0:.+]] = arith.addi %[[LPRV0]], %[[I_INCR]]
87+
!CHECK: fir.store %[[RES0]] to %[[PRV0]]
8088
!CHECK: omp.terminator
8189
!CHECK: }
8290
!CHECK: return

flang/test/Lower/OpenMP/parallel-reduction-add.f90

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,12 @@
2828
!CHECK: %[[I_DECL:.*]]:2 = hlfir.declare %[[IREF]] {uniq_name = "_QFsimple_int_addEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
2929
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
3030
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
31-
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref<i32>) {
31+
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
32+
!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
33+
!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<i32>
3234
!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
33-
!CHECK: omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref<i32>
35+
!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]] : i32
36+
!CHECK: hlfir.assign %[[RES]] to %[[PRV]] : i32, !fir.ref<i32>
3437
!CHECK: omp.terminator
3538
!CHECK: }
3639
!CHECK: return
@@ -50,9 +53,12 @@ subroutine simple_int_add
5053
!CHECK: %[[R_DECL:.*]]:2 = hlfir.declare %[[RREF]] {uniq_name = "_QFsimple_real_addEr"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
5154
!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
5255
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
53-
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref<f32>) {
56+
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<f32>) {
57+
!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
58+
!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<f32>
5459
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
55-
!CHECK: omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref<f32>
60+
!CHECK: %[[RES:.+]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
61+
!CHECK: hlfir.assign %[[RES]] to %[[PRV]] : f32, !fir.ref<f32>
5662
!CHECK: omp.terminator
5763
!CHECK: }
5864
!CHECK: return
@@ -76,11 +82,17 @@ subroutine simple_real_add
7682
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
7783
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
7884
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
79-
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref<f32>) {
85+
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[IPRV:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[RPRV:.+]] : !fir.ref<f32>) {
86+
!CHECK: %[[RP_DECL:.+]]:2 = hlfir.declare %[[RPRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
87+
!CHECK: %[[IP_DECL:.+]]:2 = hlfir.declare %[[IPRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
8088
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
81-
!CHECK: omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref<f32>
89+
!CHECK: %[[R_LPRV:.+]] = fir.load %[[RPRV]] : !fir.ref<f32>
90+
!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[R_LPRV]] {{.*}} : f32
91+
!CHECK: hlfir.assign %[[RES1]] to %[[RPRV]] : f32, !fir.ref<f32>
92+
!CHECK: %[[I_LPRV:.+]] = fir.load %[[IPRV]] : !fir.ref<i32>
8293
!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
83-
!CHECK: omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref<i32>
94+
!CHECK: %[[RES0:.+]] = arith.addi %[[I_LPRV]], %[[I_INCR]] : i32
95+
!CHECK: hlfir.assign %[[RES0]] to %[[IPRV]] : i32, !fir.ref<i32>
8496
!CHECK: omp.terminator
8597
!CHECK: }
8698
!CHECK: return

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -200,19 +200,16 @@ def ParallelOp : OpenMP_Op<"parallel", [
200200
unsigned getNumReductionVars() { return getReductionVars().size(); }
201201
}];
202202
let assemblyFormat = [{
203-
oilist( `reduction` `(`
204-
custom<ReductionVarList>(
205-
$reduction_vars, type($reduction_vars), $reductions
206-
) `)`
207-
| `if` `(` $if_expr_var `:` type($if_expr_var) `)`
203+
oilist(
204+
`if` `(` $if_expr_var `:` type($if_expr_var) `)`
208205
| `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)`
209206
| `allocate` `(`
210207
custom<AllocateAndAllocator>(
211208
$allocate_vars, type($allocate_vars),
212209
$allocators_vars, type($allocators_vars)
213210
) `)`
214211
| `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
215-
) $region attr-dict
212+
) custom<WsReduction>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
216213
}];
217214
let hasVerifier = 1;
218215
}

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/Interfaces/FoldInterfaces.h"
2222

2323
#include "llvm/ADT/BitVector.h"
24+
#include "llvm/ADT/STLExtras.h"
2425
#include "llvm/ADT/STLForwardCompat.h"
2526
#include "llvm/ADT/SmallString.h"
2627
#include "llvm/ADT/StringExtras.h"
@@ -427,6 +428,55 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
427428
// Parser, printer and verifier for ReductionVarList
428429
//===----------------------------------------------------------------------===//
429430

431+
static ParseResult
432+
parseWsReduction(OpAsmParser &parser, Region &region,
433+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
434+
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
435+
436+
// possibly parse reduction
437+
if (failed(parser.parseOptionalKeyword("reduction")))
438+
return parser.parseRegion(region);
439+
440+
SmallVector<SymbolRefAttr> reductionVec;
441+
SmallVector<OpAsmParser::Argument> privates;
442+
443+
if (failed(
444+
parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
445+
if (parser.parseAttribute(reductionVec.emplace_back()) ||
446+
parser.parseOperand(operands.emplace_back()) ||
447+
parser.parseArrow() ||
448+
parser.parseArgument(privates.emplace_back()) ||
449+
parser.parseColonType(types.emplace_back()))
450+
return failure();
451+
return success();
452+
})))
453+
return failure();
454+
455+
for (std::size_t i = 0; i < privates.size(); ++i) {
456+
privates[i].type = types[i];
457+
}
458+
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
459+
reductionSymbols = ArrayAttr::get(parser.getContext(), reductions);
460+
return parser.parseRegion(region, privates);
461+
}
462+
463+
void printWsReduction(OpAsmPrinter &p, Operation *op, Region &region,
464+
ValueRange operands, TypeRange types,
465+
ArrayAttr reductionSymbols) {
466+
if (reductionSymbols) {
467+
p << "reduction(";
468+
llvm::interleaveComma(llvm::zip_equal(reductionSymbols, operands,
469+
region.front().getArguments(), types),
470+
p, [&p](auto t) {
471+
auto [sym, op, arg, type] = t;
472+
p << sym << " " << op << " -> " << arg << " : "
473+
<< type;
474+
});
475+
p << ") ";
476+
}
477+
p.printRegion(region, /*printEntryBlockArgs=*/false);
478+
}
479+
430480
/// reduction-entry-list ::= reduction-entry
431481
/// | reduction-entry-list `,` reduction-entry
432482
/// reduction-entry ::= symbol-ref `->` ssa-id `:` type

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,9 +1018,19 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10181018
// Allocate reduction vars
10191019
SmallVector<llvm::Value *> privateReductionVariables;
10201020
DenseMap<Value, llvm::Value *> reductionVariableMap;
1021-
allocReductionVars(opInst, builder, moduleTranslation, allocaIP,
1022-
reductionDecls, privateReductionVariables,
1023-
reductionVariableMap);
1021+
{
1022+
llvm::IRBuilderBase::InsertPointGuard guard(builder);
1023+
builder.restoreIP(allocaIP);
1024+
auto args = opInst.getRegion().getArguments();
1025+
1026+
for (std::size_t i = 0; i < opInst.getNumReductionVars(); ++i) {
1027+
llvm::Value *var = builder.CreateAlloca(
1028+
moduleTranslation.convertType(reductionDecls[i].getType()));
1029+
moduleTranslation.mapValue(args[i], var);
1030+
privateReductionVariables.push_back(var);
1031+
reductionVariableMap.try_emplace(opInst.getReductionVars()[i], var);
1032+
}
1033+
}
10241034

10251035
// Store the mapping between reduction variables and their private copies on
10261036
// ModuleTranslation stack. It can be then recovered when translating

0 commit comments

Comments
 (0)