From 90cf470434d08b64ae73a514de2d19300d1ba31b Mon Sep 17 00:00:00 2001 From: redain <52823087+redain@users.noreply.github.com> Date: Sat, 7 Jun 2025 15:52:55 +0400 Subject: [PATCH 1/5] fix: array stride validation errors --- .../src/linker/array_stride_fixer.rs | 612 ++++++++++++++++++ crates/rustc_codegen_spirv/src/linker/mod.rs | 7 + 2 files changed, 619 insertions(+) create mode 100644 crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs diff --git a/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs b/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs new file mode 100644 index 0000000000..129494b771 --- /dev/null +++ b/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs @@ -0,0 +1,612 @@ +//! Fix ArrayStride decorations for newer SPIR-V versions. +//! +//! Newer SPIR-V versions forbid explicit layouts (ArrayStride decorations) in certain +//! storage classes (Function, Private, Workgroup), but allow them in others +//! (StorageBuffer, Uniform). This module removes ArrayStride decorations from +//! array types that are used in contexts where they're forbidden. + +use rspirv::dr::{Module, Operand}; +use rspirv::spirv::{Capability, Decoration, Op, StorageClass, Word}; +use rustc_data_structures::fx::FxHashSet; + +/// Check if a storage class allows explicit layout decorations based on SPIR-V version and capabilities. +/// This matches the logic from SPIRV-Tools validate_decorations.cpp AllowsLayout function. +fn allows_layout( + storage_class: StorageClass, + spirv_version: (u8, u8), + has_workgroup_layout_capability: bool, +) -> bool { + match storage_class { + // Always explicitly laid out + StorageClass::StorageBuffer + | StorageClass::Uniform + | StorageClass::PhysicalStorageBuffer + | StorageClass::PushConstant => true, + + // Never allows layout + StorageClass::UniformConstant => false, + + // Requires explicit capability + StorageClass::Workgroup => has_workgroup_layout_capability, + + // Only forbidden in SPIR-V 1.4+ + StorageClass::Function | StorageClass::Private => spirv_version < (1, 4), + + // Block is used generally and mesh shaders use Offset + StorageClass::Input | StorageClass::Output => true, + + // TODO: Some storage classes in ray tracing use explicit layout + // decorations, but it is not well documented which. For now treat other + // storage classes as allowed to be laid out. + _ => true, + } +} + +/// Remove ArrayStride decorations from array types used in storage classes where +/// newer SPIR-V versions forbid explicit layouts. +pub fn fix_array_stride_decorations(module: &mut Module) { + // Get SPIR-V version from module header + let spirv_version = module + .header + .as_ref() + .map(|h| h.version()) + .unwrap_or((1, 0)); // Default to 1.0 if no header + + // Check for WorkgroupMemoryExplicitLayoutKHR capability + let has_workgroup_layout_capability = module.capabilities.iter().any(|inst| { + inst.class.opcode == Op::Capability + && inst.operands.first() + == Some(&Operand::Capability( + Capability::WorkgroupMemoryExplicitLayoutKHR, + )) + }); + + // Find all array types that have ArrayStride decorations + let mut array_types_with_stride = FxHashSet::default(); + for inst in &module.annotations { + if inst.class.opcode == Op::Decorate + && inst.operands.len() >= 2 + && inst.operands[1] == Operand::Decoration(Decoration::ArrayStride) + { + let target_id = inst.operands[0].unwrap_id_ref(); + array_types_with_stride.insert(target_id); + } + } + + // Check each array type with ArrayStride to see if it's used in forbidden contexts + let mut array_types_to_fix = FxHashSet::default(); + for &array_type_id in &array_types_with_stride { + if is_array_type_used_in_forbidden_storage_class( + array_type_id, + module, + spirv_version, + has_workgroup_layout_capability, + ) { + array_types_to_fix.insert(array_type_id); + } + } + + // Remove ArrayStride decorations for the problematic types + module.annotations.retain(|inst| { + if inst.class.opcode == Op::Decorate + && inst.operands.len() >= 2 + && inst.operands[1] == Operand::Decoration(Decoration::ArrayStride) + { + let target_id = inst.operands[0].unwrap_id_ref(); + !array_types_to_fix.contains(&target_id) + } else { + true + } + }); +} + +/// Check if an array type is used in any variable with a forbidden storage class +fn is_array_type_used_in_forbidden_storage_class( + array_type_id: Word, + module: &Module, + spirv_version: (u8, u8), + has_workgroup_layout_capability: bool, +) -> bool { + // Check global variables + for inst in &module.types_global_values { + if inst.class.opcode == Op::Variable && inst.operands.len() >= 1 { + let storage_class = inst.operands[0].unwrap_storage_class(); + + // Check if this storage class forbids explicit layouts + if !allows_layout( + storage_class, + spirv_version, + has_workgroup_layout_capability, + ) { + // Check if this variable's type hierarchy contains the array type + if let Some(var_type_id) = inst.result_type { + if type_hierarchy_contains_array_type(var_type_id, array_type_id, module) { + return true; + } + } + } + } + } + + // Check function-local variables + for function in &module.functions { + for block in &function.blocks { + for inst in &block.instructions { + if inst.class.opcode == Op::Variable && inst.operands.len() >= 1 { + let storage_class = inst.operands[0].unwrap_storage_class(); + + // Check if this storage class forbids explicit layouts + if !allows_layout( + storage_class, + spirv_version, + has_workgroup_layout_capability, + ) { + // Check if this variable's type hierarchy contains the array type + if let Some(var_type_id) = inst.result_type { + if type_hierarchy_contains_array_type( + var_type_id, + array_type_id, + module, + ) { + return true; + } + } + } + } + } + } + } + + false +} + +/// Check if a type hierarchy contains a specific array type +fn type_hierarchy_contains_array_type( + type_id: Word, + target_array_type_id: Word, + module: &Module, +) -> bool { + if type_id == target_array_type_id { + return true; + } + + // Find the type definition + if let Some(type_inst) = module + .types_global_values + .iter() + .find(|inst| inst.result_id == Some(type_id)) + { + match type_inst.class.opcode { + Op::TypeArray | Op::TypeRuntimeArray => { + // Check element type recursively + if !type_inst.operands.is_empty() { + let element_type = type_inst.operands[0].unwrap_id_ref(); + return type_hierarchy_contains_array_type( + element_type, + target_array_type_id, + module, + ); + } + } + Op::TypeStruct => { + // Check all field types + for operand in &type_inst.operands { + if let Ok(field_type) = operand.id_ref_any().ok_or(()) { + if type_hierarchy_contains_array_type( + field_type, + target_array_type_id, + module, + ) { + return true; + } + } + } + } + Op::TypePointer => { + // Follow pointer to pointee type + if type_inst.operands.len() >= 2 { + let pointee_type = type_inst.operands[1].unwrap_id_ref(); + return type_hierarchy_contains_array_type( + pointee_type, + target_array_type_id, + module, + ); + } + } + _ => {} + } + } + false +} + +#[cfg(test)] +mod tests { + use super::*; + use rspirv::dr::Module; + + // Helper function to assemble SPIR-V from text + fn assemble_spirv(spirv: &str) -> Vec { + use spirv_tools::assembler::{self, Assembler}; + + let assembler = assembler::create(None); + let spv_binary = assembler + .assemble(spirv, assembler::AssemblerOptions::default()) + .expect("Failed to assemble test SPIR-V"); + let contents: &[u8] = spv_binary.as_ref(); + contents.to_vec() + } + + // Helper function to load SPIR-V binary into Module + fn load_spirv(bytes: &[u8]) -> Module { + use rspirv::dr::Loader; + + let mut loader = Loader::new(); + rspirv::binary::parse_bytes(bytes, &mut loader).unwrap(); + loader.module() + } + + // Helper function to count ArrayStride decorations + fn count_array_stride_decorations(module: &Module) -> usize { + module + .annotations + .iter() + .filter(|inst| { + inst.class.opcode == Op::Decorate + && inst.operands.len() >= 2 + && inst.operands[1] == Operand::Decoration(Decoration::ArrayStride) + }) + .count() + } + + #[test] + fn test_removes_array_stride_from_workgroup_arrays() { + let spirv = r#" + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + ; Type declarations + %void = OpTypeVoid + %func_ty = OpTypeFunction %void + %u32 = OpTypeInt 32 0 + %u32_256 = OpConstant %u32 256 + %array_ty = OpTypeArray %u32 %u32_256 + + ; Pointer types for workgroup storage (forbidden in newer SPIR-V) + %ptr_workgroup = OpTypePointer Workgroup %array_ty + + ; Variables in workgroup storage class + %workgroup_var = OpVariable %ptr_workgroup Workgroup + + ; ArrayStride decoration that should be removed + OpDecorate %array_ty ArrayStride 4 + + %main = OpFunction %void None %func_ty + %entry = OpLabel + OpReturn + OpFunctionEnd + "#; + + let bytes = assemble_spirv(spirv); + let mut module = load_spirv(&bytes); + + assert_eq!(count_array_stride_decorations(&module), 1); + + fix_array_stride_decorations(&mut module); + + // ArrayStride should be removed from arrays used in Workgroup storage + assert_eq!(count_array_stride_decorations(&module), 0); + } + + #[test] + fn test_keeps_array_stride_for_workgroup_with_capability() { + let spirv = r#" + OpCapability Shader + OpCapability WorkgroupMemoryExplicitLayoutKHR + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + ; Type declarations + %void = OpTypeVoid + %func_ty = OpTypeFunction %void + %u32 = OpTypeInt 32 0 + %u32_256 = OpConstant %u32 256 + %array_ty = OpTypeArray %u32 %u32_256 + + ; Pointer types for workgroup storage (allowed with capability) + %ptr_workgroup = OpTypePointer Workgroup %array_ty + + ; Variables in workgroup storage class + %workgroup_var = OpVariable %ptr_workgroup Workgroup + + ; ArrayStride decoration that should be kept with capability + OpDecorate %array_ty ArrayStride 4 + + %main = OpFunction %void None %func_ty + %entry = OpLabel + OpReturn + OpFunctionEnd + "#; + + let bytes = assemble_spirv(spirv); + let mut module = load_spirv(&bytes); + + assert_eq!(count_array_stride_decorations(&module), 1); + + fix_array_stride_decorations(&mut module); + + // ArrayStride should be kept when WorkgroupMemoryExplicitLayoutKHR capability is present + assert_eq!(count_array_stride_decorations(&module), 1); + } + + #[test] + fn test_keeps_array_stride_for_storage_buffer_arrays() { + let spirv = r#" + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + ; Type declarations + %void = OpTypeVoid + %func_ty = OpTypeFunction %void + %u32 = OpTypeInt 32 0 + %u32_256 = OpConstant %u32 256 + %array_ty = OpTypeArray %u32 %u32_256 + + ; Pointer types for storage buffer (always allowed) + %ptr_storage_buffer = OpTypePointer StorageBuffer %array_ty + + ; Variables in storage buffer storage class + %storage_buffer_var = OpVariable %ptr_storage_buffer StorageBuffer + + ; ArrayStride decoration that should be kept + OpDecorate %array_ty ArrayStride 4 + OpDecorate %storage_buffer_var DescriptorSet 0 + OpDecorate %storage_buffer_var Binding 0 + + %main = OpFunction %void None %func_ty + %entry = OpLabel + OpReturn + OpFunctionEnd + "#; + + let bytes = assemble_spirv(spirv); + let mut module = load_spirv(&bytes); + + assert_eq!(count_array_stride_decorations(&module), 1); + + fix_array_stride_decorations(&mut module); + + // ArrayStride should be kept for StorageBuffer storage class + assert_eq!(count_array_stride_decorations(&module), 1); + } + + #[test] + fn test_handles_runtime_arrays_in_workgroup() { + let spirv = r#" + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + ; Type declarations + %void = OpTypeVoid + %func_ty = OpTypeFunction %void + %u32 = OpTypeInt 32 0 + %runtime_array_ty = OpTypeRuntimeArray %u32 + + ; Pointer types for workgroup storage (forbidden) + %ptr_workgroup = OpTypePointer Workgroup %runtime_array_ty + + ; Variables in workgroup storage class + %workgroup_var = OpVariable %ptr_workgroup Workgroup + + ; ArrayStride decoration that should be removed + OpDecorate %runtime_array_ty ArrayStride 4 + + %main = OpFunction %void None %func_ty + %entry = OpLabel + OpReturn + OpFunctionEnd + "#; + + let bytes = assemble_spirv(spirv); + let mut module = load_spirv(&bytes); + + assert_eq!(count_array_stride_decorations(&module), 1); + + fix_array_stride_decorations(&mut module); + + // ArrayStride should be removed from runtime arrays in Workgroup storage + assert_eq!(count_array_stride_decorations(&module), 0); + } + + #[test] + fn test_mixed_storage_classes_removes_problematic_arrays() { + let spirv = r#" + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + ; Type declarations + %void = OpTypeVoid + %func_ty = OpTypeFunction %void + %u32 = OpTypeInt 32 0 + %u32_256 = OpConstant %u32 256 + %forbidden_array_ty = OpTypeArray %u32 %u32_256 + %allowed_array_ty = OpTypeArray %u32 %u32_256 + + ; Pointer types for different storage classes + %ptr_workgroup = OpTypePointer Workgroup %forbidden_array_ty + %ptr_storage_buffer = OpTypePointer StorageBuffer %allowed_array_ty + + ; Variables in different storage classes + %workgroup_var = OpVariable %ptr_workgroup Workgroup + %storage_buffer_var = OpVariable %ptr_storage_buffer StorageBuffer + + ; ArrayStride decorations + OpDecorate %forbidden_array_ty ArrayStride 4 + OpDecorate %allowed_array_ty ArrayStride 4 + OpDecorate %storage_buffer_var DescriptorSet 0 + OpDecorate %storage_buffer_var Binding 0 + + %main = OpFunction %void None %func_ty + %entry = OpLabel + OpReturn + OpFunctionEnd + "#; + + let bytes = assemble_spirv(spirv); + let mut module = load_spirv(&bytes); + + assert_eq!(count_array_stride_decorations(&module), 2); + + fix_array_stride_decorations(&mut module); + + // Only the Workgroup array should have its ArrayStride removed + assert_eq!(count_array_stride_decorations(&module), 1); + } + + #[test] + fn test_nested_structs_and_arrays_in_function_storage() { + let spirv = r#" + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + ; ArrayStride decoration that should be removed in SPIR-V 1.4+ + OpDecorate %inner_array_ty ArrayStride 16 + + ; Type declarations + %void = OpTypeVoid + %float = OpTypeFloat 32 + %u32 = OpTypeInt 32 0 + %u32_4 = OpConstant %u32 4 + %inner_array_ty = OpTypeArray %float %u32_4 + %inner_struct_ty = OpTypeStruct %inner_array_ty + %outer_struct_ty = OpTypeStruct %inner_struct_ty + + ; Pointer types for function storage (forbidden in SPIR-V 1.4+) + %ptr_function = OpTypePointer Function %outer_struct_ty + %func_ty = OpTypeFunction %void + + ; Function variable inside function + %main = OpFunction %void None %func_ty + %entry = OpLabel + %function_var = OpVariable %ptr_function Function + OpReturn + OpFunctionEnd + "#; + + let bytes = assemble_spirv(spirv); + let mut module = load_spirv(&bytes); + + // Force SPIR-V 1.4 for this test + if let Some(ref mut header) = module.header { + header.set_version(1, 4); + } + + assert_eq!(count_array_stride_decorations(&module), 1); + + fix_array_stride_decorations(&mut module); + + // ArrayStride should be removed in SPIR-V 1.4+ for Function storage + assert_eq!(count_array_stride_decorations(&module), 0); + } + + #[test] + fn test_function_storage_spirv_13_keeps_decorations() { + let spirv = r#" + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + ; Type declarations + %void = OpTypeVoid + %func_ty = OpTypeFunction %void + %u32 = OpTypeInt 32 0 + %u32_256 = OpConstant %u32 256 + %array_ty = OpTypeArray %u32 %u32_256 + + ; Pointer types for function storage + %ptr_function = OpTypePointer Function %array_ty + + ; Function variable + %main = OpFunction %void None %func_ty + %entry = OpLabel + %function_var = OpVariable %ptr_function Function + OpReturn + OpFunctionEnd + + ; ArrayStride decoration that should be kept in SPIR-V 1.3 + OpDecorate %array_ty ArrayStride 4 + "#; + + let bytes = assemble_spirv(spirv); + let mut module = load_spirv(&bytes); + + // Force SPIR-V 1.3 for this test + if let Some(ref mut header) = module.header { + header.set_version(1, 3); + } + + assert_eq!(count_array_stride_decorations(&module), 1); + + fix_array_stride_decorations(&mut module); + + // ArrayStride should be kept in SPIR-V 1.3 for Function storage + assert_eq!(count_array_stride_decorations(&module), 1); + } + + #[test] + fn test_private_storage_spirv_14_removes_decorations() { + let spirv = r#" + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + ; Type declarations + %void = OpTypeVoid + %func_ty = OpTypeFunction %void + %u32 = OpTypeInt 32 0 + %u32_256 = OpConstant %u32 256 + %array_ty = OpTypeArray %u32 %u32_256 + + ; Pointer types for private storage + %ptr_private = OpTypePointer Private %array_ty + + ; Variables in private storage class + %private_var = OpVariable %ptr_private Private + + ; ArrayStride decoration that should be removed in SPIR-V 1.4+ + OpDecorate %array_ty ArrayStride 4 + + %main = OpFunction %void None %func_ty + %entry = OpLabel + OpReturn + OpFunctionEnd + "#; + + let bytes = assemble_spirv(spirv); + let mut module = load_spirv(&bytes); + + // Force SPIR-V 1.4 for this test + if let Some(ref mut header) = module.header { + header.set_version(1, 4); + } + + assert_eq!(count_array_stride_decorations(&module), 1); + + fix_array_stride_decorations(&mut module); + + // ArrayStride should be removed in SPIR-V 1.4+ for Private storage + assert_eq!(count_array_stride_decorations(&module), 0); + } +} diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs index fa69dc8e7f..8e4933da36 100644 --- a/crates/rustc_codegen_spirv/src/linker/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/mod.rs @@ -1,6 +1,7 @@ #[cfg(test)] mod test; +mod array_stride_fixer; mod dce; mod destructure_composites; mod duplicates; @@ -355,6 +356,12 @@ pub fn link( }); } + // Fix ArrayStride decorations for arrays in storage classes where newer SPIR-V versions forbid explicit layouts + { + let _timer = sess.timer("fix_array_stride_decorations"); + array_stride_fixer::fix_array_stride_decorations(&mut output); + } + // NOTE(eddyb) with SPIR-T, we can do `mem2reg` before inlining, too! { if opts.dce { From 44f5287eb53d875f3f9b25a2a68d9a3f1c71e269 Mon Sep 17 00:00:00 2001 From: redain <52823087+redain@users.noreply.github.com> Date: Wed, 11 Jun 2025 15:25:19 +0400 Subject: [PATCH 2/5] fixes --- .../src/linker/array_stride_fixer.rs | 942 +++++++++--------- .../src/linker/duplicates.rs | 54 +- crates/rustc_codegen_spirv/src/linker/mod.rs | 4 +- .../function_storage_spirv13_kept.rs | 21 + .../function_storage_spirv13_kept.stderr | 30 + .../mixed_storage_classes.rs | 21 + .../mixed_storage_classes.stderr | 36 + .../nested_structs_function_storage.rs | 34 + .../nested_structs_function_storage.stderr | 45 + .../private_storage_spirv14_removed.rs | 22 + .../private_storage_spirv14_removed.stderr | 22 + .../runtime_arrays_in_workgroup.rs | 22 + .../runtime_arrays_in_workgroup.stderr | 37 + .../storage_buffer_arrays_kept.rs | 17 + .../storage_buffer_arrays_kept.stderr | 31 + .../workgroup_arrays_removed.rs | 21 + .../workgroup_arrays_removed.stderr | 36 + .../workgroup_arrays_with_capability.rs | 22 + .../workgroup_arrays_with_capability.stderr | 39 + 19 files changed, 987 insertions(+), 469 deletions(-) create mode 100644 tests/ui/linker/array_stride_fixer/function_storage_spirv13_kept.rs create mode 100644 tests/ui/linker/array_stride_fixer/function_storage_spirv13_kept.stderr create mode 100644 tests/ui/linker/array_stride_fixer/mixed_storage_classes.rs create mode 100644 tests/ui/linker/array_stride_fixer/mixed_storage_classes.stderr create mode 100644 tests/ui/linker/array_stride_fixer/nested_structs_function_storage.rs create mode 100644 tests/ui/linker/array_stride_fixer/nested_structs_function_storage.stderr create mode 100644 tests/ui/linker/array_stride_fixer/private_storage_spirv14_removed.rs create mode 100644 tests/ui/linker/array_stride_fixer/private_storage_spirv14_removed.stderr create mode 100644 tests/ui/linker/array_stride_fixer/runtime_arrays_in_workgroup.rs create mode 100644 tests/ui/linker/array_stride_fixer/runtime_arrays_in_workgroup.stderr create mode 100644 tests/ui/linker/array_stride_fixer/storage_buffer_arrays_kept.rs create mode 100644 tests/ui/linker/array_stride_fixer/storage_buffer_arrays_kept.stderr create mode 100644 tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.rs create mode 100644 tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.stderr create mode 100644 tests/ui/linker/array_stride_fixer/workgroup_arrays_with_capability.rs create mode 100644 tests/ui/linker/array_stride_fixer/workgroup_arrays_with_capability.stderr diff --git a/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs b/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs index 129494b771..ca1736af79 100644 --- a/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs +++ b/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs @@ -1,28 +1,44 @@ -//! Fix ArrayStride decorations for newer SPIR-V versions. +//! Fix `ArrayStride` decorations for newer SPIR-V versions. //! -//! Newer SPIR-V versions forbid explicit layouts (ArrayStride decorations) in certain +//! Newer SPIR-V versions forbid explicit layouts (`ArrayStride` decorations) in certain //! storage classes (Function, Private, Workgroup), but allow them in others -//! (StorageBuffer, Uniform). This module removes ArrayStride decorations from +//! (`StorageBuffer`, Uniform). This module removes `ArrayStride` decorations from //! array types that are used in contexts where they're forbidden. use rspirv::dr::{Module, Operand}; use rspirv::spirv::{Capability, Decoration, Op, StorageClass, Word}; -use rustc_data_structures::fx::FxHashSet; +use rustc_data_structures::fx::{FxHashMap, FxHashSet}; + +/// Describes how an array type is used across different storage classes +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ArrayUsagePattern { + /// Array is only used in storage classes that require explicit layout + LayoutRequired, + /// Array is only used in storage classes that forbid explicit layout + LayoutForbidden, + /// Array is used in both types of storage classes (needs specialization) + MixedUsage, + /// Array is not used in any variables (orphaned) + Unused, +} + +/// Context information about array type usage +#[derive(Debug, Clone)] +pub struct ArrayStorageContext { + /// Which storage classes this array type is used in + pub storage_classes: FxHashSet, + /// Whether this array allows or forbids layout in its contexts + pub usage_pattern: ArrayUsagePattern, +} /// Check if a storage class allows explicit layout decorations based on SPIR-V version and capabilities. -/// This matches the logic from SPIRV-Tools validate_decorations.cpp AllowsLayout function. +/// This matches the logic from SPIRV-Tools `validate_decorations.cpp` `AllowsLayout` function. fn allows_layout( storage_class: StorageClass, spirv_version: (u8, u8), has_workgroup_layout_capability: bool, ) -> bool { match storage_class { - // Always explicitly laid out - StorageClass::StorageBuffer - | StorageClass::Uniform - | StorageClass::PhysicalStorageBuffer - | StorageClass::PushConstant => true, - // Never allows layout StorageClass::UniformConstant => false, @@ -32,25 +48,18 @@ fn allows_layout( // Only forbidden in SPIR-V 1.4+ StorageClass::Function | StorageClass::Private => spirv_version < (1, 4), - // Block is used generally and mesh shaders use Offset - StorageClass::Input | StorageClass::Output => true, - - // TODO: Some storage classes in ray tracing use explicit layout - // decorations, but it is not well documented which. For now treat other - // storage classes as allowed to be laid out. + // All other storage classes allow layout by default _ => true, } } -/// Remove ArrayStride decorations from array types used in storage classes where -/// newer SPIR-V versions forbid explicit layouts. -pub fn fix_array_stride_decorations(module: &mut Module) { +/// Comprehensive fix for `ArrayStride` decorations with optional type deduplication +pub fn fix_array_stride_decorations_with_deduplication( + module: &mut Module, + use_context_aware_deduplication: bool, +) { // Get SPIR-V version from module header - let spirv_version = module - .header - .as_ref() - .map(|h| h.version()) - .unwrap_or((1, 0)); // Default to 1.0 if no header + let spirv_version = module.header.as_ref().map_or((1, 0), |h| h.version()); // Default to 1.0 if no header // Check for WorkgroupMemoryExplicitLayoutKHR capability let has_workgroup_layout_capability = module.capabilities.iter().any(|inst| { @@ -61,94 +70,139 @@ pub fn fix_array_stride_decorations(module: &mut Module) { )) }); - // Find all array types that have ArrayStride decorations - let mut array_types_with_stride = FxHashSet::default(); - for inst in &module.annotations { - if inst.class.opcode == Op::Decorate - && inst.operands.len() >= 2 - && inst.operands[1] == Operand::Decoration(Decoration::ArrayStride) - { - let target_id = inst.operands[0].unwrap_id_ref(); - array_types_with_stride.insert(target_id); - } - } + // Analyze storage class contexts for all array types + let array_contexts = + analyze_array_storage_contexts(module, spirv_version, has_workgroup_layout_capability); - // Check each array type with ArrayStride to see if it's used in forbidden contexts - let mut array_types_to_fix = FxHashSet::default(); - for &array_type_id in &array_types_with_stride { - if is_array_type_used_in_forbidden_storage_class( - array_type_id, + // Handle mixed usage arrays by creating specialized versions + let specializations = create_specialized_array_types(module, &array_contexts); + + // Update references to use appropriate specialized types + if !specializations.is_empty() { + update_references_for_specialized_arrays( module, + &specializations, spirv_version, has_workgroup_layout_capability, - ) { - array_types_to_fix.insert(array_type_id); - } + ); } - // Remove ArrayStride decorations for the problematic types + // Apply context-aware type deduplication if requested + if use_context_aware_deduplication { + crate::linker::duplicates::remove_duplicate_types_with_array_context( + module, + Some(&array_contexts), + ); + } + + // Remove ArrayStride decorations from arrays used in forbidden contexts + remove_array_stride_decorations_for_forbidden_contexts(module, &array_contexts); +} + +/// Remove `ArrayStride` decorations from arrays used in layout-forbidden storage classes +fn remove_array_stride_decorations_for_forbidden_contexts( + module: &mut Module, + array_contexts: &FxHashMap, +) { + // Find array types that should have their ArrayStride decorations removed + // Remove from arrays used in forbidden contexts OR mixed usage that includes forbidden contexts + let arrays_to_remove_stride: FxHashSet = array_contexts + .iter() + .filter_map(|(&id, context)| { + match context.usage_pattern { + // Always remove from arrays used only in forbidden contexts + ArrayUsagePattern::LayoutForbidden => Some(id), + // For mixed usage, remove if it includes forbidden contexts that would cause validation errors + ArrayUsagePattern::MixedUsage => { + // If the array is used in any context that forbids layout, remove the decoration + // This is a conservative approach that prevents validation errors + let has_forbidden_context = context.storage_classes.iter().any(|&sc| { + !allows_layout(sc, (1, 4), false) // Use SPIR-V 1.4 rules for conservative check + }); + + if has_forbidden_context { + Some(id) + } else { + None + } + } + ArrayUsagePattern::LayoutRequired | ArrayUsagePattern::Unused => None, + } + }) + .collect(); + + // Remove ArrayStride decorations for layout-forbidden arrays module.annotations.retain(|inst| { if inst.class.opcode == Op::Decorate && inst.operands.len() >= 2 && inst.operands[1] == Operand::Decoration(Decoration::ArrayStride) { let target_id = inst.operands[0].unwrap_id_ref(); - !array_types_to_fix.contains(&target_id) + !arrays_to_remove_stride.contains(&target_id) } else { true } }); } -/// Check if an array type is used in any variable with a forbidden storage class -fn is_array_type_used_in_forbidden_storage_class( - array_type_id: Word, +/// Analyze storage class contexts for all array types in the module +pub fn analyze_array_storage_contexts( module: &Module, spirv_version: (u8, u8), has_workgroup_layout_capability: bool, -) -> bool { - // Check global variables +) -> FxHashMap { + let mut array_contexts: FxHashMap = FxHashMap::default(); + + // Find all array and runtime array types + let mut array_types = FxHashSet::default(); for inst in &module.types_global_values { - if inst.class.opcode == Op::Variable && inst.operands.len() >= 1 { + if matches!(inst.class.opcode, Op::TypeArray | Op::TypeRuntimeArray) { + if let Some(result_id) = inst.result_id { + array_types.insert(result_id); + array_contexts.insert(result_id, ArrayStorageContext { + storage_classes: FxHashSet::default(), + usage_pattern: ArrayUsagePattern::Unused, + }); + } + } + } + + // Analyze global variables + for inst in &module.types_global_values { + if inst.class.opcode == Op::Variable && !inst.operands.is_empty() { let storage_class = inst.operands[0].unwrap_storage_class(); - // Check if this storage class forbids explicit layouts - if !allows_layout( - storage_class, - spirv_version, - has_workgroup_layout_capability, - ) { - // Check if this variable's type hierarchy contains the array type - if let Some(var_type_id) = inst.result_type { + if let Some(var_type_id) = inst.result_type { + // Check if this variable's type hierarchy contains any array types + for &array_type_id in &array_types { if type_hierarchy_contains_array_type(var_type_id, array_type_id, module) { - return true; + if let Some(context) = array_contexts.get_mut(&array_type_id) { + context.storage_classes.insert(storage_class); + } } } } } } - // Check function-local variables + // Analyze function-local variables for function in &module.functions { for block in &function.blocks { for inst in &block.instructions { - if inst.class.opcode == Op::Variable && inst.operands.len() >= 1 { + if inst.class.opcode == Op::Variable && !inst.operands.is_empty() { let storage_class = inst.operands[0].unwrap_storage_class(); - // Check if this storage class forbids explicit layouts - if !allows_layout( - storage_class, - spirv_version, - has_workgroup_layout_capability, - ) { - // Check if this variable's type hierarchy contains the array type - if let Some(var_type_id) = inst.result_type { + if let Some(var_type_id) = inst.result_type { + // Check if this variable's type hierarchy contains any array types + for &array_type_id in &array_types { if type_hierarchy_contains_array_type( var_type_id, array_type_id, module, ) { - return true; + if let Some(context) = array_contexts.get_mut(&array_type_id) { + context.storage_classes.insert(storage_class); + } } } } @@ -157,7 +211,357 @@ fn is_array_type_used_in_forbidden_storage_class( } } - false + // Determine usage patterns + for context in array_contexts.values_mut() { + if context.storage_classes.is_empty() { + context.usage_pattern = ArrayUsagePattern::Unused; + } else { + let mut requires_layout = false; + let mut forbids_layout = false; + + for &storage_class in &context.storage_classes { + if allows_layout( + storage_class, + spirv_version, + has_workgroup_layout_capability, + ) { + requires_layout = true; + } else { + forbids_layout = true; + } + } + + context.usage_pattern = match (requires_layout, forbids_layout) { + (true, true) => ArrayUsagePattern::MixedUsage, + (true, false) => ArrayUsagePattern::LayoutRequired, + (false, true) => ArrayUsagePattern::LayoutForbidden, + (false, false) => ArrayUsagePattern::Unused, // Should not happen + }; + } + } + + array_contexts +} + +/// Create specialized array types for mixed usage scenarios +fn create_specialized_array_types( + module: &mut Module, + array_contexts: &FxHashMap, +) -> FxHashMap { + let mut specializations = FxHashMap::default(); // original_id -> (layout_required_id, layout_forbidden_id) + + // Find arrays that need specialization (mixed usage) + let arrays_to_specialize: Vec = array_contexts + .iter() + .filter_map(|(&id, context)| { + if context.usage_pattern == ArrayUsagePattern::MixedUsage { + Some(id) + } else { + None + } + }) + .collect(); + + if arrays_to_specialize.is_empty() { + return specializations; + } + + // Generate new IDs for specialized types + let mut next_id = module.header.as_ref().map_or(1, |h| h.bound); + + for &original_id in &arrays_to_specialize { + let layout_required_id = next_id; + next_id += 1; + let layout_forbidden_id = next_id; + next_id += 1; + + specializations.insert(original_id, (layout_required_id, layout_forbidden_id)); + } + + // Update the module header bound + if let Some(ref mut header) = module.header { + header.bound = next_id; + } + + // Create specialized array type definitions + let mut new_type_instructions = Vec::new(); + + for &original_id in &arrays_to_specialize { + if let Some((layout_required_id, layout_forbidden_id)) = specializations.get(&original_id) { + // Find the original array type instruction + if let Some(original_inst) = module + .types_global_values + .iter() + .find(|inst| inst.result_id == Some(original_id)) + .cloned() + { + // Create layout-required variant (keeps ArrayStride decorations) + let mut layout_required_inst = original_inst.clone(); + layout_required_inst.result_id = Some(*layout_required_id); + new_type_instructions.push(layout_required_inst); + + // Create layout-forbidden variant (will have ArrayStride decorations removed later) + let mut layout_forbidden_inst = original_inst.clone(); + layout_forbidden_inst.result_id = Some(*layout_forbidden_id); + new_type_instructions.push(layout_forbidden_inst); + } + } + } + + // IMPORTANT: Do not add the specialized arrays to the end - this would create forward references + // Instead, we need to insert them in the correct position to maintain SPIR-V type ordering + + // Find the insertion point: after the last original array type that needs specialization + // This ensures all specialized arrays are defined before any types that might reference them + let mut insertion_point = 0; + for (i, inst) in module.types_global_values.iter().enumerate() { + if let Some(result_id) = inst.result_id { + if arrays_to_specialize.contains(&result_id) { + insertion_point = i + 1; + } + } + } + + // Insert the specialized array types at the calculated position + // This maintains the invariant that referenced types appear before referencing types + for (i, new_inst) in new_type_instructions.into_iter().enumerate() { + module + .types_global_values + .insert(insertion_point + i, new_inst); + } + + specializations +} + +/// Update all references to specialized array types based on storage class context +fn update_references_for_specialized_arrays( + module: &mut Module, + specializations: &FxHashMap, + spirv_version: (u8, u8), + has_workgroup_layout_capability: bool, +) { + // Update struct types that contain specialized arrays + // This is safe now because all specialized arrays have been properly positioned in the types section + for inst in &mut module.types_global_values { + if inst.class.opcode == Op::TypeStruct { + for operand in &mut inst.operands { + if let Some(referenced_id) = operand.id_ref_any() { + if let Some(&(layout_required_id, _layout_forbidden_id)) = + specializations.get(&referenced_id) + { + // For struct types, we use the layout-required variant since structs + // can be used in both layout-required and layout-forbidden contexts + *operand = Operand::IdRef(layout_required_id); + } + } + } + } + } + + // Collect all existing pointer types that reference specialized arrays FIRST + let mut existing_pointers_to_specialize = Vec::new(); + for inst in &module.types_global_values { + if inst.class.opcode == Op::TypePointer && inst.operands.len() >= 2 { + let pointee_type = inst.operands[1].unwrap_id_ref(); + if specializations.contains_key(&pointee_type) { + existing_pointers_to_specialize.push(inst.clone()); + } + } + } + + // Create ALL specialized pointer types from the collected existing ones + let mut next_id = module.header.as_ref().map_or(1, |h| h.bound); + let mut new_pointer_instructions = Vec::new(); + let mut pointer_type_mappings = FxHashMap::default(); // old_pointer_id -> new_pointer_id + + // Create new pointer types for each storage class context + for inst in &existing_pointers_to_specialize { + let storage_class = inst.operands[0].unwrap_storage_class(); + let pointee_type = inst.operands[1].unwrap_id_ref(); + + if let Some(&(layout_required_id, layout_forbidden_id)) = specializations.get(&pointee_type) + { + let allows_layout_for_sc = allows_layout( + storage_class, + spirv_version, + has_workgroup_layout_capability, + ); + + // Create new pointer type pointing to appropriate specialized array + let target_array_id = if allows_layout_for_sc { + layout_required_id + } else { + layout_forbidden_id + }; + + let mut new_pointer_inst = inst.clone(); + new_pointer_inst.result_id = Some(next_id); + new_pointer_inst.operands[1] = Operand::IdRef(target_array_id); + new_pointer_instructions.push(new_pointer_inst); + + // Map old pointer to new pointer + if let Some(old_pointer_id) = inst.result_id { + pointer_type_mappings.insert(old_pointer_id, next_id); + } + next_id += 1; + } + } + + // Update module header bound to account for the new pointer types + if let Some(ref mut header) = module.header { + header.bound = next_id; + } + + // Insert new pointer type instructions in the correct position + // They must come after the specialized arrays they reference, but before any variables that use them + + // Find the last specialized array position to ensure pointers come after their pointee types + let mut pointer_insertion_point = 0; + for (i, inst) in module.types_global_values.iter().enumerate() { + if let Some(result_id) = inst.result_id { + // Check if this is one of our specialized arrays + if specializations + .values() + .any(|&(req_id, forb_id)| result_id == req_id || result_id == forb_id) + { + pointer_insertion_point = i + 1; + } + } + } + + // Insert the new pointer types at the calculated position + // This ensures they appear after specialized arrays but before variables + for (i, new_pointer_inst) in new_pointer_instructions.into_iter().enumerate() { + module + .types_global_values + .insert(pointer_insertion_point + i, new_pointer_inst); + } + + // Update ALL references to old pointer types throughout the entire module + // This includes variables, function parameters, and all instructions + + // Update global variables and function types + for inst in &mut module.types_global_values { + match inst.class.opcode { + Op::Variable => { + if let Some(var_type_id) = inst.result_type { + if let Some(&new_pointer_id) = pointer_type_mappings.get(&var_type_id) { + inst.result_type = Some(new_pointer_id); + } + } + } + Op::TypeFunction => { + // Update function type operands (return type and parameter types) + for operand in &mut inst.operands { + if let Some(referenced_id) = operand.id_ref_any() { + if let Some(&new_pointer_id) = pointer_type_mappings.get(&referenced_id) { + *operand = Operand::IdRef(new_pointer_id); + } + } + } + } + _ => {} + } + } + + // Update function signatures and local variables + for function in &mut module.functions { + // Update function parameters + for param in &mut function.parameters { + if let Some(param_type_id) = param.result_type { + if let Some(&new_pointer_id) = pointer_type_mappings.get(¶m_type_id) { + param.result_type = Some(new_pointer_id); + } + } + } + + // Update all instructions in function bodies + for block in &mut function.blocks { + for inst in &mut block.instructions { + // Update result type + if let Some(result_type_id) = inst.result_type { + if let Some(&new_pointer_id) = pointer_type_mappings.get(&result_type_id) { + inst.result_type = Some(new_pointer_id); + } + } + + // Update operand references + for operand in &mut inst.operands { + if let Some(referenced_id) = operand.id_ref_any() { + if let Some(&new_pointer_id) = pointer_type_mappings.get(&referenced_id) { + *operand = Operand::IdRef(new_pointer_id); + } + } + } + } + } + } + + // Remove old pointer type instructions that reference specialized arrays + module.types_global_values.retain(|inst| { + if inst.class.opcode == Op::TypePointer && inst.operands.len() >= 2 { + let pointee_type = inst.operands[1].unwrap_id_ref(); + !specializations.contains_key(&pointee_type) + } else { + true + } + }); + + // Remove original array type instructions that were specialized + let arrays_to_remove: FxHashSet = specializations.keys().cloned().collect(); + module.types_global_values.retain(|inst| { + if let Some(result_id) = inst.result_id { + !arrays_to_remove.contains(&result_id) + } else { + true + } + }); + + // STEP 8: Copy ArrayStride decorations from original arrays to layout-required variants + // and remove them from layout-forbidden variants + let mut decorations_to_add = Vec::new(); + let layout_forbidden_arrays: FxHashSet = specializations + .values() + .map(|&(_, layout_forbidden_id)| layout_forbidden_id) + .collect(); + + // Find existing ArrayStride decorations on original arrays and copy them to layout-required variants + for inst in &module.annotations { + if inst.class.opcode == Op::Decorate + && inst.operands.len() >= 2 + && inst.operands[1] == Operand::Decoration(Decoration::ArrayStride) + { + let target_id = inst.operands[0].unwrap_id_ref(); + if let Some(&(layout_required_id, _)) = specializations.get(&target_id) { + // Copy the decoration to the layout-required variant + let mut new_decoration = inst.clone(); + new_decoration.operands[0] = Operand::IdRef(layout_required_id); + decorations_to_add.push(new_decoration); + } + } + } + + // Add the copied decorations + module.annotations.extend(decorations_to_add); + + // Remove ArrayStride decorations from layout-forbidden arrays and original arrays + let arrays_to_remove_decorations: FxHashSet = layout_forbidden_arrays + .iter() + .cloned() + .chain(specializations.keys().cloned()) // Also remove from original arrays + .collect(); + + module.annotations.retain(|inst| { + if inst.class.opcode == Op::Decorate + && inst.operands.len() >= 2 + && inst.operands[1] == Operand::Decoration(Decoration::ArrayStride) + { + let target_id = inst.operands[0].unwrap_id_ref(); + !arrays_to_remove_decorations.contains(&target_id) + } else { + true + } + }); } /// Check if a type hierarchy contains a specific array type @@ -218,395 +622,3 @@ fn type_hierarchy_contains_array_type( } false } - -#[cfg(test)] -mod tests { - use super::*; - use rspirv::dr::Module; - - // Helper function to assemble SPIR-V from text - fn assemble_spirv(spirv: &str) -> Vec { - use spirv_tools::assembler::{self, Assembler}; - - let assembler = assembler::create(None); - let spv_binary = assembler - .assemble(spirv, assembler::AssemblerOptions::default()) - .expect("Failed to assemble test SPIR-V"); - let contents: &[u8] = spv_binary.as_ref(); - contents.to_vec() - } - - // Helper function to load SPIR-V binary into Module - fn load_spirv(bytes: &[u8]) -> Module { - use rspirv::dr::Loader; - - let mut loader = Loader::new(); - rspirv::binary::parse_bytes(bytes, &mut loader).unwrap(); - loader.module() - } - - // Helper function to count ArrayStride decorations - fn count_array_stride_decorations(module: &Module) -> usize { - module - .annotations - .iter() - .filter(|inst| { - inst.class.opcode == Op::Decorate - && inst.operands.len() >= 2 - && inst.operands[1] == Operand::Decoration(Decoration::ArrayStride) - }) - .count() - } - - #[test] - fn test_removes_array_stride_from_workgroup_arrays() { - let spirv = r#" - OpCapability Shader - OpMemoryModel Logical GLSL450 - OpEntryPoint GLCompute %main "main" - OpExecutionMode %main LocalSize 1 1 1 - - ; Type declarations - %void = OpTypeVoid - %func_ty = OpTypeFunction %void - %u32 = OpTypeInt 32 0 - %u32_256 = OpConstant %u32 256 - %array_ty = OpTypeArray %u32 %u32_256 - - ; Pointer types for workgroup storage (forbidden in newer SPIR-V) - %ptr_workgroup = OpTypePointer Workgroup %array_ty - - ; Variables in workgroup storage class - %workgroup_var = OpVariable %ptr_workgroup Workgroup - - ; ArrayStride decoration that should be removed - OpDecorate %array_ty ArrayStride 4 - - %main = OpFunction %void None %func_ty - %entry = OpLabel - OpReturn - OpFunctionEnd - "#; - - let bytes = assemble_spirv(spirv); - let mut module = load_spirv(&bytes); - - assert_eq!(count_array_stride_decorations(&module), 1); - - fix_array_stride_decorations(&mut module); - - // ArrayStride should be removed from arrays used in Workgroup storage - assert_eq!(count_array_stride_decorations(&module), 0); - } - - #[test] - fn test_keeps_array_stride_for_workgroup_with_capability() { - let spirv = r#" - OpCapability Shader - OpCapability WorkgroupMemoryExplicitLayoutKHR - OpMemoryModel Logical GLSL450 - OpEntryPoint GLCompute %main "main" - OpExecutionMode %main LocalSize 1 1 1 - - ; Type declarations - %void = OpTypeVoid - %func_ty = OpTypeFunction %void - %u32 = OpTypeInt 32 0 - %u32_256 = OpConstant %u32 256 - %array_ty = OpTypeArray %u32 %u32_256 - - ; Pointer types for workgroup storage (allowed with capability) - %ptr_workgroup = OpTypePointer Workgroup %array_ty - - ; Variables in workgroup storage class - %workgroup_var = OpVariable %ptr_workgroup Workgroup - - ; ArrayStride decoration that should be kept with capability - OpDecorate %array_ty ArrayStride 4 - - %main = OpFunction %void None %func_ty - %entry = OpLabel - OpReturn - OpFunctionEnd - "#; - - let bytes = assemble_spirv(spirv); - let mut module = load_spirv(&bytes); - - assert_eq!(count_array_stride_decorations(&module), 1); - - fix_array_stride_decorations(&mut module); - - // ArrayStride should be kept when WorkgroupMemoryExplicitLayoutKHR capability is present - assert_eq!(count_array_stride_decorations(&module), 1); - } - - #[test] - fn test_keeps_array_stride_for_storage_buffer_arrays() { - let spirv = r#" - OpCapability Shader - OpMemoryModel Logical GLSL450 - OpEntryPoint GLCompute %main "main" - OpExecutionMode %main LocalSize 1 1 1 - - ; Type declarations - %void = OpTypeVoid - %func_ty = OpTypeFunction %void - %u32 = OpTypeInt 32 0 - %u32_256 = OpConstant %u32 256 - %array_ty = OpTypeArray %u32 %u32_256 - - ; Pointer types for storage buffer (always allowed) - %ptr_storage_buffer = OpTypePointer StorageBuffer %array_ty - - ; Variables in storage buffer storage class - %storage_buffer_var = OpVariable %ptr_storage_buffer StorageBuffer - - ; ArrayStride decoration that should be kept - OpDecorate %array_ty ArrayStride 4 - OpDecorate %storage_buffer_var DescriptorSet 0 - OpDecorate %storage_buffer_var Binding 0 - - %main = OpFunction %void None %func_ty - %entry = OpLabel - OpReturn - OpFunctionEnd - "#; - - let bytes = assemble_spirv(spirv); - let mut module = load_spirv(&bytes); - - assert_eq!(count_array_stride_decorations(&module), 1); - - fix_array_stride_decorations(&mut module); - - // ArrayStride should be kept for StorageBuffer storage class - assert_eq!(count_array_stride_decorations(&module), 1); - } - - #[test] - fn test_handles_runtime_arrays_in_workgroup() { - let spirv = r#" - OpCapability Shader - OpMemoryModel Logical GLSL450 - OpEntryPoint GLCompute %main "main" - OpExecutionMode %main LocalSize 1 1 1 - - ; Type declarations - %void = OpTypeVoid - %func_ty = OpTypeFunction %void - %u32 = OpTypeInt 32 0 - %runtime_array_ty = OpTypeRuntimeArray %u32 - - ; Pointer types for workgroup storage (forbidden) - %ptr_workgroup = OpTypePointer Workgroup %runtime_array_ty - - ; Variables in workgroup storage class - %workgroup_var = OpVariable %ptr_workgroup Workgroup - - ; ArrayStride decoration that should be removed - OpDecorate %runtime_array_ty ArrayStride 4 - - %main = OpFunction %void None %func_ty - %entry = OpLabel - OpReturn - OpFunctionEnd - "#; - - let bytes = assemble_spirv(spirv); - let mut module = load_spirv(&bytes); - - assert_eq!(count_array_stride_decorations(&module), 1); - - fix_array_stride_decorations(&mut module); - - // ArrayStride should be removed from runtime arrays in Workgroup storage - assert_eq!(count_array_stride_decorations(&module), 0); - } - - #[test] - fn test_mixed_storage_classes_removes_problematic_arrays() { - let spirv = r#" - OpCapability Shader - OpMemoryModel Logical GLSL450 - OpEntryPoint GLCompute %main "main" - OpExecutionMode %main LocalSize 1 1 1 - - ; Type declarations - %void = OpTypeVoid - %func_ty = OpTypeFunction %void - %u32 = OpTypeInt 32 0 - %u32_256 = OpConstant %u32 256 - %forbidden_array_ty = OpTypeArray %u32 %u32_256 - %allowed_array_ty = OpTypeArray %u32 %u32_256 - - ; Pointer types for different storage classes - %ptr_workgroup = OpTypePointer Workgroup %forbidden_array_ty - %ptr_storage_buffer = OpTypePointer StorageBuffer %allowed_array_ty - - ; Variables in different storage classes - %workgroup_var = OpVariable %ptr_workgroup Workgroup - %storage_buffer_var = OpVariable %ptr_storage_buffer StorageBuffer - - ; ArrayStride decorations - OpDecorate %forbidden_array_ty ArrayStride 4 - OpDecorate %allowed_array_ty ArrayStride 4 - OpDecorate %storage_buffer_var DescriptorSet 0 - OpDecorate %storage_buffer_var Binding 0 - - %main = OpFunction %void None %func_ty - %entry = OpLabel - OpReturn - OpFunctionEnd - "#; - - let bytes = assemble_spirv(spirv); - let mut module = load_spirv(&bytes); - - assert_eq!(count_array_stride_decorations(&module), 2); - - fix_array_stride_decorations(&mut module); - - // Only the Workgroup array should have its ArrayStride removed - assert_eq!(count_array_stride_decorations(&module), 1); - } - - #[test] - fn test_nested_structs_and_arrays_in_function_storage() { - let spirv = r#" - OpCapability Shader - OpMemoryModel Logical GLSL450 - OpEntryPoint GLCompute %main "main" - OpExecutionMode %main LocalSize 1 1 1 - - ; ArrayStride decoration that should be removed in SPIR-V 1.4+ - OpDecorate %inner_array_ty ArrayStride 16 - - ; Type declarations - %void = OpTypeVoid - %float = OpTypeFloat 32 - %u32 = OpTypeInt 32 0 - %u32_4 = OpConstant %u32 4 - %inner_array_ty = OpTypeArray %float %u32_4 - %inner_struct_ty = OpTypeStruct %inner_array_ty - %outer_struct_ty = OpTypeStruct %inner_struct_ty - - ; Pointer types for function storage (forbidden in SPIR-V 1.4+) - %ptr_function = OpTypePointer Function %outer_struct_ty - %func_ty = OpTypeFunction %void - - ; Function variable inside function - %main = OpFunction %void None %func_ty - %entry = OpLabel - %function_var = OpVariable %ptr_function Function - OpReturn - OpFunctionEnd - "#; - - let bytes = assemble_spirv(spirv); - let mut module = load_spirv(&bytes); - - // Force SPIR-V 1.4 for this test - if let Some(ref mut header) = module.header { - header.set_version(1, 4); - } - - assert_eq!(count_array_stride_decorations(&module), 1); - - fix_array_stride_decorations(&mut module); - - // ArrayStride should be removed in SPIR-V 1.4+ for Function storage - assert_eq!(count_array_stride_decorations(&module), 0); - } - - #[test] - fn test_function_storage_spirv_13_keeps_decorations() { - let spirv = r#" - OpCapability Shader - OpMemoryModel Logical GLSL450 - OpEntryPoint GLCompute %main "main" - OpExecutionMode %main LocalSize 1 1 1 - - ; Type declarations - %void = OpTypeVoid - %func_ty = OpTypeFunction %void - %u32 = OpTypeInt 32 0 - %u32_256 = OpConstant %u32 256 - %array_ty = OpTypeArray %u32 %u32_256 - - ; Pointer types for function storage - %ptr_function = OpTypePointer Function %array_ty - - ; Function variable - %main = OpFunction %void None %func_ty - %entry = OpLabel - %function_var = OpVariable %ptr_function Function - OpReturn - OpFunctionEnd - - ; ArrayStride decoration that should be kept in SPIR-V 1.3 - OpDecorate %array_ty ArrayStride 4 - "#; - - let bytes = assemble_spirv(spirv); - let mut module = load_spirv(&bytes); - - // Force SPIR-V 1.3 for this test - if let Some(ref mut header) = module.header { - header.set_version(1, 3); - } - - assert_eq!(count_array_stride_decorations(&module), 1); - - fix_array_stride_decorations(&mut module); - - // ArrayStride should be kept in SPIR-V 1.3 for Function storage - assert_eq!(count_array_stride_decorations(&module), 1); - } - - #[test] - fn test_private_storage_spirv_14_removes_decorations() { - let spirv = r#" - OpCapability Shader - OpMemoryModel Logical GLSL450 - OpEntryPoint GLCompute %main "main" - OpExecutionMode %main LocalSize 1 1 1 - - ; Type declarations - %void = OpTypeVoid - %func_ty = OpTypeFunction %void - %u32 = OpTypeInt 32 0 - %u32_256 = OpConstant %u32 256 - %array_ty = OpTypeArray %u32 %u32_256 - - ; Pointer types for private storage - %ptr_private = OpTypePointer Private %array_ty - - ; Variables in private storage class - %private_var = OpVariable %ptr_private Private - - ; ArrayStride decoration that should be removed in SPIR-V 1.4+ - OpDecorate %array_ty ArrayStride 4 - - %main = OpFunction %void None %func_ty - %entry = OpLabel - OpReturn - OpFunctionEnd - "#; - - let bytes = assemble_spirv(spirv); - let mut module = load_spirv(&bytes); - - // Force SPIR-V 1.4 for this test - if let Some(ref mut header) = module.header { - header.set_version(1, 4); - } - - assert_eq!(count_array_stride_decorations(&module), 1); - - fix_array_stride_decorations(&mut module); - - // ArrayStride should be removed in SPIR-V 1.4+ for Private storage - assert_eq!(count_array_stride_decorations(&module), 0); - } -} diff --git a/crates/rustc_codegen_spirv/src/linker/duplicates.rs b/crates/rustc_codegen_spirv/src/linker/duplicates.rs index 6b1b45d8cd..7631972001 100644 --- a/crates/rustc_codegen_spirv/src/linker/duplicates.rs +++ b/crates/rustc_codegen_spirv/src/linker/duplicates.rs @@ -117,11 +117,14 @@ fn gather_names(debug_names: &[Instruction]) -> FxHashMap { .collect() } -fn make_dedupe_key( +fn make_dedupe_key_with_array_context( inst: &Instruction, unresolved_forward_pointers: &FxHashSet, annotations: &FxHashMap>, names: &FxHashMap, + array_contexts: Option< + &FxHashMap, + >, ) -> Vec { let mut data = vec![inst.class.opcode as u32]; @@ -169,6 +172,38 @@ fn make_dedupe_key( } } + // For array types, include storage class context in the key to prevent + // inappropriate deduplication between different storage class contexts + if let Some(result_id) = inst.result_id { + if matches!(inst.class.opcode, Op::TypeArray | Op::TypeRuntimeArray) { + if let Some(contexts) = array_contexts { + if let Some(context) = contexts.get(&result_id) { + // Include usage pattern in the key so arrays with different contexts won't deduplicate + let usage_pattern_discriminant = match context.usage_pattern { + crate::linker::array_stride_fixer::ArrayUsagePattern::LayoutRequired => { + 1u32 + } + crate::linker::array_stride_fixer::ArrayUsagePattern::LayoutForbidden => { + 2u32 + } + crate::linker::array_stride_fixer::ArrayUsagePattern::MixedUsage => 3u32, + crate::linker::array_stride_fixer::ArrayUsagePattern::Unused => 4u32, + }; + data.push(usage_pattern_discriminant); + + // Also include the specific storage classes for fine-grained differentiation + let mut storage_classes: Vec = context + .storage_classes + .iter() + .map(|sc| *sc as u32) + .collect(); + storage_classes.sort(); // Ensure deterministic ordering + data.extend(storage_classes); + } + } + } + } + data } @@ -185,6 +220,15 @@ fn rewrite_inst_with_rules(inst: &mut Instruction, rules: &FxHashMap) } pub fn remove_duplicate_types(module: &mut Module) { + remove_duplicate_types_with_array_context(module, None); +} + +pub fn remove_duplicate_types_with_array_context( + module: &mut Module, + array_contexts: Option< + &FxHashMap, + >, +) { // Keep in mind, this algorithm requires forward type references to not exist - i.e. it's a valid spir-v module. // When a duplicate type is encountered, then this is a map from the deleted ID, to the new, deduplicated ID. @@ -222,7 +266,13 @@ pub fn remove_duplicate_types(module: &mut Module) { // all_inst_iter_mut pass below. However, the code is a lil bit cleaner this way I guess. rewrite_inst_with_rules(inst, &rewrite_rules); - let key = make_dedupe_key(inst, &unresolved_forward_pointers, &annotations, &names); + let key = make_dedupe_key_with_array_context( + inst, + &unresolved_forward_pointers, + &annotations, + &names, + array_contexts, + ); match key_to_result_id.entry(key) { hash_map::Entry::Vacant(entry) => { diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs index 8e4933da36..d78525fcff 100644 --- a/crates/rustc_codegen_spirv/src/linker/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/mod.rs @@ -356,10 +356,10 @@ pub fn link( }); } - // Fix ArrayStride decorations for arrays in storage classes where newer SPIR-V versions forbid explicit layouts + // Fix ArrayStride decorations (after storage classes are resolved to avoid conflicts) { let _timer = sess.timer("fix_array_stride_decorations"); - array_stride_fixer::fix_array_stride_decorations(&mut output); + array_stride_fixer::fix_array_stride_decorations_with_deduplication(&mut output, false); } // NOTE(eddyb) with SPIR-T, we can do `mem2reg` before inlining, too! diff --git a/tests/ui/linker/array_stride_fixer/function_storage_spirv13_kept.rs b/tests/ui/linker/array_stride_fixer/function_storage_spirv13_kept.rs new file mode 100644 index 0000000000..c7219048b2 --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/function_storage_spirv13_kept.rs @@ -0,0 +1,21 @@ +// Test that ArrayStride decorations are kept for function storage in SPIR-V 1.3 + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/" +// only-spv1.3 +use spirv_std::spirv; + +#[spirv(compute(threads(1)))] +pub fn main( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] output: &mut [u32; 1], +) { + // Function storage in SPIR-V 1.3 should keep ArrayStride decorations + let mut function_var: [u32; 256] = [0; 256]; + function_var[0] = 42; + function_var[1] = function_var[0] + 1; + // Force the array to be used by writing to output + output[0] = function_var[1]; +} diff --git a/tests/ui/linker/array_stride_fixer/function_storage_spirv13_kept.stderr b/tests/ui/linker/array_stride_fixer/function_storage_spirv13_kept.stderr new file mode 100644 index 0000000000..5c2290c73a --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/function_storage_spirv13_kept.stderr @@ -0,0 +1,30 @@ +OpCapability Shader +OpCapability Float64 +OpCapability Int64 +OpCapability Int16 +OpCapability Int8 +OpCapability ShaderClockKHR +OpExtension "SPV_KHR_shader_clock" +OpMemoryModel Logical Simple +OpEntryPoint GLCompute %1 "main" +OpExecutionMode %1 LocalSize 1 1 1 +%2 = OpString "$OPSTRING_FILENAME/function_storage_spirv13_kept.rs" +OpDecorate %4 ArrayStride 4 +OpDecorate %5 Block +OpMemberDecorate %5 0 Offset 0 +OpDecorate %3 Binding 0 +OpDecorate %3 DescriptorSet 0 +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 0 +%9 = OpConstant %8 1 +%4 = OpTypeArray %8 %9 +%10 = OpTypePointer StorageBuffer %4 +%5 = OpTypeStruct %4 +%11 = OpTypePointer StorageBuffer %5 +%3 = OpVariable %11 StorageBuffer +%12 = OpConstant %8 0 +%13 = OpTypeBool +%14 = OpConstant %8 256 +%15 = OpConstant %8 42 +%16 = OpTypePointer StorageBuffer %8 diff --git a/tests/ui/linker/array_stride_fixer/mixed_storage_classes.rs b/tests/ui/linker/array_stride_fixer/mixed_storage_classes.rs new file mode 100644 index 0000000000..9947ce3426 --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/mixed_storage_classes.rs @@ -0,0 +1,21 @@ +// Test that mixed storage class usage results in proper ArrayStride handling + +// compile-flags: -C llvm-args=--disassemble-globals +// only-vulkan1.1 +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/" +use spirv_std::spirv; + +#[spirv(compute(threads(64)))] +pub fn main( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] storage_data: &mut [u32; 256], + #[spirv(workgroup)] workgroup_data: &mut [u32; 256], +) { + // Both variables use the same array type [u32; 256] but in different storage classes: + // - storage_data is in StorageBuffer (requires ArrayStride) + // - workgroup_data is in Workgroup (forbids ArrayStride in SPIR-V 1.4+) + + storage_data[0] = 42; + workgroup_data[0] = storage_data[0]; +} diff --git a/tests/ui/linker/array_stride_fixer/mixed_storage_classes.stderr b/tests/ui/linker/array_stride_fixer/mixed_storage_classes.stderr new file mode 100644 index 0000000000..a6d066e4ae --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/mixed_storage_classes.stderr @@ -0,0 +1,36 @@ +OpCapability Shader +OpCapability Float64 +OpCapability Int64 +OpCapability Int16 +OpCapability Int8 +OpCapability ShaderClockKHR +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_shader_clock" +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %1 "main" +OpExecutionMode %1 LocalSize 64 1 1 +%2 = OpString "$OPSTRING_FILENAME/mixed_storage_classes.rs" +OpName %4 "workgroup_data" +OpDecorate %5 ArrayStride 4 +OpDecorate %6 Block +OpMemberDecorate %6 0 Offset 0 +OpDecorate %3 Binding 0 +OpDecorate %3 DescriptorSet 0 +%7 = OpTypeVoid +%8 = OpTypeFunction %7 +%9 = OpTypeInt 32 0 +%10 = OpConstant %9 256 +%5 = OpTypeArray %9 %10 +%11 = OpTypePointer StorageBuffer %5 +%6 = OpTypeStruct %5 +%12 = OpTypePointer StorageBuffer %6 +%3 = OpVariable %12 StorageBuffer +%13 = OpConstant %9 0 +%14 = OpTypeBool +%15 = OpTypePointer StorageBuffer %9 +%16 = OpConstant %9 42 +%17 = OpTypePointer Workgroup %9 +%18 = OpTypeArray %9 %10 +%19 = OpTypePointer Workgroup %18 +%4 = OpVariable %19 Workgroup diff --git a/tests/ui/linker/array_stride_fixer/nested_structs_function_storage.rs b/tests/ui/linker/array_stride_fixer/nested_structs_function_storage.rs new file mode 100644 index 0000000000..3407f66b90 --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/nested_structs_function_storage.rs @@ -0,0 +1,34 @@ +// Test that ArrayStride decorations are removed from nested structs with arrays in Function storage class + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals +// only-vulkan1.2 +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/" +use spirv_std::spirv; + +#[derive(Copy, Clone)] +struct InnerStruct { + data: [f32; 4], +} + +#[derive(Copy, Clone)] +struct OuterStruct { + inner: InnerStruct, +} + +#[spirv(compute(threads(1)))] +pub fn main( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] output: &mut [f32; 1], +) { + // Function-local variables with nested structs containing arrays + // Should have ArrayStride removed in SPIR-V 1.4+ + let mut function_var = OuterStruct { + inner: InnerStruct { data: [0.0; 4] }, + }; + function_var.inner.data[0] = 42.0; + function_var.inner.data[1] = function_var.inner.data[0] + 1.0; + // Force usage to prevent optimization + output[0] = function_var.inner.data[1]; +} diff --git a/tests/ui/linker/array_stride_fixer/nested_structs_function_storage.stderr b/tests/ui/linker/array_stride_fixer/nested_structs_function_storage.stderr new file mode 100644 index 0000000000..89a68f3a87 --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/nested_structs_function_storage.stderr @@ -0,0 +1,45 @@ +OpCapability Shader +OpCapability Float64 +OpCapability Int64 +OpCapability Int16 +OpCapability Int8 +OpCapability ShaderClockKHR +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_shader_clock" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %1 "main" %2 +OpExecutionMode %1 LocalSize 1 1 1 +%3 = OpString "$OPSTRING_FILENAME/nested_structs_function_storage.rs" +OpName %4 "InnerStruct" +OpMemberName %4 0 "data" +OpName %5 "OuterStruct" +OpMemberName %5 0 "inner" +OpDecorate %6 ArrayStride 4 +OpDecorate %7 Block +OpMemberDecorate %7 0 Offset 0 +OpDecorate %2 Binding 0 +OpDecorate %2 DescriptorSet 0 +OpMemberDecorate %4 0 Offset 0 +OpMemberDecorate %5 0 Offset 0 +%8 = OpTypeFloat 32 +%9 = OpTypeInt 32 0 +%10 = OpConstant %9 1 +%6 = OpTypeArray %8 %10 +%7 = OpTypeStruct %6 +%11 = OpTypePointer StorageBuffer %7 +%12 = OpTypeVoid +%13 = OpTypeFunction %12 +%14 = OpTypePointer StorageBuffer %6 +%2 = OpVariable %11 StorageBuffer +%15 = OpConstant %9 0 +%16 = OpConstant %9 4 +%17 = OpTypeArray %8 %16 +%4 = OpTypeStruct %17 +%5 = OpTypeStruct %4 +%18 = OpConstant %8 0 +%19 = OpConstantComposite %17 %18 %18 %18 %18 +%20 = OpUndef %5 +%21 = OpTypeBool +%22 = OpConstant %8 1109917696 +%23 = OpConstant %8 1065353216 +%24 = OpTypePointer StorageBuffer %8 diff --git a/tests/ui/linker/array_stride_fixer/private_storage_spirv14_removed.rs b/tests/ui/linker/array_stride_fixer/private_storage_spirv14_removed.rs new file mode 100644 index 0000000000..19de8523be --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/private_storage_spirv14_removed.rs @@ -0,0 +1,22 @@ +// Test that ArrayStride decorations are removed from private storage in SPIR-V 1.4 + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/" +// only-spv1.4 +use spirv_std::spirv; + +// Helper function to create an array in private storage +fn create_private_array() -> [u32; 4] { + [0, 1, 2, 3] +} + +#[spirv(compute(threads(1)))] +pub fn main() { + // This creates a private storage array in SPIR-V 1.4+ + // ArrayStride decorations should be removed + let mut private_array = create_private_array(); + private_array[0] = 42; +} diff --git a/tests/ui/linker/array_stride_fixer/private_storage_spirv14_removed.stderr b/tests/ui/linker/array_stride_fixer/private_storage_spirv14_removed.stderr new file mode 100644 index 0000000000..2eff9c5bee --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/private_storage_spirv14_removed.stderr @@ -0,0 +1,22 @@ +OpCapability Shader +OpCapability Float64 +OpCapability Int64 +OpCapability Int16 +OpCapability Int8 +OpCapability ShaderClockKHR +OpExtension "SPV_KHR_shader_clock" +OpMemoryModel Logical Simple +OpEntryPoint GLCompute %1 "main" +OpExecutionMode %1 LocalSize 1 1 1 +%2 = OpString "$OPSTRING_FILENAME/private_storage_spirv14_removed.rs" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 0 +%7 = OpConstant %6 4 +%8 = OpTypeArray %6 %7 +%9 = OpTypeFunction %8 +%10 = OpConstant %6 0 +%11 = OpConstant %6 1 +%12 = OpConstant %6 2 +%13 = OpConstant %6 3 +%14 = OpTypeBool diff --git a/tests/ui/linker/array_stride_fixer/runtime_arrays_in_workgroup.rs b/tests/ui/linker/array_stride_fixer/runtime_arrays_in_workgroup.rs new file mode 100644 index 0000000000..9c486544ab --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/runtime_arrays_in_workgroup.rs @@ -0,0 +1,22 @@ +// Test that ArrayStride decorations are removed from runtime arrays in Workgroup storage class + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals +// only-vulkan1.1 +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/" +use spirv_std::RuntimeArray; +use spirv_std::spirv; + +#[spirv(compute(threads(64)))] +pub fn main( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] output: &mut [u32; 1], + #[spirv(workgroup)] shared_array: &mut [u32; 256], +) { + // Workgroup arrays should have ArrayStride removed + shared_array[0] = 42; + shared_array[1] = shared_array[0] + 1; + // Force usage to prevent optimization + output[0] = shared_array[1]; +} diff --git a/tests/ui/linker/array_stride_fixer/runtime_arrays_in_workgroup.stderr b/tests/ui/linker/array_stride_fixer/runtime_arrays_in_workgroup.stderr new file mode 100644 index 0000000000..42ca0186aa --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/runtime_arrays_in_workgroup.stderr @@ -0,0 +1,37 @@ +OpCapability Shader +OpCapability Float64 +OpCapability Int64 +OpCapability Int16 +OpCapability Int8 +OpCapability ShaderClockKHR +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_shader_clock" +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %1 "main" +OpExecutionMode %1 LocalSize 64 1 1 +%2 = OpString "$OPSTRING_FILENAME/runtime_arrays_in_workgroup.rs" +OpName %4 "shared_array" +OpDecorate %5 ArrayStride 4 +OpDecorate %6 Block +OpMemberDecorate %6 0 Offset 0 +OpDecorate %3 Binding 0 +OpDecorate %3 DescriptorSet 0 +%7 = OpTypeVoid +%8 = OpTypeFunction %7 +%9 = OpTypeInt 32 0 +%10 = OpConstant %9 1 +%5 = OpTypeArray %9 %10 +%11 = OpTypePointer StorageBuffer %5 +%6 = OpTypeStruct %5 +%12 = OpTypePointer StorageBuffer %6 +%3 = OpVariable %12 StorageBuffer +%13 = OpConstant %9 0 +%14 = OpTypeBool +%15 = OpConstant %9 256 +%16 = OpTypePointer Workgroup %9 +%17 = OpTypeArray %9 %15 +%18 = OpTypePointer Workgroup %17 +%4 = OpVariable %18 Workgroup +%19 = OpConstant %9 42 +%20 = OpTypePointer StorageBuffer %9 diff --git a/tests/ui/linker/array_stride_fixer/storage_buffer_arrays_kept.rs b/tests/ui/linker/array_stride_fixer/storage_buffer_arrays_kept.rs new file mode 100644 index 0000000000..5b9261a8b0 --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/storage_buffer_arrays_kept.rs @@ -0,0 +1,17 @@ +// Test that ArrayStride decorations are kept for arrays in StorageBuffer storage class + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals +// only-vulkan1.1 +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/" +use spirv_std::spirv; + +#[spirv(compute(threads(1)))] +pub fn main( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] storage_buffer_var: &mut [u32; 256], +) { + // StorageBuffer storage class should keep ArrayStride decorations + storage_buffer_var[0] = 42; +} diff --git a/tests/ui/linker/array_stride_fixer/storage_buffer_arrays_kept.stderr b/tests/ui/linker/array_stride_fixer/storage_buffer_arrays_kept.stderr new file mode 100644 index 0000000000..5f0007f120 --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/storage_buffer_arrays_kept.stderr @@ -0,0 +1,31 @@ +OpCapability Shader +OpCapability Float64 +OpCapability Int64 +OpCapability Int16 +OpCapability Int8 +OpCapability ShaderClockKHR +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_shader_clock" +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %1 "main" +OpExecutionMode %1 LocalSize 1 1 1 +%2 = OpString "$OPSTRING_FILENAME/storage_buffer_arrays_kept.rs" +OpDecorate %4 ArrayStride 4 +OpDecorate %5 Block +OpMemberDecorate %5 0 Offset 0 +OpDecorate %3 Binding 0 +OpDecorate %3 DescriptorSet 0 +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 0 +%9 = OpConstant %8 256 +%4 = OpTypeArray %8 %9 +%10 = OpTypePointer StorageBuffer %4 +%5 = OpTypeStruct %4 +%11 = OpTypePointer StorageBuffer %5 +%3 = OpVariable %11 StorageBuffer +%12 = OpConstant %8 0 +%13 = OpTypeBool +%14 = OpTypePointer StorageBuffer %8 +%15 = OpConstant %8 42 diff --git a/tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.rs b/tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.rs new file mode 100644 index 0000000000..967113c70f --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.rs @@ -0,0 +1,21 @@ +// Test that ArrayStride decorations are removed from arrays in Function storage class (SPIR-V 1.4+) + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals +// only-vulkan1.2 +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/" +use spirv_std::spirv; + +#[spirv(compute(threads(64)))] +pub fn main( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] output: &mut [u32; 1], + #[spirv(workgroup)] shared_data: &mut [u32; 256], +) { + // Workgroup storage arrays should have ArrayStride removed + shared_data[0] = 42; + shared_data[1] = shared_data[0] + 1; + // Force usage to prevent optimization + output[0] = shared_data[1]; +} diff --git a/tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.stderr b/tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.stderr new file mode 100644 index 0000000000..3f769cba39 --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.stderr @@ -0,0 +1,36 @@ +OpCapability Shader +OpCapability Float64 +OpCapability Int64 +OpCapability Int16 +OpCapability Int8 +OpCapability ShaderClockKHR +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_shader_clock" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %1 "main" %2 %3 +OpExecutionMode %1 LocalSize 64 1 1 +%4 = OpString "$OPSTRING_FILENAME/workgroup_arrays_removed.rs" +OpName %3 "shared_data" +OpDecorate %5 ArrayStride 4 +OpDecorate %6 Block +OpMemberDecorate %6 0 Offset 0 +OpDecorate %2 Binding 0 +OpDecorate %2 DescriptorSet 0 +%7 = OpTypeInt 32 0 +%8 = OpConstant %7 1 +%5 = OpTypeArray %7 %8 +%6 = OpTypeStruct %5 +%9 = OpTypePointer StorageBuffer %6 +%10 = OpConstant %7 256 +%11 = OpTypeArray %7 %10 +%12 = OpTypePointer Workgroup %11 +%13 = OpTypeVoid +%14 = OpTypeFunction %13 +%15 = OpTypePointer StorageBuffer %5 +%2 = OpVariable %9 StorageBuffer +%16 = OpConstant %7 0 +%17 = OpTypeBool +%18 = OpTypePointer Workgroup %7 +%3 = OpVariable %12 Workgroup +%19 = OpConstant %7 42 +%20 = OpTypePointer StorageBuffer %7 diff --git a/tests/ui/linker/array_stride_fixer/workgroup_arrays_with_capability.rs b/tests/ui/linker/array_stride_fixer/workgroup_arrays_with_capability.rs new file mode 100644 index 0000000000..505809003e --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/workgroup_arrays_with_capability.rs @@ -0,0 +1,22 @@ +// Test that ArrayStride decorations are kept for arrays in Workgroup storage class with WorkgroupMemoryExplicitLayoutKHR capability + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals -Ctarget-feature=+WorkgroupMemoryExplicitLayoutKHR,+ext:SPV_KHR_workgroup_memory_explicit_layout +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/" +// only-vulkan1.2 + +use spirv_std::spirv; + +#[spirv(compute(threads(64)))] +pub fn main( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] output: &mut [u32; 1], + #[spirv(workgroup)] shared_data: &mut [u32; 256], +) { + // With WorkgroupMemoryExplicitLayoutKHR capability, ArrayStride should be kept + shared_data[0] = 42; + shared_data[1] = shared_data[0] + 1; + // Force usage to prevent optimization + output[0] = shared_data[1]; +} diff --git a/tests/ui/linker/array_stride_fixer/workgroup_arrays_with_capability.stderr b/tests/ui/linker/array_stride_fixer/workgroup_arrays_with_capability.stderr new file mode 100644 index 0000000000..a645d0bcbd --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/workgroup_arrays_with_capability.stderr @@ -0,0 +1,39 @@ +OpCapability Shader +OpCapability Float64 +OpCapability Int64 +OpCapability Int16 +OpCapability Int8 +OpCapability WorkgroupMemoryExplicitLayoutKHR +OpCapability ShaderClockKHR +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_shader_clock" +OpExtension "SPV_KHR_workgroup_memory_explicit_layout" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %1 "main" %2 %3 +OpExecutionMode %1 LocalSize 64 1 1 +%4 = OpString "$OPSTRING_FILENAME/workgroup_arrays_with_capability.rs" +OpName %3 "shared_data" +OpDecorate %5 ArrayStride 4 +OpDecorate %6 Block +OpMemberDecorate %6 0 Offset 0 +OpDecorate %7 ArrayStride 4 +OpDecorate %2 Binding 0 +OpDecorate %2 DescriptorSet 0 +%8 = OpTypeInt 32 0 +%9 = OpConstant %8 1 +%5 = OpTypeArray %8 %9 +%6 = OpTypeStruct %5 +%10 = OpTypePointer StorageBuffer %6 +%11 = OpConstant %8 256 +%7 = OpTypeArray %8 %11 +%12 = OpTypePointer Workgroup %7 +%13 = OpTypeVoid +%14 = OpTypeFunction %13 +%15 = OpTypePointer StorageBuffer %5 +%2 = OpVariable %10 StorageBuffer +%16 = OpConstant %8 0 +%17 = OpTypeBool +%18 = OpTypePointer Workgroup %8 +%3 = OpVariable %12 Workgroup +%19 = OpConstant %8 42 +%20 = OpTypePointer StorageBuffer %8 From 8be822fb115c4a516b02061ada22f9f07831899d Mon Sep 17 00:00:00 2001 From: redain <52823087+redain@users.noreply.github.com> Date: Fri, 13 Jun 2025 16:44:47 +0400 Subject: [PATCH 3/5] fixer: pt2 --- .../src/linker/array_stride_fixer.rs | 586 ++++++++++++------ crates/rustc_codegen_spirv/src/linker/mod.rs | 2 +- .../workgroup_2d_arrays_issue.rs | 30 + .../workgroup_2d_arrays_issue.stderr | 34 + 4 files changed, 453 insertions(+), 199 deletions(-) create mode 100644 tests/ui/linker/array_stride_fixer/workgroup_2d_arrays_issue.rs create mode 100644 tests/ui/linker/array_stride_fixer/workgroup_2d_arrays_issue.stderr diff --git a/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs b/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs index ca1736af79..f5c53316c3 100644 --- a/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs +++ b/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs @@ -10,7 +10,7 @@ use rspirv::spirv::{Capability, Decoration, Op, StorageClass, Word}; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; /// Describes how an array type is used across different storage classes -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ArrayUsagePattern { /// Array is only used in storage classes that require explicit layout LayoutRequired, @@ -29,6 +29,10 @@ pub struct ArrayStorageContext { pub storage_classes: FxHashSet, /// Whether this array allows or forbids layout in its contexts pub usage_pattern: ArrayUsagePattern, + /// Array types that this array contains as elements (for nested arrays) + pub element_arrays: FxHashSet, + /// Array types that contain this array as an element + pub parent_arrays: FxHashSet, } /// Check if a storage class allows explicit layout decorations based on SPIR-V version and capabilities. @@ -53,15 +57,9 @@ fn allows_layout( } } -/// Comprehensive fix for `ArrayStride` decorations with optional type deduplication -pub fn fix_array_stride_decorations_with_deduplication( - module: &mut Module, - use_context_aware_deduplication: bool, -) { - // Get SPIR-V version from module header - let spirv_version = module.header.as_ref().map_or((1, 0), |h| h.version()); // Default to 1.0 if no header - - // Check for WorkgroupMemoryExplicitLayoutKHR capability +/// Comprehensive fix for `ArrayStride` decorations with staged processing architecture +pub fn fix_array_stride_decorations_with_deduplication(module: &mut Module) { + let spirv_version = module.header.as_ref().map_or((1, 0), |h| h.version()); let has_workgroup_layout_capability = module.capabilities.iter().any(|inst| { inst.class.opcode == Op::Capability && inst.operands.first() @@ -70,39 +68,52 @@ pub fn fix_array_stride_decorations_with_deduplication( )) }); - // Analyze storage class contexts for all array types + // Analyze all array usage patterns and dependencies let array_contexts = analyze_array_storage_contexts(module, spirv_version, has_workgroup_layout_capability); - // Handle mixed usage arrays by creating specialized versions - let specializations = create_specialized_array_types(module, &array_contexts); + // Create specialized array types when necessary (mixed usage scenarios) + let specializations = create_specialized_array_types( + module, + &array_contexts, + spirv_version, + has_workgroup_layout_capability, + ); - // Update references to use appropriate specialized types + // Update references with full context awareness if !specializations.is_empty() { update_references_for_specialized_arrays( module, &specializations, + &array_contexts, spirv_version, has_workgroup_layout_capability, ); } - // Apply context-aware type deduplication if requested - if use_context_aware_deduplication { - crate::linker::duplicates::remove_duplicate_types_with_array_context( - module, - Some(&array_contexts), - ); - } - - // Remove ArrayStride decorations from arrays used in forbidden contexts - remove_array_stride_decorations_for_forbidden_contexts(module, &array_contexts); + // Remove decorations from layout-forbidden contexts + remove_array_stride_decorations_for_forbidden_contexts( + module, + &array_contexts, + spirv_version, + has_workgroup_layout_capability, + ); + + // Final cleanup and deduplication. + // Always run the context-aware variant so that arrays used in differing + // storage-class contexts are not incorrectly merged. + crate::linker::duplicates::remove_duplicate_types_with_array_context( + module, + Some(&array_contexts), + ); } /// Remove `ArrayStride` decorations from arrays used in layout-forbidden storage classes fn remove_array_stride_decorations_for_forbidden_contexts( module: &mut Module, array_contexts: &FxHashMap, + spirv_version: (u8, u8), + has_workgroup_layout_capability: bool, ) { // Find array types that should have their ArrayStride decorations removed // Remove from arrays used in forbidden contexts OR mixed usage that includes forbidden contexts @@ -117,7 +128,7 @@ fn remove_array_stride_decorations_for_forbidden_contexts( // If the array is used in any context that forbids layout, remove the decoration // This is a conservative approach that prevents validation errors let has_forbidden_context = context.storage_classes.iter().any(|&sc| { - !allows_layout(sc, (1, 4), false) // Use SPIR-V 1.4 rules for conservative check + !allows_layout(sc, spirv_version, has_workgroup_layout_capability) }); if has_forbidden_context { @@ -146,7 +157,7 @@ fn remove_array_stride_decorations_for_forbidden_contexts( } /// Analyze storage class contexts for all array types in the module -pub fn analyze_array_storage_contexts( +fn analyze_array_storage_contexts( module: &Module, spirv_version: (u8, u8), has_workgroup_layout_capability: bool, @@ -162,11 +173,33 @@ pub fn analyze_array_storage_contexts( array_contexts.insert(result_id, ArrayStorageContext { storage_classes: FxHashSet::default(), usage_pattern: ArrayUsagePattern::Unused, + element_arrays: FxHashSet::default(), + parent_arrays: FxHashSet::default(), }); } } } + // Build parent-child relationships between array types + for inst in &module.types_global_values { + if matches!(inst.class.opcode, Op::TypeArray | Op::TypeRuntimeArray) { + if let Some(parent_id) = inst.result_id { + if !inst.operands.is_empty() { + let element_type = inst.operands[0].unwrap_id_ref(); + // If the element type is also an array, record the parent-child relationship + if array_types.contains(&element_type) { + if let Some(parent_context) = array_contexts.get_mut(&parent_id) { + parent_context.element_arrays.insert(element_type); + } + if let Some(element_context) = array_contexts.get_mut(&element_type) { + element_context.parent_arrays.insert(parent_id); + } + } + } + } + } + } + // Analyze global variables for inst in &module.types_global_values { if inst.class.opcode == Op::Variable && !inst.operands.is_empty() { @@ -211,7 +244,7 @@ pub fn analyze_array_storage_contexts( } } - // Determine usage patterns + // Determine usage patterns with enhanced logic for nested arrays for context in array_contexts.values_mut() { if context.storage_classes.is_empty() { context.usage_pattern = ArrayUsagePattern::Unused; @@ -240,6 +273,50 @@ pub fn analyze_array_storage_contexts( } } + // Propagate context from parent arrays to child arrays for better consistency + // If a parent array is pure workgroup, its child arrays should inherit this context + let mut changed = true; + while changed { + changed = false; + let array_ids: Vec = array_contexts.keys().copied().collect(); + + for &array_id in &array_ids { + let (parent_arrays, current_pattern) = { + let context = &array_contexts[&array_id]; + (context.parent_arrays.clone(), context.usage_pattern) + }; + + // If this array has parent arrays that are pure workgroup, and this array + // doesn't have mixed usage, then it should also be pure workgroup + if current_pattern != ArrayUsagePattern::LayoutForbidden + && current_pattern != ArrayUsagePattern::MixedUsage + { + let has_pure_workgroup_parent = parent_arrays.iter().any(|&parent_id| { + matches!( + array_contexts.get(&parent_id).map(|ctx| ctx.usage_pattern), + Some(ArrayUsagePattern::LayoutForbidden) + ) + }); + + if has_pure_workgroup_parent { + if let Some(context) = array_contexts.get_mut(&array_id) { + // Only inherit if this array doesn't have its own conflicting storage classes + let has_layout_required_usage = context.storage_classes.iter().any(|&sc| { + allows_layout(sc, spirv_version, has_workgroup_layout_capability) + }); + + if !has_layout_required_usage { + context.usage_pattern = ArrayUsagePattern::LayoutForbidden; + // Also inherit the workgroup storage class if not already present + context.storage_classes.insert(StorageClass::Workgroup); + changed = true; + } + } + } + } + } + } + array_contexts } @@ -247,15 +324,41 @@ pub fn analyze_array_storage_contexts( fn create_specialized_array_types( module: &mut Module, array_contexts: &FxHashMap, + spirv_version: (u8, u8), + has_workgroup_layout_capability: bool, ) -> FxHashMap { let mut specializations = FxHashMap::default(); // original_id -> (layout_required_id, layout_forbidden_id) - // Find arrays that need specialization (mixed usage) + // Find arrays that need specialization (true mixed usage only) + // Be more conservative - only specialize arrays that truly have conflicting requirements let arrays_to_specialize: Vec = array_contexts .iter() .filter_map(|(&id, context)| { + // Only specialize if the array has BOTH layout-required AND layout-forbidden usage + // AND it's not just inheriting context from parents (check own storage classes) if context.usage_pattern == ArrayUsagePattern::MixedUsage { - Some(id) + let mut has_layout_required = false; + let mut has_layout_forbidden = false; + + // Check actual storage classes, not inherited patterns + for &storage_class in &context.storage_classes { + if allows_layout( + storage_class, + spirv_version, + has_workgroup_layout_capability, + ) { + has_layout_required = true; + } else { + has_layout_forbidden = true; + } + } + + // Only specialize if there's a true conflict in this array's own usage + if has_layout_required && has_layout_forbidden { + Some(id) + } else { + None + } } else { None } @@ -308,224 +411,250 @@ fn create_specialized_array_types( } } - // IMPORTANT: Do not add the specialized arrays to the end - this would create forward references - // Instead, we need to insert them in the correct position to maintain SPIR-V type ordering + // Insert specialized arrays right after their corresponding original arrays + // This maintains proper SPIR-V type ordering + let mut new_types_global_values = Vec::new(); + + for inst in &module.types_global_values { + new_types_global_values.push(inst.clone()); - // Find the insertion point: after the last original array type that needs specialization - // This ensures all specialized arrays are defined before any types that might reference them - let mut insertion_point = 0; - for (i, inst) in module.types_global_values.iter().enumerate() { + // If this is an array that was specialized, add the specialized versions right after if let Some(result_id) = inst.result_id { - if arrays_to_specialize.contains(&result_id) { - insertion_point = i + 1; + if let Some(&(layout_required_id, layout_forbidden_id)) = + specializations.get(&result_id) + { + // Find and add the specialized versions + for new_inst in &new_type_instructions { + if new_inst.result_id == Some(layout_required_id) { + new_types_global_values.push(new_inst.clone()); + break; + } + } + for new_inst in &new_type_instructions { + if new_inst.result_id == Some(layout_forbidden_id) { + new_types_global_values.push(new_inst.clone()); + break; + } + } } } } - // Insert the specialized array types at the calculated position - // This maintains the invariant that referenced types appear before referencing types - for (i, new_inst) in new_type_instructions.into_iter().enumerate() { - module - .types_global_values - .insert(insertion_point + i, new_inst); - } + module.types_global_values = new_types_global_values; specializations } -/// Update all references to specialized array types based on storage class context -fn update_references_for_specialized_arrays( - module: &mut Module, +/// Helper function to select the appropriate specialized variant based on context +fn select_array_variant_for_context( + original_id: Word, + context_storage_classes: &FxHashSet, specializations: &FxHashMap, spirv_version: (u8, u8), has_workgroup_layout_capability: bool, -) { - // Update struct types that contain specialized arrays - // This is safe now because all specialized arrays have been properly positioned in the types section - for inst in &mut module.types_global_values { - if inst.class.opcode == Op::TypeStruct { - for operand in &mut inst.operands { - if let Some(referenced_id) = operand.id_ref_any() { - if let Some(&(layout_required_id, _layout_forbidden_id)) = - specializations.get(&referenced_id) - { - // For struct types, we use the layout-required variant since structs - // can be used in both layout-required and layout-forbidden contexts - *operand = Operand::IdRef(layout_required_id); - } - } - } +) -> Option { + if let Some(&(layout_required_id, layout_forbidden_id)) = specializations.get(&original_id) { + // If context is pure workgroup, use layout_forbidden variant + if context_storage_classes.len() == 1 + && context_storage_classes.contains(&StorageClass::Workgroup) + { + return Some(layout_forbidden_id); } - } - // Collect all existing pointer types that reference specialized arrays FIRST - let mut existing_pointers_to_specialize = Vec::new(); - for inst in &module.types_global_values { - if inst.class.opcode == Op::TypePointer && inst.operands.len() >= 2 { - let pointee_type = inst.operands[1].unwrap_id_ref(); - if specializations.contains_key(&pointee_type) { - existing_pointers_to_specialize.push(inst.clone()); - } + // If context has any layout-forbidden storage classes, use layout_forbidden variant + let has_forbidden_context = context_storage_classes + .iter() + .any(|&sc| !allows_layout(sc, spirv_version, has_workgroup_layout_capability)); + + if has_forbidden_context { + Some(layout_forbidden_id) + } else { + Some(layout_required_id) } + } else { + None + } +} + +/// Update all references to specialized array types and create specialized pointer types +fn update_references_for_specialized_arrays( + module: &mut Module, + specializations: &FxHashMap, + array_contexts: &FxHashMap, + spirv_version: (u8, u8), + has_workgroup_layout_capability: bool, +) { + if specializations.is_empty() { + return; } - // Create ALL specialized pointer types from the collected existing ones let mut next_id = module.header.as_ref().map_or(1, |h| h.bound); + + // Step 1: Create new pointer types for specialized arrays + let mut pointer_rewrite_rules = FxHashMap::default(); // old_pointer_id -> new_pointer_id let mut new_pointer_instructions = Vec::new(); - let mut pointer_type_mappings = FxHashMap::default(); // old_pointer_id -> new_pointer_id - // Create new pointer types for each storage class context - for inst in &existing_pointers_to_specialize { - let storage_class = inst.operands[0].unwrap_storage_class(); - let pointee_type = inst.operands[1].unwrap_id_ref(); + // Collect all pointer types that need updating + for inst in &module.types_global_values { + if inst.class.opcode == Op::TypePointer && inst.operands.len() >= 2 { + let storage_class = inst.operands[0].unwrap_storage_class(); + let pointee_type = inst.operands[1].unwrap_id_ref(); - if let Some(&(layout_required_id, layout_forbidden_id)) = specializations.get(&pointee_type) - { - let allows_layout_for_sc = allows_layout( - storage_class, - spirv_version, - has_workgroup_layout_capability, - ); - - // Create new pointer type pointing to appropriate specialized array - let target_array_id = if allows_layout_for_sc { - layout_required_id - } else { - layout_forbidden_id - }; + if let Some(&(layout_required_id, layout_forbidden_id)) = + specializations.get(&pointee_type) + { + // Choose variant based on storage class + let target_array_id = if allows_layout( + storage_class, + spirv_version, + has_workgroup_layout_capability, + ) { + layout_required_id + } else { + layout_forbidden_id + }; - let mut new_pointer_inst = inst.clone(); - new_pointer_inst.result_id = Some(next_id); - new_pointer_inst.operands[1] = Operand::IdRef(target_array_id); - new_pointer_instructions.push(new_pointer_inst); + // Create new pointer type + let mut new_pointer_inst = inst.clone(); + new_pointer_inst.result_id = Some(next_id); + new_pointer_inst.operands[1] = Operand::IdRef(target_array_id); + new_pointer_instructions.push(new_pointer_inst); - // Map old pointer to new pointer - if let Some(old_pointer_id) = inst.result_id { - pointer_type_mappings.insert(old_pointer_id, next_id); + pointer_rewrite_rules.insert(inst.result_id.unwrap(), next_id); + next_id += 1; } - next_id += 1; } } - // Update module header bound to account for the new pointer types - if let Some(ref mut header) = module.header { - header.bound = next_id; - } - - // Insert new pointer type instructions in the correct position - // They must come after the specialized arrays they reference, but before any variables that use them + // Step 2: Update struct field and array element references that point to + // original (now specialized) arrays so they reference the appropriate + // specialized variant. - // Find the last specialized array position to ensure pointers come after their pointee types - let mut pointer_insertion_point = 0; - for (i, inst) in module.types_global_values.iter().enumerate() { - if let Some(result_id) = inst.result_id { - // Check if this is one of our specialized arrays - if specializations - .values() - .any(|&(req_id, forb_id)| result_id == req_id || result_id == forb_id) - { - pointer_insertion_point = i + 1; + // 2a) Struct field types + for inst in &mut module.types_global_values { + if inst.class.opcode == Op::TypeStruct { + for op in &mut inst.operands { + if let Some(field_type_id) = op.id_ref_any_mut() { + if let Some(new_id) = select_array_variant_for_context( + *field_type_id, + // We don't have per struct storage class context, but + // any arrays appearing inside a Block decorated struct + // are expected to be in layout-required contexts. + &[StorageClass::StorageBuffer, StorageClass::Uniform] + .iter() + .cloned() + .collect::>(), + specializations, + spirv_version, + has_workgroup_layout_capability, + ) { + *field_type_id = new_id; + } + } } } } - // Insert the new pointer types at the calculated position - // This ensures they appear after specialized arrays but before variables - for (i, new_pointer_inst) in new_pointer_instructions.into_iter().enumerate() { - module - .types_global_values - .insert(pointer_insertion_point + i, new_pointer_inst); - } - - // Update ALL references to old pointer types throughout the entire module - // This includes variables, function parameters, and all instructions - - // Update global variables and function types + // 2b) Array element references for non-specialized arrays for inst in &mut module.types_global_values { - match inst.class.opcode { - Op::Variable => { - if let Some(var_type_id) = inst.result_type { - if let Some(&new_pointer_id) = pointer_type_mappings.get(&var_type_id) { - inst.result_type = Some(new_pointer_id); - } + if matches!(inst.class.opcode, Op::TypeArray | Op::TypeRuntimeArray) { + if let Some(parent_id) = inst.result_id { + // Skip arrays that were themselves specialized (they should already be correctly set up) + if specializations.contains_key(&parent_id) { + continue; } - } - Op::TypeFunction => { - // Update function type operands (return type and parameter types) - for operand in &mut inst.operands { - if let Some(referenced_id) = operand.id_ref_any() { - if let Some(&new_pointer_id) = pointer_type_mappings.get(&referenced_id) { - *operand = Operand::IdRef(new_pointer_id); + + // Update element reference based on parent's context + if let Some(element_operand) = inst.operands.get_mut(0) { + if let Some(elem_id) = element_operand.id_ref_any_mut() { + if let Some(parent_context) = array_contexts.get(&parent_id) { + if let Some(new_elem_id) = select_array_variant_for_context( + *elem_id, + &parent_context.storage_classes, + specializations, + spirv_version, + has_workgroup_layout_capability, + ) { + *elem_id = new_elem_id; + } } } } } - _ => {} } } - // Update function signatures and local variables - for function in &mut module.functions { - // Update function parameters - for param in &mut function.parameters { - if let Some(param_type_id) = param.result_type { - if let Some(&new_pointer_id) = pointer_type_mappings.get(¶m_type_id) { - param.result_type = Some(new_pointer_id); - } + // Step 3: Update pointer types in types section to reference specialized arrays + let mut updated_pointer_types = Vec::new(); + for inst in &module.types_global_values { + if inst.class.opcode == Op::TypePointer && inst.operands.len() >= 2 { + let storage_class = inst.operands[0].unwrap_storage_class(); + let pointee_type = inst.operands[1].unwrap_id_ref(); + + if let Some(&(layout_required_id, layout_forbidden_id)) = + specializations.get(&pointee_type) + { + // Choose variant based on storage class + let target_array_id = if allows_layout( + storage_class, + spirv_version, + has_workgroup_layout_capability, + ) { + layout_required_id + } else { + layout_forbidden_id + }; + + // Update the pointer to reference the appropriate specialized array + let mut updated_inst = inst.clone(); + updated_inst.operands[1] = Operand::IdRef(target_array_id); + updated_pointer_types.push((inst.result_id.unwrap(), updated_inst)); } } + } - // Update all instructions in function bodies - for block in &mut function.blocks { - for inst in &mut block.instructions { - // Update result type - if let Some(result_type_id) = inst.result_type { - if let Some(&new_pointer_id) = pointer_type_mappings.get(&result_type_id) { - inst.result_type = Some(new_pointer_id); - } - } - - // Update operand references - for operand in &mut inst.operands { - if let Some(referenced_id) = operand.id_ref_any() { - if let Some(&new_pointer_id) = pointer_type_mappings.get(&referenced_id) { - *operand = Operand::IdRef(new_pointer_id); - } - } + // Apply pointer type updates + for inst in &mut module.types_global_values { + if let Some(result_id) = inst.result_id { + for (old_id, new_inst) in &updated_pointer_types { + if result_id == *old_id { + *inst = new_inst.clone(); + break; } } } } - // Remove old pointer type instructions that reference specialized arrays - module.types_global_values.retain(|inst| { - if inst.class.opcode == Op::TypePointer && inst.operands.len() >= 2 { - let pointee_type = inst.operands[1].unwrap_id_ref(); - !specializations.contains_key(&pointee_type) - } else { - true - } - }); + // Step 4: Keep original arrays (even after specialization) to avoid + // potential forward-reference ordering issues. These original types are + // now unused, but retaining them is harmless and greatly simplifies the + // type-ordering constraints enforced by the SPIR-V -> SPIR-T lowering + // step. + // NOTE: If size becomes a concern, we can revisit this and implement a + // safer removal strategy that preserves correct ordering. - // Remove original array type instructions that were specialized - let arrays_to_remove: FxHashSet = specializations.keys().cloned().collect(); - module.types_global_values.retain(|inst| { - if let Some(result_id) = inst.result_id { - !arrays_to_remove.contains(&result_id) - } else { - true + // Step 5: Add new pointer types to the module + module.types_global_values.extend(new_pointer_instructions); + + // Update module header bound + if let Some(ref mut header) = module.header { + header.bound = next_id; + } + + // Step 6: Apply pointer rewrite rules throughout the module + for inst in module.all_inst_iter_mut() { + if let Some(ref mut id) = inst.result_type { + *id = pointer_rewrite_rules.get(id).copied().unwrap_or(*id); } - }); + for op in &mut inst.operands { + if let Some(id) = op.id_ref_any_mut() { + *id = pointer_rewrite_rules.get(id).copied().unwrap_or(*id); + } + } + } - // STEP 8: Copy ArrayStride decorations from original arrays to layout-required variants - // and remove them from layout-forbidden variants + // Step 7: Handle ArrayStride decorations for specialized arrays let mut decorations_to_add = Vec::new(); - let layout_forbidden_arrays: FxHashSet = specializations - .values() - .map(|&(_, layout_forbidden_id)| layout_forbidden_id) - .collect(); - - // Find existing ArrayStride decorations on original arrays and copy them to layout-required variants for inst in &module.annotations { if inst.class.opcode == Op::Decorate && inst.operands.len() >= 2 @@ -533,22 +662,23 @@ fn update_references_for_specialized_arrays( { let target_id = inst.operands[0].unwrap_id_ref(); if let Some(&(layout_required_id, _)) = specializations.get(&target_id) { - // Copy the decoration to the layout-required variant let mut new_decoration = inst.clone(); new_decoration.operands[0] = Operand::IdRef(layout_required_id); decorations_to_add.push(new_decoration); } } } - - // Add the copied decorations module.annotations.extend(decorations_to_add); - // Remove ArrayStride decorations from layout-forbidden arrays and original arrays + // Remove ArrayStride decorations from original and layout-forbidden arrays + let layout_forbidden_arrays: FxHashSet = specializations + .values() + .map(|&(_, layout_forbidden_id)| layout_forbidden_id) + .collect(); let arrays_to_remove_decorations: FxHashSet = layout_forbidden_arrays .iter() .cloned() - .chain(specializations.keys().cloned()) // Also remove from original arrays + .chain(specializations.keys().cloned()) .collect(); module.annotations.retain(|inst| { @@ -562,6 +692,66 @@ fn update_references_for_specialized_arrays( true } }); + + // Step 8: Rewrite any remaining uses/results of original array IDs to the + // chosen specialized variant (defaulting to the layout-required variant). + if !specializations.is_empty() { + let mut array_default_rewrite: FxHashMap = FxHashMap::default(); + for (&orig, &(layout_required_id, _)) in specializations { + array_default_rewrite.insert(orig, layout_required_id); + } + + for inst in module.all_inst_iter_mut() { + // Skip type declarations themselves – we only want to fix *uses* of the IDs. + if matches!( + inst.class.opcode, + Op::TypeVoid + | Op::TypeBool + | Op::TypeInt + | Op::TypeFloat + | Op::TypeVector + | Op::TypeMatrix + | Op::TypeImage + | Op::TypeSampler + | Op::TypeSampledImage + | Op::TypeArray + | Op::TypeRuntimeArray + | Op::TypeStruct + | Op::TypeOpaque + | Op::TypePointer + | Op::TypeFunction + | Op::TypeEvent + | Op::TypeDeviceEvent + | Op::TypeReserveId + | Op::TypeQueue + | Op::TypePipe + | Op::TypeForwardPointer + ) { + continue; + } + + // Avoid changing the declared type of function parameters and + // composite ops, as they must stay in sync with their value + // operands. + if !matches!( + inst.class.opcode, + Op::FunctionParameter | Op::CompositeInsert | Op::CompositeExtract + ) { + if let Some(ref mut ty) = inst.result_type { + if let Some(&new) = array_default_rewrite.get(ty) { + *ty = new; + } + } + } + for op in &mut inst.operands { + if let Some(id) = op.id_ref_any_mut() { + if let Some(&new) = array_default_rewrite.get(id) { + *id = new; + } + } + } + } + } } /// Check if a type hierarchy contains a specific array type diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs index d78525fcff..d8125091e6 100644 --- a/crates/rustc_codegen_spirv/src/linker/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/mod.rs @@ -359,7 +359,7 @@ pub fn link( // Fix ArrayStride decorations (after storage classes are resolved to avoid conflicts) { let _timer = sess.timer("fix_array_stride_decorations"); - array_stride_fixer::fix_array_stride_decorations_with_deduplication(&mut output, false); + array_stride_fixer::fix_array_stride_decorations_with_deduplication(&mut output); } // NOTE(eddyb) with SPIR-T, we can do `mem2reg` before inlining, too! diff --git a/tests/ui/linker/array_stride_fixer/workgroup_2d_arrays_issue.rs b/tests/ui/linker/array_stride_fixer/workgroup_2d_arrays_issue.rs new file mode 100644 index 0000000000..771b7c9db6 --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/workgroup_2d_arrays_issue.rs @@ -0,0 +1,30 @@ +// Test that reproduces the OpInBoundsAccessChain type mismatch issue +// with workgroup 2D arrays after array_stride_fixer changes + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals +// only-vulkan1.1 +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/" + +use spirv_std::spirv; + +const TILE_SIZE: u32 = 32; + +#[spirv(compute(threads(32, 32)))] +pub fn transpose_2d_workgroup( + #[spirv(local_invocation_id)] lid: spirv_std::glam::UVec3, + #[spirv(workgroup)] shared_real: &mut [[f32; TILE_SIZE as usize]; TILE_SIZE as usize], + #[spirv(workgroup)] shared_imag: &mut [[f32; TILE_SIZE as usize]; TILE_SIZE as usize], +) { + let lx = lid.x as usize; + let ly = lid.y as usize; + + // This should trigger the OpInBoundsAccessChain issue + shared_real[ly][lx] = 1.0; + shared_imag[ly][lx] = 2.0; + + // Read back to ensure usage + let _val = shared_real[lx][ly] + shared_imag[lx][ly]; +} \ No newline at end of file diff --git a/tests/ui/linker/array_stride_fixer/workgroup_2d_arrays_issue.stderr b/tests/ui/linker/array_stride_fixer/workgroup_2d_arrays_issue.stderr new file mode 100644 index 0000000000..95bac54330 --- /dev/null +++ b/tests/ui/linker/array_stride_fixer/workgroup_2d_arrays_issue.stderr @@ -0,0 +1,34 @@ +OpCapability Shader +OpCapability Float64 +OpCapability Int64 +OpCapability Int16 +OpCapability Int8 +OpCapability ShaderClockKHR +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_shader_clock" +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %1 "transpose_2d_workgroup" %2 +OpExecutionMode %1 LocalSize 32 32 1 +%3 = OpString "$OPSTRING_FILENAME/workgroup_2d_arrays_issue.rs" +OpName %4 "shared_real" +OpName %5 "shared_imag" +OpDecorate %2 BuiltIn LocalInvocationId +%6 = OpTypeInt 32 0 +%7 = OpTypeVector %6 3 +%8 = OpTypePointer Input %7 +%9 = OpTypeVoid +%10 = OpTypeFunction %9 +%2 = OpVariable %8 Input +%11 = OpTypeBool +%12 = OpConstant %6 32 +%13 = OpTypeFloat 32 +%14 = OpTypeArray %13 %12 +%15 = OpTypePointer Workgroup %14 +%16 = OpTypeArray %14 %12 +%17 = OpTypePointer Workgroup %16 +%4 = OpVariable %17 Workgroup +%18 = OpTypePointer Workgroup %13 +%19 = OpConstant %13 1065353216 +%5 = OpVariable %17 Workgroup +%20 = OpConstant %13 1073741824 From 5d09cd07356f05e33a277e1e3e8487daa5013afc Mon Sep 17 00:00:00 2001 From: redain <52823087+redain@users.noreply.github.com> Date: Fri, 13 Jun 2025 16:45:05 +0400 Subject: [PATCH 4/5] fix: testing new approach --- crates/rustc_codegen_spirv/src/linker/mod.rs | 6 - .../src/linker/specializer.rs | 258 +++++++++++++++++- .../nested_structs_function_storage.stderr | 37 +-- .../workgroup_2d_arrays_issue.rs | 4 +- .../workgroup_arrays_removed.stderr | 25 +- 5 files changed, 278 insertions(+), 52 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs index d8125091e6..ead00e28c5 100644 --- a/crates/rustc_codegen_spirv/src/linker/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/mod.rs @@ -356,12 +356,6 @@ pub fn link( }); } - // Fix ArrayStride decorations (after storage classes are resolved to avoid conflicts) - { - let _timer = sess.timer("fix_array_stride_decorations"); - array_stride_fixer::fix_array_stride_decorations_with_deduplication(&mut output); - } - // NOTE(eddyb) with SPIR-T, we can do `mem2reg` before inlining, too! { if opts.dce { diff --git a/crates/rustc_codegen_spirv/src/linker/specializer.rs b/crates/rustc_codegen_spirv/src/linker/specializer.rs index 01c4c9f9b1..1353c68c98 100644 --- a/crates/rustc_codegen_spirv/src/linker/specializer.rs +++ b/crates/rustc_codegen_spirv/src/linker/specializer.rs @@ -57,6 +57,7 @@ use rspirv::spirv::{Op, StorageClass, Word}; use rustc_data_structures::captures::Captures; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; use smallvec::SmallVec; +use std::cell::RefCell; use std::collections::{BTreeMap, VecDeque}; use std::ops::{Range, RangeTo}; use std::{fmt, io, iter, mem, slice}; @@ -128,11 +129,23 @@ pub fn specialize( .collect(); } + let spirv_version = module.header.as_ref().map_or((1, 0), |h| h.version()); + let has_workgroup_layout_capability = module.capabilities.iter().any(|inst| { + inst.class.opcode == Op::Capability + && inst.operands.first() + == Some(&Operand::Capability( + rspirv::spirv::Capability::WorkgroupMemoryExplicitLayoutKHR, + )) + }); + let mut specializer = Specializer { specialization, debug_names, generics: IndexMap::new(), int_consts: FxHashMap::default(), + array_layout_variants: RefCell::new(FxHashMap::default()), + spirv_version, + has_workgroup_layout_capability, }; specializer.collect_generics(&module); @@ -524,6 +537,31 @@ struct Generic { replacements: Replacements, } +/// For every `OpTypeArray` (identified by its defining instruction result ID), we may +/// need two physical variants that differ only by the presence of an `ArrayStride` +/// decoration. +/// +/// * `layout_required_id` – An array type *with* `ArrayStride` attached. This variant +/// must be used whenever the SPIR-V storage class context *requires* explicit +/// layout information (e.g. `StorageBuffer`). In most existing modules this will be +/// the original array type that rust-gpu produced. +/// * `layout_forbidden_id` – An array type *without* `ArrayStride`. This variant must +/// be used in storage classes where explicit layout is disallowed (e.g. +/// `Workgroup`, `Private`, `Function` in SPIR-V ≥ 1.4 without special +/// capabilities). +/// +/// Not every module needs both variants – e.g. an array only ever used in +/// `StorageBuffer` memory does not require the `layout_forbidden_id`. Therefore the +/// two fields are optional and are filled on demand by the storage-class inference +/// pass when it discovers the need for a particular variant. +#[derive(Clone, Debug, Default)] +struct ArrayLayoutVariants { + /// Variant that *keeps* the `ArrayStride` decoration, if it exists. + layout_required_id: Option, + /// Variant that *omits* the `ArrayStride` decoration. + layout_forbidden_id: Option, +} + struct Specializer { specialization: S, @@ -536,6 +574,16 @@ struct Specializer { /// Integer `OpConstant`s (i.e. containing a `LiteralBit32`), to be used /// for interpreting `TyPat::IndexComposite` (such as for `OpAccessChain`). int_consts: FxHashMap, + + /// Lazily-populated map holding, for every original array type ID, the IDs of + /// its specialised layout-`required` / layout-`forbidden` clones (if any). + array_layout_variants: RefCell>, + + /// SPIR-V version of the current module (major, minor) + spirv_version: (u8, u8), + + /// Whether the module enables `WorkgroupMemoryExplicitLayoutKHR` + has_workgroup_layout_capability: bool, } impl Specializer { @@ -680,6 +728,17 @@ impl Specializer { Some(replacements) } } + + /// Return `true` if the given storage class allows explicit layout decorations + /// (i.e. `ArrayStride`). + fn allows_layout(&self, storage_class: StorageClass) -> bool { + match storage_class { + StorageClass::UniformConstant => false, + StorageClass::Workgroup => self.has_workgroup_layout_capability, + StorageClass::Function | StorageClass::Private => self.spirv_version < (1, 4), + _ => true, + } + } } /// Newtype'd inference variable index. @@ -2202,10 +2261,25 @@ struct Expander<'a, S: Specialization> { /// own `replacements` analyzed in order to fully collect all instances. // FIXME(eddyb) fine-tune the length of `SmallVec<[_; 4]>` here. propagate_instances_queue: VecDeque>>, + + /// Snapshot of all original type/global instructions keyed by their result ID. + /// This is necessary because during expansion we temporarily move the original + /// `types_global_values` out of the `Module`, making them inaccessible to helper + /// routines (such as `get_array_variant`). + original_types: FxHashMap, } impl<'a, S: Specialization> Expander<'a, S> { fn new(specializer: &'a Specializer, module: Module) -> Self { + // Build a lookup table of all original type/global instructions before moving + // `module` into the internal `Builder`. + let mut original_types = FxHashMap::default(); + for inst in &module.types_global_values { + if let Some(id) = inst.result_id { + original_types.insert(id, inst.clone()); + } + } + Expander { specializer, @@ -2213,6 +2287,8 @@ impl<'a, S: Specialization> Expander<'a, S> { instances: BTreeMap::new(), propagate_instances_queue: VecDeque::new(), + + original_types, } } @@ -2354,28 +2430,73 @@ impl<'a, S: Specialization> Expander<'a, S> { // Expand `Op(Member)Decorate* %target ...`, when `target` is "generic". let mut expanded_annotations = expand_debug_or_annotation(annotations); - // Expand "generic" globals (types, constants and module-scoped variables). + // Will store any additional type declarations (array variants, etc.) that we might + // need to generate while expanding pointer types. let mut expanded_types_global_values = Vec::with_capacity(types_global_values.len().next_power_of_two()); for inst in types_global_values { if let Some(result_id) = inst.result_id { if let Some(generic) = self.specializer.generics.get(&result_id) { - expanded_types_global_values.extend(self.all_instances_of(result_id).map( - |(instance, &instance_id)| { - let mut expanded_inst = inst.clone(); - expanded_inst.result_id = Some(instance_id); - for (loc, operand) in generic - .replacements - .to_concrete(&instance.generic_args, |i| self.instances[&i]) - { - expanded_inst.index_set(loc, operand.into()); + // Collect instances first to avoid borrowing `self` immutably and mutably at the same time. + let instance_info: Vec<_> = self + .all_instances_of(result_id) + .map(|(inst, &inst_id)| (inst.clone(), inst_id)) + .collect(); + + for (instance, instance_id) in instance_info { + let mut expanded_inst = inst.clone(); + expanded_inst.result_id = Some(instance_id); + for (loc, operand) in generic + .replacements + .to_concrete(&instance.generic_args, |i| self.instances[&i]) + { + expanded_inst.index_set(loc, operand.into()); + } + + // If this is a pointer type now specialized to a concrete storage + // class, ensure its pointee array type respects layout rules. + if expanded_inst.class.opcode == Op::TypePointer { + if let ( + Some(Operand::StorageClass(sc)), + Some(Operand::IdRef(pointee_id)), + ) = ( + expanded_inst.operands.first().cloned(), + expanded_inst.operands.get(1).cloned(), + ) { + let need_no_stride = !self.specializer.allows_layout(sc); + let variant_id = self.get_array_variant( + pointee_id, + need_no_stride, + &mut expanded_types_global_values, + ); + if variant_id != pointee_id { + expanded_inst.operands[1] = Operand::IdRef(variant_id); + } } - expanded_inst - }, - )); + } + + expanded_types_global_values.push(expanded_inst); + } continue; } } + let mut inst = inst; + // Ensure even non-generic `OpTypePointer`s use the proper array variant. + if inst.class.opcode == Op::TypePointer { + if let (Some(Operand::StorageClass(sc)), Some(Operand::IdRef(pointee_id))) = + (inst.operands.first().cloned(), inst.operands.get(1).cloned()) + { + let need_no_stride = !self.specializer.allows_layout(sc); + let variant_id = self.get_array_variant( + pointee_id, + need_no_stride, + &mut expanded_types_global_values, + ); + if variant_id != pointee_id { + inst.operands[1] = Operand::IdRef(variant_id); + } + } + } expanded_types_global_values.push(inst); } @@ -2463,9 +2584,52 @@ impl<'a, S: Specialization> Expander<'a, S> { module.entry_points = entry_points; module.debug_names = expanded_debug_names; module.annotations = expanded_annotations; - module.types_global_values = expanded_types_global_values; + module.types_global_values = expanded_types_global_values.clone(); module.functions = expanded_functions; + // `expanded_types_global_values` already contains any new type declarations (array variants). + module.types_global_values = expanded_types_global_values; + + // Pass 2: rewrite result types of value/composite instructions that still reference the + // *original* (layout-bearing) array types to their stride-less specialised variants. + if !self.specializer.array_layout_variants.borrow().is_empty() { + let mut variant_map = FxHashMap::::default(); + for (orig, variants) in self.specializer.array_layout_variants.borrow().iter() { + if let Some(forbidden) = variants.layout_forbidden_id { + variant_map.insert(*orig, forbidden); + } + } + + // Helper closure: given a mutable instruction, patch its result type if needed. + let patch_inst = |inst: &mut Instruction| { + if let Some(rt) = inst.result_type { + if let Some(&new_id) = variant_map.get(&rt) { + inst.result_type = Some(new_id); + } + } + }; + + // Patch globals (constants etc.). + for inst in &mut module.types_global_values { + patch_inst(inst); + } + + // Patch every instruction inside every function. + for func in &mut module.functions { + if let Some(def) = &mut func.def { + patch_inst(def); + } + for blk in &mut func.blocks { + if let Some(lbl) = &mut blk.label { + patch_inst(lbl); + } + for inst in &mut blk.instructions { + patch_inst(inst); + } + } + } + } + self.builder.module() } @@ -2551,4 +2715,70 @@ impl<'a, S: Specialization> Expander<'a, S> { } Ok(()) } + + /// Obtain (or lazily create) an array type variant that either keeps or removes + /// the `ArrayStride` decoration, according to `need_no_stride`. + /// + /// * `need_no_stride == false` => layout-required/allowed, return original ID. + /// * `need_no_stride == true` => layout-forbidden, create a stride-less clone if + /// one does not yet exist. + fn get_array_variant( + &mut self, + original_id: Word, + need_no_stride: bool, + new_types: &mut Vec, + ) -> Word { + // Only care about true arrays. + // Only care about true arrays. Look up the defining instruction from the cached map, + // as the original `types_global_values` list has been temporarily moved out of the + // `Module` during expansion. + let original_inst_opt = self.original_types.get(&original_id).cloned(); + + let original_inst = match original_inst_opt { + Some(i) if matches!(i.class.opcode, Op::TypeArray | Op::TypeRuntimeArray) => i, + _ => return original_id, + }; + + // Consult shared cache first with a short-lived borrow to avoid conflicts during + // recursive calls further below. + if need_no_stride { + if let Some(id) = self + .specializer + .array_layout_variants + .borrow() + .get(&original_id) + .and_then(|v| v.layout_forbidden_id) + { + return id; + } + } else { + return original_id; + } + + // Create stride-less clone. + let new_id = self.builder.id(); + let mut new_inst = original_inst.clone(); + new_inst.result_id = Some(new_id); + + // Recurse into element type (array of arrays). + if let Some(&Operand::IdRef(elem_ty)) = original_inst.operands.first() { + let nested_variant = self.get_array_variant(elem_ty, need_no_stride, new_types); + if nested_variant != elem_ty { + new_inst.operands[0] = Operand::IdRef(nested_variant); + } + } + + new_types.push(new_inst.clone()); + // Record in the lookup so that future queries can find this variant. + self.original_types.insert(new_id, new_inst.clone()); + + // Update cache entry now (new borrow). + self.specializer + .array_layout_variants + .borrow_mut() + .entry(original_id) + .or_default() + .layout_forbidden_id = Some(new_id); + new_id + } } diff --git a/tests/ui/linker/array_stride_fixer/nested_structs_function_storage.stderr b/tests/ui/linker/array_stride_fixer/nested_structs_function_storage.stderr index 89a68f3a87..c7bfc7b87a 100644 --- a/tests/ui/linker/array_stride_fixer/nested_structs_function_storage.stderr +++ b/tests/ui/linker/array_stride_fixer/nested_structs_function_storage.stderr @@ -19,27 +19,28 @@ OpDecorate %7 Block OpMemberDecorate %7 0 Offset 0 OpDecorate %2 Binding 0 OpDecorate %2 DescriptorSet 0 +OpDecorate %8 ArrayStride 4 OpMemberDecorate %4 0 Offset 0 OpMemberDecorate %5 0 Offset 0 -%8 = OpTypeFloat 32 -%9 = OpTypeInt 32 0 -%10 = OpConstant %9 1 -%6 = OpTypeArray %8 %10 +%9 = OpTypeFloat 32 +%10 = OpTypeInt 32 0 +%11 = OpConstant %10 1 +%6 = OpTypeArray %9 %11 %7 = OpTypeStruct %6 -%11 = OpTypePointer StorageBuffer %7 -%12 = OpTypeVoid -%13 = OpTypeFunction %12 -%14 = OpTypePointer StorageBuffer %6 -%2 = OpVariable %11 StorageBuffer -%15 = OpConstant %9 0 -%16 = OpConstant %9 4 -%17 = OpTypeArray %8 %16 -%4 = OpTypeStruct %17 +%12 = OpTypePointer StorageBuffer %7 +%13 = OpTypeVoid +%14 = OpTypeFunction %13 +%15 = OpTypePointer StorageBuffer %6 +%2 = OpVariable %12 StorageBuffer +%16 = OpConstant %10 0 +%17 = OpConstant %10 4 +%8 = OpTypeArray %9 %17 +%4 = OpTypeStruct %8 %5 = OpTypeStruct %4 -%18 = OpConstant %8 0 -%19 = OpConstantComposite %17 %18 %18 %18 %18 +%18 = OpConstant %9 0 +%19 = OpConstantComposite %8 %18 %18 %18 %18 %20 = OpUndef %5 %21 = OpTypeBool -%22 = OpConstant %8 1109917696 -%23 = OpConstant %8 1065353216 -%24 = OpTypePointer StorageBuffer %8 +%22 = OpConstant %9 1109917696 +%23 = OpConstant %9 1065353216 +%24 = OpTypePointer StorageBuffer %9 diff --git a/tests/ui/linker/array_stride_fixer/workgroup_2d_arrays_issue.rs b/tests/ui/linker/array_stride_fixer/workgroup_2d_arrays_issue.rs index 771b7c9db6..c01c4108c2 100644 --- a/tests/ui/linker/array_stride_fixer/workgroup_2d_arrays_issue.rs +++ b/tests/ui/linker/array_stride_fixer/workgroup_2d_arrays_issue.rs @@ -24,7 +24,7 @@ pub fn transpose_2d_workgroup( // This should trigger the OpInBoundsAccessChain issue shared_real[ly][lx] = 1.0; shared_imag[ly][lx] = 2.0; - + // Read back to ensure usage let _val = shared_real[lx][ly] + shared_imag[lx][ly]; -} \ No newline at end of file +} diff --git a/tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.stderr b/tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.stderr index 3f769cba39..e26ac19930 100644 --- a/tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.stderr +++ b/tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.stderr @@ -14,23 +14,24 @@ OpName %3 "shared_data" OpDecorate %5 ArrayStride 4 OpDecorate %6 Block OpMemberDecorate %6 0 Offset 0 +OpDecorate %7 ArrayStride 4 OpDecorate %2 Binding 0 OpDecorate %2 DescriptorSet 0 -%7 = OpTypeInt 32 0 -%8 = OpConstant %7 1 -%5 = OpTypeArray %7 %8 +%8 = OpTypeInt 32 0 +%9 = OpConstant %8 1 +%5 = OpTypeArray %8 %9 %6 = OpTypeStruct %5 -%9 = OpTypePointer StorageBuffer %6 -%10 = OpConstant %7 256 -%11 = OpTypeArray %7 %10 -%12 = OpTypePointer Workgroup %11 +%10 = OpTypePointer StorageBuffer %6 +%11 = OpConstant %8 256 +%7 = OpTypeArray %8 %11 +%12 = OpTypePointer Workgroup %7 %13 = OpTypeVoid %14 = OpTypeFunction %13 %15 = OpTypePointer StorageBuffer %5 -%2 = OpVariable %9 StorageBuffer -%16 = OpConstant %7 0 +%2 = OpVariable %10 StorageBuffer +%16 = OpConstant %8 0 %17 = OpTypeBool -%18 = OpTypePointer Workgroup %7 +%18 = OpTypePointer Workgroup %8 %3 = OpVariable %12 Workgroup -%19 = OpConstant %7 42 -%20 = OpTypePointer StorageBuffer %7 +%19 = OpConstant %8 42 +%20 = OpTypePointer StorageBuffer %8 From 7f66f5ab642df567ac278b2c50c7ec4fd39a6760 Mon Sep 17 00:00:00 2001 From: redain <52823087+redain@users.noreply.github.com> Date: Sun, 6 Jul 2025 14:26:33 +0400 Subject: [PATCH 5/5] fixes --- .../src/linker/array_stride_fixer.rs | 814 ------------------ .../src/linker/duplicates.rs | 45 +- crates/rustc_codegen_spirv/src/linker/mod.rs | 1 - .../src/linker/specializer.rs | 151 ++-- .../nested_structs_function_storage.stderr | 54 +- .../workgroup_arrays_removed.stderr | 25 +- 6 files changed, 140 insertions(+), 950 deletions(-) delete mode 100644 crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs diff --git a/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs b/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs deleted file mode 100644 index f5c53316c3..0000000000 --- a/crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs +++ /dev/null @@ -1,814 +0,0 @@ -//! Fix `ArrayStride` decorations for newer SPIR-V versions. -//! -//! Newer SPIR-V versions forbid explicit layouts (`ArrayStride` decorations) in certain -//! storage classes (Function, Private, Workgroup), but allow them in others -//! (`StorageBuffer`, Uniform). This module removes `ArrayStride` decorations from -//! array types that are used in contexts where they're forbidden. - -use rspirv::dr::{Module, Operand}; -use rspirv::spirv::{Capability, Decoration, Op, StorageClass, Word}; -use rustc_data_structures::fx::{FxHashMap, FxHashSet}; - -/// Describes how an array type is used across different storage classes -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ArrayUsagePattern { - /// Array is only used in storage classes that require explicit layout - LayoutRequired, - /// Array is only used in storage classes that forbid explicit layout - LayoutForbidden, - /// Array is used in both types of storage classes (needs specialization) - MixedUsage, - /// Array is not used in any variables (orphaned) - Unused, -} - -/// Context information about array type usage -#[derive(Debug, Clone)] -pub struct ArrayStorageContext { - /// Which storage classes this array type is used in - pub storage_classes: FxHashSet, - /// Whether this array allows or forbids layout in its contexts - pub usage_pattern: ArrayUsagePattern, - /// Array types that this array contains as elements (for nested arrays) - pub element_arrays: FxHashSet, - /// Array types that contain this array as an element - pub parent_arrays: FxHashSet, -} - -/// Check if a storage class allows explicit layout decorations based on SPIR-V version and capabilities. -/// This matches the logic from SPIRV-Tools `validate_decorations.cpp` `AllowsLayout` function. -fn allows_layout( - storage_class: StorageClass, - spirv_version: (u8, u8), - has_workgroup_layout_capability: bool, -) -> bool { - match storage_class { - // Never allows layout - StorageClass::UniformConstant => false, - - // Requires explicit capability - StorageClass::Workgroup => has_workgroup_layout_capability, - - // Only forbidden in SPIR-V 1.4+ - StorageClass::Function | StorageClass::Private => spirv_version < (1, 4), - - // All other storage classes allow layout by default - _ => true, - } -} - -/// Comprehensive fix for `ArrayStride` decorations with staged processing architecture -pub fn fix_array_stride_decorations_with_deduplication(module: &mut Module) { - let spirv_version = module.header.as_ref().map_or((1, 0), |h| h.version()); - let has_workgroup_layout_capability = module.capabilities.iter().any(|inst| { - inst.class.opcode == Op::Capability - && inst.operands.first() - == Some(&Operand::Capability( - Capability::WorkgroupMemoryExplicitLayoutKHR, - )) - }); - - // Analyze all array usage patterns and dependencies - let array_contexts = - analyze_array_storage_contexts(module, spirv_version, has_workgroup_layout_capability); - - // Create specialized array types when necessary (mixed usage scenarios) - let specializations = create_specialized_array_types( - module, - &array_contexts, - spirv_version, - has_workgroup_layout_capability, - ); - - // Update references with full context awareness - if !specializations.is_empty() { - update_references_for_specialized_arrays( - module, - &specializations, - &array_contexts, - spirv_version, - has_workgroup_layout_capability, - ); - } - - // Remove decorations from layout-forbidden contexts - remove_array_stride_decorations_for_forbidden_contexts( - module, - &array_contexts, - spirv_version, - has_workgroup_layout_capability, - ); - - // Final cleanup and deduplication. - // Always run the context-aware variant so that arrays used in differing - // storage-class contexts are not incorrectly merged. - crate::linker::duplicates::remove_duplicate_types_with_array_context( - module, - Some(&array_contexts), - ); -} - -/// Remove `ArrayStride` decorations from arrays used in layout-forbidden storage classes -fn remove_array_stride_decorations_for_forbidden_contexts( - module: &mut Module, - array_contexts: &FxHashMap, - spirv_version: (u8, u8), - has_workgroup_layout_capability: bool, -) { - // Find array types that should have their ArrayStride decorations removed - // Remove from arrays used in forbidden contexts OR mixed usage that includes forbidden contexts - let arrays_to_remove_stride: FxHashSet = array_contexts - .iter() - .filter_map(|(&id, context)| { - match context.usage_pattern { - // Always remove from arrays used only in forbidden contexts - ArrayUsagePattern::LayoutForbidden => Some(id), - // For mixed usage, remove if it includes forbidden contexts that would cause validation errors - ArrayUsagePattern::MixedUsage => { - // If the array is used in any context that forbids layout, remove the decoration - // This is a conservative approach that prevents validation errors - let has_forbidden_context = context.storage_classes.iter().any(|&sc| { - !allows_layout(sc, spirv_version, has_workgroup_layout_capability) - }); - - if has_forbidden_context { - Some(id) - } else { - None - } - } - ArrayUsagePattern::LayoutRequired | ArrayUsagePattern::Unused => None, - } - }) - .collect(); - - // Remove ArrayStride decorations for layout-forbidden arrays - module.annotations.retain(|inst| { - if inst.class.opcode == Op::Decorate - && inst.operands.len() >= 2 - && inst.operands[1] == Operand::Decoration(Decoration::ArrayStride) - { - let target_id = inst.operands[0].unwrap_id_ref(); - !arrays_to_remove_stride.contains(&target_id) - } else { - true - } - }); -} - -/// Analyze storage class contexts for all array types in the module -fn analyze_array_storage_contexts( - module: &Module, - spirv_version: (u8, u8), - has_workgroup_layout_capability: bool, -) -> FxHashMap { - let mut array_contexts: FxHashMap = FxHashMap::default(); - - // Find all array and runtime array types - let mut array_types = FxHashSet::default(); - for inst in &module.types_global_values { - if matches!(inst.class.opcode, Op::TypeArray | Op::TypeRuntimeArray) { - if let Some(result_id) = inst.result_id { - array_types.insert(result_id); - array_contexts.insert(result_id, ArrayStorageContext { - storage_classes: FxHashSet::default(), - usage_pattern: ArrayUsagePattern::Unused, - element_arrays: FxHashSet::default(), - parent_arrays: FxHashSet::default(), - }); - } - } - } - - // Build parent-child relationships between array types - for inst in &module.types_global_values { - if matches!(inst.class.opcode, Op::TypeArray | Op::TypeRuntimeArray) { - if let Some(parent_id) = inst.result_id { - if !inst.operands.is_empty() { - let element_type = inst.operands[0].unwrap_id_ref(); - // If the element type is also an array, record the parent-child relationship - if array_types.contains(&element_type) { - if let Some(parent_context) = array_contexts.get_mut(&parent_id) { - parent_context.element_arrays.insert(element_type); - } - if let Some(element_context) = array_contexts.get_mut(&element_type) { - element_context.parent_arrays.insert(parent_id); - } - } - } - } - } - } - - // Analyze global variables - for inst in &module.types_global_values { - if inst.class.opcode == Op::Variable && !inst.operands.is_empty() { - let storage_class = inst.operands[0].unwrap_storage_class(); - - if let Some(var_type_id) = inst.result_type { - // Check if this variable's type hierarchy contains any array types - for &array_type_id in &array_types { - if type_hierarchy_contains_array_type(var_type_id, array_type_id, module) { - if let Some(context) = array_contexts.get_mut(&array_type_id) { - context.storage_classes.insert(storage_class); - } - } - } - } - } - } - - // Analyze function-local variables - for function in &module.functions { - for block in &function.blocks { - for inst in &block.instructions { - if inst.class.opcode == Op::Variable && !inst.operands.is_empty() { - let storage_class = inst.operands[0].unwrap_storage_class(); - - if let Some(var_type_id) = inst.result_type { - // Check if this variable's type hierarchy contains any array types - for &array_type_id in &array_types { - if type_hierarchy_contains_array_type( - var_type_id, - array_type_id, - module, - ) { - if let Some(context) = array_contexts.get_mut(&array_type_id) { - context.storage_classes.insert(storage_class); - } - } - } - } - } - } - } - } - - // Determine usage patterns with enhanced logic for nested arrays - for context in array_contexts.values_mut() { - if context.storage_classes.is_empty() { - context.usage_pattern = ArrayUsagePattern::Unused; - } else { - let mut requires_layout = false; - let mut forbids_layout = false; - - for &storage_class in &context.storage_classes { - if allows_layout( - storage_class, - spirv_version, - has_workgroup_layout_capability, - ) { - requires_layout = true; - } else { - forbids_layout = true; - } - } - - context.usage_pattern = match (requires_layout, forbids_layout) { - (true, true) => ArrayUsagePattern::MixedUsage, - (true, false) => ArrayUsagePattern::LayoutRequired, - (false, true) => ArrayUsagePattern::LayoutForbidden, - (false, false) => ArrayUsagePattern::Unused, // Should not happen - }; - } - } - - // Propagate context from parent arrays to child arrays for better consistency - // If a parent array is pure workgroup, its child arrays should inherit this context - let mut changed = true; - while changed { - changed = false; - let array_ids: Vec = array_contexts.keys().copied().collect(); - - for &array_id in &array_ids { - let (parent_arrays, current_pattern) = { - let context = &array_contexts[&array_id]; - (context.parent_arrays.clone(), context.usage_pattern) - }; - - // If this array has parent arrays that are pure workgroup, and this array - // doesn't have mixed usage, then it should also be pure workgroup - if current_pattern != ArrayUsagePattern::LayoutForbidden - && current_pattern != ArrayUsagePattern::MixedUsage - { - let has_pure_workgroup_parent = parent_arrays.iter().any(|&parent_id| { - matches!( - array_contexts.get(&parent_id).map(|ctx| ctx.usage_pattern), - Some(ArrayUsagePattern::LayoutForbidden) - ) - }); - - if has_pure_workgroup_parent { - if let Some(context) = array_contexts.get_mut(&array_id) { - // Only inherit if this array doesn't have its own conflicting storage classes - let has_layout_required_usage = context.storage_classes.iter().any(|&sc| { - allows_layout(sc, spirv_version, has_workgroup_layout_capability) - }); - - if !has_layout_required_usage { - context.usage_pattern = ArrayUsagePattern::LayoutForbidden; - // Also inherit the workgroup storage class if not already present - context.storage_classes.insert(StorageClass::Workgroup); - changed = true; - } - } - } - } - } - } - - array_contexts -} - -/// Create specialized array types for mixed usage scenarios -fn create_specialized_array_types( - module: &mut Module, - array_contexts: &FxHashMap, - spirv_version: (u8, u8), - has_workgroup_layout_capability: bool, -) -> FxHashMap { - let mut specializations = FxHashMap::default(); // original_id -> (layout_required_id, layout_forbidden_id) - - // Find arrays that need specialization (true mixed usage only) - // Be more conservative - only specialize arrays that truly have conflicting requirements - let arrays_to_specialize: Vec = array_contexts - .iter() - .filter_map(|(&id, context)| { - // Only specialize if the array has BOTH layout-required AND layout-forbidden usage - // AND it's not just inheriting context from parents (check own storage classes) - if context.usage_pattern == ArrayUsagePattern::MixedUsage { - let mut has_layout_required = false; - let mut has_layout_forbidden = false; - - // Check actual storage classes, not inherited patterns - for &storage_class in &context.storage_classes { - if allows_layout( - storage_class, - spirv_version, - has_workgroup_layout_capability, - ) { - has_layout_required = true; - } else { - has_layout_forbidden = true; - } - } - - // Only specialize if there's a true conflict in this array's own usage - if has_layout_required && has_layout_forbidden { - Some(id) - } else { - None - } - } else { - None - } - }) - .collect(); - - if arrays_to_specialize.is_empty() { - return specializations; - } - - // Generate new IDs for specialized types - let mut next_id = module.header.as_ref().map_or(1, |h| h.bound); - - for &original_id in &arrays_to_specialize { - let layout_required_id = next_id; - next_id += 1; - let layout_forbidden_id = next_id; - next_id += 1; - - specializations.insert(original_id, (layout_required_id, layout_forbidden_id)); - } - - // Update the module header bound - if let Some(ref mut header) = module.header { - header.bound = next_id; - } - - // Create specialized array type definitions - let mut new_type_instructions = Vec::new(); - - for &original_id in &arrays_to_specialize { - if let Some((layout_required_id, layout_forbidden_id)) = specializations.get(&original_id) { - // Find the original array type instruction - if let Some(original_inst) = module - .types_global_values - .iter() - .find(|inst| inst.result_id == Some(original_id)) - .cloned() - { - // Create layout-required variant (keeps ArrayStride decorations) - let mut layout_required_inst = original_inst.clone(); - layout_required_inst.result_id = Some(*layout_required_id); - new_type_instructions.push(layout_required_inst); - - // Create layout-forbidden variant (will have ArrayStride decorations removed later) - let mut layout_forbidden_inst = original_inst.clone(); - layout_forbidden_inst.result_id = Some(*layout_forbidden_id); - new_type_instructions.push(layout_forbidden_inst); - } - } - } - - // Insert specialized arrays right after their corresponding original arrays - // This maintains proper SPIR-V type ordering - let mut new_types_global_values = Vec::new(); - - for inst in &module.types_global_values { - new_types_global_values.push(inst.clone()); - - // If this is an array that was specialized, add the specialized versions right after - if let Some(result_id) = inst.result_id { - if let Some(&(layout_required_id, layout_forbidden_id)) = - specializations.get(&result_id) - { - // Find and add the specialized versions - for new_inst in &new_type_instructions { - if new_inst.result_id == Some(layout_required_id) { - new_types_global_values.push(new_inst.clone()); - break; - } - } - for new_inst in &new_type_instructions { - if new_inst.result_id == Some(layout_forbidden_id) { - new_types_global_values.push(new_inst.clone()); - break; - } - } - } - } - } - - module.types_global_values = new_types_global_values; - - specializations -} - -/// Helper function to select the appropriate specialized variant based on context -fn select_array_variant_for_context( - original_id: Word, - context_storage_classes: &FxHashSet, - specializations: &FxHashMap, - spirv_version: (u8, u8), - has_workgroup_layout_capability: bool, -) -> Option { - if let Some(&(layout_required_id, layout_forbidden_id)) = specializations.get(&original_id) { - // If context is pure workgroup, use layout_forbidden variant - if context_storage_classes.len() == 1 - && context_storage_classes.contains(&StorageClass::Workgroup) - { - return Some(layout_forbidden_id); - } - - // If context has any layout-forbidden storage classes, use layout_forbidden variant - let has_forbidden_context = context_storage_classes - .iter() - .any(|&sc| !allows_layout(sc, spirv_version, has_workgroup_layout_capability)); - - if has_forbidden_context { - Some(layout_forbidden_id) - } else { - Some(layout_required_id) - } - } else { - None - } -} - -/// Update all references to specialized array types and create specialized pointer types -fn update_references_for_specialized_arrays( - module: &mut Module, - specializations: &FxHashMap, - array_contexts: &FxHashMap, - spirv_version: (u8, u8), - has_workgroup_layout_capability: bool, -) { - if specializations.is_empty() { - return; - } - - let mut next_id = module.header.as_ref().map_or(1, |h| h.bound); - - // Step 1: Create new pointer types for specialized arrays - let mut pointer_rewrite_rules = FxHashMap::default(); // old_pointer_id -> new_pointer_id - let mut new_pointer_instructions = Vec::new(); - - // Collect all pointer types that need updating - for inst in &module.types_global_values { - if inst.class.opcode == Op::TypePointer && inst.operands.len() >= 2 { - let storage_class = inst.operands[0].unwrap_storage_class(); - let pointee_type = inst.operands[1].unwrap_id_ref(); - - if let Some(&(layout_required_id, layout_forbidden_id)) = - specializations.get(&pointee_type) - { - // Choose variant based on storage class - let target_array_id = if allows_layout( - storage_class, - spirv_version, - has_workgroup_layout_capability, - ) { - layout_required_id - } else { - layout_forbidden_id - }; - - // Create new pointer type - let mut new_pointer_inst = inst.clone(); - new_pointer_inst.result_id = Some(next_id); - new_pointer_inst.operands[1] = Operand::IdRef(target_array_id); - new_pointer_instructions.push(new_pointer_inst); - - pointer_rewrite_rules.insert(inst.result_id.unwrap(), next_id); - next_id += 1; - } - } - } - - // Step 2: Update struct field and array element references that point to - // original (now specialized) arrays so they reference the appropriate - // specialized variant. - - // 2a) Struct field types - for inst in &mut module.types_global_values { - if inst.class.opcode == Op::TypeStruct { - for op in &mut inst.operands { - if let Some(field_type_id) = op.id_ref_any_mut() { - if let Some(new_id) = select_array_variant_for_context( - *field_type_id, - // We don't have per struct storage class context, but - // any arrays appearing inside a Block decorated struct - // are expected to be in layout-required contexts. - &[StorageClass::StorageBuffer, StorageClass::Uniform] - .iter() - .cloned() - .collect::>(), - specializations, - spirv_version, - has_workgroup_layout_capability, - ) { - *field_type_id = new_id; - } - } - } - } - } - - // 2b) Array element references for non-specialized arrays - for inst in &mut module.types_global_values { - if matches!(inst.class.opcode, Op::TypeArray | Op::TypeRuntimeArray) { - if let Some(parent_id) = inst.result_id { - // Skip arrays that were themselves specialized (they should already be correctly set up) - if specializations.contains_key(&parent_id) { - continue; - } - - // Update element reference based on parent's context - if let Some(element_operand) = inst.operands.get_mut(0) { - if let Some(elem_id) = element_operand.id_ref_any_mut() { - if let Some(parent_context) = array_contexts.get(&parent_id) { - if let Some(new_elem_id) = select_array_variant_for_context( - *elem_id, - &parent_context.storage_classes, - specializations, - spirv_version, - has_workgroup_layout_capability, - ) { - *elem_id = new_elem_id; - } - } - } - } - } - } - } - - // Step 3: Update pointer types in types section to reference specialized arrays - let mut updated_pointer_types = Vec::new(); - for inst in &module.types_global_values { - if inst.class.opcode == Op::TypePointer && inst.operands.len() >= 2 { - let storage_class = inst.operands[0].unwrap_storage_class(); - let pointee_type = inst.operands[1].unwrap_id_ref(); - - if let Some(&(layout_required_id, layout_forbidden_id)) = - specializations.get(&pointee_type) - { - // Choose variant based on storage class - let target_array_id = if allows_layout( - storage_class, - spirv_version, - has_workgroup_layout_capability, - ) { - layout_required_id - } else { - layout_forbidden_id - }; - - // Update the pointer to reference the appropriate specialized array - let mut updated_inst = inst.clone(); - updated_inst.operands[1] = Operand::IdRef(target_array_id); - updated_pointer_types.push((inst.result_id.unwrap(), updated_inst)); - } - } - } - - // Apply pointer type updates - for inst in &mut module.types_global_values { - if let Some(result_id) = inst.result_id { - for (old_id, new_inst) in &updated_pointer_types { - if result_id == *old_id { - *inst = new_inst.clone(); - break; - } - } - } - } - - // Step 4: Keep original arrays (even after specialization) to avoid - // potential forward-reference ordering issues. These original types are - // now unused, but retaining them is harmless and greatly simplifies the - // type-ordering constraints enforced by the SPIR-V -> SPIR-T lowering - // step. - // NOTE: If size becomes a concern, we can revisit this and implement a - // safer removal strategy that preserves correct ordering. - - // Step 5: Add new pointer types to the module - module.types_global_values.extend(new_pointer_instructions); - - // Update module header bound - if let Some(ref mut header) = module.header { - header.bound = next_id; - } - - // Step 6: Apply pointer rewrite rules throughout the module - for inst in module.all_inst_iter_mut() { - if let Some(ref mut id) = inst.result_type { - *id = pointer_rewrite_rules.get(id).copied().unwrap_or(*id); - } - for op in &mut inst.operands { - if let Some(id) = op.id_ref_any_mut() { - *id = pointer_rewrite_rules.get(id).copied().unwrap_or(*id); - } - } - } - - // Step 7: Handle ArrayStride decorations for specialized arrays - let mut decorations_to_add = Vec::new(); - for inst in &module.annotations { - if inst.class.opcode == Op::Decorate - && inst.operands.len() >= 2 - && inst.operands[1] == Operand::Decoration(Decoration::ArrayStride) - { - let target_id = inst.operands[0].unwrap_id_ref(); - if let Some(&(layout_required_id, _)) = specializations.get(&target_id) { - let mut new_decoration = inst.clone(); - new_decoration.operands[0] = Operand::IdRef(layout_required_id); - decorations_to_add.push(new_decoration); - } - } - } - module.annotations.extend(decorations_to_add); - - // Remove ArrayStride decorations from original and layout-forbidden arrays - let layout_forbidden_arrays: FxHashSet = specializations - .values() - .map(|&(_, layout_forbidden_id)| layout_forbidden_id) - .collect(); - let arrays_to_remove_decorations: FxHashSet = layout_forbidden_arrays - .iter() - .cloned() - .chain(specializations.keys().cloned()) - .collect(); - - module.annotations.retain(|inst| { - if inst.class.opcode == Op::Decorate - && inst.operands.len() >= 2 - && inst.operands[1] == Operand::Decoration(Decoration::ArrayStride) - { - let target_id = inst.operands[0].unwrap_id_ref(); - !arrays_to_remove_decorations.contains(&target_id) - } else { - true - } - }); - - // Step 8: Rewrite any remaining uses/results of original array IDs to the - // chosen specialized variant (defaulting to the layout-required variant). - if !specializations.is_empty() { - let mut array_default_rewrite: FxHashMap = FxHashMap::default(); - for (&orig, &(layout_required_id, _)) in specializations { - array_default_rewrite.insert(orig, layout_required_id); - } - - for inst in module.all_inst_iter_mut() { - // Skip type declarations themselves – we only want to fix *uses* of the IDs. - if matches!( - inst.class.opcode, - Op::TypeVoid - | Op::TypeBool - | Op::TypeInt - | Op::TypeFloat - | Op::TypeVector - | Op::TypeMatrix - | Op::TypeImage - | Op::TypeSampler - | Op::TypeSampledImage - | Op::TypeArray - | Op::TypeRuntimeArray - | Op::TypeStruct - | Op::TypeOpaque - | Op::TypePointer - | Op::TypeFunction - | Op::TypeEvent - | Op::TypeDeviceEvent - | Op::TypeReserveId - | Op::TypeQueue - | Op::TypePipe - | Op::TypeForwardPointer - ) { - continue; - } - - // Avoid changing the declared type of function parameters and - // composite ops, as they must stay in sync with their value - // operands. - if !matches!( - inst.class.opcode, - Op::FunctionParameter | Op::CompositeInsert | Op::CompositeExtract - ) { - if let Some(ref mut ty) = inst.result_type { - if let Some(&new) = array_default_rewrite.get(ty) { - *ty = new; - } - } - } - for op in &mut inst.operands { - if let Some(id) = op.id_ref_any_mut() { - if let Some(&new) = array_default_rewrite.get(id) { - *id = new; - } - } - } - } - } -} - -/// Check if a type hierarchy contains a specific array type -fn type_hierarchy_contains_array_type( - type_id: Word, - target_array_type_id: Word, - module: &Module, -) -> bool { - if type_id == target_array_type_id { - return true; - } - - // Find the type definition - if let Some(type_inst) = module - .types_global_values - .iter() - .find(|inst| inst.result_id == Some(type_id)) - { - match type_inst.class.opcode { - Op::TypeArray | Op::TypeRuntimeArray => { - // Check element type recursively - if !type_inst.operands.is_empty() { - let element_type = type_inst.operands[0].unwrap_id_ref(); - return type_hierarchy_contains_array_type( - element_type, - target_array_type_id, - module, - ); - } - } - Op::TypeStruct => { - // Check all field types - for operand in &type_inst.operands { - if let Ok(field_type) = operand.id_ref_any().ok_or(()) { - if type_hierarchy_contains_array_type( - field_type, - target_array_type_id, - module, - ) { - return true; - } - } - } - } - Op::TypePointer => { - // Follow pointer to pointee type - if type_inst.operands.len() >= 2 { - let pointee_type = type_inst.operands[1].unwrap_id_ref(); - return type_hierarchy_contains_array_type( - pointee_type, - target_array_type_id, - module, - ); - } - } - _ => {} - } - } - false -} diff --git a/crates/rustc_codegen_spirv/src/linker/duplicates.rs b/crates/rustc_codegen_spirv/src/linker/duplicates.rs index 7631972001..fe637e6fe9 100644 --- a/crates/rustc_codegen_spirv/src/linker/duplicates.rs +++ b/crates/rustc_codegen_spirv/src/linker/duplicates.rs @@ -122,9 +122,6 @@ fn make_dedupe_key_with_array_context( unresolved_forward_pointers: &FxHashSet, annotations: &FxHashMap>, names: &FxHashMap, - array_contexts: Option< - &FxHashMap, - >, ) -> Vec { let mut data = vec![inst.class.opcode as u32]; @@ -172,37 +169,7 @@ fn make_dedupe_key_with_array_context( } } - // For array types, include storage class context in the key to prevent - // inappropriate deduplication between different storage class contexts - if let Some(result_id) = inst.result_id { - if matches!(inst.class.opcode, Op::TypeArray | Op::TypeRuntimeArray) { - if let Some(contexts) = array_contexts { - if let Some(context) = contexts.get(&result_id) { - // Include usage pattern in the key so arrays with different contexts won't deduplicate - let usage_pattern_discriminant = match context.usage_pattern { - crate::linker::array_stride_fixer::ArrayUsagePattern::LayoutRequired => { - 1u32 - } - crate::linker::array_stride_fixer::ArrayUsagePattern::LayoutForbidden => { - 2u32 - } - crate::linker::array_stride_fixer::ArrayUsagePattern::MixedUsage => 3u32, - crate::linker::array_stride_fixer::ArrayUsagePattern::Unused => 4u32, - }; - data.push(usage_pattern_discriminant); - - // Also include the specific storage classes for fine-grained differentiation - let mut storage_classes: Vec = context - .storage_classes - .iter() - .map(|sc| *sc as u32) - .collect(); - storage_classes.sort(); // Ensure deterministic ordering - data.extend(storage_classes); - } - } - } - } + // Array context feature removed - was never actually used data } @@ -220,15 +187,6 @@ fn rewrite_inst_with_rules(inst: &mut Instruction, rules: &FxHashMap) } pub fn remove_duplicate_types(module: &mut Module) { - remove_duplicate_types_with_array_context(module, None); -} - -pub fn remove_duplicate_types_with_array_context( - module: &mut Module, - array_contexts: Option< - &FxHashMap, - >, -) { // Keep in mind, this algorithm requires forward type references to not exist - i.e. it's a valid spir-v module. // When a duplicate type is encountered, then this is a map from the deleted ID, to the new, deduplicated ID. @@ -271,7 +229,6 @@ pub fn remove_duplicate_types_with_array_context( &unresolved_forward_pointers, &annotations, &names, - array_contexts, ); match key_to_result_id.entry(key) { diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs index ead00e28c5..fa69dc8e7f 100644 --- a/crates/rustc_codegen_spirv/src/linker/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/mod.rs @@ -1,7 +1,6 @@ #[cfg(test)] mod test; -mod array_stride_fixer; mod dce; mod destructure_composites; mod duplicates; diff --git a/crates/rustc_codegen_spirv/src/linker/specializer.rs b/crates/rustc_codegen_spirv/src/linker/specializer.rs index 1353c68c98..0db9fa32bb 100644 --- a/crates/rustc_codegen_spirv/src/linker/specializer.rs +++ b/crates/rustc_codegen_spirv/src/linker/specializer.rs @@ -556,8 +556,6 @@ struct Generic { /// pass when it discovers the need for a particular variant. #[derive(Clone, Debug, Default)] struct ArrayLayoutVariants { - /// Variant that *keeps* the `ArrayStride` decoration, if it exists. - layout_required_id: Option, /// Variant that *omits* the `ArrayStride` decoration. layout_forbidden_id: Option, } @@ -2483,9 +2481,10 @@ impl<'a, S: Specialization> Expander<'a, S> { let mut inst = inst; // Ensure even non-generic `OpTypePointer`s use the proper array variant. if inst.class.opcode == Op::TypePointer { - if let (Some(Operand::StorageClass(sc)), Some(Operand::IdRef(pointee_id))) = - (inst.operands.first().cloned(), inst.operands.get(1).cloned()) - { + if let (Some(Operand::StorageClass(sc)), Some(Operand::IdRef(pointee_id))) = ( + inst.operands.first().cloned(), + inst.operands.get(1).cloned(), + ) { let need_no_stride = !self.specializer.allows_layout(sc); let variant_id = self.get_array_variant( pointee_id, @@ -2728,57 +2727,113 @@ impl<'a, S: Specialization> Expander<'a, S> { need_no_stride: bool, new_types: &mut Vec, ) -> Word { - // Only care about true arrays. - // Only care about true arrays. Look up the defining instruction from the cached map, - // as the original `types_global_values` list has been temporarily moved out of the - // `Module` during expansion. - let original_inst_opt = self.original_types.get(&original_id).cloned(); - - let original_inst = match original_inst_opt { - Some(i) if matches!(i.class.opcode, Op::TypeArray | Op::TypeRuntimeArray) => i, - _ => return original_id, + // Look up the defining instruction from the cached map, as the original + // `types_global_values` list has been temporarily moved out of the `Module` + // during expansion. + let original_inst = match self.original_types.get(&original_id) { + Some(i) => i.clone(), + None => return original_id, }; - // Consult shared cache first with a short-lived borrow to avoid conflicts during - // recursive calls further below. - if need_no_stride { - if let Some(id) = self - .specializer - .array_layout_variants - .borrow() - .get(&original_id) - .and_then(|v| v.layout_forbidden_id) - { - return id; - } - } else { + // Fast-path if we don't need to strip layout information. + if !need_no_stride { return original_id; } - // Create stride-less clone. - let new_id = self.builder.id(); - let mut new_inst = original_inst.clone(); - new_inst.result_id = Some(new_id); + // If we already produced a stride-less variant for this type, just return it. + if let Some(id) = self + .specializer + .array_layout_variants + .borrow() + .get(&original_id) + .and_then(|v| v.layout_forbidden_id) + { + return id; + } + + match original_inst.class.opcode { + Op::TypeArray | Op::TypeRuntimeArray => { + // Handle arrays: remove the `ArrayStride` decoration (by simply cloning the + // instruction – the decoration lives in `OpDecorate`). - // Recurse into element type (array of arrays). - if let Some(&Operand::IdRef(elem_ty)) = original_inst.operands.first() { - let nested_variant = self.get_array_variant(elem_ty, need_no_stride, new_types); - if nested_variant != elem_ty { - new_inst.operands[0] = Operand::IdRef(nested_variant); + let new_id = self.builder.id(); + let mut new_inst = original_inst.clone(); + new_inst.result_id = Some(new_id); + + // Recurse into the element type (arrays of arrays). + if let Some(&Operand::IdRef(elem_ty)) = original_inst.operands.first() { + let nested_variant = self.get_array_variant(elem_ty, need_no_stride, new_types); + if nested_variant != elem_ty { + new_inst.operands[0] = Operand::IdRef(nested_variant); + } + } + + new_types.push(new_inst.clone()); + // Cache definition for later lookups. + self.original_types.insert(new_id, new_inst); + + self.specializer + .array_layout_variants + .borrow_mut() + .entry(original_id) + .or_default() + .layout_forbidden_id = Some(new_id); + + new_id } - } - new_types.push(new_inst.clone()); - // Record in the lookup so that future queries can find this variant. - self.original_types.insert(new_id, new_inst.clone()); + Op::TypeStruct => { + // For structs, we need to create a clone that references the stride-less + // variants of any member types that themselves changed. + + // Track whether any member type changed – if not, we can reuse the original. + let mut changed = false; + let mut new_operands = original_inst.operands.clone(); + + for (idx, op) in original_inst.operands.iter().enumerate() { + if let Operand::IdRef(member_ty) = op { + let new_member_ty = + self.get_array_variant(*member_ty, need_no_stride, new_types); + if new_member_ty != *member_ty { + new_operands[idx] = Operand::IdRef(new_member_ty); + changed = true; + } + } + } - // Update cache entry now (new borrow). - self.specializer - .array_layout_variants - .borrow_mut() - .entry(original_id) - .or_default() - .layout_forbidden_id = Some(new_id); - new_id + if !changed { + // Even if no member type changed, cache lookup to avoid redundant work. + self.specializer + .array_layout_variants + .borrow_mut() + .entry(original_id) + .or_default() + .layout_forbidden_id = Some(original_id); + return original_id; + } + + let new_id = self.builder.id(); + let mut new_inst = original_inst.clone(); + new_inst.result_id = Some(new_id); + new_inst.operands = new_operands; + + new_types.push(new_inst.clone()); + self.original_types.insert(new_id, new_inst); + + self.specializer + .array_layout_variants + .borrow_mut() + .entry(original_id) + .or_default() + .layout_forbidden_id = Some(new_id); + + new_id + } + + _ => { + // Not an array or struct – nothing to do. + original_id + } + } } } diff --git a/tests/ui/linker/array_stride_fixer/nested_structs_function_storage.stderr b/tests/ui/linker/array_stride_fixer/nested_structs_function_storage.stderr index c7bfc7b87a..340badc371 100644 --- a/tests/ui/linker/array_stride_fixer/nested_structs_function_storage.stderr +++ b/tests/ui/linker/array_stride_fixer/nested_structs_function_storage.stderr @@ -10,37 +10,31 @@ OpMemoryModel Logical Vulkan OpEntryPoint GLCompute %1 "main" %2 OpExecutionMode %1 LocalSize 1 1 1 %3 = OpString "$OPSTRING_FILENAME/nested_structs_function_storage.rs" -OpName %4 "InnerStruct" -OpMemberName %4 0 "data" -OpName %5 "OuterStruct" -OpMemberName %5 0 "inner" -OpDecorate %6 ArrayStride 4 -OpDecorate %7 Block -OpMemberDecorate %7 0 Offset 0 +OpDecorate %4 ArrayStride 4 +OpDecorate %5 Block +OpMemberDecorate %5 0 Offset 0 OpDecorate %2 Binding 0 OpDecorate %2 DescriptorSet 0 -OpDecorate %8 ArrayStride 4 -OpMemberDecorate %4 0 Offset 0 -OpMemberDecorate %5 0 Offset 0 -%9 = OpTypeFloat 32 -%10 = OpTypeInt 32 0 -%11 = OpConstant %10 1 -%6 = OpTypeArray %9 %11 -%7 = OpTypeStruct %6 -%12 = OpTypePointer StorageBuffer %7 -%13 = OpTypeVoid -%14 = OpTypeFunction %13 -%15 = OpTypePointer StorageBuffer %6 -%2 = OpVariable %12 StorageBuffer -%16 = OpConstant %10 0 -%17 = OpConstant %10 4 -%8 = OpTypeArray %9 %17 -%4 = OpTypeStruct %8 +%6 = OpTypeFloat 32 +%7 = OpTypeInt 32 0 +%8 = OpConstant %7 1 +%4 = OpTypeArray %6 %8 %5 = OpTypeStruct %4 -%18 = OpConstant %9 0 -%19 = OpConstantComposite %8 %18 %18 %18 %18 -%20 = OpUndef %5 +%9 = OpTypePointer StorageBuffer %5 +%10 = OpTypeVoid +%11 = OpTypeFunction %10 +%12 = OpTypePointer StorageBuffer %4 +%2 = OpVariable %9 StorageBuffer +%13 = OpConstant %7 0 +%14 = OpConstant %7 4 +%15 = OpTypeArray %6 %14 +%16 = OpTypeStruct %15 +%17 = OpTypeStruct %16 +%18 = OpConstant %6 0 +%19 = OpConstantComposite %15 %18 %18 %18 %18 +%20 = OpUndef %17 %21 = OpTypeBool -%22 = OpConstant %9 1109917696 -%23 = OpConstant %9 1065353216 -%24 = OpTypePointer StorageBuffer %9 +%22 = OpConstant %6 1109917696 +%23 = OpConstant %6 1065353216 +%24 = OpTypePointer StorageBuffer %6 + diff --git a/tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.stderr b/tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.stderr index e26ac19930..3f769cba39 100644 --- a/tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.stderr +++ b/tests/ui/linker/array_stride_fixer/workgroup_arrays_removed.stderr @@ -14,24 +14,23 @@ OpName %3 "shared_data" OpDecorate %5 ArrayStride 4 OpDecorate %6 Block OpMemberDecorate %6 0 Offset 0 -OpDecorate %7 ArrayStride 4 OpDecorate %2 Binding 0 OpDecorate %2 DescriptorSet 0 -%8 = OpTypeInt 32 0 -%9 = OpConstant %8 1 -%5 = OpTypeArray %8 %9 +%7 = OpTypeInt 32 0 +%8 = OpConstant %7 1 +%5 = OpTypeArray %7 %8 %6 = OpTypeStruct %5 -%10 = OpTypePointer StorageBuffer %6 -%11 = OpConstant %8 256 -%7 = OpTypeArray %8 %11 -%12 = OpTypePointer Workgroup %7 +%9 = OpTypePointer StorageBuffer %6 +%10 = OpConstant %7 256 +%11 = OpTypeArray %7 %10 +%12 = OpTypePointer Workgroup %11 %13 = OpTypeVoid %14 = OpTypeFunction %13 %15 = OpTypePointer StorageBuffer %5 -%2 = OpVariable %10 StorageBuffer -%16 = OpConstant %8 0 +%2 = OpVariable %9 StorageBuffer +%16 = OpConstant %7 0 %17 = OpTypeBool -%18 = OpTypePointer Workgroup %8 +%18 = OpTypePointer Workgroup %7 %3 = OpVariable %12 Workgroup -%19 = OpConstant %8 42 -%20 = OpTypePointer StorageBuffer %8 +%19 = OpConstant %7 42 +%20 = OpTypePointer StorageBuffer %7