diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 2aaf1cb7e5878..6473c92a91aa6 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -820,7 +820,8 @@ static LogicalResult verifyOutputShape( return e.cast().getValue(); })); auto expected = - VectorType::get(expectedShape, resVectorType.getElementType()); + VectorType::get(expectedShape, resVectorType.getElementType(), + resVectorType.getScalableDims()); if (resVectorType != expected || accVectorType != expected) return op.emitOpError( "invalid accumulator/result vector shape, expected: ") diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 2154304965a5d..f00bc6e97b350 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -307,6 +307,17 @@ func.func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) - return %0 : f32 } +// CHECK-LABEL: @contraction_to_scalar_scalable +func.func @contraction_to_scalar_scalable(%arg0: vector<[10]xf32>, %arg1: vector<[10]xf32>) -> f32 { + // CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32 + %f0 = arith.constant 0.0: f32 + // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind} %{{.*}}, %{{.*}}, %[[C0]] : vector<[10]xf32>, vector<[10]xf32> into f32 + %0 = vector.contract #contraction_to_scalar_trait %arg0, %arg1, %f0 + : vector<[10]xf32>, vector<[10]xf32> into f32 + // CHECK: return %[[X]] : f32 + return %0 : f32 +} + // CHECK-LABEL: @contraction_extra_attrs func.func @contraction_extra_attrs(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 { // CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32 @@ -392,6 +403,24 @@ func.func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf3 return } +#contraction_matmul_accesses = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> +] +#contraction_matmul_trait = { + indexing_maps = #contraction_matmul_accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} +// CHECK-LABEL: @contraction_matmul_scalable +func.func @contraction_matmul_scalable(%A: vector<8x1xf32>, %B: vector<1x[32]xf32>, %C: vector<8x[32]xf32>) -> vector<8x[32]xf32> { + // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}} : vector<8x1xf32>, vector<1x[32]xf32> into vector<8x[32]xf32> + %res = vector.contract #contraction_matmul_trait %A, %B, %C + : vector<8x1xf32>, vector<1x[32]xf32> into vector<8x[32]xf32> + // CHECK: return %[[X]] : vector<8x[32]xf32> + return %res : vector<8x[32]xf32> +} + // CHECK-LABEL: @create_vector_mask func.func @create_vector_mask() { // CHECK: %[[C2:.*]] = arith.constant 2 : index