From 4f93456498d2c83c1f1d7fe1cfbfb5e0bcc33629 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Fri, 6 Oct 2023 14:40:00 +0000 Subject: [PATCH 1/6] [mlir][ArmSVE] Add convert_to/from_svbool ops This adds slightly higher-level ops for converting masks between svbool and SVE predicate types. The main reason to use these over the intrinsics is these ops support vectors of masks (via unrolling). E.g. ``` // Convert a svbool mask to a mask of SVE predicates: %svbool = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1> %mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1> // => Results in vector<2x[8]xi1> ``` Or: ``` // Convert a mask of SVE predicates to a svbool mask: %mask = vector.create_mask %c2, %dim_size : vector<2x[2]xi1> %svbool = arm_sve.convert_to_svbool %mask : vector<2x[2]xi1> // => Results in vector<2x[16]xi1> ``` --- mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 67 +++++++++++++++++ mlir/include/mlir/IR/CommonTypeConstraints.td | 43 +++++++++++ mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp | 1 + mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt | 1 + .../Dialect/ArmSVE/Transforms/CMakeLists.txt | 1 + .../Transforms/LegalizeForLLVMExport.cpp | 62 +++++++++++++++- mlir/test/Dialect/ArmSVE/invalid.mlir | 51 +++++++++++++ .../Dialect/ArmSVE/legalize-for-llvm.mlir | 73 ++++++++++++++++++- mlir/test/Dialect/ArmSVE/roundtrip.mlir | 49 ++++++++++++- 9 files changed, 343 insertions(+), 5 deletions(-) create mode 100644 mlir/test/Dialect/ArmSVE/invalid.mlir diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td index d4294b4dd9fd4..fa7f6d080d5a9 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td @@ -28,6 +28,8 @@ def ArmSVE_Dialect : Dialect { This dialect contains the definitions necessary to target specific Arm SVE scalable vector operations. }]; + + let dependentDialects = ["vector::VectorDialect"]; } //===----------------------------------------------------------------------===// @@ -40,6 +42,11 @@ def SVBool : ScalableVectorOfRankAndLengthAndType< def SVEPredicate : ScalableVectorOfRankAndLengthAndType< [1], [16, 8, 4, 2, 1], [I1]>; +// Generalizations of SVBool and SVEPredicate to ranks >= 1. +// These are masks with a single trailing scalable dimension. +def SVBoolMask : TrailingScalableVectorOfSizeAndType<[16], [I1]>; +def SVEMask : TrailingScalableVectorOfSizeAndType<[16, 8, 4, 2, 1], [I1]>; + //===----------------------------------------------------------------------===// // ArmSVE op definitions //===----------------------------------------------------------------------===// @@ -236,6 +243,66 @@ def UmmlaOp : ArmSVE_Op<"ummla", "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; } + +class SvboolTypeContraint : TypesMatchWith< + "expected corresponding svbool type widened to [16]xi1", + lhsArg, rhsArg, + "VectorType(VectorType::Builder(::llvm::cast($_self)).setDim(::llvm::cast($_self).getRank() - 1, 16))">; + +def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool", + [Pure, SvboolTypeContraint<"result", "source">]> +{ + let summary = "Convert a svbool type to a SVE predicate type"; + let description = [{ + Converts svbool types (`vector<[16]xi1>` or vectors of that type, e.g. + `vector<2x3x[16]xi1>`) to SVE predicate types. Note: Only the trailing + dimension can be scalable. + + Example 1: Convert a 1-D svbool mask to a SVE predicate. + ```mlir + %svbool = vector.load %memref[%c0] : memref, vector<[16]xi1> + %mask = arm_sve.convert_from_svbool %svbool : vector<[4]xi1> + ``` + + Example 2: Convert a 2-D svbool mask to a mask of SVE predicates. + ```mlir + %svbool = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1> + %mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1> + ``` + }]; + let arguments = (ins SVBoolMask:$source); + let results = (outs SVEMask:$result); + let assemblyFormat = "$source attr-dict `:` type($result)"; +} + +def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool", + [Pure, SvboolTypeContraint<"source", "result">]> +{ + let summary = "Convert a predicate type to a svbool type"; + let description = [{ + Converts SVE predicate types (or vectors of predicate types, e.g. + `vector<4x[4]xi1>`) to svbool types. Note: Only the trailing dimension can + be scalable. + + Example 1: Convert a 1-D SVE predicate to a svbool mask. + ```mlir + %mask = vector.create_mask %dim_size : vector<[4]xi1> + %svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1> + // => Results in vector<[16]xi1> + ``` + + Example 2: Convert a 2-D mask of SVE predicates to a svbool mask. + ```mlir + %mask = vector.create_mask %c2, %dim_size : vector<2x[2]xi1> + %svbool = arm_sve.convert_to_svbool %mask : vector<2x[2]xi1> + // => Results in vector<2x[16]xi1> + ``` + }]; + let arguments = (ins SVEMask:$source); + let results = (outs SVBoolMask:$result); + let assemblyFormat = "$source attr-dict `:` type($source)"; +} + def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition", [Commutative]>; diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 4fc14e30b8a10..54a5a97fe2b64 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -37,6 +37,12 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) && def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) && ::llvm::cast($_self).isScalable()}]>; +// Whether a type is a scalable VectorType, with a single trailing scalable dimension. +def IsTrailingScalableVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">, + CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">, + CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">, + CPred<"!llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)">]>; + // Whether a type is a VectorType and all dimensions are scalable. def allDimsScalableVectorTypePred : And<[ IsVectorTypePred, @@ -404,6 +410,10 @@ class ScalableVectorOf allowedTypes> : ShapedContainerType; +class TrailingScalableVectorOf allowedTypes> : + ShapedContainerType; + // Whether the number of elements of a vector is from the given // `allowedRanks` list class IsVectorOfRankPred allowedRanks> : @@ -481,10 +491,32 @@ class IsScalableVectorOfLengthPred allowedLengths> : == }] # allowedlength>)>]>; +class abs { + int ret = !if(!lt(value, 0), !sub(0, value), value); +} + +// Whether the n-th (starting from 1) dim of the shape matches the given `size`. +// Negative values index in reverse. +class IsNthDimSizeIsOneOfPred allowedSizes> + : And<[CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # abs.ret>, + CPred<"::llvm::is_contained(ArrayRef({" # !interleave(allowedSizes, ", ") # "}), " + # "::llvm::cast<::mlir::ShapedType>($_self).getDimSize(" + # !if(!lt(n, 0), + "::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n, + "" # !sub(n, 1)) + # "))">]>; + // Whether the shape of a vector matches the given `shape` list. class IsVectorOfShape shape> : CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef({" # !interleave(shape, ", ") # "})">; +// Any ShapedType where the size of the n-th dim is contained in `sizes`. +// Negative values index in reverse. +class ShapedTypeWithNthDimOfSize allowedSizes> : Type< + IsNthDimSizeIsOneOfPred, + " with dim " # n # " having a size of {" # !interleave(allowedSizes, ", ") # "}", + "::mlir::ShapedType">; + // Any vector where the number of elements is from the given // `allowedLengths` list class VectorOfLength allowedLengths> : Type< @@ -546,6 +578,17 @@ class ScalableVectorOfRankAndLengthAndType allowedRanks, ScalableVectorOfLength.summary, "::mlir::VectorType">; +// Any scalable vector with a single trailing scalable dimensions, where the +// size of the trailing dimension is in `allowedTrailingSizes` list, and the +// type is in the `allowedTypes` list. +class TrailingScalableVectorOfSizeAndType allowedTrailingSizes, + list allowedTypes> : AllOfType< + [TrailingScalableVectorOf, + ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>], + TrailingScalableVectorOf.summary # + ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary, + "::mlir::VectorType">; + def AnyVector : VectorOf<[AnyType]>; // Temporary vector type clone that allows gradual transition to 0-D vectors. def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>; diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp index b7f1020deba1e..594c9b4c270f2 100644 --- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp +++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt index fffc77245d12c..9ef7384fc5492 100644 --- a/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt @@ -10,5 +10,6 @@ add_mlir_dialect_library(MLIRArmSVEDialect LINK_LIBS PUBLIC MLIRIR MLIRLLVMDialect + MLIRVectorDialect MLIRSideEffectInterfaces ) diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt index 7031ab4f799c4..2f1c43fae240d 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt @@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRArmSVETransforms LINK_LIBS PUBLIC MLIRArmSVEDialect MLIRFuncDialect + MLIRVectorDialect MLIRIR MLIRLLVMCommonConversion MLIRLLVMDialect diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp index abbb978304068..d280d2415ecdb 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp @@ -12,6 +12,8 @@ #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" @@ -66,6 +68,54 @@ using ScalableMaskedDivFOpLowering = OneToOneConvertToLLVMPattern; +namespace { + +template +struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Op convertOp, typename Op::Adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = convertOp.getLoc(); + + auto source = convertOp.getSource(); + VectorType sourceType = source.getType(); + VectorType resultType = convertOp.getResult().getType(); + + Value result = rewriter.create( + loc, resultType, rewriter.getZeroAttr(resultType)); + + SmallVector tileShape(sourceType.getRank(), 1); + tileShape.back() = sourceType.getShape().back(); + + for (SmallVector index : + StaticTileOffsetRange(sourceType.getShape(), tileShape)) { + auto extractOrInsertPosition = ArrayRef(index).drop_back(); + auto sourceVector = rewriter.create( + loc, source, extractOrInsertPosition); + auto convertedType = + VectorType::Builder(llvm::cast(sourceVector.getType())) + .setDim(0, resultType.getShape().back()); + auto convertedVector = + rewriter.create(loc, TypeRange{convertedType}, sourceVector); + result = rewriter.create(loc, convertedVector, result, + extractOrInsertPosition); + } + + rewriter.replaceOp(convertOp, result); + return success(); + } +}; + +using ConvertToSvboolOpLowering = + SvboolConversionOpLowering; + +using ConvertFromSvboolOpLowering = + SvboolConversionOpLowering; + +} // namespace + /// Populate the given list with patterns that convert from ArmSVE to LLVM. void mlir::populateArmSVELegalizeForLLVMExportPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { @@ -88,7 +138,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns( ScalableMaskedMulFOpLowering, ScalableMaskedSDivIOpLowering, ScalableMaskedUDivIOpLowering, - ScalableMaskedDivFOpLowering>(converter); + ScalableMaskedDivFOpLowering, + ConvertToSvboolOpLowering, + ConvertFromSvboolOpLowering>(converter); // clang-format on } @@ -107,7 +159,9 @@ void mlir::configureArmSVELegalizeForExportTarget( ScalableMaskedMulFIntrOp, ScalableMaskedSDivIIntrOp, ScalableMaskedUDivIIntrOp, - ScalableMaskedDivFIntrOp>(); + ScalableMaskedDivFIntrOp, + ConvertToSvboolIntrOp, + ConvertFromSvboolIntrOp>(); target.addIllegalOp(); + ScalableMaskedDivFOp, + ConvertToSvboolOp, + ConvertFromSvboolOp>(); // clang-format on } diff --git a/mlir/test/Dialect/ArmSVE/invalid.mlir b/mlir/test/Dialect/ArmSVE/invalid.mlir new file mode 100644 index 0000000000000..a1fa0d0292b7b --- /dev/null +++ b/mlir/test/Dialect/ArmSVE/invalid.mlir @@ -0,0 +1,51 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// ----- + +func.func @arm_sve_convert_from_svbool__bad_mask_type(%bool: vector<2x[16]xi1>) -> vector<2x[8]xi2> { + // expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<2x[8]xi2>'}} + %mask = arm_sve.convert_from_svbool %bool : vector<2x[8]xi2> + return %mask : vector<2x[8]xi2> +} + +// ----- + +func.func @arm_sve_convert_from_svbool__bad_mask_shape(%bool : vector<[16]xi1>) -> vector<[7]xi1> { + // expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[7]xi1>'}} + %mask = arm_sve.convert_from_svbool %bool : vector<[7]xi1> + return %mask : vector<[7]xi1> +} + +// ----- + +func.func @arm_sve_convert_from_svbool__bad_mask_scalability(%bool : vector<[4]x[16]xi1>) -> vector<[4]x[8]xi1> { + // expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[4]x[8]xi1>'}} + %mask = arm_sve.convert_from_svbool %bool : vector<[4]x[8]xi1> + return %mask : vector<[4]x[8]xi1> +} + +// ----- + +func.func @arm_sve_convert_to_svbool__bad_mask_type(%mask: vector<2x[8]xi2>) -> vector<2x[16]xi1> { + // expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<2x[8]xi2>'}} + %bool = arm_sve.convert_to_svbool %mask : vector<2x[8]xi2> + return %bool : vector<2x[16]xi1> +} + +// ----- + +func.func @arm_sve_convert_to_svbool__bad_mask_shape(%mask : vector<[7]xi1>) -> vector<[16]xi1> { + // expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[7]xi1>'}} + %bool = arm_sve.convert_to_svbool %mask : vector<[7]xi1> + return +} + +// ----- + +func.func @arm_sve_convert_to_svbool__bad_mask_scalability(%mask : vector<[4]x[8]xi1>) -> vector<[4]x[16]xi1> { + // expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[4]x[8]xi1>'}} + %bool = arm_sve.convert_to_svbool %mask : vector<[4]x[8]xi1> + return +} + + diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir index 2d980db981034..04f2f43e6a5e7 100644 --- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts | mlir-opt | FileCheck %s +// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts -canonicalize -split-input-file %s | FileCheck %s func.func @arm_sve_sdot(%a: vector<[16]xi8>, %b: vector<[16]xi8>, @@ -10,6 +10,8 @@ func.func @arm_sve_sdot(%a: vector<[16]xi8>, return %0 : vector<[4]xi32> } +// ----- + func.func @arm_sve_smmla(%a: vector<[16]xi8>, %b: vector<[16]xi8>, %c: vector<[4]xi32>) @@ -20,6 +22,8 @@ func.func @arm_sve_smmla(%a: vector<[16]xi8>, return %0 : vector<[4]xi32> } +// ----- + func.func @arm_sve_udot(%a: vector<[16]xi8>, %b: vector<[16]xi8>, %c: vector<[4]xi32>) @@ -30,6 +34,8 @@ func.func @arm_sve_udot(%a: vector<[16]xi8>, return %0 : vector<[4]xi32> } +// ----- + func.func @arm_sve_ummla(%a: vector<[16]xi8>, %b: vector<[16]xi8>, %c: vector<[4]xi32>) @@ -40,6 +46,8 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>, return %0 : vector<[4]xi32> } +// ----- + func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>, %b: vector<[4]xi32>, %c: vector<[4]xi32>, @@ -65,6 +73,8 @@ func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>, return %4 : vector<[4]xi32> } +// ----- + func.func @arm_sve_arithf_masked(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %c: vector<[4]xf32>, @@ -87,6 +97,8 @@ func.func @arm_sve_arithf_masked(%a: vector<[4]xf32>, return %3 : vector<[4]xf32> } +// ----- + func.func @arm_sve_abs_diff(%a: vector<[4]xi32>, %b: vector<[4]xi32>) -> vector<[4]xi32> { @@ -111,8 +123,67 @@ func.func @arm_sve_abs_diff(%a: vector<[4]xi32>, return %3 : vector<[4]xi32> } +// ----- + func.func @get_vector_scale() -> index { // CHECK: llvm.intr.vscale %0 = vector.vscale return %0 : index } + +// ----- + +func.func @convert_1d_mask_to_svbool(%mask: vector<[4]xi1>) -> vector<[16]xi1> +{ + // CHECK: "arm_sve.intr.convert.to.svbool"(%{{.*}}) : (vector<[4]xi1>) -> vector<[16]xi1> + %svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1> + return %svbool : vector<[16]xi1> +} + +// ----- + +func.func @convert_1d_mask_from_svbool(%svbool: vector<[16]xi1>) -> vector<[2]xi1> +{ + // CHECK: "arm_sve.intr.convert.from.svbool"(%{{.*}}) : (vector<[16]xi1>) -> vector<[2]xi1> + %mask = arm_sve.convert_from_svbool %svbool : vector<[2]xi1> + return %mask : vector<[2]xi1> +} + +// ----- + +// CHECK-LABEL: @convert_2d_mask_to_svbool( +// CHECK-SAME: %[[MASK:.*]]: !llvm.array<2 x vector<[8]xi1>>) +func.func @convert_2d_mask_to_svbool(%mask: vector<2x[8]xi1>) -> vector<2x[16]xi1> +{ + // CHECK-NEXT: %[[RES0:.*]] = llvm.mlir.constant(dense : vector<2x[16]xi1>) : !llvm.array<2 x vector<[16]xi1>> + // CHECK-NEXT: %[[MASK0:.*]] = llvm.extractvalue %[[MASK]][0] : !llvm.array<2 x vector<[8]xi1>> + // CHECK-NEXT: %[[SVBOOL0:.*]] = "arm_sve.intr.convert.to.svbool"(%[[MASK0]]) : (vector<[8]xi1>) -> vector<[16]xi1> + // CHECK-NEXT: %[[RES1:.*]] = llvm.insertvalue %[[SVBOOL0]], %[[RES0]][0] : !llvm.array<2 x vector<[16]xi1>> + // CHECK-NEXT: %[[MASK1:.*]] = llvm.extractvalue %[[MASK]][1] : !llvm.array<2 x vector<[8]xi1>> + // CHECK-NEXT: %[[SVBOOL1:.*]] = "arm_sve.intr.convert.to.svbool"(%[[MASK1]]) : (vector<[8]xi1>) -> vector<[16]xi1> + // CHECK-NEXT: %[[SVBOOL:.*]] = llvm.insertvalue %[[SVBOOL1]], %[[RES1]][1] : !llvm.array<2 x vector<[16]xi1>> + %svbool = arm_sve.convert_to_svbool %mask : vector<2x[8]xi1> + // CHECK-NEXT: llvm.return %[[SVBOOL]] : !llvm.array<2 x vector<[16]xi1>> + return %svbool : vector<2x[16]xi1> +} + +// ----- + +// CHECK-LABEL: @convert_2d_mask_from_svbool( +// CHECK-SAME: %[[SVBOOL:.*]]: !llvm.array<3 x vector<[16]xi1>>) +func.func @convert_2d_mask_from_svbool(%svbool: vector<3x[16]xi1>) -> vector<3x[1]xi1> +{ + // CHECK-NEXT: %[[RES0:.*]] = llvm.mlir.constant(dense : vector<3x[1]xi1>) : !llvm.array<3 x vector<[1]xi1>> + // CHECK-NEXT: %[[SVBOOL0:.*]] = llvm.extractvalue %[[SVBOOL]][0] : !llvm.array<3 x vector<[16]xi1>> + // CHECK-NEXT: %[[MASK0:.*]] = "arm_sve.intr.convert.from.svbool"(%[[SVBOOL0]]) : (vector<[16]xi1>) -> vector<[1]xi1> + // CHECK-NEXT: %[[RES1:.*]] = llvm.insertvalue %[[MASK0]], %[[RES0]][0] : !llvm.array<3 x vector<[1]xi1>> + // CHECK-NEXT: %[[SVBOOL1:.*]] = llvm.extractvalue %[[SVBOOL]][1] : !llvm.array<3 x vector<[16]xi1>> + // CHECK-NEXT: %[[MASK1:.*]] = "arm_sve.intr.convert.from.svbool"(%[[SVBOOL1]]) : (vector<[16]xi1>) -> vector<[1]xi1> + // CHECK-NEXT: %[[RES2:.*]] = llvm.insertvalue %[[MASK1]], %[[RES1]][1] : !llvm.array<3 x vector<[1]xi1>> + // CHECK-NEXT: %[[SVBOOL2:.*]] = llvm.extractvalue %[[SVBOOL]][2] : !llvm.array<3 x vector<[16]xi1>> + // CHECK-NEXT: %[[MASK2:.*]] = "arm_sve.intr.convert.from.svbool"(%[[SVBOOL2]]) : (vector<[16]xi1>) -> vector<[1]xi1> + // CHECK-NEXT: %[[MASK:.*]] = llvm.insertvalue %[[MASK2]], %[[RES2]][2] : !llvm.array<3 x vector<[1]xi1>> + %mask = arm_sve.convert_from_svbool %svbool : vector<3x[1]xi1> + // CHECK-NEXT: llvm.return %[[MASK]] : !llvm.array<3 x vector<[1]xi1>> + return %mask : vector<3x[1]xi1> +} diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir index d2ca035c17bfb..af390bb330a34 100644 --- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s +// RUN: mlir-opt -verify-diagnostics -split-input-file %s | mlir-opt | FileCheck %s func.func @arm_sve_sdot(%a: vector<[16]xi8>, %b: vector<[16]xi8>, @@ -9,6 +9,8 @@ func.func @arm_sve_sdot(%a: vector<[16]xi8>, return %0 : vector<[4]xi32> } +// ----- + func.func @arm_sve_smmla(%a: vector<[16]xi8>, %b: vector<[16]xi8>, %c: vector<[4]xi32>) -> vector<[4]xi32> { @@ -18,6 +20,8 @@ func.func @arm_sve_smmla(%a: vector<[16]xi8>, return %0 : vector<[4]xi32> } +// ----- + func.func @arm_sve_udot(%a: vector<[16]xi8>, %b: vector<[16]xi8>, %c: vector<[4]xi32>) -> vector<[4]xi32> { @@ -27,6 +31,8 @@ func.func @arm_sve_udot(%a: vector<[16]xi8>, return %0 : vector<[4]xi32> } +// ----- + func.func @arm_sve_ummla(%a: vector<[16]xi8>, %b: vector<[16]xi8>, %c: vector<[4]xi32>) -> vector<[4]xi32> { @@ -36,6 +42,8 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>, return %0 : vector<[4]xi32> } +// ----- + func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>, %b: vector<[4]xi32>, %c: vector<[4]xi32>, @@ -61,6 +69,8 @@ func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>, return %2 : vector<[4]xi32> } +// ----- + func.func @arm_sve_masked_arithf(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %c: vector<[4]xf32>, @@ -82,3 +92,40 @@ func.func @arm_sve_masked_arithf(%a: vector<[4]xf32>, vector<[4]xf32> return %3 : vector<[4]xf32> } + +// ----- + +func.func @arm_sve_convert_to_svbool(%a: vector<[1]xi1>, + %b: vector<[2]xi1>, + %c: vector<[4]xi1>, + %d: vector<[8]xi1>) { + // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<[1]xi1> + %1 = arm_sve.convert_to_svbool %a : vector<[1]xi1> + + // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<[2]xi1> + %2 = arm_sve.convert_to_svbool %b : vector<[2]xi1> + + // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<[4]xi1> + %3 = arm_sve.convert_to_svbool %c : vector<[4]xi1> + + // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<[8]xi1> + %4 = arm_sve.convert_to_svbool %d : vector<[8]xi1> + return +} + +// ----- + +func.func @arm_sve_convert_from_svbool(%bool: vector<[16]xi1>) { + // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[1]xi1> + %1 = arm_sve.convert_from_svbool %bool : vector<[1]xi1> + + // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[2]xi1> + %2 = arm_sve.convert_from_svbool %bool : vector<[2]xi1> + + // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[4]xi1> + %3 = arm_sve.convert_from_svbool %bool : vector<[4]xi1> + + // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[8]xi1> + %4 = arm_sve.convert_from_svbool %bool : vector<[8]xi1> + return +} From 9c656ca0ba0624b5840badd412ba2323d007e52a Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Tue, 10 Oct 2023 14:39:40 +0000 Subject: [PATCH 2/6] Fixups --- mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 47 ++++++++++++------- mlir/include/mlir/IR/CommonTypeConstraints.td | 19 ++++++-- .../Transforms/LegalizeForLLVMExport.cpp | 23 +++++++++ .../Dialect/ArmSVE/legalize-for-llvm.mlir | 2 +- 4 files changed, 69 insertions(+), 22 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td index fa7f6d080d5a9..cae87b764fc67 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td @@ -45,7 +45,7 @@ def SVEPredicate : ScalableVectorOfRankAndLengthAndType< // Generalizations of SVBool and SVEPredicate to ranks >= 1. // These are masks with a single trailing scalable dimension. def SVBoolMask : TrailingScalableVectorOfSizeAndType<[16], [I1]>; -def SVEMask : TrailingScalableVectorOfSizeAndType<[16, 8, 4, 2, 1], [I1]>; +def SVEPredicateMask : TrailingScalableVectorOfSizeAndType<[16, 8, 4, 2, 1], [I1]>; //===----------------------------------------------------------------------===// // ArmSVE op definitions @@ -243,14 +243,13 @@ def UmmlaOp : ArmSVE_Op<"ummla", "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; } - -class SvboolTypeContraint : TypesMatchWith< +class SvboolTypeConstraint : TypesMatchWith< "expected corresponding svbool type widened to [16]xi1", lhsArg, rhsArg, "VectorType(VectorType::Builder(::llvm::cast($_self)).setDim(::llvm::cast($_self).getRank() - 1, 16))">; def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool", - [Pure, SvboolTypeContraint<"result", "source">]> + [Pure, SvboolTypeConstraint<"result", "source">]> { let summary = "Convert a svbool type to a SVE predicate type"; let description = [{ @@ -260,25 +259,33 @@ def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool", Example 1: Convert a 1-D svbool mask to a SVE predicate. ```mlir - %svbool = vector.load %memref[%c0] : memref, vector<[16]xi1> - %mask = arm_sve.convert_from_svbool %svbool : vector<[4]xi1> + %source = vector.load %memref[%c0] : memref, vector<[16]xi1> + %result = arm_sve.convert_from_svbool %source : vector<[4]xi1> ``` Example 2: Convert a 2-D svbool mask to a mask of SVE predicates. ```mlir - %svbool = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1> - %mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1> + %source = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1> + %result = arm_sve.convert_from_svbool %source : vector<2x[8]xi1> ``` + + --- + + A `svbool` is the smallest SVE predicate type that has a in-memory + representation (and maps to a full predicate register). In MLIR `svbool` is + represented as `vector<[16]xi1>`. Smaller SVE predicate types + (`vector<[1|2|4|8]xi1>`) must be stored as `svbool` then converted back to + a predicate after loading. }]; let arguments = (ins SVBoolMask:$source); - let results = (outs SVEMask:$result); + let results = (outs SVEPredicateMask:$result); let assemblyFormat = "$source attr-dict `:` type($result)"; } def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool", - [Pure, SvboolTypeContraint<"source", "result">]> + [Pure, SvboolTypeConstraint<"source", "result">]> { - let summary = "Convert a predicate type to a svbool type"; + let summary = "Convert a SVE predicate type to a svbool type"; let description = [{ Converts SVE predicate types (or vectors of predicate types, e.g. `vector<4x[4]xi1>`) to svbool types. Note: Only the trailing dimension can @@ -286,19 +293,27 @@ def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool", Example 1: Convert a 1-D SVE predicate to a svbool mask. ```mlir - %mask = vector.create_mask %dim_size : vector<[4]xi1> - %svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1> + %source = vector.create_mask %dim_size : vector<[4]xi1> + %result = arm_sve.convert_to_svbool %source : vector<[4]xi1> // => Results in vector<[16]xi1> ``` Example 2: Convert a 2-D mask of SVE predicates to a svbool mask. ```mlir - %mask = vector.create_mask %c2, %dim_size : vector<2x[2]xi1> - %svbool = arm_sve.convert_to_svbool %mask : vector<2x[2]xi1> + %source = vector.create_mask %c2, %dim_size : vector<2x[2]xi1> + %result = arm_sve.convert_to_svbool %source : vector<2x[2]xi1> // => Results in vector<2x[16]xi1> ``` + + --- + + A `svbool` is the smallest SVE predicate type that has a in-memory + representation (and maps to a full predicate register). In MLIR `svbool` is + represented as `vector<[16]xi1>`. Smaller SVE predicate types + (`vector<[1|2|4|8]xi1>`) must be converted to a `svbool` before they can be + stored. }]; - let arguments = (ins SVEMask:$source); + let arguments = (ins SVEPredicateMask:$source); let results = (outs SVBoolMask:$result); let assemblyFormat = "$source attr-dict `:` type($source)"; } diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 54a5a97fe2b64..a7970e59de8c2 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -38,10 +38,17 @@ def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) & ::llvm::cast($_self).isScalable()}]>; // Whether a type is a scalable VectorType, with a single trailing scalable dimension. -def IsTrailingScalableVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">, - CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">, - CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">, - CPred<"!llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)">]>; +// Examples: +// Valid: +// - vector<[4]xf32>, vector<2x3x[2]xi64>, vector<32x[8]xi32> +// Invalid +// - vector<[4]x8xi32>, vector<[2]x[2]xf64>, vector<2x[8]x4xi32> +def IsOnlyTrailingDimScalablePred : And<[ + CPred<"::llvm::isa<::mlir::VectorType>($_self)">, + CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">, + CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">, + CPred<"!llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)"> +]>; // Whether a type is a VectorType and all dimensions are scalable. def allDimsScalableVectorTypePred : And<[ @@ -410,8 +417,10 @@ class ScalableVectorOf allowedTypes> : ShapedContainerType; +// Any vector with a single trailing scalable dimension, with an element type in +// the `allowedTypes` list. class TrailingScalableVectorOf allowedTypes> : - ShapedContainerType; // Whether the number of elements of a vector is from the given diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp index d280d2415ecdb..ca9e280f51085 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp @@ -70,6 +70,25 @@ using ScalableMaskedDivFOpLowering = namespace { +/// Unrolls a conversion to/from equivalent vector types, to allow using a +/// conversion intrinsic that only supports 1-D vector types. +/// +/// Example: +/// ``` +/// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1> +/// ``` +/// is rewritten into: +/// ``` +/// %cst = arith.constant dense : vector<2x[16]xi1> +/// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1> +/// %2 = "arm_sve.intr.convert.to.svbool"(%1) +/// : (vector<[4]xi1>) -> vector<[16]xi1> +/// %3 = vector.insert %2, %cst [0] : vector<[16]xi1> into vector<2x[16]xi1> +/// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1> +/// %5 = "arm_sve.intr.convert.to.svbool"(%4) +/// : (vector<[4]xi1>) -> vector<[16]xi1> +/// %result = vector.insert %5, %3 [1] : vector<[16]xi1> into vector<2x[16]xi1> +/// ``` template struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -86,9 +105,13 @@ struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern { Value result = rewriter.create( loc, resultType, rewriter.getZeroAttr(resultType)); + // We want to iterate over the input vector in steps of the trailing + // dimension. So this creates tile shape where all leading dimensions are 1, + // and the trailing dimension step is the size of the dimension. SmallVector tileShape(sourceType.getRank(), 1); tileShape.back() = sourceType.getShape().back(); + // Iterate over all scalable mask/predicate slices of the source vector. for (SmallVector index : StaticTileOffsetRange(sourceType.getShape(), tileShape)) { auto extractOrInsertPosition = ArrayRef(index).drop_back(); diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir index 04f2f43e6a5e7..8e76fb7119b84 100644 --- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts -canonicalize -split-input-file %s | FileCheck %s +// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts -split-input-file %s | FileCheck %s func.func @arm_sve_sdot(%a: vector<[16]xi8>, %b: vector<[16]xi8>, From 308f5f16529341cb2790fdf1edbeb556fc7a0f8f Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Wed, 11 Oct 2023 14:20:31 +0000 Subject: [PATCH 3/6] Fixups --- mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 4 +- mlir/include/mlir/IR/CommonTypeConstraints.td | 27 +++++++---- .../Transforms/LegalizeForLLVMExport.cpp | 4 +- mlir/test/Dialect/ArmSVE/roundtrip.mlir | 46 ++++++++++++++++--- 4 files changed, 61 insertions(+), 20 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td index cae87b764fc67..826f7aac9b380 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td @@ -274,8 +274,8 @@ def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool", A `svbool` is the smallest SVE predicate type that has a in-memory representation (and maps to a full predicate register). In MLIR `svbool` is represented as `vector<[16]xi1>`. Smaller SVE predicate types - (`vector<[1|2|4|8]xi1>`) must be stored as `svbool` then converted back to - a predicate after loading. + (`vector<[1|2|4|8]xi1>`) must be stored as a `svbool` then converted back to + the original predicate type after loading. }]; let arguments = (ins SVBoolMask:$source); let results = (outs SVEPredicateMask:$result); diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index a7970e59de8c2..0c5453ee1a068 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -500,20 +500,27 @@ class IsScalableVectorOfLengthPred allowedLengths> : == }] # allowedlength>)>]>; -class abs { - int ret = !if(!lt(value, 0), !sub(0, value), value); +// Normalizes an index so it can be bounds checked. +// Negative values are mapped to their absolute value. +// - These are used to index in reverse (i.e. index -1 would be the last element) +// Positive values are mapped to their value + 1. +// - This results the same range of values as the negative indices +// This allows bounds checking to be: len(list) >= NormalizeIndex.ret. +class NormalizeIndex { + int ret = !if(!lt(value, 0), !sub(0, value), !add(value, 1)); } -// Whether the n-th (starting from 1) dim of the shape matches the given `size`. +// Whether the n-th dim of the shape matches the given `size`. // Negative values index in reverse. class IsNthDimSizeIsOneOfPred allowedSizes> - : And<[CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # abs.ret>, - CPred<"::llvm::is_contained(ArrayRef({" # !interleave(allowedSizes, ", ") # "}), " - # "::llvm::cast<::mlir::ShapedType>($_self).getDimSize(" - # !if(!lt(n, 0), - "::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n, - "" # !sub(n, 1)) - # "))">]>; + : And<[ + CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # NormalizeIndex.ret>, + CPred<"::llvm::is_contained(ArrayRef({" # !interleave(allowedSizes, ", ") # "}), " + # "::llvm::cast<::mlir::ShapedType>($_self).getDimSize(" + # !if(!lt(n, 0), + "::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n, + "" # n) + # "))">]>; // Whether the shape of a vector matches the given `shape` list. class IsVectorOfShape shape> diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp index ca9e280f51085..f54a26c27c2ac 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp @@ -83,11 +83,11 @@ namespace { /// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1> /// %2 = "arm_sve.intr.convert.to.svbool"(%1) /// : (vector<[4]xi1>) -> vector<[16]xi1> -/// %3 = vector.insert %2, %cst [0] : vector<[16]xi1> into vector<2x[16]xi1> +/// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1> /// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1> /// %5 = "arm_sve.intr.convert.to.svbool"(%4) /// : (vector<[4]xi1>) -> vector<[16]xi1> -/// %result = vector.insert %5, %3 [1] : vector<[16]xi1> into vector<2x[16]xi1> +/// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1> /// ``` template struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern { diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir index af390bb330a34..c9a0b6db8fa80 100644 --- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir @@ -98,7 +98,11 @@ func.func @arm_sve_masked_arithf(%a: vector<[4]xf32>, func.func @arm_sve_convert_to_svbool(%a: vector<[1]xi1>, %b: vector<[2]xi1>, %c: vector<[4]xi1>, - %d: vector<[8]xi1>) { + %d: vector<[8]xi1>, + %e: vector<2x3x[1]xi1>, + %f: vector<4x[2]xi1>, + %g: vector<1x1x1x2x[4]xi1>, + %h: vector<100x[8]xi1>) { // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<[1]xi1> %1 = arm_sve.convert_to_svbool %a : vector<[1]xi1> @@ -110,22 +114,52 @@ func.func @arm_sve_convert_to_svbool(%a: vector<[1]xi1>, // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<[8]xi1> %4 = arm_sve.convert_to_svbool %d : vector<[8]xi1> + + // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<2x3x[1]xi1> + %5 = arm_sve.convert_to_svbool %e : vector<2x3x[1]xi1> + + // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<4x[2]xi1> + %6 = arm_sve.convert_to_svbool %f : vector<4x[2]xi1> + + // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<1x1x1x2x[4]xi1> + %7 = arm_sve.convert_to_svbool %g : vector<1x1x1x2x[4]xi1> + + // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<100x[8]xi1> + %8 = arm_sve.convert_to_svbool %h : vector<100x[8]xi1> + return } // ----- -func.func @arm_sve_convert_from_svbool(%bool: vector<[16]xi1>) { +func.func @arm_sve_convert_from_svbool(%a: vector<[16]xi1>, + %b: vector<2x3x[16]xi1>, + %c: vector<4x[16]xi1>, + %d: vector<1x1x1x1x[16]xi1>, + %e: vector<32x[16]xi1>) { // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[1]xi1> - %1 = arm_sve.convert_from_svbool %bool : vector<[1]xi1> + %1 = arm_sve.convert_from_svbool %a : vector<[1]xi1> // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[2]xi1> - %2 = arm_sve.convert_from_svbool %bool : vector<[2]xi1> + %2 = arm_sve.convert_from_svbool %a : vector<[2]xi1> // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[4]xi1> - %3 = arm_sve.convert_from_svbool %bool : vector<[4]xi1> + %3 = arm_sve.convert_from_svbool %a : vector<[4]xi1> // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[8]xi1> - %4 = arm_sve.convert_from_svbool %bool : vector<[8]xi1> + %4 = arm_sve.convert_from_svbool %a : vector<[8]xi1> + + // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<2x3x[1]xi1> + %5 = arm_sve.convert_from_svbool %b : vector<2x3x[1]xi1> + + // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<4x[2]xi1> + %6 = arm_sve.convert_from_svbool %c : vector<4x[2]xi1> + + // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<1x1x1x1x[4]xi1> + %7 = arm_sve.convert_from_svbool %d : vector<1x1x1x1x[4]xi1> + + // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<32x[8]xi1> + %8 = arm_sve.convert_from_svbool %e : vector<32x[8]xi1> + return } From 1738a9b7101cd16c89bebced77ea7324a91d1636 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Wed, 11 Oct 2023 16:55:22 +0000 Subject: [PATCH 4/6] Fixups --- mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 6 ++- mlir/include/mlir/IR/CommonTypeConstraints.td | 46 ++++++++++++------- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td index 826f7aac9b380..e3f3d9e62e8fb 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td @@ -44,8 +44,10 @@ def SVEPredicate : ScalableVectorOfRankAndLengthAndType< // Generalizations of SVBool and SVEPredicate to ranks >= 1. // These are masks with a single trailing scalable dimension. -def SVBoolMask : TrailingScalableVectorOfSizeAndType<[16], [I1]>; -def SVEPredicateMask : TrailingScalableVectorOfSizeAndType<[16, 8, 4, 2, 1], [I1]>; +def SVBoolMask : VectorWithTrailingDimScalableOfSizeAndType< + [16], [I1]>; +def SVEPredicateMask : VectorWithTrailingDimScalableOfSizeAndType< + [16, 8, 4, 2, 1], [I1]>; //===----------------------------------------------------------------------===// // ArmSVE op definitions diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 0c5453ee1a068..0740a0a0c9783 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -43,7 +43,7 @@ def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) & // - vector<[4]xf32>, vector<2x3x[2]xi64>, vector<32x[8]xi32> // Invalid // - vector<[4]x8xi32>, vector<[2]x[2]xf64>, vector<2x[8]x4xi32> -def IsOnlyTrailingDimScalablePred : And<[ +def IsVectorTypeWithOnlyTrailingDimScalablePred : And<[ CPred<"::llvm::isa<::mlir::VectorType>($_self)">, CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">, CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">, @@ -419,8 +419,8 @@ class ScalableVectorOf allowedTypes> : // Any vector with a single trailing scalable dimension, with an element type in // the `allowedTypes` list. -class TrailingScalableVectorOf allowedTypes> : - ShapedContainerType allowedTypes> : + ShapedContainerType; // Whether the number of elements of a vector is from the given @@ -500,18 +500,30 @@ class IsScalableVectorOfLengthPred allowedLengths> : == }] # allowedlength>)>]>; -// Normalizes an index so it can be bounds checked. +// Normalizes a (possibly negative) index so it can be easily bounds checked. // Negative values are mapped to their absolute value. // - These are used to index in reverse (i.e. index -1 would be the last element) // Positive values are mapped to their value + 1. // - This results the same range of values as the negative indices -// This allows bounds checking to be: len(list) >= NormalizeIndex.ret. +// This allows bounds checking to be: len(list) >= NormalizeIndex.ret (see +// first CPred of IsNthDimSizeIsOneOfPred). class NormalizeIndex { int ret = !if(!lt(value, 0), !sub(0, value), !add(value, 1)); } -// Whether the n-th dim of the shape matches the given `size`. -// Negative values index in reverse. +// Whether the n-th dim of the shape is contained within `allowedSizes`. +// Negative values for `n` index in reverse. +// +// Examples: +// IsNthDimSizeIsOneOfPred<0, {2, 3, 4}> +// - Accepts any shape where the first dim is 2, 3, or 4. +// * This means shapes like: 2x8x9x5, 4, 3x1, 4x?, etc +// IsNthDimSizeIsOneOfPred<-1, {16}> +// - Accepts any shape where the last dim is 16. +// * This means shapes like 2x16, 16, 1x2x3x4x16, etc +// IsNthDimSizeIsOneOfPred<-2, {10, 5}> +// - Accepts any shape where the second to last dim is 10 or 5. +// * This means shapes like: 1x10x2, 2x1x4x5x6, 8x10x?, etc class IsNthDimSizeIsOneOfPred allowedSizes> : And<[ CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # NormalizeIndex.ret>, @@ -526,13 +538,6 @@ class IsNthDimSizeIsOneOfPred allowedSizes> class IsVectorOfShape shape> : CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef({" # !interleave(shape, ", ") # "})">; -// Any ShapedType where the size of the n-th dim is contained in `sizes`. -// Negative values index in reverse. -class ShapedTypeWithNthDimOfSize allowedSizes> : Type< - IsNthDimSizeIsOneOfPred, - " with dim " # n # " having a size of {" # !interleave(allowedSizes, ", ") # "}", - "::mlir::ShapedType">; - // Any vector where the number of elements is from the given // `allowedLengths` list class VectorOfLength allowedLengths> : Type< @@ -594,14 +599,21 @@ class ScalableVectorOfRankAndLengthAndType allowedRanks, ScalableVectorOfLength.summary, "::mlir::VectorType">; +// Any ShapedType where the size of the n-th dim is contained in `allowedSizes`. +// Negative values for `n` index in reverse. +class ShapedTypeWithNthDimOfSize allowedSizes> : Type< + IsNthDimSizeIsOneOfPred, + " with dim " # n # " having a size of {" # !interleave(allowedSizes, ", ") # "}", + "::mlir::ShapedType">; + // Any scalable vector with a single trailing scalable dimensions, where the // size of the trailing dimension is in `allowedTrailingSizes` list, and the // type is in the `allowedTypes` list. -class TrailingScalableVectorOfSizeAndType allowedTrailingSizes, +class VectorWithTrailingDimScalableOfSizeAndType allowedTrailingSizes, list allowedTypes> : AllOfType< - [TrailingScalableVectorOf, + [VectorWithTrailingDimScalableOf, ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>], - TrailingScalableVectorOf.summary # + VectorWithTrailingDimScalableOf.summary # ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary, "::mlir::VectorType">; From cde393010881faa3c4060bfd8129c95f614bb125 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Wed, 11 Oct 2023 16:59:12 +0000 Subject: [PATCH 5/6] Add comment --- mlir/include/mlir/IR/CommonTypeConstraints.td | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 0740a0a0c9783..b85360b5881a3 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -419,6 +419,9 @@ class ScalableVectorOf allowedTypes> : // Any vector with a single trailing scalable dimension, with an element type in // the `allowedTypes` list. +// +// Note: This Similar to ScalableVectorOf, with the extra requirement that only +// the trailing dim is scalable. class VectorWithTrailingDimScalableOf allowedTypes> : ShapedContainerType; From 60308167cbb7a2674f533ff3778f4cabeb8ca810 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Wed, 11 Oct 2023 17:38:11 +0000 Subject: [PATCH 6/6] Rewrite comment --- mlir/include/mlir/IR/CommonTypeConstraints.td | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index b85360b5881a3..c3f18965e343a 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -503,15 +503,15 @@ class IsScalableVectorOfLengthPred allowedLengths> : == }] # allowedlength>)>]>; -// Normalizes a (possibly negative) index so it can be easily bounds checked. -// Negative values are mapped to their absolute value. -// - These are used to index in reverse (i.e. index -1 would be the last element) -// Positive values are mapped to their value + 1. -// - This results the same range of values as the negative indices -// This allows bounds checking to be: len(list) >= NormalizeIndex.ret (see -// first CPred of IsNthDimSizeIsOneOfPred). +// Normalizes an index so the indices in both directions have the same value. +// For example, when indexing forwards index 2 is the third element. When +// indexing in reverse the third element is -3. This helper would map both of +// these to the "normalized" index of 3. This makes the bounds checking in +// IsNthDimSizeIsOneOfPred simpler (see first CPred). class NormalizeIndex { - int ret = !if(!lt(value, 0), !sub(0, value), !add(value, 1)); + int ret = !if(!lt(value, 0), + !sub(0, value) /* -value if negative */, + !add(value, 1) /* value + 1 if positive*/); } // Whether the n-th dim of the shape is contained within `allowedSizes`.