From ceb8151772fceec4cb4a65763f53e889b0a23df3 Mon Sep 17 00:00:00 2001
From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com>
Date: Wed, 2 Apr 2025 22:39:17 -0700
Subject: [PATCH 1/3] [ET-VK] Adding all tensor packing support for native layer norm.

Pull Request resolved: https://github.com/pytorch/executorch/pull/9532

This diff updates the Executorch Vulkan backend's `native layer norm` operation to support width, height, and channel packed tensors, and adds new test cases to cases.py to exercise the operation.

ghstack-source-id: 275813549
@exported-using-ghexport

Differential Revision: [D71663678](https://our.internmc.facebook.com/intern/diff/D71663678/)
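Both branches of the updated shader below compute mean and variance in a single pass with Welford's online recurrence. As a reference for that update, here is a minimal standalone C++ sketch (illustrative only, not part of this diff; the `welford` helper is a hypothetical name):

```cpp
#include <cstddef>

// One Welford step per element: after n elements, `mean` holds the running
// mean and M2 / n is the population variance, matching "var = M2 / width"
// in the shader below.
void welford(const float* x, std::size_t n, float& mean, float& var) {
  mean = 0.0f;
  float M2 = 0.0f;
  for (std::size_t i = 0; i < n; ++i) {
    const float delta = x[i] - mean;   // delta  = v - mean (before update)
    mean += delta / float(i + 1);      // mean  += delta / (w + 1)
    const float delta2 = x[i] - mean;  // delta2 = v - mean (after update)
    M2 += delta * delta2;              // M2    += delta * delta2
  }
  var = (n > 0) ? M2 / float(n) : 0.0f;
}
```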
---
 backends/vulkan/op_registry.py                |  15 ++-
 .../graph/ops/glsl/native_layer_norm.glsl     | 126 +++++++++++++-----
 .../graph/ops/impl/NativeLayerNorm.cpp        |   7 +-
 backends/vulkan/test/op_tests/cases.py        |   5 +
 4 files changed, 114 insertions(+), 39 deletions(-)

diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 26f461c062f..54b7b8651bc 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -576,7 +576,6 @@ def register_ported_op_all_packed_dims(features: OpFeatures):
     [
         exir_ops.edge.aten.embedding.default,
         exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
-        exir_ops.edge.aten.native_layer_norm.default,
     ]
 )
 def register_ported_ops_with_prepacking(features: OpFeatures):
@@ -587,6 +586,20 @@ def register_ported_ops_with_prepacking(features: OpFeatures):
     return features
 
 
+# Ported ops that support their own prepacking.
+@update_features(
+    [
+        exir_ops.edge.aten.native_layer_norm.default,
+    ]
+)
+def register_ported_ops_with_prepacking_all_dims(features: OpFeatures):
+    features.texture_impl = TextureImplFeatures(
+        valid_packed_dims=all_packed_dims,
+    )
+    features.handles_own_prepacking = True
+    return features
+
+
 #######################
 ## Utility functions ##
 #######################
diff --git a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl
index f984821600b..f518e838750 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl
@@ -15,6 +15,8 @@
 
 #define VEC4_T ${texel_type(DTYPE)}
 
+#define T ${texel_component_type(DTYPE)}
+
 layout(std430) buffer;
 
 ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
@@ -48,37 +50,97 @@ void main() {
 
   const int width = int(sizes.x);
 
-  VEC4_T mean = VEC4_T(0);
-  VEC4_T delta = VEC4_T(0);
-  VEC4_T delta2 = VEC4_T(0);
-  VEC4_T M2 = VEC4_T(0);
-
-  // Use Welford's online algorithm to compute mean and variance in one pass
-  // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
-  ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
-  for (int w = 0; w < width; ++w) {
-    in_pos[in_axis_map.x] = w;
-    VEC4_T v = load_texel(t_in, in_pos);
-    delta = v - mean;
-    mean += delta / (w + 1);
-    delta2 = v - mean;
-    M2 += delta * delta2;
-  }
-
-  VEC4_T var = M2 / width;
-  VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
-  VEC4_T offset = -rstd * mean;
-
-  for (int w = 0; w < width; ++w) {
-    in_pos[in_axis_map.x] = w;
-    VEC4_T v = load_texel(t_in, in_pos);
-    // broadcasting
-    VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx;
-    VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx;
-    VEC4_T outtex = (v * rstd + offset) * weight + bias;
-    write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
+  if (in_packed_dim != W_DIM) {
+    VEC4_T mean = VEC4_T(0);
+    VEC4_T delta = VEC4_T(0);
+    VEC4_T delta2 = VEC4_T(0);
+    VEC4_T M2 = VEC4_T(0);
+
+    // Use Welford's online algorithm to compute mean and variance in one pass
+    // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
+    ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
+    for (int w = 0; w < width; ++w) {
+      in_pos[in_axis_map.x] = w;
+      VEC4_T v = load_texel(t_in, in_pos);
+      delta = v - mean;
+      mean += delta / (w + 1);
+      delta2 = v - mean;
+      M2 += delta * delta2;
+    }
+
+    VEC4_T var = M2 / width;
+    VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
+    VEC4_T offset = -rstd * mean;
+
+    for (int w = 0; w < width; ++w) {
+      in_pos[in_axis_map.x] = w;
+      VEC4_T v = load_texel(t_in, in_pos);
+      // broadcasting
+      VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx;
+      VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx;
+      VEC4_T outtex = (v * rstd + offset) * weight + bias;
+      write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
+    }
+
+    write_texel(t_mean, lpos, mean);
+    write_texel(t_rstd, lpos, rstd);
+  } else {
+    const int packed_width = divup4(width);
+
+    T mean = T(0);
+    T delta = T(0);
+    T delta2 = T(0);
+    T M2 = T(0);
+    // Use Welford's online algorithm to compute mean and variance in one pass
+    // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
+    ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
+    T width_counter = T(1);
+
+    const bool has_unaligned_width = (width & 0x3) != 0;
+    const int fully_packed_4_comp_count = packed_width - mix(0, 1, has_unaligned_width);
+
+    // iterate through texels that are fully packed, i.e. have 4 valid components
+    for (int w = 0; w < fully_packed_4_comp_count; ++w) {
+      in_pos[in_axis_map.x] = w;
+      VEC4_T v = load_texel(t_in, in_pos);
+      for (int i = 0; i < 4; i++) {
+        delta = v[i] - mean;
+        mean += delta / width_counter;
+        delta2 = v[i] - mean;
+        M2 += delta * delta2;
+        width_counter++;
+      }
+    }
+
+    // handle the last texel if the width is not a multiple of 4
+    if (has_unaligned_width) {
+      in_pos[in_axis_map.x] = fully_packed_4_comp_count;
+      const int remaining_width = width & 0x3;
+
+      VEC4_T v = load_texel(t_in, in_pos);
+      for (int i = 0; i < remaining_width; i++) {
+        delta = v[i] - mean;
+        mean += delta / width_counter;
+        delta2 = v[i] - mean;
+        M2 += delta * delta2;
+        width_counter++;
+      }
+    }
+
+    T var = M2 / (width_counter - 1);
+    T rstd = pow(var + epsilon, T(-0.5));
+    T offset = -rstd * mean;
+
+    for (int w = 0; w < packed_width; ++w) {
+      in_pos[in_axis_map.x] = w;
+      VEC4_T v = load_texel(t_in, in_pos);
+      VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0));
+      VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0));
+      VEC4_T outtex = (v * rstd + offset) * weight + bias;
+      write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
+    }
+
+    write_texel(t_mean, lpos, VEC4_T(mean));
+    write_texel(t_rstd, lpos, VEC4_T(rstd));
   }
-
-  write_texel(t_mean, lpos, mean);
-  write_texel(t_rstd, lpos, rstd);
 }
diff --git a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp
--- a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp
@@ -60,11 +60,6 @@ void resize_native_layer_norm_node(
   rstd->virtual_resize(mean_size);
 }
 
-void check_args(const api::vTensor& in, const api::vTensor& out) {
-  VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim));
-  VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim));
-}
-
 void add_native_layer_norm_node(
     ComputeGraph& graph,
     const ValueRef in,
@@ -84,7 +79,7 @@ void add_native_layer_norm_node(
   vTensorPtr t_input = graph.get_tensor(in);
   float epsilon = graph.extract_scalar<float>(eps);
 
-  check_args(*t_input, *t_out);
+  VK_CHECK_COND(check_same_packed_dim(*t_input, *t_out));
 
   std::vector<int64_t> in_sizes = t_input->sizes();
 
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index d2e09404ca0..88f5bea5c3e 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -385,6 +385,11 @@ def get_native_layer_norm_inputs():
             ((S, XL, M1, M2), [M2], (M2), (M2), 0.001),
         ]
     )
+    test_suite.layouts = [
+        "utils::kWidthPacked",
+        "utils::kHeightPacked",
+        "utils::kChannelsPacked",
+    ]
     return test_suite
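A note on the width-packed bookkeeping above: `divup4`, the `width & 0x3` mask, and the `mix(0, 1, ...)` select reduce to a few integer operations. A standalone C++ sketch with illustrative values (not part of the diff):

```cpp
#include <cassert>

int main() {
  const int width = 10;                                 // logical row length
  const int packed_width = (width + 3) / 4;             // divup4(width) == 3 texels
  const bool has_unaligned_width = (width & 0x3) != 0;  // last texel only half full
  // mix(0, 1, b) selects 1 when b is true, so this drops the partial texel:
  const int full_texels = packed_width - (has_unaligned_width ? 1 : 0);
  const int remaining_width = width & 0x3;              // valid lanes in the tail texel
  assert(full_texels == 2 && remaining_width == 2);
  return 0;
}
```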
From c0bb116a1dc937c5dd3e2e9f151349d6eae8c223 Mon Sep 17 00:00:00 2001
From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com>
Date: Wed, 2 Apr 2025 22:39:19 -0700
Subject: [PATCH 2/3] [ET-VK] Adding round op support.

Pull Request resolved: https://github.com/pytorch/executorch/pull/9792

This diff adds support for the round op in the Vulkan backend for Executorch.

ghstack-source-id: 275813551

Differential Revision: [D72218482](https://our.internmc.facebook.com/intern/diff/D72218482/)
---
 backends/vulkan/op_registry.py                       | 1 +
 backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml | 2 ++
 backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp   | 2 ++
 backends/vulkan/test/op_tests/cases.py               | 1 +
 4 files changed, 6 insertions(+)

diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 54b7b8651bc..b33430a6bca 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -276,6 +276,7 @@ def register_binary_op(features: OpFeatures):
         exir_ops.edge.aten.sqrt.default,
         exir_ops.edge.aten.rsqrt.default,
         exir_ops.edge.aten.tanh.default,
+        exir_ops.edge.aten.round.default,
     ]
 )
 def register_unary_op(features: OpFeatures):
diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
index 6757d2a6d45..f13393ce6c7 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
@@ -44,3 +44,5 @@ unary_op:
     OPERATOR: hardsigmoid(X)
   - NAME: leaky_relu
     OPERATOR: leaky_relu(X, A)
+  - NAME: round
+    OPERATOR: round(X)
diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
index 4bf73fad5a1..9a3ab002403 100644
--- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
@@ -149,6 +149,7 @@ DEFINE_HARDSHRINK_FN(hardshrink);
 DEFINE_ACTIVATION_FN(hardswish);
 DEFINE_ACTIVATION_FN(hardsigmoid);
 DEFINE_LEAKY_RELU_FN(leaky_relu);
+DEFINE_ACTIVATION_FN(round);
 
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.abs.default, abs);
@@ -168,6 +169,7 @@ REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.hardswish.default, hardswish);
   VK_REGISTER_OP(aten.hardsigmoid.default, hardsigmoid);
   VK_REGISTER_OP(aten.leaky_relu.default, leaky_relu);
+  VK_REGISTER_OP(aten.round.default, round);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index 88f5bea5c3e..85008a52ff0 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -1092,6 +1092,7 @@ def get_reduce_op_inputs():
         "aten.hardswish.default",
         "aten.hardsigmoid.default",
         "aten.leaky_relu.default",
+        "aten.round.default",
     ]
 )
 def get_unary_ops_inputs():

From 476a8384d0b38fcb1b67cca9e908df82559f709f Mon Sep 17 00:00:00 2001
From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com>
Date: Wed, 2 Apr 2025 22:39:20 -0700
Subject: [PATCH 3/3] [ET-VK] Replace Uniform buffers with push constants for native layer norm op

Pull Request resolved: https://github.com/pytorch/executorch/pull/9831

This diff replaces uniform buffers with push constants for the native layer norm op in the Vulkan backend of Executorch: the shader now reads `out_limits`, `sizes`, and `epsilon` from a single push-constant block instead of three UBOs, and the C++ code supplies those values as push constants when dispatching the shader.

Differential Revision: [D70943355](https://our.internmc.facebook.com/intern/diff/D70943355/)

ghstack-source-id: 275813550
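For context, a push constant travels inside the recorded command buffer itself, so no descriptor set or buffer allocation backs these few bytes. A rough sketch of the raw Vulkan plumbing this maps onto (illustrative only; the struct and function names are hypothetical, and ExecuTorch's `PushConstantDataInfo` abstracts this layer away):

```cpp
#include <cstdint>
#include <vulkan/vulkan.h>

// Mirrors the shader's push-constant block. Under std430 rules an ivec3 is
// padded to 16 bytes, so `sizes` lands at offset 16 and `epsilon` at 32;
// the 36-byte total is well under the 128-byte minimum push-constant
// capacity the Vulkan spec guarantees.
struct LayerNormParams {
  int32_t out_limits[3];
  int32_t _pad;
  int32_t sizes[4];
  float epsilon;
};

void record_params(VkCommandBuffer cmd, VkPipelineLayout layout,
                   const LayerNormParams& params) {
  // One call at command-recording time replaces the three UBO bindings.
  vkCmdPushConstants(cmd, layout, VK_SHADER_STAGE_COMPUTE_BIT,
                     /*offset=*/0, sizeof(LayerNormParams), &params);
}
```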
---
 .../runtime/graph/ops/glsl/native_layer_norm.glsl |  8 +++++---
 .../runtime/graph/ops/impl/NativeLayerNorm.cpp    | 13 +++++++------
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl
index f518e838750..d6c94661ace 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl
@@ -27,9 +27,11 @@ ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
 ${layout_declare_tensor(B, "r", "t_weight", DTYPE, STORAGE)}
 ${layout_declare_tensor(B, "r", "t_bias", DTYPE, STORAGE)}
 
-${layout_declare_ubo(B, "ivec3", "out_limits")}
-${layout_declare_ubo(B, "ivec4", "sizes")}
-${layout_declare_ubo(B, "float", "epsilon")}
+layout(push_constant) uniform PRECISION restrict Block {
+  ivec3 out_limits;
+  ivec4 sizes;
+  float epsilon;
+};
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
diff --git a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp
index 102c1518667..7aa98e52654 100644
--- a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp
@@ -101,11 +101,7 @@ void add_native_layer_norm_node(
        vkapi::MemoryAccessType::WRITE},
       {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
       // Shader params buffers
-      {
-          t_out->logical_limits_ubo(),
-          t_out->sizes_ubo(),
-          graph.create_params_buffer(epsilon),
-      },
+      {},
       // Specialization Constants
       {
           t_input->hashed_layout(),
@@ -113,7 +109,12 @@ void add_native_layer_norm_node(
       },
       // Resizing Logic
       resize_native_layer_norm_node,
-      {normalized_shape}));
+      {normalized_shape},
+      {
+          graph.logical_limits_pc_of(out_val->at(0)),
+          graph.sizes_pc_of(out_val->at(0)),
+          PushConstantDataInfo(&epsilon, sizeof(epsilon)),
+      }));
 }
 
 void native_layer_norm(ComputeGraph& graph, const std::vector<ValueRef>& args) {
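As a closing sanity check on the width-packed path from the first patch: walking a row as full 4-wide texels plus a tail applies the same Welford updates in the same element order as the flat loop, so the two passes must agree. A hypothetical standalone check (not part of these patches):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

int main() {
  const std::vector<float> row = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};

  // Flat Welford pass, as in the non-width-packed shader branch.
  float mean = 0.f, M2 = 0.f;
  for (std::size_t i = 0; i < row.size(); ++i) {
    const float d = row[i] - mean;
    mean += d / float(i + 1);
    M2 += d * (row[i] - mean);
  }

  // Chunked pass: full 4-wide texels, then the 2-element tail.
  float cmean = 0.f, cM2 = 0.f, counter = 1.f;
  for (std::size_t base = 0; base < row.size(); base += 4) {
    const std::size_t lanes = std::min<std::size_t>(4, row.size() - base);
    for (std::size_t i = 0; i < lanes; ++i) {
      const float d = row[base + i] - cmean;
      cmean += d / counter;
      cM2 += d * (row[base + i] - cmean);
      counter += 1.f;
    }
  }

  assert(std::fabs(mean - cmean) < 1e-6f && std::fabs(M2 - cM2) < 1e-4f);
  return 0;
}
```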