diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index c75d217663a9e..1cc2a49f027c6 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -43,9 +43,13 @@ using namespace mlir; template static LogicalResult checkExtensionRequirements( LabelT label, const spirv::TargetEnv &targetEnv, - const spirv::SPIRVType::ExtensionArrayRefVector &candidates) { + const spirv::SPIRVType::ExtensionArrayRefVector &candidates, + ArrayRef elidedCandidates = {}) { for (const auto &ors : candidates) { - if (targetEnv.allows(ors)) + if (targetEnv.allows(ors) || + llvm::any_of(elidedCandidates, [&ors](spirv::Extension elidedExt) { + return llvm::is_contained(ors, elidedExt); + })) continue; LLVM_DEBUG({ @@ -71,9 +75,13 @@ static LogicalResult checkExtensionRequirements( template static LogicalResult checkCapabilityRequirements( LabelT label, const spirv::TargetEnv &targetEnv, - const spirv::SPIRVType::CapabilityArrayRefVector &candidates) { + const spirv::SPIRVType::CapabilityArrayRefVector &candidates, + ArrayRef elidedCandidates = {}) { for (const auto &ors : candidates) { - if (targetEnv.allows(ors)) + if (targetEnv.allows(ors) || + llvm::any_of(elidedCandidates, [&ors](spirv::Capability elidedCap) { + return llvm::is_contained(ors, elidedCap); + })) continue; LLVM_DEBUG({ @@ -90,8 +98,55 @@ static LogicalResult checkCapabilityRequirements( return success(); } -/// Returns true if the given `storageClass` needs explicit layout when used in -/// Shader environments. +/// Check capabilities and extensions requirements +/// Checks that `capCandidates`, `extCandidates`, and capability +/// (`capCandidates`) infered extension requirements are possible to be +/// satisfied with the given `targetEnv`. +/// It also provides a way to relax requirements for certain capabilities and +/// extensions (e.g., `elidedCapCandidates`, `elidedExtCandidates`), this is to +/// allow passes to relax certain requirements based on an option (e.g., +/// relaxing bitwidth requirement, see `convertScalarType()`, +/// `ConvertVectorType()`). +template +static LogicalResult checkCapabilityAndExtensionRequirements( + LabelT label, const spirv::TargetEnv &targetEnv, + const spirv::SPIRVType::CapabilityArrayRefVector &capCandidates, + const spirv::SPIRVType::ExtensionArrayRefVector &extCandidates, + ArrayRef elidedCapCandidates = {}, + ArrayRef elidedExtCandidates = {}) { + SmallVector, 8> updatedExtCandidates; + llvm::append_range(updatedExtCandidates, extCandidates); + + if (failed(checkCapabilityRequirements(label, targetEnv, capCandidates, + elidedCapCandidates))) + return failure(); + // Add capablity infered extensions to the list of extension requirement list, + // only considers the capabilities that already available in the `targetEnv`. + + // WARNING: Some capabilities are part of both the core SPIR-V + // specification and an extension (e.g., 'Groups' capability is part of both + // core specification and SPV_AMD_shader_ballot extension, hence we should + // relax the capability inferred extension for these cases). + static const spirv::Capability multiModalCaps[] = {spirv::Capability::Groups}; + ArrayRef multiModalCapsArrayRef(multiModalCaps, + std::size(multiModalCaps)); + + for (auto cap : targetEnv.getAttr().getCapabilities()) { + if (llvm::any_of(multiModalCapsArrayRef, + [&cap](spirv::Capability mMCap) { return cap == mMCap; })) + continue; + std::optional> ext = getExtensions(cap); + if (ext) + updatedExtCandidates.push_back(*ext); + } + if (failed(checkExtensionRequirements(label, targetEnv, updatedExtCandidates, + elidedExtCandidates))) + return failure(); + return success(); +} + +/// Returns true if the given `storageClass` needs explicit layout when used +/// in Shader environments. static bool needsExplicitLayout(spirv::StorageClass storageClass) { switch (storageClass) { case spirv::StorageClass::PhysicalStorageBuffer: