diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir index 0cd134717b1a0..35418b38df9b2 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir @@ -1,14 +1,87 @@ // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s ///---------------------------------------------------------------------------------------- -/// vector.transfer_write +/// vector.transfer_write -> vector.transpose + vector.transfer_write +/// [Pattern: TransferWritePermutationLowering] ///---------------------------------------------------------------------------------------- -/// Input: -/// * vector.transfer_write op with a map which _is not_ the permutation of a -/// minor identity +/// Input: +/// * vector.transfer_write op with a permutation that under a transpose +/// _would be_ a minor identity permutation map /// Output: -/// * vector.broadcast + vector.transfer_write with a map which _is_ the permutation of a +/// * vector.transpose + vector.transfer_write with a permutation map which +/// _is_ a minor identity + +// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map +// CHECK-SAME: %[[ARG_0:.*]]: vector<4x8xi16>, +// CHECK-SAME: %[[MEM:.*]]: memref<2x2x8x4xi16>) { +// CHECK: %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x8xi16> to vector<8x4xi16> +// CHECK: vector.transfer_write +// CHECK-NOT: permutation_map +// CHECK-SAME: %[[TR]], %[[MEM]]{{.*}} {in_bounds = [true, true]} : vector<8x4xi16>, memref<2x2x8x4xi16> +func.func @xfer_write_transposing_permutation_map( + %arg0: vector<4x8xi16>, + %mem: memref<2x2x8x4xi16>) { + + %c0 = arith.constant 0 : index + vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0] { + in_bounds = [true, true], + permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)> + } : vector<4x8xi16>, memref<2x2x8x4xi16> + + return +} + +// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_with_mask_scalable +// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>, +// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x4xi16>, +// CHECK-SAME: %[[MASK:.*]]: vector<[8]x4xi1>) { +// CHECK: %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x[8]xi16> to vector<[8]x4xi16> +// CHECK: vector.transfer_write +// CHECK-NOT: permutation_map +// CHECK-SAME: %[[TR]], %[[MEM]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<[8]x4xi16>, memref<2x2x?x4xi16> +func.func @xfer_write_transposing_permutation_map_with_mask_scalable( + %arg0: vector<4x[8]xi16>, + %mem: memref<2x2x?x4xi16>, + %mask: vector<[8]x4xi1>) { + + %c0 = arith.constant 0 : index + vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0], %mask { + in_bounds = [true, true], + permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)> + } : vector<4x[8]xi16>, memref<2x2x?x4xi16> + + return +} + +// Masked version is not supported +// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_masked +// CHECK-NOT: vector.transpose +func.func @xfer_write_transposing_permutation_map_masked( + %arg0: vector<4x8xi16>, + %mem: memref<2x2x8x4xi16>, + %mask: vector<8x4xi1>) { + + %c0 = arith.constant 0 : index + vector.mask %mask { + vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0] { + in_bounds = [true, true], + permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)> + } : vector<4x8xi16>, memref<2x2x8x4xi16> + } : vector<8x4xi1> + + return +} + +///---------------------------------------------------------------------------------------- +/// vector.transfer_write -> vector.broadcast + vector.transpose + vector.transfer_write +/// [Patterns: TransferWriteNonPermutationLowering + TransferWritePermutationLowering] +///---------------------------------------------------------------------------------------- +/// Input: +/// * vector.transfer_write op with a map which _is not_ a permutation of a /// minor identity +/// Output: +/// * vector.broadcast + vector.transpose + vector.transfer_write with a map +/// which _is_ a permutation of a minor identity // CHECK-LABEL: func @permutation_with_mask_xfer_write_fixed_width( // CHECK: %[[vec:.*]] = arith.constant dense<-2.000000e+00> : vector<7x1xf32> @@ -94,7 +167,7 @@ func.func @masked_non_permutation_xfer_write_fixed_width( ///---------------------------------------------------------------------------------------- /// vector.transfer_read ///---------------------------------------------------------------------------------------- -/// Input: +/// Input: /// * vector.transfer_read op with a permutation map /// Output: /// * vector.transfer_read with a permutation map composed of leading zeros followed by a minor identiy + @@ -190,6 +263,10 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// vector.transfer_read +///---------------------------------------------------------------------------------------- +/// TODO: Review and categorize // CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)> // CHECK: func.func @transfer_read_reduce_rank_scalable(