From 303b01dc8cbc68bfb5e11b3c317dae6947c93f16 Mon Sep 17 00:00:00 2001
From: Amir Bishara
Date: Wed, 4 Dec 2024 23:50:29 +0200
Subject: [PATCH 1/2] [mlir][bufferization]-Add lit tests for unhandled cases in EmptyTensorElimination

In many cases EmptyTensorElimination cannot transform or eliminate the
empty tensor that is being inserted into the `SubsetInsertionOpInterface`.

There are two major reasons for this:

1- The transform fails to find a legal/suitable insertion point for the
`subsetExtract` that is about to replace the empty tensor. We may be able
to handle this by moving the values needed to build the `subsetExtract`
next to the empty tensor (which is about to be eliminated), increasing
the probability of finding a legal insertion point.

2- The EmptyTensorElimination transform replaces all of the tensor.empty's
uses at once in a single application, rather than replacing only the
specific use that was visited in the use-def chain (when traversing from
the tensor.insert_slice). Replacing all uses of the tensor.empty may lead
to additional read effects after bufferization of the specific subset
extract/subview, which should not be the case.

Both cases may result in extra copies in the subsequent bufferization
that cannot be canonicalized away.

The first case shows up when a `tensor.empty` is followed by a
`SubsetInsertionOpInterface` (in simple words, a `tensor.insert_slice`),
e.g. as lowered from `tensor/tosa.concat`.

The second case shows up when a `tensor.empty` has many uses: the
transformation is applied only once, since all of the uses are replaced
at once.

This MR only adds the lit tests for the cases shown above (NFC) to
illustrate how the transform currently works; the coming MRs will make
slight changes to handle these cases in a more generic way.
---
 ...ot-bufferize-empty-tensor-elimination.mlir | 98 +++++++++++++++++++
 1 file changed, 98 insertions(+)

diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index efe59af97d964..9d9bb44316046 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -365,3 +365,101 @@ func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32
   bufferization.materialize_in_destination %selected in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
   return
 }
+
+// -----
+
+// `EmptyTensorElimination` fails to find a valid insertion
+// point for the newly injected `SubsetExtraction`.
+// CHECK-LABEL: func.func @fail_to_eliminate_any_empty_tensors +func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> { + %cst_1 = arith.constant 1.0 : f32 + %cst_2 = arith.constant 2.0 : f32 + // CHECK: memref.alloc + // CHECK: memref.alloc + // CHECK: memref.alloc + %empty_1 = tensor.empty() : tensor<5x6x64xf32> + %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32> + %empty_2 = tensor.empty() : tensor<5x6x64xf32> + %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32> + %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32> + // CHECK: memref.copy + %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1] + : tensor<5x6x64xf32> into tensor<5x6x128xf32> + %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1] + : tensor<5x6x64xf32> into tensor<5x6x128xf32> + return %inserted_slice_2 : tensor<5x6x128xf32> +} + +// ----- + +// CHECK-LABEL: func.func @succeed_to_eliminate_one_empty_tensor +func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> { + %cst_1 = arith.constant 1.0 : f32 + %cst_2 = arith.constant 2.0 : f32 + // CHECK: memref.alloc + // CHECK: memref.alloc + %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32> + %empty_1 = tensor.empty() : tensor<5x6x64xf32> + %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32> + %empty_2 = tensor.empty() : tensor<5x6x64xf32> + %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32> + // CHECK: memref.copy + %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1] + : tensor<5x6x64xf32> into tensor<5x6x128xf32> + %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1] + : tensor<5x6x64xf32> into tensor<5x6x128xf32> + return %inserted_slice_2 : tensor<5x6x128xf32> +} + +// ----- + +// `EmptyTensorElimination` replaces all of the uses of the tensor +// empty with the new injected `SubsetExtraction`, without to consider +// the specific use has been tracked, sometimes creating a non existent +// bufferization conflicts. 
+
+// CHECK-ELIM-LABEL: func.func @mutli_use_of_the_same_tensor_empty
+// CHECK-LABEL: func.func @mutli_use_of_the_same_tensor_empty
+func.func @mutli_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> {
+  %cst_1 = arith.constant 1.0 : f32
+  %cst_2 = arith.constant 2.0 : f32
+  %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  // CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice
+  // CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]]
+  // CHECK-ELIM: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]]
+  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  // CHECK: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  // CHECK: memref.copy
+  %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_2 : tensor<5x6x128xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read
+func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: tensor<5x6x128xf32> , %arg2: tensor<5x6x64xf32>)
+    -> (tensor<5x6x128xf32>, tensor<5x6x64xf32>) {
+  %cst_1 = arith.constant 1.0 : f32
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  // CHECK: memref.alloc
+  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %res_2 = linalg.generic{
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  }
+  ins(%empty_1 : tensor<5x6x64xf32>)
+  outs(%arg2 :tensor<5x6x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %res = arith.addf %in, %in : f32
+    linalg.yield %res : f32
+  } -> tensor<5x6x64xf32>
+  // CHECK: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %arg1[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>
+}

From 11506c3394c950ad5513dbd0a5be75bc837c50dd Mon Sep 17 00:00:00 2001
From: Amir Bishara
Date: Fri, 6 Dec 2024 13:27:20 +0200
Subject: [PATCH 2/2] [mlir][bufferization]-Replace only one use in TensorEmptyElimination

This MR handles the second case, where we want to replace only the
specific use that was visited in the `use-def` chain (when traversing
from the tensor.insert_slice's source).

Replacing all the uses of the tensor.empty may lead to additional read
effects after bufferization of the specific subset extract/subview,
which should not be the case. Replacing only the tracked use thus
eliminates potential copies.
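
For illustration, consider the following simplified, hypothetical
snippet (the names and shapes are made up for this note; the real cases
are the lit tests updated by this patch). A tensor.empty has two uses,
and only one of them feeds the tensor.insert_slice source:

  %e  = tensor.empty() : tensor<8xf32>
  // This use is reached when walking back from the insert_slice source,
  // so it is the one that gets replaced.
  %f1 = linalg.fill ins(%c1 : f32) outs(%e : tensor<8xf32>) -> tensor<8xf32>
  // This use is unrelated to the insert_slice and must keep using the
  // original tensor.empty.
  %f2 = linalg.fill ins(%c2 : f32) outs(%e : tensor<8xf32>) -> tensor<8xf32>
  %r  = tensor.insert_slice %f1 into %dest[0] [8] [1]
      : tensor<8xf32> into tensor<16xf32>

With this change, only the outs operand of %f1 is rewritten to the
injected tensor.extract_slice of %dest; %f2 keeps using the original
tensor.empty, so bufferization no longer sees a spurious read of the
destination through %f2.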
--- .../IR/BufferizableOpInterface.h | 6 +- .../IR/BufferizableOpInterface.cpp | 8 ++- .../Transforms/EmptyTensorElimination.cpp | 56 +++++++++++-------- ...ize-analysis-empty-tensor-elimination.mlir | 1 + ...ot-bufferize-empty-tensor-elimination.mlir | 20 ++++--- 5 files changed, 54 insertions(+), 37 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index 4866e31b19d5d..983f7a29cb220 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -459,7 +459,8 @@ class AnalysisState { /// Starting from `value`, follow the use-def chain in reverse, always /// selecting the aliasing OpOperands. Find and return Values for which /// `condition` evaluates to true. OpOperands of such matching Values are not - /// traversed any further. + /// traversed any further, the visited aliasing opOperands will be preserved + /// through `visitedOpOperands`. /// /// When reaching the end of a chain, also return the last Value of that /// chain if `config.alwaysIncludeLeaves` is set. @@ -484,7 +485,8 @@ class AnalysisState { /// `config`. SetVector findValueInReverseUseDefChain( Value value, llvm::function_ref condition, - TraversalConfig config = TraversalConfig()) const; + TraversalConfig config = TraversalConfig(), + llvm::DenseSet *visitedOpOperands = nullptr) const; /// Find the values that may define the contents of the given value at /// runtime. A block argument is always a definition. An OpResult is a diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 065739ea8e595..f8a7a22787404 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -483,10 +483,12 @@ bool AnalysisState::isValueRead(Value value) const { // Starting from `value`, follow the use-def chain in reverse, always selecting // the aliasing OpOperands. Find and return Values for which `condition` // evaluates to true. OpOperands of such matching Values are not traversed any -// further. +// further, the visited aliasing opOperands will be preserved through +// `visitedOpOperands`. llvm::SetVector AnalysisState::findValueInReverseUseDefChain( Value value, llvm::function_ref condition, - TraversalConfig config) const { + TraversalConfig config, + llvm::DenseSet *visitedOpOperands) const { llvm::DenseSet visited; llvm::SetVector result, workingSet; workingSet.insert(value); @@ -553,6 +555,8 @@ llvm::SetVector AnalysisState::findValueInReverseUseDefChain( } workingSet.insert(a.opOperand->get()); + if (visitedOpOperands) + visitedOpOperands->insert(a.opOperand); } } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp index cb2efef5c038b..abc0635a2cdff 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -48,27 +48,20 @@ neededValuesDominateInsertionPoint(const DominanceInfo &domInfo, return true; } -/// Return true if the given `insertionPoint` dominates all uses of -/// `emptyTensorOp`. 
-static bool insertionPointDominatesUses(const DominanceInfo &domInfo, - Operation *insertionPoint, - Operation *emptyTensorOp) { - return llvm::all_of(emptyTensorOp->getUsers(), [&](Operation *user) { - return domInfo.dominates(insertionPoint, user); - }); -} - -/// Find a valid insertion point for a replacement of `emptyTensorOp`, assuming -/// that the replacement may use any value from `neededValues`. +/// Find a valid insertion point for a replacement of `emptyTensorOp`'s +/// use of `user` operation, assuming that the replacement may use any +/// value from `neededValues`. static Operation * -findValidInsertionPoint(Operation *emptyTensorOp, +findValidInsertionPoint(Operation *emptyTensorOp, Operation *user, const SmallVector &neededValues) { DominanceInfo domInfo; + Operation *candidateInsertionPoint = emptyTensorOp; - // Gather all possible insertion points: the location of `emptyTensorOp` and - // right after the definition of each value in `neededValues`. + // Gather all possible insertion points: the location of + // `candidateInsertionPoint` and right after the definition of each value in + // `neededValues`. SmallVector insertionPointCandidates; - insertionPointCandidates.push_back(emptyTensorOp); + insertionPointCandidates.push_back(candidateInsertionPoint); for (Value val : neededValues) { // Note: The anchor op is using all of `neededValues`, so: // * in case of a block argument: There must be at least one op in the block @@ -90,8 +83,8 @@ findValidInsertionPoint(Operation *emptyTensorOp, if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint, neededValues)) continue; - // Check if the insertion point is before all uses. - if (!insertionPointDominatesUses(domInfo, insertionPoint, emptyTensorOp)) + // Check if the insertion point is before the use to be replaced. + if (!domInfo.dominates(insertionPoint, user)) continue; return insertionPoint; } @@ -103,8 +96,9 @@ findValidInsertionPoint(Operation *emptyTensorOp, LogicalResult mlir::bufferization::eliminateEmptyTensors( RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) { OpBuilder::InsertionGuard g(rewriter); - + llvm::DenseSet visitedOpOperands; op->walk([&](SubsetInsertionOpInterface op) { + visitedOpOperands.clear(); OpOperand &source = op.getSourceOperand(); // Skip operands that do not bufferize inplace. "tensor.empty" could still // be replaced, but the transformation may not be beneficial. @@ -131,16 +125,28 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors( config.followSameTypeOrCastsOnly = true; SetVector emptyTensors = state.findValueInReverseUseDefChain( source.get(), /*condition=*/ - [&](Value val) { return val.getDefiningOp(); }, - config); + [&](Value val) { return val.getDefiningOp(); }, config, + &visitedOpOperands); for (Value v : emptyTensors) { Operation *emptyTensorOp = v.getDefiningOp(); + // Find the use to be replaced from the use-def chain. + auto iter = llvm::find_if( + visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) { + return llvm::count(emptyTensorOp->getUses(), *opOperand); + }); + // This could be achieved when a use of `emptyTensorOp` is being + // consumed by `SubsetInsertionOpInterface`'s source directly. + if (iter == visitedOpOperands.end()) + continue; + OpOperand *useToBeReplaced = *iter; + Operation *user = useToBeReplaced->getOwner(); + // Find a suitable insertion point. If no suitable insertion point for // the replacement can be found, skip this replacement. 
Operation *insertionPoint = - findValidInsertionPoint(emptyTensorOp, neededValues); + findValidInsertionPoint(emptyTensorOp, user, neededValues); if (!insertionPoint) continue; @@ -159,8 +165,10 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors( replacement = rewriter.create(v.getLoc(), v.getType(), replacement); } - // Replace the tensor::EmptyOp. - rewriter.replaceOp(emptyTensorOp, replacement); + // Replace the specific use of the tensor::EmptyOp. + rewriter.modifyOpInPlace(user, [&]() { + user->setOperand(useToBeReplaced->getOperandNumber(), replacement); + }); state.resetCache(); } diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir index 2ba8246a8d525..9150986f4c2a2 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir @@ -55,6 +55,7 @@ func.func @buffer_forwarding_conflict_with_different_element_type(%arg0: tensor< // CHECK: tensor.extract_slice // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"] %cst = arith.constant 0.000000e+00 : f32 + // CHECK: bufferization.alloc_tensor(%arg1) %0 = tensor.empty(%arg1) : tensor // CHECK: bufferization.alloc_tensor(%arg1) diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir index 9d9bb44316046..26434774730e1 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir @@ -396,8 +396,9 @@ func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> { func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> { %cst_1 = arith.constant 1.0 : f32 %cst_2 = arith.constant 2.0 : f32 + // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32> // CHECK: memref.alloc - // CHECK: memref.alloc + // CHECK-NOT: memref.alloc %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32> %empty_1 = tensor.empty() : tensor<5x6x64xf32> %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32> @@ -413,10 +414,9 @@ func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> { // ----- -// `EmptyTensorElimination` replaces all of the uses of the tensor -// empty with the new injected `SubsetExtraction`, without to consider -// the specific use has been tracked, sometimes creating a non existent -// bufferization conflicts. +// `EmptyTensorElimination` will replace the specific use of the tensor +// empty with the new injected `SubsetExtraction`, i.e. the specific use +// which has been tracked. 
// CHECK-ELIM-LABEL: func.func @mutli_use_of_the_same_tensor_empty // CHECK-LABEL: func.func @mutli_use_of_the_same_tensor_empty @@ -427,13 +427,13 @@ func.func @mutli_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> { %empty_1 = tensor.empty() : tensor<5x6x64xf32> // CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice // CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]] - // CHECK-ELIM: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]] + // CHECK-ELIM-NOT: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]] %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32> %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32> // CHECK: memref.copy %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1] : tensor<5x6x64xf32> into tensor<5x6x128xf32> - // CHECK: memref.copy + // CHECK-NOT: memref.copy %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1] : tensor<5x6x64xf32> into tensor<5x6x128xf32> return %inserted_slice_2 : tensor<5x6x128xf32> @@ -442,11 +442,13 @@ func.func @mutli_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> { // ----- // CHECK-LABEL: func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read +// CHECK-ELIM-LABEL: func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: tensor<5x6x128xf32> , %arg2: tensor<5x6x64xf32>) -> (tensor<5x6x128xf32>, tensor<5x6x64xf32>) { %cst_1 = arith.constant 1.0 : f32 %empty_1 = tensor.empty() : tensor<5x6x64xf32> - // CHECK: memref.alloc + // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x64xf32> + // CHECK-NOT: memref.alloc %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32> %res_2 = linalg.generic{ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], @@ -458,7 +460,7 @@ func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: t %res = arith.addf %in, %in : f32 linalg.yield %res : f32 } -> tensor<5x6x64xf32> - // CHECK: memref.copy + // CHECK-NOT: memref.copy %inserted_slice_1 = tensor.insert_slice %res_1 into %arg1[0, 0, 0][5, 6, 64][1, 1, 1] : tensor<5x6x64xf32> into tensor<5x6x128xf32> return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>