diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d5f3634377e4c..fe2793e75aa60 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2033,7 +2033,7 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
 static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
                                                 ArrayRef<int64_t> staticPos,
                                                 int64_t poisonVal) {
-  if (!llvm::is_contained(staticPos, poisonVal))
+  if (!is_contained(staticPos, poisonVal))
     return {};
 
   return ub::PoisonAttr::get(context);
@@ -2041,12 +2041,63 @@ static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
 
 /// Fold a vector extract from is a poison source.
 static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
-  if (llvm::isa_and_nonnull<ub::PoisonAttr>(srcAttr))
+  if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
     return srcAttr;
 
   return {};
 }
 
+/// Fold a vector extract extracting from a DenseElementsAttr.
+static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
+                                                   Attribute srcAttr) {
+  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
+  if (!denseAttr) {
+    return {};
+  }
+
+  if (denseAttr.isSplat()) {
+    Attribute newAttr = denseAttr.getSplatValue<Attribute>();
+    if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
+      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
+    return newAttr;
+  }
+
+  auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
+  if (vecTy.isScalable())
+    return {};
+
+  if (extractOp.hasDynamicPosition()) {
+    return {};
+  }
+
+  // Materializing subsets of a large constant array can generally lead to
+  // explosion in IR size because of different combination of subsets that
+  // can exist. However, vector.extract is a restricted form of subset
+  // extract where you can only extract non-overlapping (or the same) subset for
+  // a given rank of the subset. Because of this property, the IR size can only
+  // increase at most by `rank * size(array)` from a single constant array being
+  // extracted by multiple extracts.
+
+  // Calculate the linearized position of the continuous chunk of elements to
+  // extract.
+  SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
+  copy(extractOp.getStaticPosition(), completePositions.begin());
+  int64_t startPos =
+      linearize(completePositions, computeStrides(vecTy.getShape()));
+  auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
+
+  TypedAttr newAttr;
+  if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
+    SmallVector<Attribute> elementValues(
+        denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
+    newAttr = DenseElementsAttr::get(resVecTy, elementValues);
+  } else {
+    newAttr = *denseValuesBegin;
+  }
+
+  return newAttr;
+}
+
 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
   // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2058,6 +2109,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
     return res;
   if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
     return res;
+  if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
+    return res;
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
   if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
@@ -2121,80 +2174,6 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
   }
 };
 
-// Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
-class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    // Return if 'ExtractOp' operand is not defined by a splat vector
-    // ConstantOp.
-    Value sourceVector = extractOp.getVector();
-    Attribute vectorCst;
-    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
-      return failure();
-    auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
-    if (!splat)
-      return failure();
-    TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
-    if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
-      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
-    return success();
-  }
-};
-
-// Pattern to rewrite a ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
-class ExtractOpNonSplatConstantFolder final
-    : public OpRewritePattern<ExtractOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    // TODO: Canonicalization for dynamic position not implemented yet.
-    if (extractOp.hasDynamicPosition())
-      return failure();
-
-    // Return if 'ExtractOp' operand is not defined by a compatible vector
-    // ConstantOp.
-    Value sourceVector = extractOp.getVector();
-    Attribute vectorCst;
-    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
-      return failure();
-
-    auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
-    if (vecTy.isScalable())
-      return failure();
-
-    // The splat case is handled by `ExtractOpSplatConstantFolder`.
-    auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
-    if (!dense || dense.isSplat())
-      return failure();
-
-    // Calculate the linearized position of the continuous chunk of elements to
-    // extract.
-    llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
-    copy(extractOp.getStaticPosition(), completePositions.begin());
-    int64_t elemBeginPosition =
-        linearize(completePositions, computeStrides(vecTy.getShape()));
-    auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
-
-    TypedAttr newAttr;
-    if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
-      SmallVector<Attribute> elementValues(
-          denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
-      newAttr = DenseElementsAttr::get(resVecTy, elementValues);
-    } else {
-      newAttr = *denseValuesBegin;
-    }
-
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
-    return success();
-  }
-};
-
 // Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
 public:
@@ -2332,8 +2311,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
-              ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index e66fbe968d9b0..cd83e1239fdda 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -32,14 +32,8 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(
 
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[CST_0:.+]] = arith.constant dense<0> : vector<1xindex>
-// CHECK-DAG: %[[CST_1:.+]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
-// CHECK-DAG: %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
-// CHECK-DAG: %[[IDX2:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
-// CHECK-DAG: %[[IDX3:.+]] = vector.extract %[[CST_1]][0] : index from vector<3xindex>
-
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[IDX3]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
 // CHECK: vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
 
 // -----
 
@@ -175,16 +169,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(%arg0: tensor<80x16
 // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>,
 // CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
-
-// CHECK-DAG: %[[CST_0:.+]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG: %[[CST_1:.+]] = arith.constant dense<16> : vector<4x1xindex>
-// CHECK-DAG: %[[IDX0:.+]] = vector.extract %[[CST_1]][0, 0] : index from vector<4x1xindex>
-// CHECK-DAG: %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<4xindex>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 
-// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[IDX0]], %[[IDX1]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
-// CHECK: %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[VAL_4]], %[[VAL_4]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[C16]], %[[C0]]], %[[CST]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
+// CHECK: %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
 // CHECK: return %[[VAL_9]] : tensor<1x4xf32>
 // CHECK: }
 
@@ -675,9 +665,7 @@ func.func @scalar_read_with_broadcast_from_column_tensor(%init: tensor<1x1x4xi32
 // CHECK-DAG: %[[PAD:.*]] = arith.constant 0 : i32
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[SRC:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
-// CHECK-DAG: %[[IDX_VEC:.*]] = arith.constant dense<0> : vector<1xindex>
-// CHECK: %[[IDX_ELT:.*]] = vector.extract %[[IDX_VEC]][0] : index from vector<1xindex>
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[IDX_ELT]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
 // CHECK: %[[READ_BCAST:.*]] = vector.broadcast %[[READ]] : vector<i32> to vector<1x1x4xi32>
 // CHECK: %[[RES:.*]] = vector.transfer_write %[[READ_BCAST]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x4xi32>, tensor<1x1x4xi32>
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 99b1bbab1eede..8e5ddbfffcdd9 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -310,12 +310,12 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector
 
 // -----
 
 // ALL-LABEL: test_vector_extract_scalar
-func.func @test_vector_extract_scalar() {
+func.func @test_vector_extract_scalar(%idx : index) {
   %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
   // ALL-NOT: vector.shuffle
   // ALL: vector.extract
   // ALL-NOT: vector.shuffle
-  %0 = vector.extract %cst[0] : i32 from vector<4xi32>
+  %0 = vector.extract %cst[%idx] : i32 from vector<4xi32>
   return
 }
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index c5cb09b9aa9f9..b4ebb14b8829e 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -101,9 +101,8 @@ func.func @transfer_read_2d_extract(%m: memref<?x?xf32>, %idx: index, %idx2:
 
 // CHECK-LABEL: func @transfer_write_arith_constant(
 // CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
-// CHECK: %[[cst:.*]] = arith.constant dense<5.000000e+00> : vector<1x1xf32>
-// CHECK: %[[extract:.*]] = vector.extract %[[cst]][0, 0] : f32 from vector<1x1xf32>
-// CHECK: memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK: %[[cst:.*]] = arith.constant 5.000000e+00 : f32
+// CHECK: memref.store %[[cst]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
 func.func @transfer_write_arith_constant(%m: memref<?x?x?xf32>, %idx: index) {
   %cst = arith.constant dense<5.000000e+00> : vector<1x1xf32>
   vector.transfer_write %cst, %m[%idx, %idx, %idx] : vector<1x1xf32>, memref<?x?x?xf32>
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 20e9400ed698d..5be267c1be984 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -242,33 +242,29 @@ func.func @strided_gather(%base : memref<100x3xf32>,
 // CHECK-SAME: %[[IDXS:.*]]: vector<4xindex>,
 // CHECK-SAME: %[[VAL_4:.*]]: index,
 // CHECK-SAME: %[[VAL_5:.*]]: index) -> vector<4xf32> {
+// CHECK: %[[TRUE:.*]] = arith.constant true
 // CHECK: %[[CST_3:.*]] = arith.constant dense<3> : vector<4xindex>
-// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<4xi1>
 // CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[base]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
 // CHECK: %[[NEW_IDXS:.*]] = arith.muli %[[IDXS]], %[[CST_3]] : vector<4xindex>
-// CHECK: %[[MASK_0:.*]] = vector.extract %[[MASK]][0] : i1 from vector<4xi1>
 // CHECK: %[[IDX_0:.*]] = vector.extract %[[NEW_IDXS]][0] : index from vector<4xindex>
-// CHECK: scf.if %[[MASK_0]] -> (vector<4xf32>)
+// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
 // CHECK: %[[M_0:.*]] = vector.load %[[COLLAPSED]][%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
 // CHECK: %[[V_0:.*]] = vector.extract %[[M_0]][0] : f32 from vector<1xf32>
-// CHECK: %[[MASK_1:.*]] = vector.extract %[[MASK]][1] : i1 from vector<4xi1>
 // CHECK: %[[IDX_1:.*]] = vector.extract %[[NEW_IDXS]][1] : index from vector<4xindex>
-// CHECK: scf.if %[[MASK_1]] -> (vector<4xf32>)
+// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
 // CHECK: %[[M_1:.*]] = vector.load %[[COLLAPSED]][%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
 // CHECK: %[[V_1:.*]] = vector.extract %[[M_1]][0] : f32 from vector<1xf32>
-// CHECK: %[[MASK_2:.*]] = vector.extract %[[MASK]][2] : i1 from vector<4xi1>
 // CHECK: %[[IDX_2:.*]] = vector.extract %[[NEW_IDXS]][2] : index from vector<4xindex>
-// CHECK: scf.if %[[MASK_2]] -> (vector<4xf32>)
+// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
 // CHECK: %[[M_2:.*]] = vector.load %[[COLLAPSED]][%[[IDX_2]]] : memref<300xf32>, vector<1xf32>
 // CHECK: %[[V_2:.*]] = vector.extract %[[M_2]][0] : f32 from vector<1xf32>
-// CHECK: %[[MASK_3:.*]] = vector.extract %[[MASK]][3] : i1 from vector<4xi1>
 // CHECK: %[[IDX_3:.*]] = vector.extract %[[NEW_IDXS]][3] : index from vector<4xindex>
-// CHECK: scf.if %[[MASK_3]] -> (vector<4xf32>)
+// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
 // CHECK: %[[M_3:.*]] = vector.load %[[COLLAPSED]][%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
 // CHECK: %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>
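Note on the fold: the start-offset computation in `foldDenseElementsAttrSrcExtractOp` is plain row-major linearization of the zero-padded static position against the source vector's shape. The sketch below mirrors that arithmetic with standard-library code only; it is an illustration, not the patch's implementation, and it deliberately does not call the MLIR `computeStrides`/`linearize` helpers used above. The shape and position are made-up example values.

```cpp
// Standalone sketch of the index arithmetic the new folder relies on:
// row-major strides of the source vector shape, then a dot product with the
// zero-padded static extract position to find where the contiguous result
// chunk starts in the flat element list of the DenseElementsAttr.
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Row-major strides, e.g. shape [3, 3, 3] -> strides [9, 3, 1].
static std::vector<int64_t> rowMajorStrides(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  return strides;
}

// Linearize a zero-padded position against the strides.
static int64_t linearizePos(const std::vector<int64_t> &pos,
                            const std::vector<int64_t> &strides) {
  return std::inner_product(pos.begin(), pos.end(), strides.begin(),
                            static_cast<int64_t>(0));
}

int main() {
  // Example: vector.extract %cst[1, 2] : vector<3xf32> from vector<3x3x3xf32>.
  std::vector<int64_t> shape = {3, 3, 3};
  std::vector<int64_t> pos = {1, 2, 0}; // static position padded to rank 3
  int64_t start = linearizePos(pos, rowMajorStrides(shape));
  assert(start == 15); // the folded vector<3xf32> covers flat elements [15, 18)
  return 0;
}
```

For that example, the folder copies the `resVecTy.getNumElements()` contiguous values starting at the linearized offset into a new `DenseElementsAttr`, which is why only static, non-scalable extracts are folded this way.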