From 2fd5a4de51ff690cb144f9902bf20c7c16c2f036 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 19 Dec 2023 18:10:56 +0100 Subject: [PATCH 1/6] [mlir][scf] Add reductions support to `scf.parallel` fusion --- .../SCF/Transforms/ParallelLoopFusion.cpp | 52 ++++++-- .../Dialect/SCF/parallel-loop-fusion.mlir | 124 +++++++++++++++++- 2 files changed, 166 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp index d3dca1427e517..7d9e220518441 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp @@ -161,29 +161,63 @@ static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, } /// Prepends operations of firstPloop's body into secondPloop's body. -static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop, - OpBuilder b, +/// Updates secondPloop with new loop. +static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, + OpBuilder builder, llvm::function_ref mayAlias) { + Block *block1 = firstPloop.getBody(); + Block *block2 = secondPloop.getBody(); IRMapping firstToSecondPloopIndices; - firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(), - secondPloop.getBody()->getArguments()); + firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments()); if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)) return; - b.setInsertionPointToStart(secondPloop.getBody()); - for (auto &op : firstPloop.getBody()->without_terminator()) - b.clone(op, firstToSecondPloopIndices); + DominanceInfo dom; + for (Operation *user : firstPloop->getUsers()) + if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false)) + return; + + ValueRange inits1 = firstPloop.getInitVals(); + ValueRange inits2 = secondPloop.getInitVals(); + + SmallVector newInitVars(inits1.begin(), inits1.end()); + 
newInitVars.append(inits2.begin(), inits2.end()); + + IRRewriter b(builder); + b.setInsertionPoint(secondPloop); + auto newSecondPloop = b.create( + secondPloop.getLoc(), secondPloop.getLowerBound(), + secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars); + + Block *newBlock = newSecondPloop.getBody(); + newBlock->getTerminator()->erase(); + + block1->getTerminator()->erase(); + + b.inlineBlockBefore(block1, newBlock, newBlock->end(), + newBlock->getArguments()); + b.inlineBlockBefore(block2, newBlock, newBlock->end(), + newBlock->getArguments()); + + ValueRange results = newSecondPloop.getResults(); + firstPloop.replaceAllUsesWith(results.take_front(inits1.size())); + secondPloop.replaceAllUsesWith(results.take_back(inits2.size())); firstPloop.erase(); + secondPloop.erase(); + secondPloop = newSecondPloop; } void mlir::scf::naivelyFuseParallelOps( Region &region, llvm::function_ref mayAlias) { OpBuilder b(region); // Consider every single block and attempt to fuse adjacent loops. + SmallVector, 1> ploopChains; for (auto &block : region) { - SmallVector, 1> ploopChains{{}}; + ploopChains.clear(); + ploopChains.push_back({}); + // Not using `walk()` to traverse only top-level parallel loops and also // make sure that there are no side-effecting ops between the parallel // loops. @@ -201,7 +235,7 @@ void mlir::scf::naivelyFuseParallelOps( // TODO: Handle region side effects properly. 
noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0; } - for (ArrayRef ploops : ploopChains) { + for (MutableArrayRef ploops : ploopChains) { for (int i = 0, e = ploops.size(); i + 1 < e; ++i) fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias); } diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir index 9c136bb635658..94ccbff4d8560 100644 --- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir +++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir @@ -89,7 +89,7 @@ func.func @fuse_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>) { memref.store %product_elem, %prod[%i, %j] : memref<2x2xf32> scf.reduce } - scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> %res_elem = arith.addf %A_elem, %c2fp : f32 memref.store %res_elem, %B[%i, %j] : memref<2x2xf32> @@ -575,3 +575,125 @@ func.func @do_not_fuse_affine_apply_to_non_ind_var( // CHECK-NEXT: } // CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<2x3xf32> // CHECK-NEXT: return + +// ----- + +func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) { + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %init1 = arith.constant 1.0 : f32 + %init2 = arith.constant 2.0 : f32 + %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 { + %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> + scf.reduce(%A_elem) : f32 { + ^bb0(%lhs: f32, %rhs: f32): + %1 = arith.addf %lhs, %rhs : f32 + scf.reduce.return %1 : f32 + } + scf.yield + } + %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 { + %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> + scf.reduce(%B_elem) : f32 { + ^bb0(%lhs: f32, %rhs: f32): + %1 = arith.mulf %lhs, %rhs : f32 + scf.reduce.return %1 : f32 + } + scf.yield + } + 
return %res1, %res2 : f32, f32 +} + +// CHECK-LABEL: func @fuse_reductions +// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) +// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) +// CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32) +// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]] +// CHECK: scf.reduce(%[[VAL_A]]) : f32 { +// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): +// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32 +// CHECK: scf.reduce.return %[[R]] : f32 +// CHECK: } +// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]] +// CHECK: scf.reduce(%[[VAL_B]]) : f32 { +// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): +// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32 +// CHECK: scf.reduce.return %[[R]] : f32 +// CHECK: } +// CHECK: scf.yield +// CHECK: return %[[RES]]#0, %[[RES]]#1 + +// ----- + +func.func @reductions_use_res(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) { + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %init1 = arith.constant 1.0 : f32 + %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 { + %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> + scf.reduce(%A_elem) : f32 { + ^bb0(%lhs: f32, %rhs: f32): + %1 = arith.addf %lhs, %rhs : f32 + scf.reduce.return %1 : f32 + } + scf.yield + } + %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%res1) -> f32 { + %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> + scf.reduce(%B_elem) : f32 { + ^bb0(%lhs: f32, %rhs: f32): + %1 = 
arith.mulf %lhs, %rhs : f32 + scf.reduce.return %1 : f32 + } + scf.yield + } + return %res1, %res2 : f32, f32 +} + +// %res1 is used as second scf.parallel arg, cannot fuse +// CHECK-LABEL: func @reductions_use_res +// CHECK: scf.parallel +// CHECK: scf.parallel + +// ----- + +func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) { + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %init1 = arith.constant 1.0 : f32 + %init2 = arith.constant 2.0 : f32 + %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 { + %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> + scf.reduce(%A_elem) : f32 { + ^bb0(%lhs: f32, %rhs: f32): + %1 = arith.addf %lhs, %rhs : f32 + scf.reduce.return %1 : f32 + } + scf.yield + } + %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 { + %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> + %sum = arith.addf %B_elem, %res1 : f32 + scf.reduce(%sum) : f32 { + ^bb0(%lhs: f32, %rhs: f32): + %1 = arith.mulf %lhs, %rhs : f32 + scf.reduce.return %1 : f32 + } + scf.yield + } + return %res1, %res2 : f32, f32 +} + +// %res1 is used inside second scf.parallel arg, cannot fuse +// CHECK-LABEL: func @reductions_use_res_inside +// CHECK: scf.parallel +// CHECK: scf.parallel From d858a4af347587be86ace6a619009fdc58b2d87c Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 19 Dec 2023 18:17:04 +0100 Subject: [PATCH 2/6] typo --- mlir/test/Dialect/SCF/parallel-loop-fusion.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir index 94ccbff4d8560..7644d1bafb183 100644 --- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir +++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir @@ -693,7 +693,7 @@ func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) - return 
%res1, %res2 : f32, f32 } -// %res1 is used inside second scf.parallel arg, cannot fuse +// %res1 is used inside second scf.parallel, cannot fuse // CHECK-LABEL: func @reductions_use_res_inside // CHECK: scf.parallel // CHECK: scf.parallel From 47ec48ee15d319e6973f3422d1bc0a15d1901d84 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 19 Dec 2023 23:05:27 +0100 Subject: [PATCH 3/6] update test --- mlir/test/Dialect/SCF/parallel-loop-fusion.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir index 7644d1bafb183..9ced6d932274e 100644 --- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir +++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir @@ -606,7 +606,7 @@ func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f3 } // CHECK-LABEL: func @fuse_reductions -// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) +// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32) // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index @@ -628,7 +628,7 @@ func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f3 // CHECK: scf.reduce.return %[[R]] : f32 // CHECK: } // CHECK: scf.yield -// CHECK: return %[[RES]]#0, %[[RES]]#1 +// CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32 // ----- From 65a3b05d56c3674c136ba81bcb552c8d0d2cfb6e Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 20 Dec 2023 18:29:17 +0100 Subject: [PATCH 4/6] Update to new reductions format --- .../SCF/Transforms/ParallelLoopFusion.cpp | 34 +++++-- .../Dialect/SCF/parallel-loop-fusion.mlir | 93 +++++++++++++++---- 2 files changed, 102 insertions(+), 25 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp index 
7d9e220518441..853b63f5adaf5 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp @@ -192,18 +192,38 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars); Block *newBlock = newSecondPloop.getBody(); - newBlock->getTerminator()->erase(); + auto term1 = cast(block1->getTerminator()); + auto term2 = cast(block2->getTerminator()); - block1->getTerminator()->erase(); - - b.inlineBlockBefore(block1, newBlock, newBlock->end(), + b.inlineBlockBefore(block2, newBlock, newBlock->begin(), newBlock->getArguments()); - b.inlineBlockBefore(block2, newBlock, newBlock->end(), + b.inlineBlockBefore(block1, newBlock, newBlock->begin(), newBlock->getArguments()); ValueRange results = newSecondPloop.getResults(); - firstPloop.replaceAllUsesWith(results.take_front(inits1.size())); - secondPloop.replaceAllUsesWith(results.take_back(inits2.size())); + if (!results.empty()) { + b.setInsertionPointToEnd(newBlock); + + ValueRange reduceArgs1 = term1.getOperands(); + ValueRange reduceArgs2 = term2.getOperands(); + SmallVector newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); + newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); + + auto newReduceOp = b.create(term2.getLoc(), newReduceArgs); + + for (auto &&[i, reg] : llvm::enumerate(llvm::concat( + term1.getReductions(), term2.getReductions()))) { + Block &oldRedBlock = reg.front(); + Block &newRedBlock = newReduceOp.getReductions()[i].front(); + b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(), + newRedBlock.getArguments()); + } + + firstPloop.replaceAllUsesWith(results.take_front(inits1.size())); + secondPloop.replaceAllUsesWith(results.take_back(inits2.size())); + } + term1->erase(); + term2->erase(); firstPloop.erase(); secondPloop.erase(); secondPloop = newSecondPloop; diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir 
b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir index 9ced6d932274e..d171f96811b10 100644 --- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir +++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir @@ -578,7 +578,7 @@ func.func @do_not_fuse_affine_apply_to_non_ind_var( // ----- -func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) { +func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) { %c2 = arith.constant 2 : index %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -586,26 +586,24 @@ func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f3 %init2 = arith.constant 2.0 : f32 %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 { %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> - scf.reduce(%A_elem) : f32 { + scf.reduce(%A_elem : f32) { ^bb0(%lhs: f32, %rhs: f32): %1 = arith.addf %lhs, %rhs : f32 scf.reduce.return %1 : f32 } - scf.yield } %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 { %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> - scf.reduce(%B_elem) : f32 { + scf.reduce(%B_elem : f32) { ^bb0(%lhs: f32, %rhs: f32): %1 = arith.mulf %lhs, %rhs : f32 scf.reduce.return %1 : f32 } - scf.yield } return %res1, %res2 : f32, f32 } -// CHECK-LABEL: func @fuse_reductions +// CHECK-LABEL: func @fuse_reductions_two // CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32) // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index @@ -616,22 +614,85 @@ func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f3 // CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) // CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32) // CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]] -// CHECK: scf.reduce(%[[VAL_A]]) : f32 { +// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]] 
+// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) { // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): // CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32 // CHECK: scf.reduce.return %[[R]] : f32 // CHECK: } -// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]] -// CHECK: scf.reduce(%[[VAL_B]]) : f32 { // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): // CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32 // CHECK: scf.reduce.return %[[R]] : f32 // CHECK: } -// CHECK: scf.yield // CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32 // ----- +func.func @fuse_reductions_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>, %C: memref<2x2xf32>) -> (f32, f32, f32) { + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %init1 = arith.constant 1.0 : f32 + %init2 = arith.constant 2.0 : f32 + %init3 = arith.constant 3.0 : f32 + %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 { + %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> + scf.reduce(%A_elem : f32) { + ^bb0(%lhs: f32, %rhs: f32): + %1 = arith.addf %lhs, %rhs : f32 + scf.reduce.return %1 : f32 + } + } + %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 { + %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> + scf.reduce(%B_elem : f32) { + ^bb0(%lhs: f32, %rhs: f32): + %1 = arith.mulf %lhs, %rhs : f32 + scf.reduce.return %1 : f32 + } + } + %res3 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init3) -> f32 { + %A_elem = memref.load %C[%i, %j] : memref<2x2xf32> + scf.reduce(%A_elem : f32) { + ^bb0(%lhs: f32, %rhs: f32): + %1 = arith.addf %lhs, %rhs : f32 + scf.reduce.return %1 : f32 + } + } + return %res1, %res2, %res3 : f32, f32, f32 +} + +// CHECK-LABEL: func @fuse_reductions_three +// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>, %[[C:.*]]: memref<2x2xf32>) -> (f32, f32, f32) +// CHECK-DAG: %[[C0:.*]] = arith.constant 
0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[INIT3:.*]] = arith.constant 3.000000e+00 : f32 +// CHECK: %[[RES:.*]]:3 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) +// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) +// CHECK-SAME: init (%[[INIT1]], %[[INIT2]], %[[INIT3]]) -> (f32, f32, f32) +// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]] +// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]] +// CHECK: %[[VAL_C:.*]] = memref.load %[[C]][%[[I]], %[[J]]] +// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]], %[[VAL_C]] : f32, f32, f32) { +// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): +// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32 +// CHECK: scf.reduce.return %[[R]] : f32 +// CHECK: } +// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): +// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32 +// CHECK: scf.reduce.return %[[R]] : f32 +// CHECK: } +// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): +// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32 +// CHECK: scf.reduce.return %[[R]] : f32 +// CHECK: } +// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : f32, f32, f32 + +// ----- + func.func @reductions_use_res(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) { %c2 = arith.constant 2 : index %c0 = arith.constant 0 : index @@ -639,21 +700,19 @@ func.func @reductions_use_res(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, %init1 = arith.constant 1.0 : f32 %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 { %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> - scf.reduce(%A_elem) : f32 { + scf.reduce(%A_elem : f32) { ^bb0(%lhs: f32, %rhs: f32): %1 = arith.addf %lhs, %rhs : f32 scf.reduce.return %1 : f32 } - scf.yield } %res2 = scf.parallel 
(%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%res1) -> f32 { %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> - scf.reduce(%B_elem) : f32 { + scf.reduce(%B_elem : f32) { ^bb0(%lhs: f32, %rhs: f32): %1 = arith.mulf %lhs, %rhs : f32 scf.reduce.return %1 : f32 } - scf.yield } return %res1, %res2 : f32, f32 } @@ -673,22 +732,20 @@ func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) - %init2 = arith.constant 2.0 : f32 %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 { %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> - scf.reduce(%A_elem) : f32 { + scf.reduce(%A_elem : f32) { ^bb0(%lhs: f32, %rhs: f32): %1 = arith.addf %lhs, %rhs : f32 scf.reduce.return %1 : f32 } - scf.yield } %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 { %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> %sum = arith.addf %B_elem, %res1 : f32 - scf.reduce(%sum) : f32 { + scf.reduce(%sum : f32) { ^bb0(%lhs: f32, %rhs: f32): %1 = arith.mulf %lhs, %rhs : f32 scf.reduce.return %1 : f32 } - scf.yield } return %res1, %res2 : f32, f32 } From 083707eae339c3f916619626f4ca8aca10022195 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 4 Jan 2024 19:46:13 +0100 Subject: [PATCH 5/6] add comments --- mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp index 853b63f5adaf5..5934d85373b03 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp @@ -175,6 +175,8 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, return; DominanceInfo dom; + // We are fusing first loop into second, make sure there are no users of the + // first loop results between loops. 
for (Operation *user : firstPloop->getUsers()) if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false)) return; From 8f7b4a40ff443b25f8aac9736f581e90f612c0f6 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 31 Jan 2024 23:49:43 +0100 Subject: [PATCH 6/6] more tests --- .../Dialect/SCF/parallel-loop-fusion.mlir | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir index d171f96811b10..0d4ea6f20e8d9 100644 --- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir +++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir @@ -24,6 +24,32 @@ func.func @fuse_empty_loops() { // ----- +func.func @fuse_ops_between(%A: f32, %B: f32) -> f32 { + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + scf.reduce + } + %res = arith.addf %A, %B : f32 + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + scf.reduce + } + return %res : f32 +} +// CHECK-LABEL: func @fuse_ops_between +// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index +// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index +// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index +// CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f32 +// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) { +// CHECK: scf.reduce +// CHECK: } +// CHECK-NOT: scf.parallel + +// ----- + func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) { %c2 = arith.constant 2 : index %c0 = arith.constant 0 : index @@ -754,3 +780,36 @@ func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) - // CHECK-LABEL: func @reductions_use_res_inside // CHECK: scf.parallel // CHECK: scf.parallel + +// ----- + +func.func @reductions_use_res_between(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32, f32) { + %c2 = 
arith.constant 2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %init1 = arith.constant 1.0 : f32 + %init2 = arith.constant 2.0 : f32 + %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 { + %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> + scf.reduce(%A_elem : f32) { + ^bb0(%lhs: f32, %rhs: f32): + %1 = arith.addf %lhs, %rhs : f32 + scf.reduce.return %1 : f32 + } + } + %res3 = arith.addf %res1, %init2 : f32 + %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 { + %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> + scf.reduce(%B_elem : f32) { + ^bb0(%lhs: f32, %rhs: f32): + %1 = arith.mulf %lhs, %rhs : f32 + scf.reduce.return %1 : f32 + } + } + return %res1, %res2, %res3 : f32, f32, f32 +} + +// instruction in between the loops uses the first loop result +// CHECK-LABEL: func @reductions_use_res_between +// CHECK: scf.parallel +// CHECK: scf.parallel