diff --git a/backends/vulkan/runtime/api/containers/Tensor.cpp b/backends/vulkan/runtime/api/containers/Tensor.cpp index 4cbd1290401..0712062a37e 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.cpp +++ b/backends/vulkan/runtime/api/containers/Tensor.cpp @@ -497,9 +497,7 @@ vTensor::vTensor( VK_CHECK_COND( dim_order_is_valid(dim_order_), "computed dim order is invalid"); - if (storage_type != utils::kBuffer) { - set_logical_limits(storage_.image_extents_); - } + set_logical_limits(storage_.image_extents_); } // NOLINTNEXTLINE diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.glsl b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.glsl new file mode 100644 index 00000000000..c7fcdcc775a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.glsl @@ -0,0 +1,136 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_required_extensions("uint8")} +${define_required_extensions("int8")} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "nchw_4x2", "uint8", "buffer")} + +layout(push_constant) uniform restrict Block { + ivec4 qmat2_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +uint8_t get_first(const uint8_t packed) { + return uint8_t((packed & 0xF0) >> 4); +} + +uint8_t get_second(const uint8_t packed) { + return uint8_t(packed & 0x0F); +} + +uint8_t combine(const uint8_t first, const uint8_t second) { + return uint8_t(first << 4 | second); +} + +/* + * This shader packs the weight tensor into a texture. + * + * The original tensor has a (W, H) shape of (K / 2, N) and each scalar element + * is a uint8_t, which contains 2 packed 4 bit uint values. + * + * The transform performed by this shader is to first transpose the tensor, so + * the shape of the packed tensor becomes (N / 2, K). Then, the 4 bit integers + * are re-packed in groups of 8. For each 4 uint8_t values, the "left" 4-bits + * of each value contain the 0, 1, 2, 3 4-bit values, and the "right" 4-bits of + * each value contain the 4, 5, 6, 7 4-bit values. + * + * As a concrete example, consider the following weight tensor. The | demarks + * the packing boundary, so 1| 2 represents a single uint8_t value with 1 in the + * leftmost 4 bits and 2 in the rightmost 4 bits. + * + * 1| 2, 3| 4, 5| 6, 7| 8, + * 9|10, 11|12, 13|14, 15|16, + * 17|18, 19|20, 21|22, 23|24, + * 25|26, 27|28, 29|30, 31|32, + * 33|34, 35|36, 37|38, 39|40, + * 41|42, 43|44, 45|46, 47|48, + * 49|50, 51|52, 53|54, 55|56, + * 57|58, 59|60, 61|62, 63|64, + * + * After packing, the packed tensor would contain + * + * 1|33, 9|41, 17|49, 25|57, + * 2|34, 10|42, 18|50, 26|58, + * 3|35, 11|43, 19|51, 27|59, + * 4|36, 12|44, 20|52, 28|60, + * 5|37, 13|45, 21|53, 29|61, + * 6|38, 14|46, 22|54, 30|62, + * 7|39, 15|47, 23|55, 31|63, + * 8|40, 16|48, 24|56, 32|64, + * + * The purpose of interleaving is to make it easier to extract the unpacked + * values in order using the u8vec4 vectorized type. 
+ * With the packing in place, the 4-bit values can be extracted via
+ *
+ * u8vec4 packed;
+ * u8vec4 vals_0123 = (packed & 0xF0) >> 4;
+ * u8vec4 vals_4567 = (packed & 0x0F);
+ */
+void main() {
+  // Each thread writes 2 output texels along the height axis
+  ivec2 packed_pos = ivec2(
+      gl_GlobalInvocationID.x,
+      gl_GlobalInvocationID.y << 1);
+
+  // The packed tensor is width packed
+  if ((packed_pos.x << 2) >= qmat2_sizes.x || packed_pos.y >= qmat2_sizes.y) {
+    return;
+  }
+
+  int out_col = packed_pos.x << 3;
+  int out_row = packed_pos.y;
+
+  int in_col = out_row;
+  int in_int8_col = in_col >> 1;
+  int in_row = out_col;
+
+  int in_numrows = qmat2_sizes.x << 1;
+  int in_numcols = qmat2_sizes.y;
+  int in_num_int8_cols = qmat2_sizes.y >> 1;
+
+  uint8_t in_vals[8][2];
+  for (int r = 0; r < 8; ++r) {
+    if (in_row + r < in_numrows) {
+      uint8_t in_val_packed = nchw_4x2[(in_row + r) * in_num_int8_cols + in_int8_col];
+      in_vals[r][0] = get_first(in_val_packed);
+      in_vals[r][1] = get_second(in_val_packed);
+    } else {
+      in_vals[r][0] = uint8_t(254);
+      in_vals[r][1] = uint8_t(254);
+    }
+  }
+
+  u8vec4 out_tex_1 = u8vec4(
+      combine(in_vals[0][0], in_vals[4][0]),
+      combine(in_vals[1][0], in_vals[5][0]),
+      combine(in_vals[2][0], in_vals[6][0]),
+      combine(in_vals[3][0], in_vals[7][0]));
+
+  u8vec4 out_tex_2 = u8vec4(
+      combine(in_vals[0][1], in_vals[4][1]),
+      combine(in_vals[1][1], in_vals[5][1]),
+      combine(in_vals[2][1], in_vals[6][1]),
+      combine(in_vals[3][1], in_vals[7][1]));
+
+  $if STORAGE == "buffer":
+    int stride = qmat2_sizes.x >> 2;
+    t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1;
+    t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2;
+  $else:
+    imageStore(t_qmat2, ivec3(packed_pos.xy, 0), out_tex_1);
+    imageStore(t_qmat2, ivec3(packed_pos.x, packed_pos.y + 1, 0), out_tex_2);
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.yaml b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.yaml
new file mode 100644
index 00000000000..168b18fffe4
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.yaml
@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
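+#
+# The texture3d variant is used by default; the buffer variant is selected at
+# prepack time when the packed weight would exceed the device's maximum
+# texture extents (see prepack_int4_linear_weight_transposed_interleaved in
+# QuantizedLinearGroupwiseInt4.cpp).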
+ +pack_int4_linear_weight_transposed_interleaved: + parameter_names_with_default_values: + STORAGE: texture3d + shader_variants: + - NAME: pack_int4_linear_weight_transposed_interleaved_texture3d + - NAME: pack_int4_linear_weight_transposed_interleaved_buffer + STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl index b702a110a65..29f2934f957 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl @@ -8,34 +8,30 @@ #version 450 core -#include "indexing_utils.h" - #define PRECISION ${PRECISION} -#define FOUR 4 - -#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} -#define FLOAT_T ${buffer_scalar_type(DTYPE)} +#define T ${buffer_scalar_type(DTYPE)} +#define VEC4_T ${buffer_gvec_type(DTYPE, 4)} -${define_active_storage_type(STORAGE)} - -${define_required_extensions([DTYPE, "uint8", "uint16"])} -#extension GL_EXT_control_flow_attributes : require +${define_required_extensions(DTYPE)} +${define_required_extensions("int8")} layout(std430) buffer; -${layout_declare_tensor(B, "w", "ret", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "x", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "weights", "uint8", "buffer")} -${layout_declare_tensor(B, "r", "qparams", DTYPE, STORAGE)} -${layout_declare_ubo(B, "ivec3", "ret_limits")} -${layout_declare_ubo(B, "ivec4", "x_sizes")} -${layout_declare_ubo(B, "ivec4", "weights_strides")} -${layout_declare_ubo(B, "ivec4", "qparams_strides")} +${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "texture3D")} + +layout(push_constant) uniform restrict Block { + ivec4 out_sizes; + ivec4 mat1_sizes; + ivec4 qmat2_sizes; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -layout(constant_id = 3) const int group_size = 1; +layout(constant_id = 3) const int group_size = 64; /* * This shader computes a linear operator between a floating point input matrix @@ -43,9 +39,11 @@ layout(constant_id = 3) const int group_size = 1; * * The (W, H, C) shape of each tensor is: * - x: (K, M) - * - weights: (K / 2, N) + * - weights: (N / 2, K) * - The weights tensor has a data type of `uint8`. Each element in the tensor * contains 2 4-bit values packed into a uint8. + * - See the pack_int4_linear_weight_transposed_interleave shader to see more + * details on how the weight tensor is stored. * - qparams: (2, N, number_of_groups) * - This tensor contains the scales and zeros quantization parameters for the * weights tensor. The weight tensor is quantized group-wise, which means @@ -57,56 +55,68 @@ layout(constant_id = 3) const int group_size = 1; * Note that this shader assumes that all tensors are width packed. */ void main() { - // output positions being calculated are (n, m), (n + 1, m), ... - // This means multiplying the m-th row of x with the n-th, (n+1)-th, ... rows - // of the weights tensor. - const u16vec3 ret_pos = u16vec3(gl_GlobalInvocationID); - if (any(greaterThanEqual(ret_pos, ret_limits))) { + const uint out_row = gl_GlobalInvocationID.y; + // Each thread writes out 2 texels along the width axis, equivalent to 8 + // scalar elements. Therefore multiply the thread_idx.x by 8. 
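+  // Illustrative example (added comment; values assume the indexing below):
+  // thread x = 2 gives out_col = 16 and out_col_texel_idx = 4, so the thread
+  // produces output columns 16..23, i.e. output texels 4 and 5.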
+ const uint out_col = gl_GlobalInvocationID.x << 3; + // Similar reasoning to the above, each thread works on 2 texels along the + // width axis so multiply thread_idx.x by 2. + const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1; + + if (out_col >= out_sizes.x || out_row >= out_sizes.y) { return; } - // Since ret is width packed, need to multiply by 4 - const uint16_t n = uint16_t(ret_pos.x * 4); + const int num_blocks = mat1_sizes.x / group_size; - // K is guaranteed to be a multiple of group size - const uint16_t num_blocks = uint16_t(x_sizes.x / group_size); + VEC4_T sums[2]; - uint16_t k_texel_i = uint16_t(0); - vec4 sums = vec4(0.0); - for (uint16_t block_idx = uint16_t(0); block_idx < num_blocks; block_idx++) { - vec4 scales; - vec4 zeros; + sums[0] = VEC4_T(0); + sums[1] = VEC4_T(0); - [[unroll]] for (int comp = 0; comp < 4; comp++) { - const vec4 scale_and_zero = load_texel( - qparams, u16vec3(0, n + comp, block_idx)); - scales[comp] = scale_and_zero.x; - zeros[comp] = scale_and_zero.y; - } + VEC4_T scales[2]; + VEC4_T zeros[2]; + + $if WEIGHT_STORAGE == "buffer": + const int qmat2_stride = qmat2_sizes.x >> 2; + + for (int block_idx = 0; block_idx < num_blocks; ++block_idx) { + scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0); + zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0); - for (uint16_t i = uint16_t(0); i < group_size; i += uint16_t(4), k_texel_i++) { - const VEC4_T x_texel = load_texel( - x, u16vec3(k_texel_i, ret_pos.y, ret_pos.z)); - - [[unroll]] for (int comp = 0; comp < 4; comp++) { - const int weights_bufi = (n + comp) * weights_strides.y + (k_texel_i * 2); - // Need to read 4 unpacked values, which corresponds to 2 packed values - const uint8_t weights_val_1 = weights[weights_bufi]; - const uint8_t weights_val_2 = weights[weights_bufi + 1]; - - const u8vec4 weights_texel = u8vec4( - (weights_val_1 & 0xF0) >> 4, - weights_val_1 & 0x0F, - (weights_val_2 & 0xF0) >> 4, - weights_val_2 & 0x0F); - - // Note that the unpacked 4-bit values are unsigned, therefore they must - // first be "centered" around 0 by subtracting 8 before applying the - // scale and zero point. 
- sums[comp] += dot( - x_texel, (vec4(weights_texel) - 8.0) * scales[comp] + zeros[comp]); + scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0); + zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0); + + for (int g_idx = 0; g_idx < group_size; g_idx += 4) { + const int k = block_idx * group_size + g_idx; + + $if IN_STORAGE == "buffer": + const VEC4_T mat1_tex = t_mat1[(out_row * mat1_sizes.x + k) >> 2]; + $else: + const VEC4_T mat1_tex = texelFetch(t_mat1, ivec3(k >> 2, out_row, 0), 0); + + for (int comp = 0; comp < 4; ++comp) { + $if WEIGHT_STORAGE == "buffer": + const u8vec4 packed_weight_tex = t_qmat2[(k + comp) * qmat2_stride + gl_GlobalInvocationID.x]; + $else: + const uvec4 packed_weight_tex = texelFetch( + t_qmat2, + ivec3(gl_GlobalInvocationID.x, k + comp, 0), + 0); + + const uvec4 weight_tex_1 = (packed_weight_tex & 0xF0) >> 4; + const uvec4 weight_tex_2 = packed_weight_tex & 0x0F; + + sums[0] += mat1_tex[comp] * ((vec4(weight_tex_1) - 8.0) * scales[0] + zeros[0]); + sums[1] += mat1_tex[comp] * ((vec4(weight_tex_2) - 8.0) * scales[1] + zeros[1]); } } } - write_texel(ret, ret_pos, sums); + + $if OUT_STORAGE == "buffer": + t_out[(out_row * out_sizes.x + out_col) >> 2] = sums[0]; + t_out[(out_row * out_sizes.x + out_col + 4) >> 2] = sums[1]; + $else: + imageStore(t_out, ivec3(out_col_texel_idx, out_row, 0), sums[0]); + imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row, 0), sums[1]); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml index 40d95d4a05f..fac9c25c220 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml @@ -7,10 +7,17 @@ q_4w_linear: parameter_names_with_default_values: DTYPE: float - STORAGE: texture3d - generate_variant_forall: - DTYPE: - - VALUE: float - - VALUE: half + OUT_STORAGE: texture3d + IN_STORAGE: texture3d + WEIGHT_STORAGE: texture3d shader_variants: - - NAME: q_4w_linear_texture3d + - NAME: q_4w_linear_texture3d_texture3d_texture3d_float + - NAME: q_4w_linear_texture3d_buffer_texture3d_float + IN_STORAGE: buffer + - NAME: q_4w_linear_buffer_buffer_texture3d_float + OUT_STORAGE: buffer + IN_STORAGE: buffer + - NAME: q_4w_linear_buffer_buffer_buffer_float + OUT_STORAGE: buffer + IN_STORAGE: buffer + WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearGroupwiseInt4.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearGroupwiseInt4.cpp new file mode 100644 index 00000000000..b795e574291 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearGroupwiseInt4.cpp @@ -0,0 +1,191 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */
+
+#include
+
+#include
+
+#include
+#include
+
+namespace vkcompute {
+
+void check_q_4w_linear_args(
+    ComputeGraph& graph,
+    const ValueRef mat1,
+    const ValueRef mat2_data,
+    const ValueRef group_size,
+    const ValueRef scales_and_zeros,
+    const ValueRef out) {
+  VK_CHECK_COND(graph.int16_shader_types_enabled());
+  VK_CHECK_COND(graph.int8_buffers_enabled());
+
+  VK_CHECK_COND(graph.val_is_tensor(mat1));
+  VK_CHECK_COND(graph.val_is_tref(mat2_data));
+  VK_CHECK_COND(graph.val_is_tref(scales_and_zeros));
+
+  VK_CHECK_COND(graph.dim_of(mat1) <= 3);
+  VK_CHECK_COND(graph.dim_of(mat2_data) == 2);
+  VK_CHECK_COND(graph.dim_of(scales_and_zeros) == 3);
+
+  VK_CHECK_COND(graph.size_at<int>(-3, mat1) == 1);
+  const int K = graph.size_at<int>(-1, mat1);
+  VK_CHECK_COND(graph.size_at<int>(-1, mat2_data) * 2 == K);
+
+  const int group_size_val = graph.extract_scalar<int>(group_size);
+  VK_CHECK_COND(K % group_size_val == 0);
+  // Due to the way weight packing works, group size needs to be a multiple of 8
+  VK_CHECK_COND(group_size_val % 8 == 0);
+
+  VK_CHECK_COND(graph.has_standard_axis_map(mat1));
+  VK_CHECK_COND(graph.has_standard_axis_map(out));
+}
+
+void resize_q_4w_linear_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  (void)extra_args;
+
+  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
+  vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]);
+  vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]);
+
+  const int out_cols = utils::val_at(-2, mat1->sizes());
+  const int out_rows = utils::val_at(-1, mat2->sizes()) * 2;
+
+  std::vector<int64_t> new_out_sizes(3);
+  if (mat1->sizes().size() == 2) {
+    new_out_sizes.resize(2);
+    new_out_sizes.at(0) = out_cols;
+    new_out_sizes.at(1) = out_rows;
+  } else {
+    new_out_sizes.at(0) = mat1->sizes().at(0);
+    new_out_sizes.at(1) = out_cols;
+    new_out_sizes.at(2) = out_rows;
+  }
+
+  out->virtual_resize(new_out_sizes);
+}
+
+ValueRef prepack_int4_linear_weight_transposed_interleaved(
+    ComputeGraph& graph,
+    const ValueRef qmat2_data) {
+  std::vector<int64_t> qmat2_orig_sizes = graph.sizes_of(qmat2_data);
+  const int64_t ndim = graph.dim_of(qmat2_data);
+
+  const int64_t K = qmat2_orig_sizes.at(ndim - 1) * 2;
+  const int64_t N = qmat2_orig_sizes.at(ndim - 2);
+  const int64_t N_div2 = N / int64_t(2);
+
+  utils::StorageType storage_type = utils::kTexture3D;
+  utils::uvec3 max_extents =
+      graph.context()->adapter_ptr()->max_texture_extents();
+  if (N_div2 > max_extents[0] * 4 || K > max_extents[1]) {
+    storage_type = utils::kBuffer;
+  }
+
+  std::vector<int64_t> qmat2_sizes{K, N_div2};
+  ValueRef qmat2 = graph.add_tensor(
+      qmat2_sizes, vkcompute::vkapi::kByte, storage_type, utils::kWidthPacked);
+
+  utils::uvec3 global_wg_size;
+  global_wg_size = graph.logical_limits_of(qmat2);
+  global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(2));
+
+  std::string kernel_name = "pack_int4_linear_weight_transposed_interleaved";
+  add_storage_type_suffix(kernel_name, storage_type);
+
+  graph.prepack_nodes().emplace_back(new PrepackNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      global_wg_size,
+      graph.create_local_wg_size(global_wg_size),
+      // Inputs and Outputs
+      qmat2_data,
+      qmat2,
+      // UBOs
+      {},
+      // Specialization Constants
+      {},
+      // Push Constants
+      {graph.sizes_pc_of(qmat2)}));
+
+  return qmat2;
+}
+
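+// Sizing example (illustrative numbers, not taken from the patch): for a
+// weight with N = 1024 output channels and K = 4096 input channels, the
+// source tensor has sizes {1024, 2048} (each uint8 holds two 4-bit values),
+// and the packed tensor created above has sizes {K, N / 2} = {4096, 512},
+// i.e. 128 texels wide when stored as a width-packed texture.
+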
+void add_q_4w_linear_node(
+    ComputeGraph& graph,
+    const ValueRef mat1,
+    const ValueRef mat2_data,
+    const ValueRef group_size,
+    const ValueRef scales_and_zeros_data,
+    const ValueRef out) {
+  check_q_4w_linear_args(
+      graph, mat1, mat2_data, group_size, scales_and_zeros_data, out);
+
+  ValueRef mat2 =
+      prepack_int4_linear_weight_transposed_interleaved(graph, mat2_data);
+
+  ValueRef scales_and_zeros = prepack_standard_hw_transposed(
+      graph, scales_and_zeros_data, utils::kTexture3D, utils::kWidthPacked);
+
+  std::string kernel_name = "q_4w_linear";
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(mat1));
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(mat2));
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
+
+  const uint32_t group_size_val = graph.extract_scalar<uint32_t>(group_size);
+
+  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
+  global_wg_size[0] = utils::div_up(global_wg_size[0], uint32_t(2));
+
+  utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);
+
+  graph.execute_nodes().emplace_back(new DispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      global_wg_size,
+      local_wg_size,
+      // Inputs and Outputs
+      {{out, vkapi::kWrite}, {{mat1, mat2, scales_and_zeros}, vkapi::kRead}},
+      // Shader params buffers
+      {},
+      // Specialization Constants
+      {SV(group_size_val)},
+      // Resizing Logic
+      resize_q_4w_linear_node,
+      {},
+      // Push Constants
+      {graph.sizes_pc_of(out),
+       graph.sizes_pc_of(mat1),
+       graph.sizes_pc_of(mat2)}));
+}
+
+void linear_weight_int4(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  return add_q_4w_linear_node(
+      graph,
+      args[0], // mat1
+      args[1], // mat2
+      args[2], // group_size
+      args[3], // scales_and_zeros
+      // There is an unused variable inner_k_tiles which is used to call
+      // _convert_weight_to_int4pack in the AOT custom op, which is why the 4th
+      // argument is skipped.
+      args[5] // out
+  );
+}
+
+REGISTER_OPERATORS {
+  VK_REGISTER_OP(et_vk.linear_weight_int4.default, linear_weight_int4);
+}
+
+} // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp
similarity index 64%
rename from backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
rename to backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp
index f4f5c853ddd..49085ff4e06 100644
--- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp
@@ -268,157 +268,8 @@ void weight_int8pack_mm(
   return add_q_8w_linear_node(graph, args[0], args[1], args[2], args[3]);
 }
 
-void check_q_4w_linear_args(
-    ComputeGraph& graph,
-    const ValueRef mat1,
-    const ValueRef mat2_data,
-    const ValueRef group_size,
-    const ValueRef scales_and_zeros,
-    const ValueRef out) {
-  VK_CHECK_COND(graph.int16_shader_types_enabled());
-  VK_CHECK_COND(graph.int8_buffers_enabled());
-
-  VK_CHECK_COND(graph.val_is_tensor(mat1));
-  VK_CHECK_COND(graph.val_is_tref(mat2_data));
-  VK_CHECK_COND(graph.val_is_tref(scales_and_zeros));
-
-  VK_CHECK_COND(graph.dim_of(mat1) <= 3);
-  VK_CHECK_COND(graph.dim_of(mat2_data) == 2);
-  VK_CHECK_COND(graph.dim_of(scales_and_zeros) == 3);
-
-  VK_CHECK_COND(graph.size_at<int>(-3, mat1) == 1);
-  const int K = graph.size_at<int>(-1, mat1);
-  VK_CHECK_COND(graph.size_at<int>(-1, mat2_data) * 2 == K);
-
-  const int group_size_val = graph.extract_scalar<int>(group_size);
-  VK_CHECK_COND(K % group_size_val == 0);
-
-  VK_CHECK_COND(graph.has_standard_axis_map(mat1));
-  VK_CHECK_COND(graph.has_standard_axis_map(out));
-}
-
-void resize_q_4w_linear_node(
-    ComputeGraph* graph,
-    const std::vector<ArgGroup>& args,
-    const std::vector<ValueRef>& extra_args) {
-  (void)extra_args;
-
vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]); - vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]); - - const int out_cols = utils::val_at(-2, mat1->sizes()); - const int out_rows = utils::val_at(-2, mat2->sizes()); - - std::vector new_out_sizes(3); - if (mat1->sizes().size() == 2) { - new_out_sizes.resize(2); - new_out_sizes.at(0) = out_cols; - new_out_sizes.at(1) = out_rows; - } else { - new_out_sizes.at(0) = mat1->sizes().at(0); - new_out_sizes.at(1) = out_cols; - new_out_sizes.at(2) = out_rows; - } - - out->virtual_resize(new_out_sizes); -} - -void add_q_4w_linear_node( - ComputeGraph& graph, - const ValueRef mat1, - const ValueRef mat2_data, - const ValueRef group_size, - const ValueRef scales_and_zeros_data, - const ValueRef out) { - check_q_4w_linear_args( - graph, mat1, mat2_data, group_size, scales_and_zeros_data, out); - - utils::StorageType storage_type = graph.storage_type_of(out); - - ValueRef mat2 = prepack_direct_copy_buffer(graph, mat2_data); - - ValueRef scales_and_zeros = prepack_standard( - graph, - scales_and_zeros_data, - graph.storage_type_of(out), - utils::kWidthPacked); - - std::string kernel_name = "q_4w_linear"; - add_storage_type_suffix(kernel_name, storage_type); - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - - const uint32_t group_size_val = graph.extract_scalar(group_size); - - ValueRef mat1_W_packed = mat1; - ValueRef out_W_packed = out; - auto viewFn = VK_GET_OP_FN("aten.view_copy.default"); - // Create temporary tensors to store the width packed versions of mat1 and out - TmpTensor mat1_tmp( - &graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked); - TmpTensor out_tmp( - &graph, graph.sizes_of(out), graph.dtype_of(out), utils::kWidthPacked); - if (storage_type == utils::kTexture3D) { - if (!graph.is_buffer_storage(out) && - graph.packed_dim_of(mat1) != WHCN::kWidthDim) { - // Ensure mat1 is width packed - mat1_W_packed = mat1_tmp; - viewFn(graph, {mat1, graph.add_none(), mat1_W_packed}); - // Ensure out is packed correctly - out_W_packed = out_tmp; - } - } - - vkapi::ParamsBindList ubos({}); - ubos.append(graph.logical_limits_ubo(out_W_packed)); - ubos.append(graph.sizes_ubo(mat1_W_packed)); - ubos.append(graph.strides_ubo(mat2)); - ubos.append(graph.strides_ubo(scales_and_zeros)); - - utils::uvec3 global_wg_size = graph.logical_limits_of(out_W_packed); - utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); - - graph.execute_nodes().emplace_back(new DispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - local_wg_size, - // Inputs and Outputs - {{out_W_packed, vkapi::MemoryAccessType::WRITE}, - {{mat1_W_packed, mat2, scales_and_zeros}, - vkapi::MemoryAccessType::READ}}, - // Shader params buffers - ubos, - // Specialization Constants - {SV(group_size_val)}, - // Resizing Logic - resize_q_4w_linear_node, - {})); - if (!graph.is_buffer_storage(out) && - graph.packed_dim_of(out) != WHCN::kWidthDim) { - viewFn(graph, {out_W_packed, graph.add_none(), out}); - } -} - -void linear_weight_int4( - ComputeGraph& graph, - const std::vector& args) { - return add_q_4w_linear_node( - graph, - args[0], // mat1 - args[1], // mat2 - args[2], // group_size - args[3], // scales_and_zeros - // There is an unused variable inner_k_tiles which is used to call - // _convert_weight_to_int4pack in the AOT custom op, which is why the 4th - // argument is skipped. 
- args[5] // out - ); -} - REGISTER_OPERATORS { VK_REGISTER_OP(aten._weight_int8pack_mm.default, weight_int8pack_mm); - VK_REGISTER_OP(et_vk.linear_weight_int4.default, linear_weight_int4); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/vk_api/Adapter.h b/backends/vulkan/runtime/vk_api/Adapter.h index be0554161d3..68371d3eebf 100644 --- a/backends/vulkan/runtime/vk_api/Adapter.h +++ b/backends/vulkan/runtime/vk_api/Adapter.h @@ -211,6 +211,13 @@ class Adapter final { return physical_device_.min_ubo_alignment; } + inline utils::uvec3 max_texture_extents() const { + return { + physical_device_.properties.limits.maxImageDimension1D, + physical_device_.properties.limits.maxImageDimension2D, + physical_device_.properties.limits.maxImageDimension3D}; + } + // Command Buffer Submission void diff --git a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp index 66a585844cf..884b068bc24 100644 --- a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp +++ b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp @@ -152,13 +152,17 @@ vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { } } -void test_vulkan_linear_int4( +void test_vulkan_linear_int4_impl( const int B, const int M, const int K, const int N, const int group_size = 32, - const int inner_k_tiles = 8) { + const int inner_k_tiles = 8, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { assert(K % group_size == 0); at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat)); @@ -169,8 +173,13 @@ void test_vulkan_linear_int4( at::Tensor scales_and_zeros = at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat)); - at::Tensor out_ref = dequantize_and_linear( - x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles); + at::Tensor weights_int = unpack_weights_4x2(weights_4x2); + at::Tensor out_ref = linear_weight_int4_reference_impl( + x, + at::_convert_weight_to_int4pack_for_cpu(weights_int, group_size), + group_size, + scales_and_zeros, + inner_k_tiles); // Build Vulkan graph using namespace vkcompute; @@ -188,14 +197,13 @@ void test_vulkan_linear_int4( MAKE_TENSORREF_FOR(weights_4x2); MAKE_TENSORREF_FOR(scales_and_zeros); -#define MAKE_INPUT_FOR(x) \ - IOValueRef r_##x = graph.add_input_tensor( \ - x.sizes().vec(), from_at_scalartype(x.scalar_type())); - - MAKE_INPUT_FOR(x); + IOValueRef r_x = graph.add_input_tensor( + x.sizes().vec(), from_at_scalartype(x.scalar_type()), in_storage); const ValueRef r_out = graph.add_tensor( - out_ref.sizes().vec(), from_at_scalartype(out_ref.scalar_type())); + out_ref.sizes().vec(), + from_at_scalartype(out_ref.scalar_type()), + out_storage); VK_GET_OP_FN("et_vk.linear_weight_int4.default") (graph, @@ -229,6 +237,34 @@ void test_vulkan_linear_int4( ASSERT_TRUE(at::allclose(vk_out, out_ref, 1e-4, 1e-4)); } +void test_vulkan_linear_int4( + const int B, + const int M, + const int K, + const int N, + const int group_size = 32, + const int inner_k_tiles = 8) { + test_vulkan_linear_int4_impl( + B, + M, + K, + N, + group_size, + inner_k_tiles, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + test_vulkan_linear_int4_impl( + B, + M, + K, + N, + group_size, + inner_k_tiles, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + TEST(VulkanInt4LinearTest, test_reference_impl) { test_reference_linear_int4( /*B = */ 1,