diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 1013cbc8ca562..c458a500eb367 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4146,7 +4146,12 @@ def SPIRV_Int32 : TypeAlias; def SPIRV_Float32 : TypeAlias; def SPIRV_Float : FloatOfWidths<[16, 32, 64]>; def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>; -def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16], +// Remove the vector size restriction. +// Although the vector size can be upto (2^64-1), uint64, +// 2^32-1 (UNINT32_MAX>) is a more realistic number, it should serve the purpose +// for all practical cases. +// Also unsigned is used for the number elements for composite tyeps. +def SPIRV_Vector : VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>; // Component type check is done in the type parser for the following SPIR-V // dialect-specific types so we use "Any" here. @@ -4206,10 +4211,10 @@ class SPIRV_JointMatrixOfType allowedTypes> : "Joint Matrix">; class SPIRV_ScalarOrVectorOf : - AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>; + AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>]>; class SPIRV_ScalarOrVectorOrCoopMatrixOf : - AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>, + AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>, SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>; class SPIRV_MatrixOrCoopMatrixOf : diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 4fc14e30b8a10..8e5f2b065d6bb 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -546,6 +546,76 @@ class ScalableVectorOfRankAndLengthAndType allowedRanks, ScalableVectorOfLength.summary, "::mlir::VectorType">; +// Whether the number of elements of a vector is from the given +// `allowedRanges` list, the list has two values, start and end +// of the range (inclusive). +class IsVectorOfLengthRangePred allowedRanges> + : And<[IsVectorTypePred, + And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements()>= }] # allowedRanges[0]>, + CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>; + +// Whether the number of elements of a fixed-length vector is from the given +// `allowedRanges` list, the list has two values, start and end of the range (inclusive). +class IsFixedVectorOfLengthRangePred allowedRanges> + : And<[IsFixedVectorTypePred, + And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() >= }] # allowedRanges[0]>, + CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>; + +// Whether the minimum number of elements of a scalable vector is from the given +// `allowedRanges` list, the list has two values, start and end of the range (inclusive). +class IsScalableVectorOfMinLengthRangePred allowedRanges> + : And<[IsScalableVectorTypePred, + And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() >= }] # allowedRanges[0]>, + CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>; + +// Any vector where the number of elements is from the given +// `allowedRanges` list. +class VectorOfLengthRange allowedRanges> + : Type, + " of length " # !interleave(allowedRanges, "-"), + "::mlir::VectorType">; + +// Any fixed-length vector where the number of elements is from the given +// `allowedRanges` list. +class FixedVectorOfLengthRange allowedRanges> + : Type, + " of length " # !interleave(allowedRanges, "-"), + "::mlir::VectorType">; + +// Any scalable vector where the minimum number of elements is from the given +// `allowedRanges` list. +class ScalableVectorOfMinLengthRange allowedRanges> + : Type, + " of length " # !interleave(allowedRanges, "-"), + "::mlir::VectorType">; + +// Any vector where the number of elements is from the given +// `allowedRanges` list and the type is from the given `allowedTypes` +// list. +class VectorOfLengthRangeAndType allowedRanges, list allowedTypes> + : Type.predicate, VectorOfLengthRange.predicate]>, + VectorOf.summary # VectorOfLengthRange.summary, + "::mlir::VectorType">; + +// Any fixed-length vector where the number of elements is from the given +// `allowedRanges` list and the type is from the given `allowedTypes` +// list. +class FixedVectorOfLengthRangeAndType allowedRanges, list allowedTypes> + : Type< + And<[FixedVectorOf.predicate, FixedVectorOfLengthRange.predicate]>, + FixedVectorOf.summary # FixedVectorOfLengthRange.summary, + "::mlir::VectorType">; + +// Any scalable vector where the minimum number of elements is from the given +// `allowedRanges` list and the type is from the given `allowedTypes` +// list. +class ScalableVectorOfMinLengthRangeAndType allowedRanges, list allowedTypes> + : Type< + And<[ScalableVectorOf.predicate, ScalableVectorOfMinLengthRange.predicate]>, + ScalableVectorOf.summary # ScalableVectorOfMinLengthRange.summary, + "::mlir::VectorType">; + + def AnyVector : VectorOf<[AnyType]>; // Temporary vector type clone that allows gradual transition to 0-D vectors. def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index a51d77dda78bf..be85d3c330a88 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -184,9 +184,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t; return Type(); } - if (t.getNumElements() > 4) { + // Number of elements should be between [2 - 2^32 -1], + // since getNumElements() returns an unsigned, the upper limit check is + // unnecessary. + if (t.getNumElements() < 2) { parser.emitError( - typeLoc, "vector length has to be less than or equal to 4 but found ") + typeLoc, "vector length has to be between [2 - 2^32 -1] but found ") << t.getNumElements(); return Type(); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 39d6603a46f96..9d39d99b41482 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -101,9 +101,11 @@ bool CompositeType::classof(Type type) { } bool CompositeType::isValid(VectorType type) { - return type.getRank() == 1 && - llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) && - llvm::isa(type.getElementType()); + // Number of elements should be between [2 - 2^32 -1], + // since getNumElements() returns an unsigned, the upper limit check is + // unnecessary. + return type.getRank() == 1 && llvm::isa(type.getElementType()) && + type.getNumElements() >= 2; } Type CompositeType::getElementType(unsigned index) const { @@ -171,9 +173,17 @@ void CompositeType::getCapabilities( .Case([&](VectorType type) { auto vecSize = getNumElements(); if (vecSize == 8 || vecSize == 16) { - static const Capability caps[] = {Capability::Vector16}; - ArrayRef ref(caps, std::size(caps)); - capabilities.push_back(ref); + static constexpr Capability caps[] = {Capability::Vector16, + Capability::VectorAnyINTEL}; + capabilities.push_back(caps); + } + // VectorAnyINTEL capability removes the vector size restriction and + // allows the vector size to be up to (2^32-1). + // Vector16 capability allows the vector size to be 8 and 16 + SmallVector allowedVecRange = {2, 3, 4, 8, 16}; + if (vecSize >= 2 && !llvm::is_contained(allowedVecRange, vecSize)) { + static constexpr Capability caps[] = {Capability::VectorAnyINTEL}; + capabilities.push_back(caps); } return llvm::cast(type.getElementType()) .getCapabilities(capabilities, storage); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index c75d217663a9e..25e6a080642e6 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -43,9 +43,13 @@ using namespace mlir; template static LogicalResult checkExtensionRequirements( LabelT label, const spirv::TargetEnv &targetEnv, - const spirv::SPIRVType::ExtensionArrayRefVector &candidates) { + const spirv::SPIRVType::ExtensionArrayRefVector &candidates, + const ArrayRef elidedCandidates = {}) { for (const auto &ors : candidates) { - if (targetEnv.allows(ors)) + if (targetEnv.allows(ors) || + llvm::any_of(elidedCandidates, [&ors](spirv::Extension elidedExt) { + return llvm::is_contained(ors, elidedExt); + })) continue; LLVM_DEBUG({ @@ -71,9 +75,13 @@ static LogicalResult checkExtensionRequirements( template static LogicalResult checkCapabilityRequirements( LabelT label, const spirv::TargetEnv &targetEnv, - const spirv::SPIRVType::CapabilityArrayRefVector &candidates) { + const spirv::SPIRVType::CapabilityArrayRefVector &candidates, + const ArrayRef elidedCandidates = {}) { for (const auto &ors : candidates) { - if (targetEnv.allows(ors)) + if (targetEnv.allows(ors) || + llvm::any_of(elidedCandidates, [&ors](spirv::Capability elidedCap) { + return llvm::is_contained(ors, elidedCap); + })) continue; LLVM_DEBUG({ @@ -90,8 +98,55 @@ static LogicalResult checkCapabilityRequirements( return success(); } -/// Returns true if the given `storageClass` needs explicit layout when used in -/// Shader environments. +/// Check capabilities and extensions requirements +/// Checks that `capCandidates`, `extCandidates`, and capability +/// (`capCandidates`) infered extension requirements are possible to be +/// satisfied with the given `targetEnv`. +/// It also provides a way to relax requirements for certain capabilities and +/// extensions (e.g., `elidedCapCandidates`, `elidedExtCandidates`), this is to +/// allow passes to relax certain requirements based on an option (e.g., +/// relaxing bitwidth requirement, see `convertScalarType()`, +/// `ConvertVectorType()`). +template +static LogicalResult checkCapabilityAndExtensionRequirements( + LabelT label, const spirv::TargetEnv &targetEnv, + const spirv::SPIRVType::CapabilityArrayRefVector &capCandidates, + const spirv::SPIRVType::ExtensionArrayRefVector &extCandidates, + const ArrayRef elidedCapCandidates = {}, + const ArrayRef elidedExtCandidates = {}) { + SmallVector, 8> updatedExtCandidates; + llvm::append_range(updatedExtCandidates, extCandidates); + + if (failed(checkCapabilityRequirements(label, targetEnv, capCandidates, + elidedCapCandidates))) + return failure(); + // Add capablity infered extensions to the list of extension requirement list, + // only considers the capabilities that already available in the `targetEnv`. + + // WARNING: Some capabilities are part of both the core SPIR-V + // specification and an extension (e.g., 'Groups' capability is part of both + // core specification and SPV_AMD_shader_ballot extension, hence we should + // relax the capability inferred extension for these cases). + static const spirv::Capability multiModalCaps[] = {spirv::Capability::Groups}; + ArrayRef multiModalCapsArrayRef(multiModalCaps, + std::size(multiModalCaps)); + + for (auto cap : targetEnv.getAttr().getCapabilities()) { + if (llvm::any_of(multiModalCapsArrayRef, + [&cap](spirv::Capability mMCap) { return cap == mMCap; })) + continue; + std::optional> ext = getExtensions(cap); + if (ext) + updatedExtCandidates.push_back(*ext); + } + if (failed(checkExtensionRequirements(label, targetEnv, updatedExtCandidates, + elidedExtCandidates))) + return failure(); + return success(); +} + +/// Returns true if the given `storageClass` needs explicit layout when used +/// in Shader environments. static bool needsExplicitLayout(spirv::StorageClass storageClass) { switch (storageClass) { case spirv::StorageClass::PhysicalStorageBuffer: @@ -230,8 +285,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv, type.getCapabilities(capabilities, storageClass); // If all requirements are met, then we can accept this type as-is. - if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && - succeeded(checkExtensionRequirements(type, targetEnv, extensions))) + if (succeeded(checkCapabilityAndExtensionRequirements( + type, targetEnv, capabilities, extensions))) return type; // Otherwise we need to adjust the type, which really means adjusting the @@ -342,15 +397,35 @@ convertVectorType(const spirv::TargetEnv &targetEnv, cast(type).getExtensions(extensions, storageClass); cast(type).getCapabilities(capabilities, storageClass); + // If the bit-width related capabilities and extensions are not met + // for lower bit-width (<32-bit), convert it to 32-bit + auto elementType = + convertScalarType(targetEnv, options, scalarType, storageClass); + if (!elementType) + return nullptr; + type = VectorType::get(type.getShape(), elementType); + + SmallVector elidedCaps; + SmallVector elidedExts; + + // Relax the bitwidth requirements for capabilities and extensions + if (options.emulateLT32BitScalarTypes) { + elidedCaps.push_back(spirv::Capability::Int8); + elidedCaps.push_back(spirv::Capability::Int16); + elidedCaps.push_back(spirv::Capability::Float16); + } + // For capabilities whose requirements were relaxed, relax requirements for + // the extensions that were infered by those capabilities (e.g., elidedCaps) + for (spirv::Capability cap : elidedCaps) { + std::optional> ext = spirv::getExtensions(cap); + if (ext) + llvm::append_range(elidedExts, *ext); + } // If all requirements are met, then we can accept this type as-is. - if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && - succeeded(checkExtensionRequirements(type, targetEnv, extensions))) + if (succeeded(checkCapabilityAndExtensionRequirements( + type, targetEnv, capabilities, extensions, elidedCaps, elidedExts))) return type; - auto elementType = - convertScalarType(targetEnv, options, scalarType, storageClass); - if (elementType) - return VectorType::get(type.getShape(), elementType); return nullptr; } @@ -656,8 +731,9 @@ std::optional castToSourceType(const spirv::TargetEnv &targetEnv, SmallVector, 2> caps; scalarType.getExtensions(exts); scalarType.getCapabilities(caps); - if (failed(checkCapabilityRequirements(type, targetEnv, caps)) || - failed(checkExtensionRequirements(type, targetEnv, exts))) { + + if (failed(checkCapabilityAndExtensionRequirements(type, targetEnv, caps, + exts))) { auto castOp = builder.create(loc, type, inputs); return castOp.getResult(0); } @@ -1150,16 +1226,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) { SmallVector, 4> typeExtensions; SmallVector, 8> typeCapabilities; for (Type valueType : valueTypes) { - typeExtensions.clear(); - cast(valueType).getExtensions(typeExtensions); - if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, - typeExtensions))) - return false; - typeCapabilities.clear(); cast(valueType).getCapabilities(typeCapabilities); - if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, - typeCapabilities))) + typeExtensions.clear(); + cast(valueType).getExtensions(typeExtensions); + // Checking for capability and extension requirements along with capability + // infered extensions. + // If a capability is present, the extension that + // supports it should also be present, this reduces the burden of adding + // extension requirement that may or maynot be added in + // CompositeType::getExtensions(). + if (failed(checkCapabilityAndExtensionRequirements( + op->getName(), this->targetEnv, typeCapabilities, typeExtensions))) return false; } diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir index 0d92a8e676d85..d61ace8d6876b 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir @@ -11,9 +11,9 @@ module attributes { #spirv.vce, #spirv.resource_limits<>> } { -func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) { +func.func @unsupported_5elem_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) { // expected-error@+1 {{failed to legalize operation 'arith.subi'}} - %1 = arith.subi %arg0, %arg0: vector<5xi32> + %1 = arith.subi %arg0, %arg1: vector<5xi32> return } diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index 0221e4815a939..6ceeade486efd 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -1407,3 +1407,43 @@ func.func @float_scalar(%arg0: f16) { } } // end module + +// ----- + +//===----------------------------------------------------------------------===// +// VectorAnyINTEL support +//===----------------------------------------------------------------------===// + +// Check that with VectorAnyINTEL, VectorComputeINTEL capability, +// and SPV_INTEL_vector_compute extension, any sized (2-2^32 -1) vector is allowed. +module attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce, #spirv.resource_limits<>> +} { + +// CHECK-LABEL: @any_vector +func.func @any_vector(%arg0: vector<16xi32>, %arg1: vector<16xi32>) { + // CHECK: spirv.ISub %{{.+}}, %{{.+}}: vector<16xi32> + %0 = arith.subi %arg0, %arg1: vector<16xi32> + return +} + +// CHECK-LABEL: @max_vector +func.func @max_vector(%arg0: vector<4294967295xi32>, %arg1: vector<4294967295xi32>) { + // CHECK: spirv.ISub %{{.+}}, %{{.+}}: vector<4294967295xi32> + %0 = arith.subi %arg0, %arg1: vector<4294967295xi32> + return +} + + +// Check float vector types of any size. +// CHECK-LABEL: @float_vector58 +func.func @float_vector58(%arg0: vector<5xf16>, %arg1: vector<8xf64>) { + // CHECK: spirv.FAdd %{{.*}}, %{{.*}}: vector<5xf16> + %0 = arith.addf %arg0, %arg0: vector<5xf16> + // CHECK: spirv.FMul %{{.*}}, %{{.*}}: vector<8xf64> + %1 = arith.mulf %arg1, %arg1: vector<8xf64> + return +} + +} // end module diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir index 82d750755ffe2..6f364c5b0875c 100644 --- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir +++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir @@ -351,8 +351,21 @@ module attributes { spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { -// CHECK-NOT: spirv.func @large_vector -func.func @large_vector(%arg0: vector<1024xi32>) { return } +// CHECK-NOT: spirv.func @large_vector_unsupported +func.func @large_vector_unsupported(%arg0: vector<1024xi32>) { return } + +} // end module + + +// ----- + +// Check that large vectors are supported with VectorAnyINTEL or VectorComputeINTEL. +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + +// CHECK: spirv.func @large_any_vector +func.func @large_any_vector(%arg0: vector<1024xi32>) { return } } // end module diff --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir index 1e5d64387650c..5fc50ae99cfe9 100644 --- a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir @@ -2,7 +2,7 @@ module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -22,7 +22,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -42,7 +42,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -62,7 +62,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -82,7 +82,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -102,7 +102,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -122,7 +122,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -142,7 +142,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -162,7 +162,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -182,7 +182,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -202,7 +202,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -222,7 +222,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -242,7 +242,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -262,7 +262,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -282,7 +282,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -302,7 +302,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -322,7 +322,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -342,7 +342,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -362,7 +362,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -382,7 +382,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -402,7 +402,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -422,7 +422,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -442,7 +442,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -462,7 +462,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -482,7 +482,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -502,7 +502,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -522,7 +522,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -542,7 +542,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -562,7 +562,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -582,7 +582,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -602,7 +602,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -622,7 +622,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -637,3 +637,22 @@ gpu.module @kernels { } } + +// ----- + +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + +gpu.module @kernels { + // CHECK-NOT: spirv.func @test_unsupported + gpu.func @test_unsupported(%arg : i32) kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}} + %reduced = gpu.subgroup_reduce max %arg : (i32) -> (i32) + gpu.return + } +} + +} \ No newline at end of file diff --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir index 82a2316f6c784..88a8e507c1993 100644 --- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir @@ -137,7 +137,7 @@ func.func @bitwise_or_all_ones_vector(%arg: vector<3xi8>) -> vector<3xi8> { // ----- func.func @bitwise_or_float(%arg0: f16, %arg1: f16) -> f16 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-4294967295}} %0 = spirv.BitwiseOr %arg0, %arg1 : f16 return %0 : f16 } @@ -163,7 +163,7 @@ func.func @bitwise_xor_vector(%arg: vector<4xi32>) -> vector<4xi32> { // ----- func.func @bitwise_xor_float(%arg0: f16, %arg1: f16) -> f16 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-4294967295}} %0 = spirv.BitwiseXor %arg0, %arg1 : f16 return %0 : f16 } @@ -272,7 +272,7 @@ func.func @bitwise_and_zext_vector(%arg: vector<2xi8>) -> vector<2xi32> { // ----- func.func @bitwise_and_float(%arg0: f16, %arg1: f16) -> f16 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-4294967295}} %0 = spirv.BitwiseAnd %arg0, %arg1 : f16 return %0 : f16 } diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir index 3683e5b469b17..a95a6001fd204 100644 --- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir @@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () { // ----- func.func @exp(%arg0 : vector<5xf32>) -> () { - // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}} + // CHECK: spirv.GL.Exp {{%.*}} : vector<5xf32 %2 = spirv.GL.Exp %arg0 : vector<5xf32> return } diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir index 53a1015de75bc..6929ef9b21d0e 100644 --- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir @@ -21,7 +21,7 @@ spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" { // ----- spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" { - // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}} + // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2-4294967295, but got}} %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f64 to i16 spirv.Return } @@ -57,7 +57,7 @@ spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" { // ----- spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" { - // expected-error @+1 {{result #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}} + // expected-error @+1 {{result #0 must be Float32 or vector of Float32 values of length 2-4294967295, but got}} %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f16 spirv.Return } diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir index 7dc0bd99f54b3..fa4d9e253307d 100644 --- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir @@ -166,7 +166,7 @@ func.func @logicalUnary(%arg0 : i1) func.func @logicalUnary(%arg0 : i32) { - // expected-error @+1 {{'operand' must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}} + // expected-error @+1 {{'operand' must be bool or vector of bool values of length 2-4294967295, but got 'i32'}} %0 = spirv.LogicalNot %arg0 : i32 return } diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir index 29a4a46136156..24fe2f9458413 100644 --- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir @@ -18,17 +18,17 @@ func.func @expvec(%arg0 : vector<3xf16>) -> () { // ----- -func.func @exp(%arg0 : i32) -> () { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} - %2 = spirv.CL.exp %arg0 : i32 +func.func @exp_any_vec(%arg0 : vector<5xf32>) -> () { + // CHECK: spirv.CL.exp {{%.*}} : vector<5xf32> + %2 = spirv.CL.exp %arg0 : vector<5xf32> return } // ----- -func.func @exp(%arg0 : vector<5xf32>) -> () { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}} - %2 = spirv.CL.exp %arg0 : vector<5xf32> +func.func @exp(%arg0 : i32) -> () { + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + %2 = spirv.CL.exp %arg0 : i32 return } @@ -66,6 +66,14 @@ func.func @fabsvec(%arg0 : vector<3xf16>) -> () { return } +// ----- + +func.func @fabs_any_vec(%arg0 : vector<5xf32>) -> () { + // CHECK: spirv.CL.fabs {{%.*}} : vector<5xf32> + %2 = spirv.CL.fabs %arg0 : vector<5xf32> + return +} + func.func @fabsf64(%arg0 : f64) -> () { // CHECK: spirv.CL.fabs {{%.*}} : f64 %2 = spirv.CL.fabs %arg0 : f64 @@ -82,14 +90,6 @@ func.func @fabs(%arg0 : i32) -> () { // ----- -func.func @fabs(%arg0 : vector<5xf32>) -> () { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}} - %2 = spirv.CL.fabs %arg0 : vector<5xf32> - return -} - -// ----- - func.func @fabs(%arg0 : f32, %arg1 : f32) -> () { // expected-error @+1 {{expected ':'}} %2 = spirv.CL.fabs %arg0, %arg1 : i32 @@ -122,6 +122,14 @@ func.func @sabsvec(%arg0 : vector<3xi16>) -> () { return } +// ----- + +func.func @sabs_any_vec(%arg0 : vector<5xi32>) -> () { + // CHECK: spirv.CL.s_abs {{%.*}} : vector<5xi32> + %2 = spirv.CL.s_abs %arg0 : vector<5xi32> + return +} + func.func @sabsi64(%arg0 : i64) -> () { // CHECK: spirv.CL.s_abs {{%.*}} : i64 %2 = spirv.CL.s_abs %arg0 : i64 @@ -142,13 +150,7 @@ func.func @sabs(%arg0 : f32) -> () { return } -// ----- -func.func @sabs(%arg0 : vector<5xi32>) -> () { - // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}} - %2 = spirv.CL.s_abs %arg0 : vector<5xi32> - return -} // ----- diff --git a/mlir/test/Target/SPIRV/arithmetic-ops.mlir b/mlir/test/Target/SPIRV/arithmetic-ops.mlir index b1ea13c6854fd..90144afc6f3af 100644 --- a/mlir/test/Target/SPIRV/arithmetic-ops.mlir +++ b/mlir/test/Target/SPIRV/arithmetic-ops.mlir @@ -6,9 +6,9 @@ spirv.module Logical GLSL450 requires #spirv.vce { %0 = spirv.FMul %arg0, %arg1 : f32 spirv.Return } - spirv.func @fadd(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) "None" { - // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : vector<4xf32> - %0 = spirv.FAdd %arg0, %arg1 : vector<4xf32> + spirv.func @fadd(%arg0 : vector<5xf32>, %arg1 : vector<5xf32>) "None" { + // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : vector<5xf32> + %0 = spirv.FAdd %arg0, %arg1 : vector<5xf32> spirv.Return } spirv.func @fdiv(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) "None" { diff --git a/mlir/test/Target/SPIRV/ocl-ops.mlir b/mlir/test/Target/SPIRV/ocl-ops.mlir index 9a2e4cf62e370..31a7f616d648e 100644 --- a/mlir/test/Target/SPIRV/ocl-ops.mlir +++ b/mlir/test/Target/SPIRV/ocl-ops.mlir @@ -39,6 +39,12 @@ spirv.module Physical64 OpenCL requires #spirv.vce) "None" { + // CHECK: {{%.*}} = spirv.CL.fabs {{%.*}} : vector<5000xf32> + %0 = spirv.CL.fabs %arg0 : vector<5000xf32> + spirv.Return + } + spirv.func @fma(%arg0 : f32, %arg1 : f32, %arg2 : f32) "None" { // CHECK: spirv.CL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32 %13 = spirv.CL.fma %arg0, %arg1, %arg2 : f32