diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h index ac55433fadb2f..64bfb0fab743a 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -68,9 +68,14 @@ enum class BroadcastableToResult { DimensionMismatch = 2, SourceTypeNotAVector = 3 }; + +struct VectorDim { + int64_t dim; + bool isScalable; +}; BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, - std::pair *mismatchingDims = nullptr); + std::pair *mismatchingDims = nullptr); /// Collect a set of vector-to-vector canonicalization patterns. void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 434ff3956c250..ead7faae46cbe 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -367,6 +367,8 @@ def Vector_BroadcastOp : s_1 x .. x s_j x .. x s_k ``` + * in addition, any scalable unit dimension, `[1]`, must match exactly. + The source operand is duplicated over all the missing leading dimensions and stretched over the trailing dimensions where the source has a non-equal dimension of 1. These rules imply that any scalar broadcast (k=0) to any diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 5047bd925d4c5..22220a6672382 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2371,9 +2371,9 @@ Value BroadcastOp::createOrFoldBroadcastOp( return res; } -BroadcastableToResult -mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType, - std::pair *mismatchingDims) { +BroadcastableToResult mlir::vector::isBroadcastableTo( + Type srcType, VectorType dstVectorType, + std::pair *mismatchingDims) { // Broadcast scalar to vector of the same element type. if (srcType.isIntOrIndexOrFloat() && dstVectorType && getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType)) @@ -2390,13 +2390,31 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType, // Source has an exact match or singleton value for all trailing dimensions // (all leading dimensions are simply duplicated). int64_t lead = dstRank - srcRank; - for (int64_t r = 0; r < srcRank; ++r) { - int64_t srcDim = srcVectorType.getDimSize(r); - int64_t dstDim = dstVectorType.getDimSize(lead + r); - if (srcDim != 1 && srcDim != dstDim) { - if (mismatchingDims) { - mismatchingDims->first = srcDim; - mismatchingDims->second = dstDim; + for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) { + // Have mismatching dims (in the sense of vector.broadcast semantics) been + // encountered? + bool foundMismatchingDims = false; + + // Check fixed-width dims. + int64_t srcDim = srcVectorType.getDimSize(dimIdx); + int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx); + if (srcDim != 1 && srcDim != dstDim) + foundMismatchingDims = true; + + // Check scalable flags. + bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx]; + bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx]; + if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) || + (srcDimScalableFlag != dstDimScalableFlag)) + foundMismatchingDims = true; + + if (foundMismatchingDims) { + if (mismatchingDims != nullptr) { + mismatchingDims->first.dim = srcDim; + mismatchingDims->first.isScalable = srcDimScalableFlag; + + mismatchingDims->second.dim = dstDim; + mismatchingDims->second.isScalable = dstDimScalableFlag; } return BroadcastableToResult::DimensionMismatch; } @@ -2406,16 +2424,22 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType, } LogicalResult BroadcastOp::verify() { - std::pair mismatchingDims; + std::pair mismatchingDims; BroadcastableToResult res = isBroadcastableTo( getSourceType(), getResultVectorType(), &mismatchingDims); if (res == BroadcastableToResult::Success) return success(); if (res == BroadcastableToResult::SourceRankHigher) return emitOpError("source rank higher than destination rank"); - if (res == BroadcastableToResult::DimensionMismatch) + if (res == BroadcastableToResult::DimensionMismatch) { return emitOpError("dimension mismatch (") - << mismatchingDims.first << " vs. " << mismatchingDims.second << ")"; + << (mismatchingDims.first.isScalable ? "[" : "") + << mismatchingDims.first.dim + << (mismatchingDims.first.isScalable ? "]" : "") << " vs. " + << (mismatchingDims.second.isScalable ? "[" : "") + << mismatchingDims.second.dim + << (mismatchingDims.second.isScalable ? "]" : "") << ")"; + } if (res == BroadcastableToResult::SourceTypeNotAVector) return emitOpError("source type is not a vector"); llvm_unreachable("unexpected vector.broadcast op error"); diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 00914c1d1baf6..13712578536cd 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -35,6 +35,27 @@ func.func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) { // ----- +func.func @broadcast_scalable_unit_dim(%arg0: vector<[1]xf32>) { + // expected-error@+1 {{'vector.broadcast' op dimension mismatch ([1] vs. [4])}} + %0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<[4]xf32> +} + +// ----- + +func.func @broadcast_fixed_to_scalable(%arg0: vector<2xf32>) { + // expected-error@+1 {{'vector.broadcast' op dimension mismatch (2 vs. [2])}} + %0 = vector.broadcast %arg0 : vector<2xf32> to vector<[2]xf32> +} + +// ----- + +func.func @broadcast_scalable_to_fixed(%arg0: vector<[1]xf32>) { + // expected-error@+1 {{'vector.broadcast' op dimension mismatch ([1] vs. 1)}} + %0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<4x1xf32> +} + +// ----- + func.func @broadcast_unknown(%arg0: memref<4x8xf32>) { // expected-error@+1 {{'vector.broadcast' op source type is not a vector}} %1 = vector.broadcast %arg0 : memref<4x8xf32> to vector<1x8xf32>