From 08c38a85510e6a7f826486ed4c759690f1ac2a22 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Sat, 15 Jun 2024 20:27:07 +0100 Subject: [PATCH] [mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (2/n) The main goal of this and subsequent PRs is to unify and categorize tests in: * vector-transfer-flatten.mlir This should make it easier to identify the edge cases being tested (and how they differ), remove duplicates and to add tests for scalable vectors. The main contributions of this PR: 1. `@transfer_{read|write}_dims_mismatch_non_contiguous` and `@transfer_read_flattenable_negative` duplicated `@transfer_{read|write}_dims_mismatch_non_contiguous_slice`. Both tests are deleted (`@transfer_{read|write}_dims_mismatch_non_contiguous_slice` is preserved). 2. `@transfer_read_flattenable_negative2` is replaced with `@transfer_read_non_contiguous_src` and `@transfer_write_non_contiguous_src` (i.e. a dedicated test for xfer_read and xfer_read with more descriptive func names) Depends on https://github.com/llvm/llvm-project/pull/95743. **Only review the top commit.** --- .../Vector/vector-transfer-flatten.mlir | 116 +++++++----------- 1 file changed, 44 insertions(+), 72 deletions(-) diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir index 40a8b7e5e0737..3a5041fca53fc 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir @@ -131,25 +131,6 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices( // ----- -func.func @transfer_read_dims_mismatch_non_contiguous( - %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> { - - %c0 = arith.constant 0 : index - %cst = arith.constant 0 : i8 - %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : - memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8> - return %v : vector<2x1x2x2xi8> -} - -// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous -// CHECK-NOT: memref.collapse_shape -// CHECK-NOT: vector.shape_cast - -// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous( -// CHECK-128B-NOT: memref.collapse_shape - -// ----- - // The input memref has a dynamic trailing shape and hence is not flattened. // TODO: This case could be supported via memref.dim @@ -214,6 +195,28 @@ func.func @transfer_read_0d( // ----- +// Strides make the input memref non-contiguous, hence non-flattenable. + +func.func @transfer_read_non_contiguous_src( + %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> { + + %c0 = arith.constant 0 : index + %cst = arith.constant 0 : i8 + %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : + memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8> + return %v : vector<5x4x3x2xi8> +} + +// CHECK-LABEL: func.func @transfer_read_non_contiguous_src +// CHECK-NOT: memref.collapse_shape +// CHECK-NOT: vector.shape_cast + +// CHECK-128B-LABEL: func @transfer_read_non_contiguous_src +// CHECK-128B-NOT: memref.collapse_shape +// CHECK-128B-NOT: vector.shape_cast + +// ----- + ///---------------------------------------------------------------------------------------- /// vector.transfer_write /// [Pattern: FlattenContiguousRowMajorTransferWritePattern] @@ -342,25 +345,6 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices( // ----- -func.func @transfer_write_dims_mismatch_non_contiguous( - %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, - %vec : vector<2x1x2x2xi8>) { - - %c0 = arith.constant 0 : index - vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : - vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> - return -} - -// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous -// CHECK-NOT: memref.collapse_shape -// CHECK-NOT: vector.shape_cast - -// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous( -// CHECK-128B-NOT: memref.collapse_shape - -// ----- - // The input memref has a dynamic trailing shape and hence is not flattened. // TODO: This case could be supported via memref.dim @@ -427,6 +411,28 @@ func.func @transfer_write_0d( // ----- +// The strides make the input memref non-contiguous, hence non-flattenable. + +func.func @transfer_write_non_contiguous_src( + %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, + %vec : vector<5x4x3x2xi8>) { + + %c0 = arith.constant 0 : index + vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] : + vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>> + return +} + +// CHECK-LABEL: func.func @transfer_write_non_contiguous_src +// CHECK-NOT: memref.collapse_shape +// CHECK-NOT: vector.shape_cast + +// CHECK-128B-LABEL: func @transfer_write_non_contiguous_src +// CHECK-128B-NOT: memref.collapse_shape +// CHECK-128B-NOT: vector.shape_cast + +// ----- + ///---------------------------------------------------------------------------------------- /// TODO: Categorize + re-format ///---------------------------------------------------------------------------------------- @@ -478,40 +484,6 @@ func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vecto // ----- -func.func @transfer_read_flattenable_negative( - %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x2x2x2xi8> { - %c0 = arith.constant 0 : index - %cst = arith.constant 0 : i8 - %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : - memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x2x2x2xi8> - return %v : vector<2x2x2x2xi8> -} - -// CHECK-LABEL: func @transfer_read_flattenable_negative -// CHECK: vector.transfer_read {{.*}} vector<2x2x2x2xi8> - -// CHECK-128B-LABEL: func @transfer_read_flattenable_negative( -// CHECK-128B-NOT: memref.collapse_shape - -// ----- - -func.func @transfer_read_flattenable_negative2( - %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> { - %c0 = arith.constant 0 : index - %cst = arith.constant 0 : i8 - %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : - memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8> - return %v : vector<5x4x3x2xi8> -} - -// CHECK-LABEL: func @transfer_read_flattenable_negative2 -// CHECK: vector.transfer_read {{.*}} vector<5x4x3x2xi8> - -// CHECK-128B-LABEL: func @transfer_read_flattenable_negative2( -// CHECK-128B-NOT: memref.collapse_shape - -// ----- - func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> { %add = arith.addi %arg0, %arg0 : vector<1x8xi32> return %add : vector<1x8xi32>