From 4f5ccf9560855438a68341b461917834ab60ecc0 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Wed, 11 Jun 2025 19:42:47 +0000 Subject: [PATCH 1/4] [mlir][Vector] Add simple folders for `vector.from_element`/`vector.to_elements` This PR adds simple folders to remove no-op sequences of `vector.from_elements` and `vector.to_elements`. --- .../mlir/Dialect/Vector/IR/VectorOps.td | 2 + mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 90 +++++++++++++++++++ mlir/test/Dialect/Vector/canonicalize.mlir | 52 +++++++++++ 3 files changed, 144 insertions(+) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 125cd4645ccc2..7c44cfbde0367 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -836,6 +836,7 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [ let arguments = (ins AnyVectorOfAnyRank:$source); let results = (outs Variadic:$elements); let assemblyFormat = "$source attr-dict `:` type($source)"; + let hasFolder = 1; } def Vector_FromElementsOp : Vector_Op<"from_elements", [ @@ -873,6 +874,7 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [ let arguments = (ins Variadic:$elements); let results = (outs AnyFixedVectorOfAnyRank:$dest); let assemblyFormat = "$elements attr-dict `:` type($dest)"; + let hasFolder = 1; let hasCanonicalizer = 1; } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index e576eeac23656..7482b6a22c400 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2373,10 +2373,100 @@ std::optional> FMAOp::getShapeForUnroll() { return llvm::to_vector<4>(getVectorType().getShape()); } +//===----------------------------------------------------------------------===// +// ToElementsOp +//===----------------------------------------------------------------------===// + +/// Returns true if all the `operands` are defined by `defOp`. +/// Otherwise, returns false. +static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) { + if (operands.empty()) + return false; + + return llvm::all_of(operands, [&](Value operand) { + Operation *currentDef = operand.getDefiningOp(); + return currentDef == defOp; + }); +} + +/// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into +/// (%e0, %e1, ...). For example: +/// +/// %0 = vector.from_elements %a, %b, %c : vector<3xf32> +/// %1:3 = vector.to_elements %0 : vector<3xf32> +/// user_op %1#0, %1#1, %1#2 +/// +/// becomes: +/// +/// user_op %a, %b, %c +/// +static LogicalResult +foldToElementsFromElements(ToElementsOp toElementsOp, + SmallVectorImpl &results) { + auto fromElementsOp = toElementsOp.getSource().getDefiningOp(); + if (!fromElementsOp) + return failure(); + + results.append(fromElementsOp.getElements().begin(), + fromElementsOp.getElements().end()); + return success(); +} + +LogicalResult ToElementsOp::fold(FoldAdaptor adaptor, + SmallVectorImpl &results) { + if (succeeded(foldToElementsFromElements(*this, results))) + return success(); + return failure(); +} + //===----------------------------------------------------------------------===// // FromElementsOp //===----------------------------------------------------------------------===// +/// Folds vector.from_elements(vector.to_elements(%vector)) into %vector. +/// +/// Case #1: Input and output vectors are the same. +/// +/// %0:3 = vector.to_elements %a : vector<3xf32> +/// %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32> +/// user_op %1 +/// +/// becomes: +/// +/// user_op %a +/// +static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) { + auto fromElemsOperands = fromElementsOp.getElements(); + + if (fromElemsOperands.empty()) + return {}; + + auto toElementsOp = fromElemsOperands[0].getDefiningOp(); + if (!toElementsOp) + return {}; + + if (!haveSameDefiningOp(fromElemsOperands, toElementsOp)) + return {}; + + // Case #1: Input and output vectors are the same. Forward the input vector. + Value toElementsInput = toElementsOp.getSource(); + if (fromElementsOp.getType() == toElementsInput.getType() && + llvm::equal(fromElemsOperands, toElementsOp.getResults())) { + return toElementsInput; + } + + // TODO: Support cases with different input and output shapes and different + // number of elements. + + return {}; +} + +OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) { + if (auto result = foldFromElementsToElements(*this)) + return result; + return {}; +} + /// Rewrite a vector.from_elements into a vector.splat if all elements are the /// same SSA value. E.g.: /// diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 6691cb52acdc0..65b73375831da 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -3023,6 +3023,58 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector, // ----- +// CHECK-LABEL: func @to_elements_from_elements_no_op( +// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32 +func.func @to_elements_from_elements_no_op(%a: f32, %b: f32) -> (f32, f32) { + // CHECK-NOT: vector.from_elements + // CHECK-NOT: vector.to_elements + %0 = vector.from_elements %b, %a : vector<2xf32> + %1:2 = vector.to_elements %0 : vector<2xf32> + // CHECK: return %[[B]], %[[A]] + return %1#0, %1#1 : f32, f32 +} + +// ----- + +// CHECK-LABEL: func @from_elements_to_elements_no_op( +// CHECK-SAME: %[[A:.*]]: vector<4x2xf32> +func.func @from_elements_to_elements_no_op(%a: vector<4x2xf32>) -> vector<4x2xf32> { + // CHECK-NOT: vector.from_elements + // CHECK-NOT: vector.to_elements + %0:8 = vector.to_elements %a : vector<4x2xf32> + %1 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7 : vector<4x2xf32> + // CHECK: return %[[A]] + return %1 : vector<4x2xf32> +} + +// ----- + +// CHECK-LABEL: func @from_elements_to_elements_dup_elems( +// CHECK-SAME: %[[A:.*]]: vector<4xf32> +func.func @from_elements_to_elements_dup_elems(%a: vector<4xf32>) -> vector<4x2xf32> { + // CHECK: %[[TO_EL:.*]]:4 = vector.to_elements %[[A]] + // CHECK: %[[FROM_EL:.*]] = vector.from_elements %[[TO_EL]]#0, %[[TO_EL]]#1, %[[TO_EL]]#2 + %0:4 = vector.to_elements %a : vector<4xf32> // 4 elements + %1 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %0#0, %0#1, %0#2, %0#3 : vector<4x2xf32> + // CHECK: return %[[FROM_EL]] + return %1 : vector<4x2xf32> +} + +// ----- + +// CHECK-LABEL: func @from_elements_to_elements_shuffle( +// CHECK-SAME: %[[A:.*]]: vector<4x2xf32> +func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2xf32> { + // CHECK: %[[TO_EL:.*]]:8 = vector.to_elements %[[A]] + // CHECK: %[[FROM_EL:.*]] = vector.from_elements %[[TO_EL]]#7, %[[TO_EL]]#0, %[[TO_EL]]#6 + %0:8 = vector.to_elements %a : vector<4x2xf32> + %1 = vector.from_elements %0#7, %0#0, %0#6, %0#1, %0#5, %0#2, %0#4, %0#3 : vector<4x2xf32> + // CHECK: return %[[FROM_EL]] + return %1 : vector<4x2xf32> +} + +// ----- + // CHECK-LABEL: func @vector_insert_const_regression( // CHECK: llvm.mlir.undef // CHECK: vector.insert From 7fc28b25b3a73b331004f55a451bfbf7d0775965 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Wed, 18 Jun 2025 20:45:54 +0000 Subject: [PATCH 2/4] Feedback --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 7482b6a22c400..fab49e27562ac 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2407,16 +2407,13 @@ foldToElementsFromElements(ToElementsOp toElementsOp, if (!fromElementsOp) return failure(); - results.append(fromElementsOp.getElements().begin(), - fromElementsOp.getElements().end()); + llvm::append_range(results, fromElementsOp.getElements()); return success(); } LogicalResult ToElementsOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { - if (succeeded(foldToElementsFromElements(*this, results))) - return success(); - return failure(); + return foldToElementsFromElements(*this, results); } //===----------------------------------------------------------------------===// @@ -2462,9 +2459,7 @@ static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) { } OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) { - if (auto result = foldFromElementsToElements(*this)) - return result; - return {}; + return foldFromElementsToElements(*this); } /// Rewrite a vector.from_elements into a vector.splat if all elements are the From 5561485a5a9f892e3ccc3dd0a1aafe955ae57a6d Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Wed, 18 Jun 2025 20:54:48 +0000 Subject: [PATCH 3/4] Fix format --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index fab49e27562ac..25341b62634a7 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2403,7 +2403,8 @@ static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) { static LogicalResult foldToElementsFromElements(ToElementsOp toElementsOp, SmallVectorImpl &results) { - auto fromElementsOp = toElementsOp.getSource().getDefiningOp(); + auto fromElementsOp = + toElementsOp.getSource().getDefiningOp(); if (!fromElementsOp) return failure(); From f725c325f5fd27d9317228c439153f7514d2566c Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Wed, 18 Jun 2025 20:58:56 +0000 Subject: [PATCH 4/4] Fix auto --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 25341b62634a7..6f0ac6bb58282 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2434,8 +2434,7 @@ LogicalResult ToElementsOp::fold(FoldAdaptor adaptor, /// user_op %a /// static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) { - auto fromElemsOperands = fromElementsOp.getElements(); - + OperandRange fromElemsOperands = fromElementsOp.getElements(); if (fromElemsOperands.empty()) return {};