diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index c75d217663a9e..25e6a080642e6 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -43,9 +43,13 @@ using namespace mlir; template static LogicalResult checkExtensionRequirements( LabelT label, const spirv::TargetEnv &targetEnv, - const spirv::SPIRVType::ExtensionArrayRefVector &candidates) { + const spirv::SPIRVType::ExtensionArrayRefVector &candidates, + const ArrayRef elidedCandidates = {}) { for (const auto &ors : candidates) { - if (targetEnv.allows(ors)) + if (targetEnv.allows(ors) || + llvm::any_of(elidedCandidates, [&ors](spirv::Extension elidedExt) { + return llvm::is_contained(ors, elidedExt); + })) continue; LLVM_DEBUG({ @@ -71,9 +75,13 @@ static LogicalResult checkExtensionRequirements( template static LogicalResult checkCapabilityRequirements( LabelT label, const spirv::TargetEnv &targetEnv, - const spirv::SPIRVType::CapabilityArrayRefVector &candidates) { + const spirv::SPIRVType::CapabilityArrayRefVector &candidates, + const ArrayRef elidedCandidates = {}) { for (const auto &ors : candidates) { - if (targetEnv.allows(ors)) + if (targetEnv.allows(ors) || + llvm::any_of(elidedCandidates, [&ors](spirv::Capability elidedCap) { + return llvm::is_contained(ors, elidedCap); + })) continue; LLVM_DEBUG({ @@ -90,8 +98,55 @@ static LogicalResult checkCapabilityRequirements( return success(); } -/// Returns true if the given `storageClass` needs explicit layout when used in -/// Shader environments. +/// Check capabilities and extensions requirements +/// Checks that `capCandidates`, `extCandidates`, and capability +/// (`capCandidates`) infered extension requirements are possible to be +/// satisfied with the given `targetEnv`. +/// It also provides a way to relax requirements for certain capabilities and +/// extensions (e.g., `elidedCapCandidates`, `elidedExtCandidates`), this is to +/// allow passes to relax certain requirements based on an option (e.g., +/// relaxing bitwidth requirement, see `convertScalarType()`, +/// `ConvertVectorType()`). +template +static LogicalResult checkCapabilityAndExtensionRequirements( + LabelT label, const spirv::TargetEnv &targetEnv, + const spirv::SPIRVType::CapabilityArrayRefVector &capCandidates, + const spirv::SPIRVType::ExtensionArrayRefVector &extCandidates, + const ArrayRef elidedCapCandidates = {}, + const ArrayRef elidedExtCandidates = {}) { + SmallVector, 8> updatedExtCandidates; + llvm::append_range(updatedExtCandidates, extCandidates); + + if (failed(checkCapabilityRequirements(label, targetEnv, capCandidates, + elidedCapCandidates))) + return failure(); + // Add capablity infered extensions to the list of extension requirement list, + // only considers the capabilities that already available in the `targetEnv`. + + // WARNING: Some capabilities are part of both the core SPIR-V + // specification and an extension (e.g., 'Groups' capability is part of both + // core specification and SPV_AMD_shader_ballot extension, hence we should + // relax the capability inferred extension for these cases). + static const spirv::Capability multiModalCaps[] = {spirv::Capability::Groups}; + ArrayRef multiModalCapsArrayRef(multiModalCaps, + std::size(multiModalCaps)); + + for (auto cap : targetEnv.getAttr().getCapabilities()) { + if (llvm::any_of(multiModalCapsArrayRef, + [&cap](spirv::Capability mMCap) { return cap == mMCap; })) + continue; + std::optional> ext = getExtensions(cap); + if (ext) + updatedExtCandidates.push_back(*ext); + } + if (failed(checkExtensionRequirements(label, targetEnv, updatedExtCandidates, + elidedExtCandidates))) + return failure(); + return success(); +} + +/// Returns true if the given `storageClass` needs explicit layout when used +/// in Shader environments. static bool needsExplicitLayout(spirv::StorageClass storageClass) { switch (storageClass) { case spirv::StorageClass::PhysicalStorageBuffer: @@ -230,8 +285,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv, type.getCapabilities(capabilities, storageClass); // If all requirements are met, then we can accept this type as-is. - if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && - succeeded(checkExtensionRequirements(type, targetEnv, extensions))) + if (succeeded(checkCapabilityAndExtensionRequirements( + type, targetEnv, capabilities, extensions))) return type; // Otherwise we need to adjust the type, which really means adjusting the @@ -342,15 +397,35 @@ convertVectorType(const spirv::TargetEnv &targetEnv, cast(type).getExtensions(extensions, storageClass); cast(type).getCapabilities(capabilities, storageClass); + // If the bit-width related capabilities and extensions are not met + // for lower bit-width (<32-bit), convert it to 32-bit + auto elementType = + convertScalarType(targetEnv, options, scalarType, storageClass); + if (!elementType) + return nullptr; + type = VectorType::get(type.getShape(), elementType); + + SmallVector elidedCaps; + SmallVector elidedExts; + + // Relax the bitwidth requirements for capabilities and extensions + if (options.emulateLT32BitScalarTypes) { + elidedCaps.push_back(spirv::Capability::Int8); + elidedCaps.push_back(spirv::Capability::Int16); + elidedCaps.push_back(spirv::Capability::Float16); + } + // For capabilities whose requirements were relaxed, relax requirements for + // the extensions that were infered by those capabilities (e.g., elidedCaps) + for (spirv::Capability cap : elidedCaps) { + std::optional> ext = spirv::getExtensions(cap); + if (ext) + llvm::append_range(elidedExts, *ext); + } // If all requirements are met, then we can accept this type as-is. - if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && - succeeded(checkExtensionRequirements(type, targetEnv, extensions))) + if (succeeded(checkCapabilityAndExtensionRequirements( + type, targetEnv, capabilities, extensions, elidedCaps, elidedExts))) return type; - auto elementType = - convertScalarType(targetEnv, options, scalarType, storageClass); - if (elementType) - return VectorType::get(type.getShape(), elementType); return nullptr; } @@ -656,8 +731,9 @@ std::optional castToSourceType(const spirv::TargetEnv &targetEnv, SmallVector, 2> caps; scalarType.getExtensions(exts); scalarType.getCapabilities(caps); - if (failed(checkCapabilityRequirements(type, targetEnv, caps)) || - failed(checkExtensionRequirements(type, targetEnv, exts))) { + + if (failed(checkCapabilityAndExtensionRequirements(type, targetEnv, caps, + exts))) { auto castOp = builder.create(loc, type, inputs); return castOp.getResult(0); } @@ -1150,16 +1226,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) { SmallVector, 4> typeExtensions; SmallVector, 8> typeCapabilities; for (Type valueType : valueTypes) { - typeExtensions.clear(); - cast(valueType).getExtensions(typeExtensions); - if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, - typeExtensions))) - return false; - typeCapabilities.clear(); cast(valueType).getCapabilities(typeCapabilities); - if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, - typeCapabilities))) + typeExtensions.clear(); + cast(valueType).getExtensions(typeExtensions); + // Checking for capability and extension requirements along with capability + // infered extensions. + // If a capability is present, the extension that + // supports it should also be present, this reduces the burden of adding + // extension requirement that may or maynot be added in + // CompositeType::getExtensions(). + if (failed(checkCapabilityAndExtensionRequirements( + op->getName(), this->targetEnv, typeCapabilities, typeExtensions))) return false; } diff --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir index 1e5d64387650c..5fc50ae99cfe9 100644 --- a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir @@ -2,7 +2,7 @@ module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -22,7 +22,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -42,7 +42,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -62,7 +62,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -82,7 +82,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -102,7 +102,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -122,7 +122,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -142,7 +142,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -162,7 +162,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -182,7 +182,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -202,7 +202,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -222,7 +222,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -242,7 +242,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -262,7 +262,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -282,7 +282,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -302,7 +302,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -322,7 +322,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -342,7 +342,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -362,7 +362,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -382,7 +382,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -402,7 +402,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -422,7 +422,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -442,7 +442,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -462,7 +462,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -482,7 +482,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -502,7 +502,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -522,7 +522,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -542,7 +542,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -562,7 +562,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -582,7 +582,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -602,7 +602,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -622,7 +622,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { gpu.module @kernels { @@ -637,3 +637,22 @@ gpu.module @kernels { } } + +// ----- + +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + +gpu.module @kernels { + // CHECK-NOT: spirv.func @test_unsupported + gpu.func @test_unsupported(%arg : i32) kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}} + %reduced = gpu.subgroup_reduce max %arg : (i32) -> (i32) + gpu.return + } +} + +} \ No newline at end of file