diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td index d4294b4dd9fd4..e3f3d9e62e8fb 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td @@ -28,6 +28,8 @@ def ArmSVE_Dialect : Dialect { This dialect contains the definitions necessary to target specific Arm SVE scalable vector operations. }]; + + let dependentDialects = ["vector::VectorDialect"]; } //===----------------------------------------------------------------------===// @@ -40,6 +42,13 @@ def SVBool : ScalableVectorOfRankAndLengthAndType< def SVEPredicate : ScalableVectorOfRankAndLengthAndType< [1], [16, 8, 4, 2, 1], [I1]>; +// Generalizations of SVBool and SVEPredicate to ranks >= 1. +// These are masks with a single trailing scalable dimension. +def SVBoolMask : VectorWithTrailingDimScalableOfSizeAndType< + [16], [I1]>; +def SVEPredicateMask : VectorWithTrailingDimScalableOfSizeAndType< + [16, 8, 4, 2, 1], [I1]>; + //===----------------------------------------------------------------------===// // ArmSVE op definitions //===----------------------------------------------------------------------===// @@ -236,6 +245,81 @@ def UmmlaOp : ArmSVE_Op<"ummla", "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; } +class SvboolTypeConstraint : TypesMatchWith< + "expected corresponding svbool type widened to [16]xi1", + lhsArg, rhsArg, + "VectorType(VectorType::Builder(::llvm::cast($_self)).setDim(::llvm::cast($_self).getRank() - 1, 16))">; + +def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool", + [Pure, SvboolTypeConstraint<"result", "source">]> +{ + let summary = "Convert a svbool type to a SVE predicate type"; + let description = [{ + Converts svbool types (`vector<[16]xi1>` or vectors of that type, e.g. + `vector<2x3x[16]xi1>`) to SVE predicate types. Note: Only the trailing + dimension can be scalable. + + Example 1: Convert a 1-D svbool mask to a SVE predicate. + ```mlir + %source = vector.load %memref[%c0] : memref, vector<[16]xi1> + %result = arm_sve.convert_from_svbool %source : vector<[4]xi1> + ``` + + Example 2: Convert a 2-D svbool mask to a mask of SVE predicates. + ```mlir + %source = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1> + %result = arm_sve.convert_from_svbool %source : vector<2x[8]xi1> + ``` + + --- + + A `svbool` is the smallest SVE predicate type that has a in-memory + representation (and maps to a full predicate register). In MLIR `svbool` is + represented as `vector<[16]xi1>`. Smaller SVE predicate types + (`vector<[1|2|4|8]xi1>`) must be stored as a `svbool` then converted back to + the original predicate type after loading. + }]; + let arguments = (ins SVBoolMask:$source); + let results = (outs SVEPredicateMask:$result); + let assemblyFormat = "$source attr-dict `:` type($result)"; +} + +def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool", + [Pure, SvboolTypeConstraint<"source", "result">]> +{ + let summary = "Convert a SVE predicate type to a svbool type"; + let description = [{ + Converts SVE predicate types (or vectors of predicate types, e.g. + `vector<4x[4]xi1>`) to svbool types. Note: Only the trailing dimension can + be scalable. + + Example 1: Convert a 1-D SVE predicate to a svbool mask. + ```mlir + %source = vector.create_mask %dim_size : vector<[4]xi1> + %result = arm_sve.convert_to_svbool %source : vector<[4]xi1> + // => Results in vector<[16]xi1> + ``` + + Example 2: Convert a 2-D mask of SVE predicates to a svbool mask. + ```mlir + %source = vector.create_mask %c2, %dim_size : vector<2x[2]xi1> + %result = arm_sve.convert_to_svbool %source : vector<2x[2]xi1> + // => Results in vector<2x[16]xi1> + ``` + + --- + + A `svbool` is the smallest SVE predicate type that has a in-memory + representation (and maps to a full predicate register). In MLIR `svbool` is + represented as `vector<[16]xi1>`. Smaller SVE predicate types + (`vector<[1|2|4|8]xi1>`) must be converted to a `svbool` before they can be + stored. + }]; + let arguments = (ins SVEPredicateMask:$source); + let results = (outs SVBoolMask:$result); + let assemblyFormat = "$source attr-dict `:` type($source)"; +} + def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition", [Commutative]>; diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 4fc14e30b8a10..c3f18965e343a 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -37,6 +37,19 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) && def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) && ::llvm::cast($_self).isScalable()}]>; +// Whether a type is a scalable VectorType, with a single trailing scalable dimension. +// Examples: +// Valid: +// - vector<[4]xf32>, vector<2x3x[2]xi64>, vector<32x[8]xi32> +// Invalid +// - vector<[4]x8xi32>, vector<[2]x[2]xf64>, vector<2x[8]x4xi32> +def IsVectorTypeWithOnlyTrailingDimScalablePred : And<[ + CPred<"::llvm::isa<::mlir::VectorType>($_self)">, + CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">, + CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">, + CPred<"!llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)"> +]>; + // Whether a type is a VectorType and all dimensions are scalable. def allDimsScalableVectorTypePred : And<[ IsVectorTypePred, @@ -404,6 +417,15 @@ class ScalableVectorOf allowedTypes> : ShapedContainerType; +// Any vector with a single trailing scalable dimension, with an element type in +// the `allowedTypes` list. +// +// Note: This Similar to ScalableVectorOf, with the extra requirement that only +// the trailing dim is scalable. +class VectorWithTrailingDimScalableOf allowedTypes> : + ShapedContainerType; + // Whether the number of elements of a vector is from the given // `allowedRanks` list class IsVectorOfRankPred allowedRanks> : @@ -481,6 +503,40 @@ class IsScalableVectorOfLengthPred allowedLengths> : == }] # allowedlength>)>]>; +// Normalizes an index so the indices in both directions have the same value. +// For example, when indexing forwards index 2 is the third element. When +// indexing in reverse the third element is -3. This helper would map both of +// these to the "normalized" index of 3. This makes the bounds checking in +// IsNthDimSizeIsOneOfPred simpler (see first CPred). +class NormalizeIndex { + int ret = !if(!lt(value, 0), + !sub(0, value) /* -value if negative */, + !add(value, 1) /* value + 1 if positive*/); +} + +// Whether the n-th dim of the shape is contained within `allowedSizes`. +// Negative values for `n` index in reverse. +// +// Examples: +// IsNthDimSizeIsOneOfPred<0, {2, 3, 4}> +// - Accepts any shape where the first dim is 2, 3, or 4. +// * This means shapes like: 2x8x9x5, 4, 3x1, 4x?, etc +// IsNthDimSizeIsOneOfPred<-1, {16}> +// - Accepts any shape where the last dim is 16. +// * This means shapes like 2x16, 16, 1x2x3x4x16, etc +// IsNthDimSizeIsOneOfPred<-2, {10, 5}> +// - Accepts any shape where the second to last dim is 10 or 5. +// * This means shapes like: 1x10x2, 2x1x4x5x6, 8x10x?, etc +class IsNthDimSizeIsOneOfPred allowedSizes> + : And<[ + CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # NormalizeIndex.ret>, + CPred<"::llvm::is_contained(ArrayRef({" # !interleave(allowedSizes, ", ") # "}), " + # "::llvm::cast<::mlir::ShapedType>($_self).getDimSize(" + # !if(!lt(n, 0), + "::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n, + "" # n) + # "))">]>; + // Whether the shape of a vector matches the given `shape` list. class IsVectorOfShape shape> : CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef({" # !interleave(shape, ", ") # "})">; @@ -546,6 +602,24 @@ class ScalableVectorOfRankAndLengthAndType allowedRanks, ScalableVectorOfLength.summary, "::mlir::VectorType">; +// Any ShapedType where the size of the n-th dim is contained in `allowedSizes`. +// Negative values for `n` index in reverse. +class ShapedTypeWithNthDimOfSize allowedSizes> : Type< + IsNthDimSizeIsOneOfPred, + " with dim " # n # " having a size of {" # !interleave(allowedSizes, ", ") # "}", + "::mlir::ShapedType">; + +// Any scalable vector with a single trailing scalable dimensions, where the +// size of the trailing dimension is in `allowedTrailingSizes` list, and the +// type is in the `allowedTypes` list. +class VectorWithTrailingDimScalableOfSizeAndType allowedTrailingSizes, + list allowedTypes> : AllOfType< + [VectorWithTrailingDimScalableOf, + ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>], + VectorWithTrailingDimScalableOf.summary # + ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary, + "::mlir::VectorType">; + def AnyVector : VectorOf<[AnyType]>; // Temporary vector type clone that allows gradual transition to 0-D vectors. def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>; diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp index b7f1020deba1e..594c9b4c270f2 100644 --- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp +++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt index fffc77245d12c..9ef7384fc5492 100644 --- a/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt @@ -10,5 +10,6 @@ add_mlir_dialect_library(MLIRArmSVEDialect LINK_LIBS PUBLIC MLIRIR MLIRLLVMDialect + MLIRVectorDialect MLIRSideEffectInterfaces ) diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt index 7031ab4f799c4..2f1c43fae240d 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt @@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRArmSVETransforms LINK_LIBS PUBLIC MLIRArmSVEDialect MLIRFuncDialect + MLIRVectorDialect MLIRIR MLIRLLVMCommonConversion MLIRLLVMDialect diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp index abbb978304068..f54a26c27c2ac 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp @@ -12,6 +12,8 @@ #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" @@ -66,6 +68,77 @@ using ScalableMaskedDivFOpLowering = OneToOneConvertToLLVMPattern; +namespace { + +/// Unrolls a conversion to/from equivalent vector types, to allow using a +/// conversion intrinsic that only supports 1-D vector types. +/// +/// Example: +/// ``` +/// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1> +/// ``` +/// is rewritten into: +/// ``` +/// %cst = arith.constant dense : vector<2x[16]xi1> +/// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1> +/// %2 = "arm_sve.intr.convert.to.svbool"(%1) +/// : (vector<[4]xi1>) -> vector<[16]xi1> +/// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1> +/// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1> +/// %5 = "arm_sve.intr.convert.to.svbool"(%4) +/// : (vector<[4]xi1>) -> vector<[16]xi1> +/// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1> +/// ``` +template +struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Op convertOp, typename Op::Adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = convertOp.getLoc(); + + auto source = convertOp.getSource(); + VectorType sourceType = source.getType(); + VectorType resultType = convertOp.getResult().getType(); + + Value result = rewriter.create( + loc, resultType, rewriter.getZeroAttr(resultType)); + + // We want to iterate over the input vector in steps of the trailing + // dimension. So this creates tile shape where all leading dimensions are 1, + // and the trailing dimension step is the size of the dimension. + SmallVector tileShape(sourceType.getRank(), 1); + tileShape.back() = sourceType.getShape().back(); + + // Iterate over all scalable mask/predicate slices of the source vector. + for (SmallVector index : + StaticTileOffsetRange(sourceType.getShape(), tileShape)) { + auto extractOrInsertPosition = ArrayRef(index).drop_back(); + auto sourceVector = rewriter.create( + loc, source, extractOrInsertPosition); + auto convertedType = + VectorType::Builder(llvm::cast(sourceVector.getType())) + .setDim(0, resultType.getShape().back()); + auto convertedVector = + rewriter.create(loc, TypeRange{convertedType}, sourceVector); + result = rewriter.create(loc, convertedVector, result, + extractOrInsertPosition); + } + + rewriter.replaceOp(convertOp, result); + return success(); + } +}; + +using ConvertToSvboolOpLowering = + SvboolConversionOpLowering; + +using ConvertFromSvboolOpLowering = + SvboolConversionOpLowering; + +} // namespace + /// Populate the given list with patterns that convert from ArmSVE to LLVM. void mlir::populateArmSVELegalizeForLLVMExportPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { @@ -88,7 +161,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns( ScalableMaskedMulFOpLowering, ScalableMaskedSDivIOpLowering, ScalableMaskedUDivIOpLowering, - ScalableMaskedDivFOpLowering>(converter); + ScalableMaskedDivFOpLowering, + ConvertToSvboolOpLowering, + ConvertFromSvboolOpLowering>(converter); // clang-format on } @@ -107,7 +182,9 @@ void mlir::configureArmSVELegalizeForExportTarget( ScalableMaskedMulFIntrOp, ScalableMaskedSDivIIntrOp, ScalableMaskedUDivIIntrOp, - ScalableMaskedDivFIntrOp>(); + ScalableMaskedDivFIntrOp, + ConvertToSvboolIntrOp, + ConvertFromSvboolIntrOp>(); target.addIllegalOp(); + ScalableMaskedDivFOp, + ConvertToSvboolOp, + ConvertFromSvboolOp>(); // clang-format on } diff --git a/mlir/test/Dialect/ArmSVE/invalid.mlir b/mlir/test/Dialect/ArmSVE/invalid.mlir new file mode 100644 index 0000000000000..a1fa0d0292b7b --- /dev/null +++ b/mlir/test/Dialect/ArmSVE/invalid.mlir @@ -0,0 +1,51 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// ----- + +func.func @arm_sve_convert_from_svbool__bad_mask_type(%bool: vector<2x[16]xi1>) -> vector<2x[8]xi2> { + // expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<2x[8]xi2>'}} + %mask = arm_sve.convert_from_svbool %bool : vector<2x[8]xi2> + return %mask : vector<2x[8]xi2> +} + +// ----- + +func.func @arm_sve_convert_from_svbool__bad_mask_shape(%bool : vector<[16]xi1>) -> vector<[7]xi1> { + // expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[7]xi1>'}} + %mask = arm_sve.convert_from_svbool %bool : vector<[7]xi1> + return %mask : vector<[7]xi1> +} + +// ----- + +func.func @arm_sve_convert_from_svbool__bad_mask_scalability(%bool : vector<[4]x[16]xi1>) -> vector<[4]x[8]xi1> { + // expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[4]x[8]xi1>'}} + %mask = arm_sve.convert_from_svbool %bool : vector<[4]x[8]xi1> + return %mask : vector<[4]x[8]xi1> +} + +// ----- + +func.func @arm_sve_convert_to_svbool__bad_mask_type(%mask: vector<2x[8]xi2>) -> vector<2x[16]xi1> { + // expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<2x[8]xi2>'}} + %bool = arm_sve.convert_to_svbool %mask : vector<2x[8]xi2> + return %bool : vector<2x[16]xi1> +} + +// ----- + +func.func @arm_sve_convert_to_svbool__bad_mask_shape(%mask : vector<[7]xi1>) -> vector<[16]xi1> { + // expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[7]xi1>'}} + %bool = arm_sve.convert_to_svbool %mask : vector<[7]xi1> + return +} + +// ----- + +func.func @arm_sve_convert_to_svbool__bad_mask_scalability(%mask : vector<[4]x[8]xi1>) -> vector<[4]x[16]xi1> { + // expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[4]x[8]xi1>'}} + %bool = arm_sve.convert_to_svbool %mask : vector<[4]x[8]xi1> + return +} + + diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir index 2d980db981034..8e76fb7119b84 100644 --- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts | mlir-opt | FileCheck %s +// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts -split-input-file %s | FileCheck %s func.func @arm_sve_sdot(%a: vector<[16]xi8>, %b: vector<[16]xi8>, @@ -10,6 +10,8 @@ func.func @arm_sve_sdot(%a: vector<[16]xi8>, return %0 : vector<[4]xi32> } +// ----- + func.func @arm_sve_smmla(%a: vector<[16]xi8>, %b: vector<[16]xi8>, %c: vector<[4]xi32>) @@ -20,6 +22,8 @@ func.func @arm_sve_smmla(%a: vector<[16]xi8>, return %0 : vector<[4]xi32> } +// ----- + func.func @arm_sve_udot(%a: vector<[16]xi8>, %b: vector<[16]xi8>, %c: vector<[4]xi32>) @@ -30,6 +34,8 @@ func.func @arm_sve_udot(%a: vector<[16]xi8>, return %0 : vector<[4]xi32> } +// ----- + func.func @arm_sve_ummla(%a: vector<[16]xi8>, %b: vector<[16]xi8>, %c: vector<[4]xi32>) @@ -40,6 +46,8 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>, return %0 : vector<[4]xi32> } +// ----- + func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>, %b: vector<[4]xi32>, %c: vector<[4]xi32>, @@ -65,6 +73,8 @@ func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>, return %4 : vector<[4]xi32> } +// ----- + func.func @arm_sve_arithf_masked(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %c: vector<[4]xf32>, @@ -87,6 +97,8 @@ func.func @arm_sve_arithf_masked(%a: vector<[4]xf32>, return %3 : vector<[4]xf32> } +// ----- + func.func @arm_sve_abs_diff(%a: vector<[4]xi32>, %b: vector<[4]xi32>) -> vector<[4]xi32> { @@ -111,8 +123,67 @@ func.func @arm_sve_abs_diff(%a: vector<[4]xi32>, return %3 : vector<[4]xi32> } +// ----- + func.func @get_vector_scale() -> index { // CHECK: llvm.intr.vscale %0 = vector.vscale return %0 : index } + +// ----- + +func.func @convert_1d_mask_to_svbool(%mask: vector<[4]xi1>) -> vector<[16]xi1> +{ + // CHECK: "arm_sve.intr.convert.to.svbool"(%{{.*}}) : (vector<[4]xi1>) -> vector<[16]xi1> + %svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1> + return %svbool : vector<[16]xi1> +} + +// ----- + +func.func @convert_1d_mask_from_svbool(%svbool: vector<[16]xi1>) -> vector<[2]xi1> +{ + // CHECK: "arm_sve.intr.convert.from.svbool"(%{{.*}}) : (vector<[16]xi1>) -> vector<[2]xi1> + %mask = arm_sve.convert_from_svbool %svbool : vector<[2]xi1> + return %mask : vector<[2]xi1> +} + +// ----- + +// CHECK-LABEL: @convert_2d_mask_to_svbool( +// CHECK-SAME: %[[MASK:.*]]: !llvm.array<2 x vector<[8]xi1>>) +func.func @convert_2d_mask_to_svbool(%mask: vector<2x[8]xi1>) -> vector<2x[16]xi1> +{ + // CHECK-NEXT: %[[RES0:.*]] = llvm.mlir.constant(dense : vector<2x[16]xi1>) : !llvm.array<2 x vector<[16]xi1>> + // CHECK-NEXT: %[[MASK0:.*]] = llvm.extractvalue %[[MASK]][0] : !llvm.array<2 x vector<[8]xi1>> + // CHECK-NEXT: %[[SVBOOL0:.*]] = "arm_sve.intr.convert.to.svbool"(%[[MASK0]]) : (vector<[8]xi1>) -> vector<[16]xi1> + // CHECK-NEXT: %[[RES1:.*]] = llvm.insertvalue %[[SVBOOL0]], %[[RES0]][0] : !llvm.array<2 x vector<[16]xi1>> + // CHECK-NEXT: %[[MASK1:.*]] = llvm.extractvalue %[[MASK]][1] : !llvm.array<2 x vector<[8]xi1>> + // CHECK-NEXT: %[[SVBOOL1:.*]] = "arm_sve.intr.convert.to.svbool"(%[[MASK1]]) : (vector<[8]xi1>) -> vector<[16]xi1> + // CHECK-NEXT: %[[SVBOOL:.*]] = llvm.insertvalue %[[SVBOOL1]], %[[RES1]][1] : !llvm.array<2 x vector<[16]xi1>> + %svbool = arm_sve.convert_to_svbool %mask : vector<2x[8]xi1> + // CHECK-NEXT: llvm.return %[[SVBOOL]] : !llvm.array<2 x vector<[16]xi1>> + return %svbool : vector<2x[16]xi1> +} + +// ----- + +// CHECK-LABEL: @convert_2d_mask_from_svbool( +// CHECK-SAME: %[[SVBOOL:.*]]: !llvm.array<3 x vector<[16]xi1>>) +func.func @convert_2d_mask_from_svbool(%svbool: vector<3x[16]xi1>) -> vector<3x[1]xi1> +{ + // CHECK-NEXT: %[[RES0:.*]] = llvm.mlir.constant(dense : vector<3x[1]xi1>) : !llvm.array<3 x vector<[1]xi1>> + // CHECK-NEXT: %[[SVBOOL0:.*]] = llvm.extractvalue %[[SVBOOL]][0] : !llvm.array<3 x vector<[16]xi1>> + // CHECK-NEXT: %[[MASK0:.*]] = "arm_sve.intr.convert.from.svbool"(%[[SVBOOL0]]) : (vector<[16]xi1>) -> vector<[1]xi1> + // CHECK-NEXT: %[[RES1:.*]] = llvm.insertvalue %[[MASK0]], %[[RES0]][0] : !llvm.array<3 x vector<[1]xi1>> + // CHECK-NEXT: %[[SVBOOL1:.*]] = llvm.extractvalue %[[SVBOOL]][1] : !llvm.array<3 x vector<[16]xi1>> + // CHECK-NEXT: %[[MASK1:.*]] = "arm_sve.intr.convert.from.svbool"(%[[SVBOOL1]]) : (vector<[16]xi1>) -> vector<[1]xi1> + // CHECK-NEXT: %[[RES2:.*]] = llvm.insertvalue %[[MASK1]], %[[RES1]][1] : !llvm.array<3 x vector<[1]xi1>> + // CHECK-NEXT: %[[SVBOOL2:.*]] = llvm.extractvalue %[[SVBOOL]][2] : !llvm.array<3 x vector<[16]xi1>> + // CHECK-NEXT: %[[MASK2:.*]] = "arm_sve.intr.convert.from.svbool"(%[[SVBOOL2]]) : (vector<[16]xi1>) -> vector<[1]xi1> + // CHECK-NEXT: %[[MASK:.*]] = llvm.insertvalue %[[MASK2]], %[[RES2]][2] : !llvm.array<3 x vector<[1]xi1>> + %mask = arm_sve.convert_from_svbool %svbool : vector<3x[1]xi1> + // CHECK-NEXT: llvm.return %[[MASK]] : !llvm.array<3 x vector<[1]xi1>> + return %mask : vector<3x[1]xi1> +} diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir index d2ca035c17bfb..c9a0b6db8fa80 100644 --- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s +// RUN: mlir-opt -verify-diagnostics -split-input-file %s | mlir-opt | FileCheck %s func.func @arm_sve_sdot(%a: vector<[16]xi8>, %b: vector<[16]xi8>, @@ -9,6 +9,8 @@ func.func @arm_sve_sdot(%a: vector<[16]xi8>, return %0 : vector<[4]xi32> } +// ----- + func.func @arm_sve_smmla(%a: vector<[16]xi8>, %b: vector<[16]xi8>, %c: vector<[4]xi32>) -> vector<[4]xi32> { @@ -18,6 +20,8 @@ func.func @arm_sve_smmla(%a: vector<[16]xi8>, return %0 : vector<[4]xi32> } +// ----- + func.func @arm_sve_udot(%a: vector<[16]xi8>, %b: vector<[16]xi8>, %c: vector<[4]xi32>) -> vector<[4]xi32> { @@ -27,6 +31,8 @@ func.func @arm_sve_udot(%a: vector<[16]xi8>, return %0 : vector<[4]xi32> } +// ----- + func.func @arm_sve_ummla(%a: vector<[16]xi8>, %b: vector<[16]xi8>, %c: vector<[4]xi32>) -> vector<[4]xi32> { @@ -36,6 +42,8 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>, return %0 : vector<[4]xi32> } +// ----- + func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>, %b: vector<[4]xi32>, %c: vector<[4]xi32>, @@ -61,6 +69,8 @@ func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>, return %2 : vector<[4]xi32> } +// ----- + func.func @arm_sve_masked_arithf(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %c: vector<[4]xf32>, @@ -82,3 +92,74 @@ func.func @arm_sve_masked_arithf(%a: vector<[4]xf32>, vector<[4]xf32> return %3 : vector<[4]xf32> } + +// ----- + +func.func @arm_sve_convert_to_svbool(%a: vector<[1]xi1>, + %b: vector<[2]xi1>, + %c: vector<[4]xi1>, + %d: vector<[8]xi1>, + %e: vector<2x3x[1]xi1>, + %f: vector<4x[2]xi1>, + %g: vector<1x1x1x2x[4]xi1>, + %h: vector<100x[8]xi1>) { + // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<[1]xi1> + %1 = arm_sve.convert_to_svbool %a : vector<[1]xi1> + + // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<[2]xi1> + %2 = arm_sve.convert_to_svbool %b : vector<[2]xi1> + + // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<[4]xi1> + %3 = arm_sve.convert_to_svbool %c : vector<[4]xi1> + + // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<[8]xi1> + %4 = arm_sve.convert_to_svbool %d : vector<[8]xi1> + + // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<2x3x[1]xi1> + %5 = arm_sve.convert_to_svbool %e : vector<2x3x[1]xi1> + + // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<4x[2]xi1> + %6 = arm_sve.convert_to_svbool %f : vector<4x[2]xi1> + + // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<1x1x1x2x[4]xi1> + %7 = arm_sve.convert_to_svbool %g : vector<1x1x1x2x[4]xi1> + + // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<100x[8]xi1> + %8 = arm_sve.convert_to_svbool %h : vector<100x[8]xi1> + + return +} + +// ----- + +func.func @arm_sve_convert_from_svbool(%a: vector<[16]xi1>, + %b: vector<2x3x[16]xi1>, + %c: vector<4x[16]xi1>, + %d: vector<1x1x1x1x[16]xi1>, + %e: vector<32x[16]xi1>) { + // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[1]xi1> + %1 = arm_sve.convert_from_svbool %a : vector<[1]xi1> + + // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[2]xi1> + %2 = arm_sve.convert_from_svbool %a : vector<[2]xi1> + + // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[4]xi1> + %3 = arm_sve.convert_from_svbool %a : vector<[4]xi1> + + // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[8]xi1> + %4 = arm_sve.convert_from_svbool %a : vector<[8]xi1> + + // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<2x3x[1]xi1> + %5 = arm_sve.convert_from_svbool %b : vector<2x3x[1]xi1> + + // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<4x[2]xi1> + %6 = arm_sve.convert_from_svbool %c : vector<4x[2]xi1> + + // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<1x1x1x1x[4]xi1> + %7 = arm_sve.convert_from_svbool %d : vector<1x1x1x1x[4]xi1> + + // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<32x[8]xi1> + %8 = arm_sve.convert_from_svbool %e : vector<32x[8]xi1> + + return +}