From de2aab69de54dc48dca6f577d215c8b077aeff87 Mon Sep 17 00:00:00 2001
From: Stephen Jia
Date: Mon, 31 Mar 2025 12:18:12 -0700
Subject: [PATCH] [ET-VK] Store weights transposed for int8 linear

Pull Request resolved: https://github.com/pytorch/executorch/pull/9765

## Context

The weight tensor of a linear layer is usually stored in a transposed manner,
such that when computing the matrix multiplication, the reduction traverses
along the rows of the weight tensor as opposed to the columns. This results in
a better memory access pattern for CPUs.

However, for GPUs, I have found that "un-transposing" the weight tensors
results in better performance. This is likely because GPUs can compute
multiple output elements in parallel, so reading along the columns allows
memory loads to be coalesced among the threads in a work group.

## Changes

* Introduce the ability to transpose the height and width dims when
  transferring tensor data to the GPU.
* Prepack the weight tensor "un-transposed" for the int8 quantized linear
  operator.

ghstack-source-id: 275180033
@exported-using-ghexport

Differential Revision: [D72066588](https://our.internmc.facebook.com/intern/diff/D72066588/)
---
 .../nchw_to_bitw8_image_nobitw8buffer.glsl    | 21 ++++++++++--
 .../graph/ops/glsl/nchw_to_buffer.glsl        |  9 ++++-
 .../runtime/graph/ops/glsl/nchw_to_image.glsl | 21 ++++++++++--
 .../runtime/graph/ops/glsl/q_8w_linear.glsl   | 31 ++++++++---------
 .../graph/ops/impl/QuantizedLinear.cpp        |  4 +--
 .../vulkan/runtime/graph/ops/impl/Staging.cpp | 34 +++++++++++++++++--
 .../vulkan/runtime/graph/ops/impl/Staging.h   | 12 +++++++
 backends/vulkan/test/op_tests/cases.py        |  4 ++-
 8 files changed, 110 insertions(+), 26 deletions(-)

diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl
index 25113887dca..327c3868847 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl
@@ -27,6 +27,8 @@ ${layout_declare_ubo(B, "ivec4", "sizes")}
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

 ${layout_declare_spec_const(C, "int", "t_layout", "DEFAULT_LAYOUT")}
+${layout_declare_spec_const(C, "int", "transpose_hw", "0")}
+
 const lowp ivec4 axis_map = unhash_axis_map(t_layout);
 const lowp int packed_dim = unhash_packed_dim(t_layout);

@@ -41,8 +43,23 @@ int extend_sign(int x) {
 }

 ivec4 read_texel(ivec4 tidx) {
+  ivec4 tidx_to_use = tidx;
+  ivec4 sizes_to_use = sizes;
+  int packed_dim_to_use = packed_dim;
+  if (transpose_hw == 1) {
+    sizes_to_use.xy = sizes_to_use.yx;
+    tidx_to_use.xy = tidx.yx;
+
+    if (packed_dim == 1) {
+      packed_dim_to_use = 0;
+    }
+    if (packed_dim == 0) {
+      packed_dim_to_use = 1;
+    }
+  }
+
   const ivec4 buf_indices = tidx_to_nchwi(
-      tidx, sizes, packed_dim);
+      tidx_to_use, sizes_to_use, packed_dim_to_use);

   int shift = (1 << 8) - 1;
   ivec4 masks;
@@ -70,7 +87,7 @@ ivec4 read_texel(ivec4 tidx) {

 void main() {
   const ivec3 lpos = ivec3(gl_GlobalInvocationID);
-  const ivec4 tidx = lpos_to_tidx(lpos, sizes, axis_map.w, packed_dim);
+  ivec4 tidx = lpos_to_tidx(lpos, sizes, axis_map.w, packed_dim);

   if (any(greaterThanEqual(tidx, sizes))) {
     return;
diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl
index bf498f34d5b..32235a9ad65 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl
@@ -21,6 +21,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 // This constant is unused in this shader but is kept so that the signature is
 // consistent with nchw_to_image.
 ${layout_declare_spec_const(C, "int", "UNUSED_layout", "0")}
+${layout_declare_spec_const(C, "int", "transpose_hw", "0")}

 void main() {
   int out_bufi = int(gl_GlobalInvocationID.x);
@@ -29,7 +30,13 @@ void main() {
   }

   ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides);
-  const int in_nchwi = tidx_to_nchwi(out_tidx, out_sizes);
+
+  ivec4 sizes = out_sizes;
+  if (transpose_hw == 1) {
+    sizes.xy = sizes.yx;
+    out_tidx.xy = out_tidx.yx;
+  }
+  const int in_nchwi = tidx_to_nchwi(out_tidx, sizes);

   t_out[out_bufi] = nchw_in[in_nchwi];
 }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl
index 3d2a102dac7..2f55535c82c 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl
@@ -30,14 +30,31 @@ $if not FROM_STAGING:
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

 ${layout_declare_spec_const(C, "int", "t_layout", "DEFAULT_LAYOUT")}
+${layout_declare_spec_const(C, "int", "transpose_hw", "0")}
+
 const lowp ivec4 axis_map = unhash_axis_map(t_layout);
 const lowp int packed_dim = unhash_packed_dim(t_layout);

 VEC4_T read_texel(ivec4 tidx) {
+  ivec4 tidx_to_use = tidx;
+  ivec4 sizes_to_use = sizes;
+  int packed_dim_to_use = packed_dim;
+  if (transpose_hw == 1) {
+    sizes_to_use.xy = sizes_to_use.yx;
+    tidx_to_use.xy = tidx.yx;
+
+    if (packed_dim == 1) {
+      packed_dim_to_use = 0;
+    }
+    if (packed_dim == 0) {
+      packed_dim_to_use = 1;
+    }
+  }
+
   $if FROM_STAGING:
-    const ivec4 buf_indices = tidx_to_nchwi(tidx, sizes, packed_dim);
+    const ivec4 buf_indices = tidx_to_nchwi(tidx_to_use, sizes_to_use, packed_dim_to_use);
   $else:
-    const ivec4 buf_indices = tidx_to_4bufi(tidx, buf_strides, packed_dim);
+    const ivec4 buf_indices = tidx_to_4bufi(tidx_to_use, buf_strides, packed_dim_to_use);

   VEC4_T texel = VEC4_T(0);
   if (tidx[packed_dim] < sizes[packed_dim]) {
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl
index 56bffaee675..228e2e8f870 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl
@@ -64,24 +64,21 @@ void main() {

   FLOAT_T outval = FLOAT_T(0.0);

-  // Initial mat1 tensor idx will be (0, out_tidx.y, out_tidx.z, 0)
   int mat1_offset = out_tidx.y * mat1_strides.y + out_tidx.z * qmat2_strides.z;
-  // Initial qmat2 tensor idx wil be (0, out_tidx.x, 0, 0); note that the qmat2
-  // tensor is transposed
-  int qmat2_offset = out_tidx.x * qmat2_strides.y;
+  int qmat2_offset = out_tidx.x;

   // TODO(ssjia): optimize memory access pattern by traversing mat1 x in inner loop
   for (int i = 0; i < mat1_sizes.x; i++) {
     const FLOAT_T mat1_val = t_mat1[mat1_offset];
-    const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale;
+    const FLOAT_T mat2_val = FLOAT_T(t_qmat2[qmat2_offset]);

     outval += mat1_val * mat2_val;

     mat1_offset++;
-    qmat2_offset++;
+    qmat2_offset += qmat2_strides.y;
   }

-  t_out[out_bufi] = outval;
+  t_out[out_bufi] = outval * scale;
 }

 #else // USING_TEXTURE
@@ -97,25 +94,27 @@ void main() {
     return;
   }

-  const uint16_t qmat2_pos_y = out_pos.x * uint16_t(4);
+  const uint16_t qmat2_pos_x = out_pos.x;

   VEC4_T outtex = VEC4_T(0);
   const VEC4_T scales = load_texel(t_scales, u16vec3(out_pos.x, 0, 0));

+  VEC4_T mat1_tex;
+  VEC4_T mat2_tex[4];
   for (
       uint16_t i = uint16_t(0), x = uint16_t(0);
       i < uint16_t(mat1_sizes.x);
       i += uint16_t(4), x++)
   {
-    const VEC4_T mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0));
-    const VEC4_T sums = VEC4_T(
-        dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y, 0))),
-        dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(1), 0))),
-        dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(2), 0))),
-        dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(3), 0))));
-
-    outtex += sums;
+    mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0));
+
+    mat2_tex[0] = load_texel(t_qmat2, u16vec3(out_pos.x, i, 0));
+    mat2_tex[1] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(1), 0));
+    mat2_tex[2] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(2), 0));
+    mat2_tex[3] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(3), 0));
+
+    outtex += mat1_tex.x * mat2_tex[0] + mat1_tex.y * mat2_tex[1] + mat1_tex.z * mat2_tex[2] + mat1_tex.w * mat2_tex[3];
   }

   outtex *= scales;
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
index 59684d73bd2..2011331ec38 100644
--- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
@@ -48,7 +48,7 @@ void resize_q_8w_linear_node(
   vTensorPtr qmat2 = graph->get_tensor(args[1].refs[1]);

   const int out_cols = utils::val_at(-2, mat1->sizes());
-  const int out_rows = utils::val_at(-2, qmat2->sizes());
+  const int out_rows = utils::val_at(-1, qmat2->sizes());

   std::vector<int64_t> new_out_sizes(3);
   if (mat1->sizes().size() == 2) {
@@ -86,7 +86,7 @@ void add_q_8w_linear_node(
     // Ensure out is packed correctly
     out_W_packed = out_tmp;
   }
-  ValueRef q_mat2 = prepack_standard(
+  ValueRef q_mat2 = prepack_standard_hw_transposed(
       graph, q_mat2_data, graph.storage_type_of(out), utils::kWidthPacked);
   ValueRef scales = prepack_standard(
       graph, scales_data, graph.storage_type_of(out), utils::kWidthPacked);
diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp
index 959d3974b73..f59d1cd65d9 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp
@@ -113,7 +113,8 @@ void add_tensor_to_staging_node(
 void add_prepack_standard_node(
     ComputeGraph& graph,
     const ValueRef tensor_data,
-    const ValueRef tensor) {
+    const ValueRef tensor,
+    const bool transpose_hw = false) {
   vkapi::ShaderInfo shader = get_nchw_to_tensor_shader(
       *graph.get_tensor(tensor), graph.int8_buffers_enabled());

@@ -127,6 +128,8 @@ void add_prepack_standard_node(
     ubos.append({graph.sizes_ubo(tensor)});
   }

+  int transpose_hw_spec = transpose_hw ? 1 : 0;
+
   graph.prepack_nodes().emplace_back(new PrepackNode(
       graph,
       shader,
@@ -138,7 +141,7 @@ void add_prepack_standard_node(
       // Parameter Buffers
       ubos,
       // Specialization Constants
-      {graph.hashed_layout_of(tensor)}));
+      {graph.hashed_layout_of(tensor), transpose_hw_spec}));
 }

 ValueRef prepack_standard(
@@ -158,6 +161,33 @@ ValueRef prepack_standard(
   return tensor;
 }

+ValueRef prepack_standard_hw_transposed(
+    ComputeGraph& graph,
+    const ValueRef tensor_data,
+    const utils::StorageType storage_type,
+    const utils::GPUMemoryLayout layout,
+    const bool passthrough,
+    const utils::AxisMapLayout axis_map_layout) {
+  (void)passthrough;
+
+  VK_CHECK_COND(graph.val_is_tref(tensor_data));
+  std::vector<int64_t> new_out_sizes = graph.sizes_of(tensor_data);
+  const int w_dim = new_out_sizes.size() - 1;
+  const int h_dim = new_out_sizes.size() - 2;
+  const int64_t tmp = new_out_sizes.at(w_dim);
+  new_out_sizes.at(w_dim) = new_out_sizes.at(h_dim);
+  new_out_sizes.at(h_dim) = tmp;
+  ValueRef tensor = graph.add_tensor(
+      new_out_sizes,
+      graph.dtype_of(tensor_data),
+      storage_type,
+      layout,
+      -1,
+      axis_map_layout);
+  add_prepack_standard_node(graph, tensor_data, tensor, true);
+  return tensor;
+}
+
 ValueRef prepack_standard_like(
     ComputeGraph& graph,
     const ValueRef tensor_data,
diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.h b/backends/vulkan/runtime/graph/ops/impl/Staging.h
index bc501d5d053..1b6f245bd34 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Staging.h
+++ b/backends/vulkan/runtime/graph/ops/impl/Staging.h
@@ -51,6 +51,18 @@ ValueRef prepack_standard(
     const bool passthrough = false,
     const utils::AxisMapLayout axis_map_layout = utils::kDefaultAxisMap);

+/*
+ * Same as prepack_standard, but transpose the height and width dimensions of
+ * the tensor while packing.
+ */
+ValueRef prepack_standard_hw_transposed(
+    ComputeGraph& graph,
+    const ValueRef tensor_data,
+    const utils::StorageType storage_type,
+    const utils::GPUMemoryLayout layout,
+    const bool passthrough = false,
+    const utils::AxisMapLayout axis_map_layout = utils::kDefaultAxisMap);
+
 /*
  * Equivalent to `prepack_standard()` function, except the `storage_type` and
  * `memory_layout` are set to match `to_copy`, which must be a `Tensor`.
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index 41d8edf1f25..329d62c2285 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -157,12 +157,14 @@ def get_weight_int8pack_mm_inputs():
         [6, 1024, 256],
         [6, 256, 256],
         [6, 256, 512],
+        [4, 768, 4096],
+        [1024, 1024, 1024],
     ]

     inputs_list = [((M, K), (N, K), (N)) for M, K, N in MKN_list]

     test_suite = VkTestSuite(inputs_list)
-    test_suite.dtypes = ["at::kFloat", "at::kHalf"]
+    test_suite.dtypes = ["at::kFloat"]
     test_suite.layouts = ["utils::kWidthPacked"]
     test_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"]
     test_suite.prepacked_args = ["mat2", "scales"]