From aaeb968f5314bd66402bbbbc3b84549926dd7c6a Mon Sep 17 00:00:00 2001 From: Md Abdullah Shahneous Bari Date: Tue, 26 Sep 2023 14:45:05 -0700 Subject: [PATCH 1/4] [mlir] Add support for vector types whose number of elements are from a range of values Add types and predicates for Vector, Fixed Vector, and Scalable Vector whose number of elements is from a given `allowedRanges` list. The list has two values, start and end of the range (inclusive). --- mlir/include/mlir/IR/CommonTypeConstraints.td | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 4fc14e30b8a10..8e5f2b065d6bb 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -546,6 +546,76 @@ class ScalableVectorOfRankAndLengthAndType allowedRanks, ScalableVectorOfLength.summary, "::mlir::VectorType">; +// Whether the number of elements of a vector is from the given +// `allowedRanges` list, the list has two values, start and end +// of the range (inclusive). +class IsVectorOfLengthRangePred allowedRanges> + : And<[IsVectorTypePred, + And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements()>= }] # allowedRanges[0]>, + CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>; + +// Whether the number of elements of a fixed-length vector is from the given +// `allowedRanges` list, the list has two values, start and end of the range (inclusive). +class IsFixedVectorOfLengthRangePred allowedRanges> + : And<[IsFixedVectorTypePred, + And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() >= }] # allowedRanges[0]>, + CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>; + +// Whether the minimum number of elements of a scalable vector is from the given +// `allowedRanges` list, the list has two values, start and end of the range (inclusive). +class IsScalableVectorOfMinLengthRangePred allowedRanges> + : And<[IsScalableVectorTypePred, + And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() >= }] # allowedRanges[0]>, + CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>; + +// Any vector where the number of elements is from the given +// `allowedRanges` list. +class VectorOfLengthRange allowedRanges> + : Type, + " of length " # !interleave(allowedRanges, "-"), + "::mlir::VectorType">; + +// Any fixed-length vector where the number of elements is from the given +// `allowedRanges` list. +class FixedVectorOfLengthRange allowedRanges> + : Type, + " of length " # !interleave(allowedRanges, "-"), + "::mlir::VectorType">; + +// Any scalable vector where the minimum number of elements is from the given +// `allowedRanges` list. +class ScalableVectorOfMinLengthRange allowedRanges> + : Type, + " of length " # !interleave(allowedRanges, "-"), + "::mlir::VectorType">; + +// Any vector where the number of elements is from the given +// `allowedRanges` list and the type is from the given `allowedTypes` +// list. +class VectorOfLengthRangeAndType allowedRanges, list allowedTypes> + : Type.predicate, VectorOfLengthRange.predicate]>, + VectorOf.summary # VectorOfLengthRange.summary, + "::mlir::VectorType">; + +// Any fixed-length vector where the number of elements is from the given +// `allowedRanges` list and the type is from the given `allowedTypes` +// list. +class FixedVectorOfLengthRangeAndType allowedRanges, list allowedTypes> + : Type< + And<[FixedVectorOf.predicate, FixedVectorOfLengthRange.predicate]>, + FixedVectorOf.summary # FixedVectorOfLengthRange.summary, + "::mlir::VectorType">; + +// Any scalable vector where the minimum number of elements is from the given +// `allowedRanges` list and the type is from the given `allowedTypes` +// list. +class ScalableVectorOfMinLengthRangeAndType allowedRanges, list allowedTypes> + : Type< + And<[ScalableVectorOf.predicate, ScalableVectorOfMinLengthRange.predicate]>, + ScalableVectorOf.summary # ScalableVectorOfMinLengthRange.summary, + "::mlir::VectorType">; + + def AnyVector : VectorOf<[AnyType]>; // Temporary vector type clone that allows gradual transition to 0-D vectors. def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>; From cce7b285bf61c69e2e4cad3605301b875ba6ac57 Mon Sep 17 00:00:00 2001 From: Md Abdullah Shahneous Bari Date: Mon, 2 Oct 2023 11:15:08 -0700 Subject: [PATCH 2/4] [mlir][spirv] Extend capabilities and extensions requirements checking Allow a way to relax requirements for certain capabilities and extensions (e.g., `elidedCandidates`). Also add a combined check for capabilities and extensions in `checkCapabilityAndExtensionRequirements`. This function checks capabilities, extensions, and capability infered extension requirements. --- .../SPIRV/Transforms/SPIRVConversion.cpp | 67 +++++++++++++++++-- 1 file changed, 61 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index c75d217663a9e..7bcd36da0c21e 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -43,9 +43,13 @@ using namespace mlir; template static LogicalResult checkExtensionRequirements( LabelT label, const spirv::TargetEnv &targetEnv, - const spirv::SPIRVType::ExtensionArrayRefVector &candidates) { + const spirv::SPIRVType::ExtensionArrayRefVector &candidates, + const ArrayRef elidedCandidates = {}) { for (const auto &ors : candidates) { - if (targetEnv.allows(ors)) + if (targetEnv.allows(ors) || + llvm::any_of(elidedCandidates, [&ors](spirv::Extension elidedExt) { + return llvm::is_contained(ors, elidedExt); + })) continue; LLVM_DEBUG({ @@ -71,9 +75,13 @@ static LogicalResult checkExtensionRequirements( template static LogicalResult checkCapabilityRequirements( LabelT label, const spirv::TargetEnv &targetEnv, - const spirv::SPIRVType::CapabilityArrayRefVector &candidates) { + const spirv::SPIRVType::CapabilityArrayRefVector &candidates, + const ArrayRef elidedCandidates = {}) { for (const auto &ors : candidates) { - if (targetEnv.allows(ors)) + if (targetEnv.allows(ors) || + llvm::any_of(elidedCandidates, [&ors](spirv::Capability elidedCap) { + return llvm::is_contained(ors, elidedCap); + })) continue; LLVM_DEBUG({ @@ -90,8 +98,55 @@ static LogicalResult checkCapabilityRequirements( return success(); } -/// Returns true if the given `storageClass` needs explicit layout when used in -/// Shader environments. +/// Check capabilities and extensions requirements +/// Checks that `capCandidates`, `extCandidates`, and capability +/// (`capCandidates`) infered extension requirements are possible to be +/// satisfied with the given `targetEnv`. +/// It also provides a way to relax requirements for certain capabilities and +/// extensions (e.g., `elidedCapCandidates`, `elidedExtCandidates`), this is to +/// allow passes to relax certain requirements based on an option (e.g., +/// relaxing bitwidth requirement, see `convertScalarType()`, +/// `ConvertVectorType()`). +template +static LogicalResult checkCapabilityAndExtensionRequirements( + LabelT label, const spirv::TargetEnv &targetEnv, + const spirv::SPIRVType::CapabilityArrayRefVector &capCandidates, + const spirv::SPIRVType::ExtensionArrayRefVector &extCandidates, + const ArrayRef elidedCapCandidates = {}, + const ArrayRef elidedExtCandidates = {}) { + SmallVector, 8> updatedExtCandidates; + llvm::append_range(updatedExtCandidates, extCandidates); + + if (failed(checkCapabilityRequirements(label, targetEnv, capCandidates, + elidedCapCandidates))) + return failure(); + // Add capablity infered extensions to the list of extension requirement list, + // only considers the capabilities that already available in the `targetEnv`. + + // WARNING: Some capabilities are part of both the core SPIR-V + // specification and an extension (e.g., 'Groups' capability is part of both + // core specification and SPV_AMD_shader_ballot extension, hence we should + // relax the capability inferred extension for these cases). + static const spirv::Capability multiModalCaps[] = {spirv::Capability::Groups}; + ArrayRef multiModalCapsArrayRef(multiModalCaps, + std::size(multiModalCaps)); + + for (auto cap : targetEnv.getAttr().getCapabilities()) { + if (llvm::any_of(multiModalCapsArrayRef, + [&cap](spirv::Capability mMCap) { return cap == mMCap; })) + continue; + std::optional> ext = getExtensions(cap); + if (ext) + updatedExtCandidates.push_back(*ext); + } + if (failed(checkExtensionRequirements(label, targetEnv, updatedExtCandidates, + elidedExtCandidates))) + return failure(); + return success(); +} + +/// Returns true if the given `storageClass` needs explicit layout when used +/// in Shader environments. static bool needsExplicitLayout(spirv::StorageClass storageClass) { switch (storageClass) { case spirv::StorageClass::PhysicalStorageBuffer: From 4ae99c664012e95d5a0f7f8c7f007e1d65954fda Mon Sep 17 00:00:00 2001 From: Md Abdullah Shahneous Bari Date: Mon, 2 Oct 2023 11:16:43 -0700 Subject: [PATCH 3/4] [mlir][spirv] Use combined-check for type related extension and capability requirements Replace the seperate extension and capability checking with combined check `checkCapabilityAndExtensionRequirements()`. This makes the code flow simpler. Also adds the extra check for capability inferred extension check. Need for capability inferred extension check: If a capability is a requirement, the respective extension that implements it should also become an extension requirement, there were no support for that check, as a result, the extension requirement had to be added separately. This separate requirement addition causes problem when a feature is enabled by multiple capability, and one of the capability is part of an extension. E.g., vector size of 16 can be enabled by both "Vector16" and "vectorAnyINTEL" capability, however, only "vectorAnyINTEL" has an extension requirement ("SPV_INTEL_vector_compute"). Since the process of adding capability and extension requirement are independent, there is no way, to handle cases like this. Therefore, for cases like this, enable adding capability requirement initially, then do the check for capability inferred extension. --- .../SPIRV/Transforms/SPIRVConversion.cpp | 59 +++++++++---- .../Conversion/GPUToSPIRV/reductions.mlir | 83 ++++++++++++------- 2 files changed, 92 insertions(+), 50 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 7bcd36da0c21e..25e6a080642e6 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -285,8 +285,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv, type.getCapabilities(capabilities, storageClass); // If all requirements are met, then we can accept this type as-is. - if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && - succeeded(checkExtensionRequirements(type, targetEnv, extensions))) + if (succeeded(checkCapabilityAndExtensionRequirements( + type, targetEnv, capabilities, extensions))) return type; // Otherwise we need to adjust the type, which really means adjusting the @@ -397,15 +397,35 @@ convertVectorType(const spirv::TargetEnv &targetEnv, cast(type).getExtensions(extensions, storageClass); cast(type).getCapabilities(capabilities, storageClass); + // If the bit-width related capabilities and extensions are not met + // for lower bit-width (<32-bit), convert it to 32-bit + auto elementType = + convertScalarType(targetEnv, options, scalarType, storageClass); + if (!elementType) + return nullptr; + type = VectorType::get(type.getShape(), elementType); + + SmallVector elidedCaps; + SmallVector elidedExts; + + // Relax the bitwidth requirements for capabilities and extensions + if (options.emulateLT32BitScalarTypes) { + elidedCaps.push_back(spirv::Capability::Int8); + elidedCaps.push_back(spirv::Capability::Int16); + elidedCaps.push_back(spirv::Capability::Float16); + } + // For capabilities whose requirements were relaxed, relax requirements for + // the extensions that were infered by those capabilities (e.g., elidedCaps) + for (spirv::Capability cap : elidedCaps) { + std::optional> ext = spirv::getExtensions(cap); + if (ext) + llvm::append_range(elidedExts, *ext); + } // If all requirements are met, then we can accept this type as-is. - if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && - succeeded(checkExtensionRequirements(type, targetEnv, extensions))) + if (succeeded(checkCapabilityAndExtensionRequirements( + type, targetEnv, capabilities, extensions, elidedCaps, elidedExts))) return type; - auto elementType = - convertScalarType(targetEnv, options, scalarType, storageClass); - if (elementType) - return VectorType::get(type.getShape(), elementType); return nullptr; } @@ -711,8 +731,9 @@ std::optional castToSourceType(const spirv::TargetEnv &targetEnv, SmallVector, 2> caps; scalarType.getExtensions(exts); scalarType.getCapabilities(caps); - if (failed(checkCapabilityRequirements(type, targetEnv, caps)) || - failed(checkExtensionRequirements(type, targetEnv, exts))) { + + if (failed(checkCapabilityAndExtensionRequirements(type, targetEnv, caps, + exts))) { auto castOp = builder.create(loc, type, inputs); return castOp.getResult(0); } @@ -1205,16 +1226,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) { SmallVector, 4> typeExtensions; SmallVector, 8> typeCapabilities; for (Type valueType : valueTypes) { - typeExtensions.clear(); - cast(valueType).getExtensions(typeExtensions); - if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, - typeExtensions))) - return false; - typeCapabilities.clear(); cast(valueType).getCapabilities(typeCapabilities); - if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, - typeCapabilities))) + typeExtensions.clear(); + cast(valueType).getExtensions(typeExtensions); + // Checking for capability and extension requirements along with capability + // infered extensions. + // If a capability is present, the extension that + // supports it should also be present, this reduces the burden of adding + // extension requirement that may or maynot be added in + // CompositeType::getExtensions(). + if (failed(checkCapabilityAndExtensionRequirements( + op->getName(), this->targetEnv, typeCapabilities, typeExtensions))) return false; } diff --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir index 1e5d64387650c..5fc50ae99cfe9 100644 --- a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir @@ -2,7 +2,7 @@ module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -22,7 +22,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -42,7 +42,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -62,7 +62,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -82,7 +82,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -102,7 +102,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -122,7 +122,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -142,7 +142,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -162,7 +162,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -182,7 +182,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -202,7 +202,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -222,7 +222,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -242,7 +242,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -262,7 +262,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -282,7 +282,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -302,7 +302,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -322,7 +322,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -342,7 +342,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -362,7 +362,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -382,7 +382,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -402,7 +402,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -422,7 +422,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -442,7 +442,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -462,7 +462,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -482,7 +482,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -502,7 +502,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -522,7 +522,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -542,7 +542,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -562,7 +562,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -582,7 +582,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -602,7 +602,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -622,7 +622,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -637,3 +637,22 @@ gpu.module @kernels { } } + +// ----- + +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + +gpu.module @kernels { + // CHECK-NOT: spirv.func @test_unsupported + gpu.func @test_unsupported(%arg : i32) kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}} + %reduced = gpu.subgroup_reduce max %arg : (i32) -> (i32) + gpu.return + } +} + +} \ No newline at end of file From 10416587e015304e0e0fcfb910fd6194768104e6 Mon Sep 17 00:00:00 2001 From: Md Abdullah Shahneous Bari Date: Mon, 2 Oct 2023 11:25:52 -0700 Subject: [PATCH 4/4] [mlir][spirv] Add support for VectorAnyINTEL capability Allow vector of any lengths between [2-2^32-1]. VectorAnyINTEL capability (part of "SPV_INTEL_vector_compute" extension) relaxes the length constraint on SPIR-V vector sizes from 2,3, and 4. --- .../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 11 +++-- mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 7 +++- mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 22 +++++++--- .../arith-to-spirv-unsupported.mlir | 4 +- .../ArithToSPIRV/arith-to-spirv.mlir | 40 ++++++++++++++++++ .../FuncToSPIRV/types-to-spirv.mlir | 17 +++++++- mlir/test/Dialect/SPIRV/IR/bit-ops.mlir | 6 +-- mlir/test/Dialect/SPIRV/IR/gl-ops.mlir | 2 +- mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 4 +- mlir/test/Dialect/SPIRV/IR/logical-ops.mlir | 2 +- mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir | 42 ++++++++++--------- mlir/test/Target/SPIRV/arithmetic-ops.mlir | 6 +-- mlir/test/Target/SPIRV/ocl-ops.mlir | 6 +++ 13 files changed, 124 insertions(+), 45 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 1013cbc8ca562..c458a500eb367 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4146,7 +4146,12 @@ def SPIRV_Int32 : TypeAlias; def SPIRV_Float32 : TypeAlias; def SPIRV_Float : FloatOfWidths<[16, 32, 64]>; def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>; -def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16], +// Remove the vector size restriction. +// Although the vector size can be upto (2^64-1), uint64, +// 2^32-1 (UNINT32_MAX>) is a more realistic number, it should serve the purpose +// for all practical cases. +// Also unsigned is used for the number elements for composite tyeps. +def SPIRV_Vector : VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>; // Component type check is done in the type parser for the following SPIR-V // dialect-specific types so we use "Any" here. @@ -4206,10 +4211,10 @@ class SPIRV_JointMatrixOfType allowedTypes> : "Joint Matrix">; class SPIRV_ScalarOrVectorOf : - AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>; + AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>]>; class SPIRV_ScalarOrVectorOrCoopMatrixOf : - AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>, + AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>, SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>; class SPIRV_MatrixOrCoopMatrixOf : diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index a51d77dda78bf..be85d3c330a88 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -184,9 +184,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t; return Type(); } - if (t.getNumElements() > 4) { + // Number of elements should be between [2 - 2^32 -1], + // since getNumElements() returns an unsigned, the upper limit check is + // unnecessary. + if (t.getNumElements() < 2) { parser.emitError( - typeLoc, "vector length has to be less than or equal to 4 but found ") + typeLoc, "vector length has to be between [2 - 2^32 -1] but found ") << t.getNumElements(); return Type(); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 39d6603a46f96..9d39d99b41482 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -101,9 +101,11 @@ bool CompositeType::classof(Type type) { } bool CompositeType::isValid(VectorType type) { - return type.getRank() == 1 && - llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) && - llvm::isa(type.getElementType()); + // Number of elements should be between [2 - 2^32 -1], + // since getNumElements() returns an unsigned, the upper limit check is + // unnecessary. + return type.getRank() == 1 && llvm::isa(type.getElementType()) && + type.getNumElements() >= 2; } Type CompositeType::getElementType(unsigned index) const { @@ -171,9 +173,17 @@ void CompositeType::getCapabilities( .Case([&](VectorType type) { auto vecSize = getNumElements(); if (vecSize == 8 || vecSize == 16) { - static const Capability caps[] = {Capability::Vector16}; - ArrayRef ref(caps, std::size(caps)); - capabilities.push_back(ref); + static constexpr Capability caps[] = {Capability::Vector16, + Capability::VectorAnyINTEL}; + capabilities.push_back(caps); + } + // VectorAnyINTEL capability removes the vector size restriction and + // allows the vector size to be up to (2^32-1). + // Vector16 capability allows the vector size to be 8 and 16 + SmallVector allowedVecRange = {2, 3, 4, 8, 16}; + if (vecSize >= 2 && !llvm::is_contained(allowedVecRange, vecSize)) { + static constexpr Capability caps[] = {Capability::VectorAnyINTEL}; + capabilities.push_back(caps); } return llvm::cast(type.getElementType()) .getCapabilities(capabilities, storage); diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir index 0d92a8e676d85..d61ace8d6876b 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir @@ -11,9 +11,9 @@ module attributes { #spirv.vce, #spirv.resource_limits<>> } { -func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) { +func.func @unsupported_5elem_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) { // expected-error@+1 {{failed to legalize operation 'arith.subi'}} - %1 = arith.subi %arg0, %arg0: vector<5xi32> + %1 = arith.subi %arg0, %arg1: vector<5xi32> return } diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index 0221e4815a939..6ceeade486efd 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -1407,3 +1407,43 @@ func.func @float_scalar(%arg0: f16) { } } // end module + +// ----- + +//===----------------------------------------------------------------------===// +// VectorAnyINTEL support +//===----------------------------------------------------------------------===// + +// Check that with VectorAnyINTEL, VectorComputeINTEL capability, +// and SPV_INTEL_vector_compute extension, any sized (2-2^32 -1) vector is allowed. +module attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce, #spirv.resource_limits<>> +} { + +// CHECK-LABEL: @any_vector +func.func @any_vector(%arg0: vector<16xi32>, %arg1: vector<16xi32>) { + // CHECK: spirv.ISub %{{.+}}, %{{.+}}: vector<16xi32> + %0 = arith.subi %arg0, %arg1: vector<16xi32> + return +} + +// CHECK-LABEL: @max_vector +func.func @max_vector(%arg0: vector<4294967295xi32>, %arg1: vector<4294967295xi32>) { + // CHECK: spirv.ISub %{{.+}}, %{{.+}}: vector<4294967295xi32> + %0 = arith.subi %arg0, %arg1: vector<4294967295xi32> + return +} + + +// Check float vector types of any size. +// CHECK-LABEL: @float_vector58 +func.func @float_vector58(%arg0: vector<5xf16>, %arg1: vector<8xf64>) { + // CHECK: spirv.FAdd %{{.*}}, %{{.*}}: vector<5xf16> + %0 = arith.addf %arg0, %arg0: vector<5xf16> + // CHECK: spirv.FMul %{{.*}}, %{{.*}}: vector<8xf64> + %1 = arith.mulf %arg1, %arg1: vector<8xf64> + return +} + +} // end module diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir index 82d750755ffe2..6f364c5b0875c 100644 --- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir +++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir @@ -351,8 +351,21 @@ module attributes { spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { -// CHECK-NOT: spirv.func @large_vector -func.func @large_vector(%arg0: vector<1024xi32>) { return } +// CHECK-NOT: spirv.func @large_vector_unsupported +func.func @large_vector_unsupported(%arg0: vector<1024xi32>) { return } + +} // end module + + +// ----- + +// Check that large vectors are supported with VectorAnyINTEL or VectorComputeINTEL. +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + +// CHECK: spirv.func @large_any_vector +func.func @large_any_vector(%arg0: vector<1024xi32>) { return } } // end module diff --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir index 82a2316f6c784..88a8e507c1993 100644 --- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir @@ -137,7 +137,7 @@ func.func @bitwise_or_all_ones_vector(%arg: vector<3xi8>) -> vector<3xi8> { // ----- func.func @bitwise_or_float(%arg0: f16, %arg1: f16) -> f16 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-4294967295}} %0 = spirv.BitwiseOr %arg0, %arg1 : f16 return %0 : f16 } @@ -163,7 +163,7 @@ func.func @bitwise_xor_vector(%arg: vector<4xi32>) -> vector<4xi32> { // ----- func.func @bitwise_xor_float(%arg0: f16, %arg1: f16) -> f16 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-4294967295}} %0 = spirv.BitwiseXor %arg0, %arg1 : f16 return %0 : f16 } @@ -272,7 +272,7 @@ func.func @bitwise_and_zext_vector(%arg: vector<2xi8>) -> vector<2xi32> { // ----- func.func @bitwise_and_float(%arg0: f16, %arg1: f16) -> f16 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-4294967295}} %0 = spirv.BitwiseAnd %arg0, %arg1 : f16 return %0 : f16 } diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir index 3683e5b469b17..a95a6001fd204 100644 --- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir @@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () { // ----- func.func @exp(%arg0 : vector<5xf32>) -> () { - // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}} + // CHECK: spirv.GL.Exp {{%.*}} : vector<5xf32 %2 = spirv.GL.Exp %arg0 : vector<5xf32> return } diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir index 53a1015de75bc..6929ef9b21d0e 100644 --- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir @@ -21,7 +21,7 @@ spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" { // ----- spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" { - // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}} + // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2-4294967295, but got}} %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f64 to i16 spirv.Return } @@ -57,7 +57,7 @@ spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" { // ----- spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" { - // expected-error @+1 {{result #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}} + // expected-error @+1 {{result #0 must be Float32 or vector of Float32 values of length 2-4294967295, but got}} %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f16 spirv.Return } diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir index 7dc0bd99f54b3..fa4d9e253307d 100644 --- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir @@ -166,7 +166,7 @@ func.func @logicalUnary(%arg0 : i1) func.func @logicalUnary(%arg0 : i32) { - // expected-error @+1 {{'operand' must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}} + // expected-error @+1 {{'operand' must be bool or vector of bool values of length 2-4294967295, but got 'i32'}} %0 = spirv.LogicalNot %arg0 : i32 return } diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir index 29a4a46136156..24fe2f9458413 100644 --- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir @@ -18,17 +18,17 @@ func.func @expvec(%arg0 : vector<3xf16>) -> () { // ----- -func.func @exp(%arg0 : i32) -> () { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} - %2 = spirv.CL.exp %arg0 : i32 +func.func @exp_any_vec(%arg0 : vector<5xf32>) -> () { + // CHECK: spirv.CL.exp {{%.*}} : vector<5xf32> + %2 = spirv.CL.exp %arg0 : vector<5xf32> return } // ----- -func.func @exp(%arg0 : vector<5xf32>) -> () { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}} - %2 = spirv.CL.exp %arg0 : vector<5xf32> +func.func @exp(%arg0 : i32) -> () { + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + %2 = spirv.CL.exp %arg0 : i32 return } @@ -66,6 +66,14 @@ func.func @fabsvec(%arg0 : vector<3xf16>) -> () { return } +// ----- + +func.func @fabs_any_vec(%arg0 : vector<5xf32>) -> () { + // CHECK: spirv.CL.fabs {{%.*}} : vector<5xf32> + %2 = spirv.CL.fabs %arg0 : vector<5xf32> + return +} + func.func @fabsf64(%arg0 : f64) -> () { // CHECK: spirv.CL.fabs {{%.*}} : f64 %2 = spirv.CL.fabs %arg0 : f64 @@ -82,14 +90,6 @@ func.func @fabs(%arg0 : i32) -> () { // ----- -func.func @fabs(%arg0 : vector<5xf32>) -> () { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}} - %2 = spirv.CL.fabs %arg0 : vector<5xf32> - return -} - -// ----- - func.func @fabs(%arg0 : f32, %arg1 : f32) -> () { // expected-error @+1 {{expected ':'}} %2 = spirv.CL.fabs %arg0, %arg1 : i32 @@ -122,6 +122,14 @@ func.func @sabsvec(%arg0 : vector<3xi16>) -> () { return } +// ----- + +func.func @sabs_any_vec(%arg0 : vector<5xi32>) -> () { + // CHECK: spirv.CL.s_abs {{%.*}} : vector<5xi32> + %2 = spirv.CL.s_abs %arg0 : vector<5xi32> + return +} + func.func @sabsi64(%arg0 : i64) -> () { // CHECK: spirv.CL.s_abs {{%.*}} : i64 %2 = spirv.CL.s_abs %arg0 : i64 @@ -142,13 +150,7 @@ func.func @sabs(%arg0 : f32) -> () { return } -// ----- -func.func @sabs(%arg0 : vector<5xi32>) -> () { - // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}} - %2 = spirv.CL.s_abs %arg0 : vector<5xi32> - return -} // ----- diff --git a/mlir/test/Target/SPIRV/arithmetic-ops.mlir b/mlir/test/Target/SPIRV/arithmetic-ops.mlir index b1ea13c6854fd..90144afc6f3af 100644 --- a/mlir/test/Target/SPIRV/arithmetic-ops.mlir +++ b/mlir/test/Target/SPIRV/arithmetic-ops.mlir @@ -6,9 +6,9 @@ spirv.module Logical GLSL450 requires #spirv.vce { %0 = spirv.FMul %arg0, %arg1 : f32 spirv.Return } - spirv.func @fadd(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) "None" { - // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : vector<4xf32> - %0 = spirv.FAdd %arg0, %arg1 : vector<4xf32> + spirv.func @fadd(%arg0 : vector<5xf32>, %arg1 : vector<5xf32>) "None" { + // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : vector<5xf32> + %0 = spirv.FAdd %arg0, %arg1 : vector<5xf32> spirv.Return } spirv.func @fdiv(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) "None" { diff --git a/mlir/test/Target/SPIRV/ocl-ops.mlir b/mlir/test/Target/SPIRV/ocl-ops.mlir index 9a2e4cf62e370..31a7f616d648e 100644 --- a/mlir/test/Target/SPIRV/ocl-ops.mlir +++ b/mlir/test/Target/SPIRV/ocl-ops.mlir @@ -39,6 +39,12 @@ spirv.module Physical64 OpenCL requires #spirv.vce) "None" { + // CHECK: {{%.*}} = spirv.CL.fabs {{%.*}} : vector<5000xf32> + %0 = spirv.CL.fabs %arg0 : vector<5000xf32> + spirv.Return + } + spirv.func @fma(%arg0 : f32, %arg1 : f32, %arg2 : f32) "None" { // CHECK: spirv.CL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32 %13 = spirv.CL.fma %arg0, %arg1, %arg2 : f32