Skip to content

Commit 988b67e

Browse files
committed
[mlir][vector] Extend mask calculation for vector.contract
Make sure that when calculating the expected mask for `vector.contract`, scalable sizes are correctly taken into account.
1 parent 53b6a16 commit 988b67e

File tree

3 files changed

+31
-6
lines changed

3 files changed

+31
-6
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -912,22 +912,27 @@ Type ContractionOp::getExpectedMaskType() {
912912

913913
unsigned numVecDims = lhsIdxMap.getNumDims();
914914
SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
915+
SmallVector<bool> maskShapeScalabledims(numVecDims, false);
915916

916917
// Using the information in the indexing maps, extract the size of each
917918
// dimension in the vector.contract operation from the two input operands.
918-
for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape()))
919+
for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
919920
maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
920-
for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape()))
921+
maskShapeScalabledims[lhsIdxMap.getDimPosition(dimIdx)] =
922+
lhsType.getScalableDims()[dimIdx];
923+
}
924+
for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
921925
maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
926+
maskShapeScalabledims[rhsIdxMap.getDimPosition(dimIdx)] =
927+
rhsType.getScalableDims()[dimIdx];
928+
}
922929

923930
assert(!ShapedType::isDynamicShape(maskShape) &&
924931
"Mask shape couldn't be computed");
925-
// TODO: Extend the scalable vector type representation with a bit map.
926-
assert(!lhsType.isScalable() && !rhsType.isScalable() &&
927-
"Scalable vectors are not supported yet");
928932

929933
return VectorType::get(maskShape,
930-
IntegerType::get(lhsType.getContext(), /*width=*/1));
934+
IntegerType::get(lhsType.getContext(), /*width=*/1),
935+
maskShapeScalabledims);
931936
}
932937

933938
SmallVector<StringRef> ContractionOp::getTraitAttrNames() {

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -979,3 +979,22 @@ func.func @vector_scalable_extract(%sv: vector<[8]xi32>) {
979979
%2 = vector.scalable.extract %sv[4] : vector<4xi32> from vector<[8]xi32>
980980
return
981981
}
982+
983+
#matmat_accesses = [
984+
affine_map<(i, j, k) -> (i, k)>,
985+
affine_map<(i, j, k) -> (k, j)>,
986+
affine_map<(i, j, k) -> (i, j)>
987+
]
988+
#matmat_trait = {
989+
indexing_maps = #matmat_accesses,
990+
iterator_types = ["parallel", "parallel", "reduction"]
991+
}
992+
func.func @matmul_masked_scalable(%arg0: vector<3x4xf32>,
993+
%arg1: vector<4x[8]xf32>,
994+
%arg2: vector<3x[8]xf32>,
995+
%m : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
996+
%0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
997+
: vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
998+
return %0 : vector<3x[8]xf32>
999+
}
1000+

mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,4 @@ transform.sequence failures(propagate) {
6060
transform.apply_patterns.vector.lower_contraction lowering_strategy = "parallelarith"
6161
} : !transform.any_op
6262
}
63+

0 commit comments

Comments
 (0)