Skip to content

Commit 77b8395

Browse files
committed
fixup! [mlir][vector] Extend mask calculation for vector.contract
1 parent 988b67e commit 77b8395

File tree

3 files changed

+16
-12
lines changed

3 files changed

+16
-12
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -912,18 +912,18 @@ Type ContractionOp::getExpectedMaskType() {
912912

913913
unsigned numVecDims = lhsIdxMap.getNumDims();
914914
SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
915-
SmallVector<bool> maskShapeScalabledims(numVecDims, false);
915+
SmallVector<bool> maskShapeScalableDims(numVecDims, false);
916916

917917
// Using the information in the indexing maps, extract the size of each
918918
// dimension in the vector.contract operation from the two input operands.
919919
for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
920920
maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
921-
maskShapeScalabledims[lhsIdxMap.getDimPosition(dimIdx)] =
921+
maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
922922
lhsType.getScalableDims()[dimIdx];
923923
}
924924
for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
925925
maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
926-
maskShapeScalabledims[rhsIdxMap.getDimPosition(dimIdx)] =
926+
maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
927927
rhsType.getScalableDims()[dimIdx];
928928
}
929929

@@ -932,7 +932,7 @@ Type ContractionOp::getExpectedMaskType() {
932932

933933
return VectorType::get(maskShape,
934934
IntegerType::get(lhsType.getContext(), /*width=*/1),
935-
maskShapeScalabledims);
935+
maskShapeScalableDims);
936936
}
937937

938938
SmallVector<StringRef> ContractionOp::getTraitAttrNames() {

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -989,12 +989,17 @@ func.func @vector_scalable_extract(%sv: vector<[8]xi32>) {
989989
indexing_maps = #matmat_accesses,
990990
iterator_types = ["parallel", "parallel", "reduction"]
991991
}
992-
func.func @matmul_masked_scalable(%arg0: vector<3x4xf32>,
993-
%arg1: vector<4x[8]xf32>,
994-
%arg2: vector<3x[8]xf32>,
995-
%m : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
996-
%0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
997-
: vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
992+
// CHECK-LABEL: func.func @contraction_masked_scalable(
993+
// CHECK-SAME: %[[A:.*]]: vector<3x4xf32>,
994+
// CHECK-SAME: %[[B:.*]]: vector<4x[8]xf32>,
995+
// CHECK-SAME: %[[C:.*]]: vector<3x[8]xf32>,
996+
// CHECK-SAME: %[[M:.*]]: vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
997+
func.func @contraction_masked_scalable(%A: vector<3x4xf32>,
998+
%B: vector<4x[8]xf32>,
999+
%C: vector<3x[8]xf32>,
1000+
%M : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
1001+
// CHECK: vector.mask %[[M]] { vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
1002+
%0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> }
1003+
: vector<3x[8]x4xi1> -> vector<3x[8]xf32>
9981004
return %0 : vector<3x[8]xf32>
9991005
}
1000-

mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,3 @@ transform.sequence failures(propagate) {
6060
transform.apply_patterns.vector.lower_contraction lowering_strategy = "parallelarith"
6161
} : !transform.any_op
6262
}
63-

0 commit comments

Comments
 (0)