Skip to content

Commit c7ea791

Browse files
committed
[mlir][flang][openmp] Rework parallel reduction operations
This patch reworks the way that parallel reduction operations function to better match the expected semantics from the OpenMP specification. Previously, specific omp.reduction operations were used inside the region, meaning that the reduction only applied when the correct operation was used, whereas the specification states that any change to the variable inside the region should be taken into account for the reduction. The new semantics create a private reduction variable as a block argument, which should be used normally for all operations on that variable in the region; this private variable is then combined with the others into the shared variable. This way, no special omp.reduction operations are needed inside the region. This patch only makes the change for the `parallel` operation; the change for the `wsloop` operation will be in a separate patch.
1 parent b788d62 commit c7ea791

File tree

8 files changed

+173
-57
lines changed

8 files changed

+173
-57
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,9 @@ class ClauseProcessor {
621621
bool processReduction(
622622
mlir::Location currentLocation,
623623
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
624-
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const;
624+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
625+
llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *reductionSymbols =
626+
nullptr) const;
625627
bool processSectionsReduction(mlir::Location currentLocation) const;
626628
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
627629
bool
@@ -1077,7 +1079,9 @@ class ReductionProcessor {
10771079
Fortran::lower::AbstractConverter &converter,
10781080
const Fortran::parser::OmpReductionClause &reduction,
10791081
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
1080-
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) {
1082+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
1083+
llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *reductionSymbols =
1084+
nullptr) {
10811085
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
10821086
mlir::omp::ReductionDeclareOp decl;
10831087
const auto &redOperator{
@@ -1106,7 +1110,9 @@ class ReductionProcessor {
11061110
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
11071111
if (const auto *name{
11081112
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
1109-
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
1113+
if (Fortran::semantics::Symbol * symbol{name->symbol}) {
1114+
if (reductionSymbols)
1115+
reductionSymbols->push_back(symbol);
11101116
mlir::Value symVal = converter.getSymbolAddress(*symbol);
11111117
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
11121118
symVal = declOp.getBase();
@@ -1138,7 +1144,9 @@ class ReductionProcessor {
11381144
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
11391145
if (const auto *name{
11401146
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
1141-
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
1147+
if (Fortran::semantics::Symbol * symbol{name->symbol}) {
1148+
if (reductionSymbols)
1149+
reductionSymbols->push_back(symbol);
11421150
mlir::Value symVal = converter.getSymbolAddress(*symbol);
11431151
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
11441152
symVal = declOp.getBase();
@@ -1932,13 +1940,16 @@ bool ClauseProcessor::processMap(
19321940
bool ClauseProcessor::processReduction(
19331941
mlir::Location currentLocation,
19341942
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
1935-
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const {
1943+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
1944+
llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *reductionSymbols)
1945+
const {
19361946
return findRepeatableClause<ClauseTy::Reduction>(
19371947
[&](const ClauseTy::Reduction *reductionClause,
19381948
const Fortran::parser::CharBlock &) {
19391949
ReductionProcessor rp;
19401950
rp.addReductionDecl(currentLocation, converter, reductionClause->v,
1941-
reductionVars, reductionDeclSymbols);
1951+
reductionVars, reductionDeclSymbols,
1952+
reductionSymbols);
19421953
});
19431954
}
19441955

@@ -2502,6 +2513,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
25022513
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
25032514
reductionVars;
25042515
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
2516+
llvm::SmallVector<Fortran::semantics::Symbol *> reductionSymbols;
25052517

25062518
ClauseProcessor cp(converter, clauseList);
25072519
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
@@ -2511,9 +2523,10 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
25112523
cp.processDefault();
25122524
cp.processAllocate(allocatorOperands, allocateOperands);
25132525
if (!outerCombined)
2514-
cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
2526+
cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols,
2527+
&reductionSymbols);
25152528

2516-
return genOpWithBody<mlir::omp::ParallelOp>(
2529+
auto op = genOpWithBody<mlir::omp::ParallelOp>(
25172530
converter, eval, genNested, currentLocation, outerCombined, &clauseList,
25182531
/*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
25192532
numThreadsClauseOperand, allocateOperands, allocatorOperands,
@@ -2523,6 +2536,24 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
25232536
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
25242537
reductionDeclSymbols),
25252538
procBindKindAttr);
2539+
2540+
// Add reduction block arguments
2541+
if (!reductionVars.empty()) {
2542+
mlir::Block &regionBlock = op.getRegion().front();
2543+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2544+
for (auto [val, sym] : llvm::zip_equal(reductionVars, reductionSymbols)) {
2545+
auto savedIP = firOpBuilder.getInsertionPoint();
2546+
firOpBuilder.setInsertionPointToStart(&regionBlock);
2547+
auto prv = regionBlock.addArgument(val.getType(), op.getLoc());
2548+
converter.bindSymbol(*sym, prv);
2549+
val.replaceUsesWithIf(prv, [&regionBlock](mlir::OpOperand &use) {
2550+
return use.getOwner()->getBlock() == &regionBlock;
2551+
});
2552+
firOpBuilder.setInsertionPoint(&regionBlock, savedIP);
2553+
}
2554+
}
2555+
2556+
return op;
25262557
}
25272558

25282559
static mlir::omp::SectionOp
@@ -3483,10 +3514,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
34833514
break;
34843515
}
34853516

3486-
if (singleDirective) {
3487-
genOpenMPReduction(converter, beginClauseList);
3517+
if (singleDirective)
34883518
return;
3489-
}
34903519

34913520
// Codegen for combined directives
34923521
bool combinedDirective = false;
@@ -3522,7 +3551,6 @@ genOMP(Fortran::lower::AbstractConverter &converter,
35223551
")");
35233552

35243553
genNestedEvaluations(converter, eval);
3525-
genOpenMPReduction(converter, beginClauseList);
35263554
}
35273555

35283556
static void

flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@
2727
!CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_int_addEi"}
2828
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
2929
!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
30-
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>) {
31-
!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
32-
!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
30+
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV:.+]] : !fir.ref<i32>) {
31+
!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<i32>
32+
!CHECK: %[[I_INCR:.+]] = arith.constant 1 : i32
33+
!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]]
34+
!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref<i32>
3335
!CHECK: omp.terminator
3436
!CHECK: }
3537
!CHECK: return
@@ -48,9 +50,11 @@ subroutine simple_int_add
4850
!CHECK: %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFsimple_real_addEr"}
4951
!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
5052
!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
51-
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
52-
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
53-
!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
53+
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[RREF]] -> %[[PRV:.+]] : !fir.ref<f32>) {
54+
!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<f32>
55+
!CHECK: %[[R_INCR:.+]] = arith.constant 1.500000e+00 : f32
56+
!CHECK: %[[RES]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
57+
!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref<f32>
5458
!CHECK: omp.terminator
5559
!CHECK: }
5660
!CHECK: return
@@ -72,11 +76,15 @@ subroutine simple_real_add
7276
!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
7377
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
7478
!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
75-
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
79+
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV0:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[RREF]] -> %[[PRV1:.+]] : !fir.ref<f32>) {
7680
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
77-
!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
81+
!CHECK: %[[LPRV1:.+]] = fir.load %[[PRV1]] : !fir.ref<f32>
82+
!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[LPRV1]] {{.*}} : f32
83+
!CHECK: fir.store %[[RES1]] to %[[PRV1]]
84+
!CHECK: %[[LPRV0:.+]] = fir.load %[[PRV0]] : !fir.ref<i32>
7885
!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
79-
!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
86+
!CHECK: %[[RES0:.+]] = arith.addi %[[LPRV0]], %[[I_INCR]]
87+
!CHECK: fir.store %[[RES0]] to %[[PRV0]]
8088
!CHECK: omp.terminator
8189
!CHECK: }
8290
!CHECK: return

flang/test/Lower/OpenMP/parallel-reduction-add.f90

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,12 @@
2828
!CHECK: %[[I_DECL:.*]]:2 = hlfir.declare %[[IREF]] {uniq_name = "_QFsimple_int_addEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
2929
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
3030
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
31-
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref<i32>) {
31+
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
32+
!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
33+
!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<i32>
3234
!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
33-
!CHECK: omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref<i32>
35+
!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]] : i32
36+
!CHECK: hlfir.assign %[[RES]] to %[[PRV]] : i32, !fir.ref<i32>
3437
!CHECK: omp.terminator
3538
!CHECK: }
3639
!CHECK: return
@@ -50,9 +53,12 @@ subroutine simple_int_add
5053
!CHECK: %[[R_DECL:.*]]:2 = hlfir.declare %[[RREF]] {uniq_name = "_QFsimple_real_addEr"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
5154
!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
5255
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
53-
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref<f32>) {
56+
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<f32>) {
57+
!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
58+
!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<f32>
5459
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
55-
!CHECK: omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref<f32>
60+
!CHECK: %[[RES:.+]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
61+
!CHECK: hlfir.assign %[[RES]] to %[[PRV]] : f32, !fir.ref<f32>
5662
!CHECK: omp.terminator
5763
!CHECK: }
5864
!CHECK: return
@@ -76,11 +82,17 @@ subroutine simple_real_add
7682
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
7783
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
7884
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
79-
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref<f32>) {
85+
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[IPRV:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[RPRV:.+]] : !fir.ref<f32>) {
86+
!CHECK: %[[RP_DECL:.+]]:2 = hlfir.declare %[[RPRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
87+
!CHECK: %[[IP_DECL:.+]]:2 = hlfir.declare %[[IPRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
8088
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
81-
!CHECK: omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref<f32>
89+
!CHECK: %[[R_LPRV:.+]] = fir.load %[[RPRV]] : !fir.ref<f32>
90+
!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[R_LPRV]] {{.*}} : f32
91+
!CHECK: hlfir.assign %[[RES1]] to %[[RPRV]] : f32, !fir.ref<f32>
92+
!CHECK: %[[I_LPRV:.+]] = fir.load %[[IPRV]] : !fir.ref<i32>
8293
!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
83-
!CHECK: omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref<i32>
94+
!CHECK: %[[RES0:.+]] = arith.addi %[[I_LPRV]], %[[I_INCR]] : i32
95+
!CHECK: hlfir.assign %[[RES0]] to %[[IPRV]] : i32, !fir.ref<i32>
8496
!CHECK: omp.terminator
8597
!CHECK: }
8698
!CHECK: return

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -200,19 +200,16 @@ def ParallelOp : OpenMP_Op<"parallel", [
200200
unsigned getNumReductionVars() { return getReductionVars().size(); }
201201
}];
202202
let assemblyFormat = [{
203-
oilist( `reduction` `(`
204-
custom<ReductionVarList>(
205-
$reduction_vars, type($reduction_vars), $reductions
206-
) `)`
207-
| `if` `(` $if_expr_var `:` type($if_expr_var) `)`
203+
oilist(
204+
`if` `(` $if_expr_var `:` type($if_expr_var) `)`
208205
| `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)`
209206
| `allocate` `(`
210207
custom<AllocateAndAllocator>(
211208
$allocate_vars, type($allocate_vars),
212209
$allocators_vars, type($allocators_vars)
213210
) `)`
214211
| `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
215-
) $region attr-dict
212+
) custom<WsReduction>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
216213
}];
217214
let hasVerifier = 1;
218215
}

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/Interfaces/FoldInterfaces.h"
2222

2323
#include "llvm/ADT/BitVector.h"
24+
#include "llvm/ADT/STLExtras.h"
2425
#include "llvm/ADT/STLForwardCompat.h"
2526
#include "llvm/ADT/SmallString.h"
2627
#include "llvm/ADT/StringExtras.h"
@@ -427,6 +428,55 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
427428
// Parser, printer and verifier for ReductionVarList
428429
//===----------------------------------------------------------------------===//
429430

431+
static ParseResult
432+
parseWsReduction(OpAsmParser &parser, Region &region,
433+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
434+
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
435+
436+
// possibly parse reduction
437+
if (failed(parser.parseOptionalKeyword("reduction")))
438+
return parser.parseRegion(region);
439+
440+
SmallVector<SymbolRefAttr> reductionVec;
441+
SmallVector<OpAsmParser::Argument> privates;
442+
443+
if (failed(
444+
parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
445+
if (parser.parseAttribute(reductionVec.emplace_back()) ||
446+
parser.parseOperand(operands.emplace_back()) ||
447+
parser.parseArrow() ||
448+
parser.parseArgument(privates.emplace_back()) ||
449+
parser.parseColonType(types.emplace_back()))
450+
return failure();
451+
return success();
452+
})))
453+
return failure();
454+
455+
for (std::size_t i = 0; i < privates.size(); ++i) {
456+
privates[i].type = types[i];
457+
}
458+
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
459+
reductionSymbols = ArrayAttr::get(parser.getContext(), reductions);
460+
return parser.parseRegion(region, privates);
461+
}
462+
463+
void printWsReduction(OpAsmPrinter &p, Operation *op, Region &region,
464+
ValueRange operands, TypeRange types,
465+
ArrayAttr reductionSymbols) {
466+
if (reductionSymbols) {
467+
p << "reduction(";
468+
llvm::interleaveComma(llvm::zip_equal(reductionSymbols, operands,
469+
region.front().getArguments(), types),
470+
p, [&p](auto t) {
471+
auto [sym, op, arg, type] = t;
472+
p << sym << " " << op << " -> " << arg << " : "
473+
<< type;
474+
});
475+
p << ") ";
476+
}
477+
p.printRegion(region, /*printEntryBlockArgs=*/false);
478+
}
479+
430480
/// reduction-entry-list ::= reduction-entry
431481
/// | reduction-entry-list `,` reduction-entry
432482
/// reduction-entry ::= symbol-ref `->` ssa-id `:` type

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,9 +1018,19 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10181018
// Allocate reduction vars
10191019
SmallVector<llvm::Value *> privateReductionVariables;
10201020
DenseMap<Value, llvm::Value *> reductionVariableMap;
1021-
allocReductionVars(opInst, builder, moduleTranslation, allocaIP,
1022-
reductionDecls, privateReductionVariables,
1023-
reductionVariableMap);
1021+
{
1022+
llvm::IRBuilderBase::InsertPointGuard guard(builder);
1023+
builder.restoreIP(allocaIP);
1024+
auto args = opInst.getRegion().getArguments();
1025+
1026+
for (std::size_t i = 0; i < opInst.getNumReductionVars(); ++i) {
1027+
llvm::Value *var = builder.CreateAlloca(
1028+
moduleTranslation.convertType(reductionDecls[i].getType()));
1029+
moduleTranslation.mapValue(args[i], var);
1030+
privateReductionVariables.push_back(var);
1031+
reductionVariableMap.try_emplace(opInst.getReductionVars()[i], var);
1032+
}
1033+
}
10241034

10251035
// Store the mapping between reduction variables and their private copies on
10261036
// ModuleTranslation stack. It can be then recovered when translating

0 commit comments

Comments
 (0)