From 1abe1a969610883ccf37d53c6cb2b0621ccf2c78 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 3 Apr 2025 13:48:46 -0700 Subject: [PATCH 1/2] [ET-VK] Improve packing format for int4 linear operator + misc improvements ## Context Improve performance of the quantized int4 linear shader by packing the scales and zeros tensor, as well as the weight tensor in a more optimal way. See the comments in the `pack_int4_linear_weight_transposed_interleave` shader for more details about how the new packing works. ## Changes * Split int8 quantized linear and int4 quantized linear into separate C++ files for better code organization * Introduce packing shader for int4 weights * Update int4 linear shader to account for packed weights ## Impact This change massively improves the performance of the weight int4 quantized linear operator. With this change, running LLaMa 3.2 1B can now achieve 10 tok/s, from 0.9 tok/s on an Adreno 740. This is a 10x improvement! With this change: ``` /home/ssjia/scratch/bin/app_bin: 1 file pushed, 0 skipped. 332.3 MB/s (74692800 bytes in 0.214s) I 00:00:00.003353 executorch:cpuinfo_utils.cpp:62] Reading file /sys/devices/soc0/image_version I 00:00:00.003533 executorch:cpuinfo_utils.cpp:78] Failed to open midr file /sys/devices/soc0/image_version I 00:00:00.003563 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu0/regs/identification/midr_el1 I 00:00:00.003685 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu1/regs/identification/midr_el1 I 00:00:00.003747 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu2/regs/identification/midr_el1 I 00:00:00.003799 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu3/regs/identification/midr_el1 I 00:00:00.003852 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu4/regs/identification/midr_el1 I 00:00:00.003902 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu5/regs/identification/midr_el1 I 00:00:00.003976 executorch:main.cpp:69] Resetting threadpool with num threads = 6 I 00:00:00.004289 executorch:runner.cpp:68] Creating LLaMa runner: model_path=/data/local/tmp/llama3-1b/vk/llama3.pte, tokenizer_path=/data/local/tmp/tokenizer.model I 00:00:04.841690 executorch:runner.cpp:101] Reading metadata from model I 00:00:04.841808 executorch:runner.cpp:126] Metadata: get_vocab_size = 128256 I 00:00:04.841830 executorch:runner.cpp:126] Metadata: get_bos_id = 128000 I 00:00:04.841851 executorch:runner.cpp:126] Metadata: use_sdpa_with_kv_cache = 1 I 00:00:04.841874 executorch:runner.cpp:126] Metadata: use_kv_cache = 1 I 00:00:04.841893 executorch:runner.cpp:126] Metadata: get_max_context_len = 128 I 00:00:04.841909 executorch:runner.cpp:126] Metadata: get_max_seq_len = 128 I 00:00:04.841927 executorch:runner.cpp:126] Metadata: enable_dynamic_shape = 0 I 00:00:04.841945 executorch:runner.cpp:133] eos_id = 128009 I 00:00:04.841951 executorch:runner.cpp:133] eos_id = 128001 I 00:00:04.841963 executorch:runner.cpp:188] RSS after loading model: 2229.828125 MiB (0 if unsupported) <|begin_of_text|><|start_header_id|>system<|end_header_id|>Tell me a short story.<|eot_id|><|start_header_id|>assistant<|end_header_id|> I 00:00:06.239633 executorch:runner.cpp:258] RSS after prompt prefill: 2229.828125 MiB (0 if unsupported) Here's a short story for you: **The Library of Lost Memories** In a small, dusty town nestled between two great rivers, there was a library that held the secrets of the past. It was a place where memories were stored, not retrieved, and the librarians were the guardians of the past. The library was called the Library of Lost Memories, and it was said that anyone who entered its doors would be given a glimpse into the memories of those who had come before. The librarians were wise and kind, and they would only allow those who wereI 00:00:17.699086 executorch:runner.cpp:272] RSS after finishing text generation: 2229.828125 MiB (0 if unsupported) I 00:00:17.699155 executorch:stats.h:108] Prompt Tokens: 14 Generated Tokens: 113 I 00:00:17.699161 executorch:stats.h:114] Model Load Time: 4.837000 (seconds) I 00:00:17.699165 executorch:stats.h:124] Total inference time: 12.857000 (seconds) Rate: 8.788987 (tokens/second) I 00:00:17.699168 executorch:stats.h:132] Prompt evaluation: 1.398000 (seconds) Rate: 10.014306 (tokens/second) I 00:00:17.699171 executorch:stats.h:143] Generated 113 tokens: 11.459000 (seconds) Rate: 9.861244 (tokens/second) I 00:00:17.699174 executorch:stats.h:151] Time to first generated token: 1.398000 (seconds) I 00:00:17.699177 executorch:stats.h:158] Sampling time over 127 tokens: 549246500.843000 (seconds) ``` Before this change: ``` /home/ssjia/scratch/bin/app_bin: 1 file pushed, 0 skipped. 302.0 MB/s (74637464 bytes in 0.236s) I 00:00:00.003050 executorch:cpuinfo_utils.cpp:62] Reading file /sys/devices/soc0/image_version I 00:00:00.003200 executorch:cpuinfo_utils.cpp:78] Failed to open midr file /sys/devices/soc0/image_version I 00:00:00.003226 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu0/regs/identification/midr_el1 I 00:00:00.003337 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu1/regs/identification/midr_el1 I 00:00:00.003396 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu2/regs/identification/midr_el1 I 00:00:00.003449 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu3/regs/identification/midr_el1 I 00:00:00.003502 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu4/regs/identification/midr_el1 I 00:00:00.003553 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu5/regs/identification/midr_el1 I 00:00:00.003629 executorch:main.cpp:69] Resetting threadpool with num threads = 6 I 00:00:00.004075 executorch:runner.cpp:68] Creating LLaMa runner: model_path=/data/local/tmp/llama3-1b/vk/llama3.pte, tokenizer_path=/data/local/tmp/tokenizer.model I 00:00:05.417531 executorch:runner.cpp:101] Reading metadata from model I 00:00:05.417647 executorch:runner.cpp:126] Metadata: get_vocab_size = 128256 I 00:00:05.417669 executorch:runner.cpp:126] Metadata: get_bos_id = 128000 I 00:00:05.417698 executorch:runner.cpp:126] Metadata: use_sdpa_with_kv_cache = 1 I 00:00:05.417716 executorch:runner.cpp:126] Metadata: use_kv_cache = 1 I 00:00:05.417735 executorch:runner.cpp:126] Metadata: get_max_context_len = 128 I 00:00:05.417751 executorch:runner.cpp:126] Metadata: get_max_seq_len = 128 I 00:00:05.417768 executorch:runner.cpp:126] Metadata: enable_dynamic_shape = 0 I 00:00:05.417787 executorch:runner.cpp:133] eos_id = 128009 I 00:00:05.417793 executorch:runner.cpp:133] eos_id = 128001 I 00:00:05.417808 executorch:runner.cpp:188] RSS after loading model: 2230.812500 MiB (0 if unsupported) <|begin_of_text|><|start_header_id|>system<|end_header_id|>Tell me a short story.<|eot_id|><|start_header_id|>assistant<|end_header_id|> I 00:00:19.689616 executorch:runner.cpp:258] RSS after prompt prefill: 2230.812500 MiB (0 if unsupported) Here's a short story for you: **The Library of Lost Memories** In a small, dusty town nestled between two great rivers, there was a library that held the secrets of the past. It was a place where memories were stored, not retrieved, and the librarians were the guardians of the past. The library was called the Library of Lost Memories, and it was said that anyone who entered its doors would be given a glimpse into the memories of those who had come before. The librarians were wise and kind, and they would only allow those who wereI 00:02:15.269693 executorch:runner.cpp:272] RSS after finishing text generation: 2230.812500 MiB (0 if unsupported) I 00:02:15.269810 executorch:stats.h:108] Prompt Tokens: 14 Generated Tokens: 113 I 00:02:15.269825 executorch:stats.h:114] Model Load Time: 5.414000 (seconds) I 00:02:15.269832 executorch:stats.h:124] Total inference time: 129.852000 (seconds) Rate: 0.870221 (tokens/second) I 00:02:15.269837 executorch:stats.h:132] Prompt evaluation: 14.271000 (seconds) Rate: 0.981010 (tokens/second) I 00:02:15.269841 executorch:stats.h:143] Generated 113 tokens: 115.581000 (seconds) Rate: 0.977669 (tokens/second) I 00:02:15.269844 executorch:stats.h:151] Time to first generated token: 14.271000 (seconds) I 00:02:15.269847 executorch:stats.h:158] Sampling time over 127 tokens: 549711269.115000 (seconds) PyTorchObserver {"prompt_tokens":14,"generated_tokens":113,"model_load_start_ms":1743712527974,"model_load_end_ms":1743712533388,"inference_start_ms":1743712533388,"inference_end_ms":1743712663240,"prompt_eval_end_ms":1743712547659,"first_token_ms":1743712547659,"aggregate_sampling_time_ms":549711269115,"SCALING_FACTOR_UNITS_PER_SECOND":1000} ``` Differential Revision: [D72412950](https://our.internmc.facebook.com/intern/diff/D72412950/) [ghstack-poisoned] --- ..._linear_weight_transposed_interleaved.glsl | 131 +++++++++++++ ..._linear_weight_transposed_interleaved.yaml | 11 ++ .../runtime/graph/ops/glsl/q_4w_linear.glsl | 130 +++++++------ .../runtime/graph/ops/glsl/q_4w_linear.yaml | 7 +- .../graph/ops/impl/QuantizedLinearInt4.cpp | 183 ++++++++++++++++++ ...izedLinear.cpp => QuantizedLinearInt8.cpp} | 149 -------------- 6 files changed, 395 insertions(+), 216 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt4.cpp rename backends/vulkan/runtime/graph/ops/impl/{QuantizedLinear.cpp => QuantizedLinearInt8.cpp} (64%) diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.glsl b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.glsl new file mode 100644 index 00000000000..b156861213e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.glsl @@ -0,0 +1,131 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_required_extensions("uint8")} +${define_required_extensions("int8")} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE)} +${layout_declare_tensor(B, "r", "nchw_4x2", "uint8", "buffer")} + +layout(push_constant) uniform restrict Block { + ivec4 qmat2_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +uint8_t get_first(const uint8_t packed) { + return uint8_t((packed & 0xF0) >> 4); +} + +uint8_t get_second(const uint8_t packed) { + return uint8_t(packed & 0x0F); +} + +uint8_t combine(const uint8_t first, const uint8_t second) { + return uint8_t(first << 4 | second); +} + +/* + * This shader packs the weight tensor into a texture. + * + * The original tensor has a (W, H) shape of (K / 2, N) and each scalar element + * is a uint8_t, which contains 2 packed 4 bit uint values. + * + * The transform performed by this shader is to first transpose the tensor, so + * the shape of the packed tensor becomes (N / 2, K). Then, the 4 bit integers + * are re-packed in groups of 8. For each 4 uint8_t values, the "left" 4-bits + * of each value contain the 0, 1, 2, 3 4-bit values, and the "right" 4-bits of + * each value contain the 4, 5, 6, 7 4-bit values. + * + * As a concrete example, consider the following weight tensor. The | demarks + * the packing boundary, so 1| 2 represents a single uint8_t value with 1 in the + * leftmost 4 bits and 2 in the rightmost 4 bits. + * + * 1| 2, 3| 4, 5| 6, 7| 8, + * 9|10, 11|12, 13|14, 15|16, + * 17|18, 19|20, 21|22, 23|24, + * 25|26, 27|28, 29|30, 31|32, + * 33|34, 35|36, 37|38, 39|40, + * 41|42, 43|44, 45|46, 47|48, + * 49|50, 51|52, 53|54, 55|56, + * 57|58, 59|60, 61|62, 63|64, + * + * After packing, the packed tensor would contain + * + * 1|33, 9|41, 17|49, 25|57, + * 2|34, 10|42, 18|50, 26|58, + * 3|35, 11|43, 19|51, 27|59, + * 4|36, 12|44, 20|52, 28|60, + * 5|37, 13|45, 21|53, 29|61, + * 6|38, 14|46, 22|54, 30|62, + * 7|39, 15|47, 23|55, 31|63, + * 8|40, 16|48, 24|56, 32|64, + * + * The purpose of interleaving is to make it easier to extract the unpacked + * values in order using the u8vec4 vectorized type. With the packing in place, + * The 4-bit values can be extracted via + * + * u8vec4 packed; + * u8vec4 vals_0123 = (packed & 0xF0) >> 4; + * u8vec4 vals_4567 = (packed | 0x0F); + */ +void main() { + // Each thread writes 2 output texels along the height axis + ivec2 packed_pos = ivec2( + gl_GlobalInvocationID.x, + gl_GlobalInvocationID.y << 1); + + // The packed tensor is width packed + if ((packed_pos.x << 2) >= qmat2_sizes.x || packed_pos.y >= qmat2_sizes.y) { + return; + } + + int out_col = packed_pos.x << 3; + int out_row = packed_pos.y; + + int in_col = out_row; + int in_int8_col = in_col >> 1; + int in_row = out_col; + + int in_numrows = qmat2_sizes.x << 1; + int in_numcols = qmat2_sizes.y; + int in_num_int8_cols = qmat2_sizes.y >> 1; + + uint8_t in_vals[8][2]; + for (int r = 0; r < 8; ++r) { + if (in_row + r < in_numrows) { + uint8_t in_val_packed = nchw_4x2[(in_row + r) * in_num_int8_cols + in_int8_col]; + in_vals[r][0] = get_first(in_val_packed); + in_vals[r][1] = get_second(in_val_packed); + } else { + in_vals[r][0] = uint8_t(254); + in_vals[r][1] = uint8_t(254); + } + } + + u8vec4 out_tex_1 = u8vec4( + combine(in_vals[0][0], in_vals[4][0]), + combine(in_vals[1][0], in_vals[5][0]), + combine(in_vals[2][0], in_vals[6][0]), + combine(in_vals[3][0], in_vals[7][0])); + + u8vec4 out_tex_2 = u8vec4( + combine(in_vals[0][1], in_vals[4][1]), + combine(in_vals[1][1], in_vals[5][1]), + combine(in_vals[2][1], in_vals[6][1]), + combine(in_vals[3][1], in_vals[7][1])); + + imageStore(t_qmat2, ivec3(packed_pos.xy, 0), out_tex_1); + imageStore(t_qmat2, ivec3(packed_pos.x, packed_pos.y + 1, 0), out_tex_2); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.yaml b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.yaml new file mode 100644 index 00000000000..1dbb0b87557 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.yaml @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +pack_int4_linear_weight_transposed_interleaved: + parameter_names_with_default_values: + STORAGE: texture3d + shader_variants: + - NAME: pack_int4_linear_weight_transposed_interleaved diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl index b702a110a65..bfb3dc516d8 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl @@ -8,34 +8,42 @@ #version 450 core -#include "indexing_utils.h" - #define PRECISION ${PRECISION} -#define FOUR 4 - -#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} -#define FLOAT_T ${buffer_scalar_type(DTYPE)} - -${define_active_storage_type(STORAGE)} +#define T ${buffer_scalar_type(DTYPE)} +#define VEC4_T ${buffer_gvec_type(DTYPE, 4)} -${define_required_extensions([DTYPE, "uint8", "uint16"])} -#extension GL_EXT_control_flow_attributes : require +${define_required_extensions(DTYPE)} +${define_required_extensions("int8")} layout(std430) buffer; -${layout_declare_tensor(B, "w", "ret", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "x", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "weights", "uint8", "buffer")} -${layout_declare_tensor(B, "r", "qparams", DTYPE, STORAGE)} -${layout_declare_ubo(B, "ivec3", "ret_limits")} -${layout_declare_ubo(B, "ivec4", "x_sizes")} -${layout_declare_ubo(B, "ivec4", "weights_strides")} -${layout_declare_ubo(B, "ivec4", "qparams_strides")} +${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "t_mat1", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE)} +${layout_declare_tensor(B, "r", "t_qparams", DTYPE, STORAGE)} + +layout(push_constant) uniform restrict Block { + ivec4 out_sizes; + ivec4 mat1_sizes; + ivec4 qmat2_sizes; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -layout(constant_id = 3) const int group_size = 1; +layout(constant_id = 3) const int group_size = 64; + +uint8_t get_first(const uint8_t packed) { + return uint8_t((packed & 0xF0) >> 4); +} + +uint8_t get_second(const uint8_t packed) { + return uint8_t(packed & 0x0F); +} + +uint8_t combine(const uint8_t first, const uint8_t second) { + return uint8_t(first << 4 | second); +} /* * This shader computes a linear operator between a floating point input matrix @@ -43,9 +51,11 @@ layout(constant_id = 3) const int group_size = 1; * * The (W, H, C) shape of each tensor is: * - x: (K, M) - * - weights: (K / 2, N) + * - weights: (N / 2, K) * - The weights tensor has a data type of `uint8`. Each element in the tensor * contains 2 4-bit values packed into a uint8. + * - See the pack_int4_linear_weight_transposed_interleave shader to see more + * details on how the weight tensor is stored. * - qparams: (2, N, number_of_groups) * - This tensor contains the scales and zeros quantization parameters for the * weights tensor. The weight tensor is quantized group-wise, which means @@ -57,56 +67,52 @@ layout(constant_id = 3) const int group_size = 1; * Note that this shader assumes that all tensors are width packed. */ void main() { - // output positions being calculated are (n, m), (n + 1, m), ... - // This means multiplying the m-th row of x with the n-th, (n+1)-th, ... rows - // of the weights tensor. - const u16vec3 ret_pos = u16vec3(gl_GlobalInvocationID); - if (any(greaterThanEqual(ret_pos, ret_limits))) { + const uint out_row = gl_GlobalInvocationID.y; + // Each thread writes out 2 texels along the width axis, equivalent to 8 + // scalar elements. Therefore multiply the thread_idx.x by 8. + const uint out_col = gl_GlobalInvocationID.x << 3; + + if (out_col >= out_sizes.x || out_row >= out_sizes.y) { return; } - // Since ret is width packed, need to multiply by 4 - const uint16_t n = uint16_t(ret_pos.x * 4); + const int num_blocks = mat1_sizes.x / group_size; - // K is guaranteed to be a multiple of group size - const uint16_t num_blocks = uint16_t(x_sizes.x / group_size); + VEC4_T sums[2]; - uint16_t k_texel_i = uint16_t(0); - vec4 sums = vec4(0.0); - for (uint16_t block_idx = uint16_t(0); block_idx < num_blocks; block_idx++) { - vec4 scales; - vec4 zeros; + sums[0] = VEC4_T(0); + sums[1] = VEC4_T(0); - [[unroll]] for (int comp = 0; comp < 4; comp++) { - const vec4 scale_and_zero = load_texel( - qparams, u16vec3(0, n + comp, block_idx)); - scales[comp] = scale_and_zero.x; - zeros[comp] = scale_and_zero.y; - } + VEC4_T scales[2]; + VEC4_T zeros[2]; + + for (int block_idx = 0; block_idx < num_blocks; ++block_idx) { + scales[0] = texelFetch(t_qparams, ivec3(gl_GlobalInvocationID.x << 1, 0, block_idx), 0); + zeros[0] = texelFetch(t_qparams, ivec3(gl_GlobalInvocationID.x << 1, 1, block_idx), 0); + + scales[1] = texelFetch(t_qparams, ivec3((gl_GlobalInvocationID.x << 1) + 1, 0, block_idx), 0); + zeros[1] = texelFetch(t_qparams, ivec3((gl_GlobalInvocationID.x << 1) + 1, 1, block_idx), 0); + + for (int g_idx = 0; g_idx < group_size; g_idx += 4) { + const int k = block_idx * group_size + g_idx; - for (uint16_t i = uint16_t(0); i < group_size; i += uint16_t(4), k_texel_i++) { - const VEC4_T x_texel = load_texel( - x, u16vec3(k_texel_i, ret_pos.y, ret_pos.z)); - - [[unroll]] for (int comp = 0; comp < 4; comp++) { - const int weights_bufi = (n + comp) * weights_strides.y + (k_texel_i * 2); - // Need to read 4 unpacked values, which corresponds to 2 packed values - const uint8_t weights_val_1 = weights[weights_bufi]; - const uint8_t weights_val_2 = weights[weights_bufi + 1]; - - const u8vec4 weights_texel = u8vec4( - (weights_val_1 & 0xF0) >> 4, - weights_val_1 & 0x0F, - (weights_val_2 & 0xF0) >> 4, - weights_val_2 & 0x0F); - - // Note that the unpacked 4-bit values are unsigned, therefore they must - // first be "centered" around 0 by subtracting 8 before applying the - // scale and zero point. - sums[comp] += dot( - x_texel, (vec4(weights_texel) - 8.0) * scales[comp] + zeros[comp]); + const VEC4_T mat1_tex = texelFetch(t_mat1, ivec3(k >> 2, out_row, 0), 0); + + for (int comp = 0; comp < 4; ++comp) { + const uvec4 packed_weight_tex = texelFetch( + t_qmat2, + ivec3(gl_GlobalInvocationID.x, k + comp, 0), + 0); + + const uvec4 weight_tex_1 = (packed_weight_tex & 0xF0) >> 4; + const uvec4 weight_tex_2 = packed_weight_tex & 0x0F; + + sums[0] += mat1_tex[comp] * ((vec4(weight_tex_1) - 8.0) * scales[0] + zeros[0]); + sums[1] += mat1_tex[comp] * ((vec4(weight_tex_2) - 8.0) * scales[1] + zeros[1]); } } } - write_texel(ret, ret_pos, sums); + + imageStore(t_out, ivec3((gl_GlobalInvocationID.x << 1), out_row, 0), sums[0]); + imageStore(t_out, ivec3((gl_GlobalInvocationID.x << 1) + 1, out_row, 0), sums[1]); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml index 40d95d4a05f..cfd328653e4 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml @@ -8,9 +8,6 @@ q_4w_linear: parameter_names_with_default_values: DTYPE: float STORAGE: texture3d - generate_variant_forall: - DTYPE: - - VALUE: float - - VALUE: half + WEIGHT_STORAGE: texture3d shader_variants: - - NAME: q_4w_linear_texture3d + - NAME: q_4w_linear_texture3d_float diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt4.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt4.cpp new file mode 100644 index 00000000000..6f887fe4868 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt4.cpp @@ -0,0 +1,183 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include + +namespace vkcompute { + +void check_q_4w_linear_args( + ComputeGraph& graph, + const ValueRef mat1, + const ValueRef mat2_data, + const ValueRef group_size, + const ValueRef scales_and_zeros, + const ValueRef out) { + VK_CHECK_COND(graph.int16_shader_types_enabled()); + VK_CHECK_COND(graph.int8_buffers_enabled()); + + VK_CHECK_COND(graph.val_is_tensor(mat1)); + VK_CHECK_COND(graph.val_is_tref(mat2_data)); + VK_CHECK_COND(graph.val_is_tref(scales_and_zeros)); + + VK_CHECK_COND(graph.dim_of(mat1) <= 3); + VK_CHECK_COND(graph.dim_of(mat2_data) == 2); + VK_CHECK_COND(graph.dim_of(scales_and_zeros) == 3); + + VK_CHECK_COND(graph.size_at(-3, mat1) == 1); + const int K = graph.size_at(-1, mat1); + VK_CHECK_COND(graph.size_at(-1, mat2_data) * 2 == K); + + const int group_size_val = graph.extract_scalar(group_size); + VK_CHECK_COND(K % group_size_val == 0); + // Due to the way weight packing works, group size needs to be a multiple of 8 + VK_CHECK_COND(group_size_val % 8 == 0); + + VK_CHECK_COND(graph.has_standard_axis_map(mat1)); + VK_CHECK_COND(graph.has_standard_axis_map(out)); +} + +void resize_q_4w_linear_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + + vTensorPtr out = graph->get_tensor(args[0].refs[0]); + vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]); + vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]); + + const int out_cols = utils::val_at(-2, mat1->sizes()); + const int out_rows = utils::val_at(-1, mat2->sizes()) * 2; + + std::vector new_out_sizes(3); + if (mat1->sizes().size() == 2) { + new_out_sizes.resize(2); + new_out_sizes.at(0) = out_cols; + new_out_sizes.at(1) = out_rows; + } else { + new_out_sizes.at(0) = mat1->sizes().at(0); + new_out_sizes.at(1) = out_cols; + new_out_sizes.at(2) = out_rows; + } + + out->virtual_resize(new_out_sizes); +} + +ValueRef prepack_int4_linear_weight_transposed_interleaved( + ComputeGraph& graph, + const ValueRef qmat2_data, + const utils::StorageType storage_type) { + std::vector qmat2_orig_sizes = graph.sizes_of(qmat2_data); + const int64_t ndim = graph.dim_of(qmat2_data); + + const int64_t K = qmat2_orig_sizes.at(ndim - 1) * 2; + const int64_t N = qmat2_orig_sizes.at(ndim - 2); + const int64_t N_div2 = N / int64_t(2); + + std::vector qmat2_sizes{K, N_div2}; + ValueRef qmat2 = graph.add_tensor( + qmat2_sizes, vkcompute::vkapi::kByte, storage_type, utils::kWidthPacked); + + utils::uvec3 global_size = graph.logical_limits_of(qmat2); + global_size[1] = utils::div_up(global_size[1], uint32_t(2)); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + VK_KERNEL(pack_int4_linear_weight_transposed_interleaved), + global_size, + graph.create_local_wg_size(global_size), + // Inputs and Outputs + qmat2_data, + qmat2, + // UBOs + {}, + // Specialization Constants + {}, + // Push Constants + {graph.sizes_pc_of(qmat2)})); + + return qmat2; +} + +void add_q_4w_linear_node( + ComputeGraph& graph, + const ValueRef mat1, + const ValueRef mat2_data, + const ValueRef group_size, + const ValueRef scales_and_zeros_data, + const ValueRef out) { + check_q_4w_linear_args( + graph, mat1, mat2_data, group_size, scales_and_zeros_data, out); + + ValueRef mat2 = prepack_int4_linear_weight_transposed_interleaved( + graph, mat2_data, utils::kTexture3D); + // ValueRef mat2 = prepack_direct_copy_buffer(graph, mat2_data); + + ValueRef scales_and_zeros = prepack_standard_hw_transposed( + graph, + scales_and_zeros_data, + graph.storage_type_of(out), + utils::kWidthPacked); + + std::string kernel_name = "q_4w_linear"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + const uint32_t group_size_val = graph.extract_scalar(group_size); + + utils::uvec3 global_wg_size = graph.logical_limits_of(out); + global_wg_size[0] = utils::div_up(global_wg_size[0], uint32_t(2)); + + utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); + + graph.execute_nodes().emplace_back(new DispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{mat1, mat2, scales_and_zeros}, vkapi::kRead}}, + // Shader params buffers + {}, + // Specialization Constants + {SV(group_size_val)}, + // Resizing Logic + resize_q_4w_linear_node, + {}, + // Push Constants + {graph.sizes_pc_of(out), + graph.sizes_pc_of(mat1), + graph.sizes_pc_of(mat2)})); +} + +void linear_weight_int4( + ComputeGraph& graph, + const std::vector& args) { + return add_q_4w_linear_node( + graph, + args[0], // mat1 + args[1], // mat2 + args[2], // group_size + args[3], // scales_and_zeros + // There is an unused variable inner_k_tiles which is used to call + // _convert_weight_to_int4pack in the AOT custom op, which is why the 4th + // argument is skipped. + args[5] // out + ); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.linear_weight_int4.default, linear_weight_int4); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp similarity index 64% rename from backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp rename to backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp index f4f5c853ddd..49085ff4e06 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp @@ -268,157 +268,8 @@ void weight_int8pack_mm( return add_q_8w_linear_node(graph, args[0], args[1], args[2], args[3]); } -void check_q_4w_linear_args( - ComputeGraph& graph, - const ValueRef mat1, - const ValueRef mat2_data, - const ValueRef group_size, - const ValueRef scales_and_zeros, - const ValueRef out) { - VK_CHECK_COND(graph.int16_shader_types_enabled()); - VK_CHECK_COND(graph.int8_buffers_enabled()); - - VK_CHECK_COND(graph.val_is_tensor(mat1)); - VK_CHECK_COND(graph.val_is_tref(mat2_data)); - VK_CHECK_COND(graph.val_is_tref(scales_and_zeros)); - - VK_CHECK_COND(graph.dim_of(mat1) <= 3); - VK_CHECK_COND(graph.dim_of(mat2_data) == 2); - VK_CHECK_COND(graph.dim_of(scales_and_zeros) == 3); - - VK_CHECK_COND(graph.size_at(-3, mat1) == 1); - const int K = graph.size_at(-1, mat1); - VK_CHECK_COND(graph.size_at(-1, mat2_data) * 2 == K); - - const int group_size_val = graph.extract_scalar(group_size); - VK_CHECK_COND(K % group_size_val == 0); - - VK_CHECK_COND(graph.has_standard_axis_map(mat1)); - VK_CHECK_COND(graph.has_standard_axis_map(out)); -} - -void resize_q_4w_linear_node( - ComputeGraph* graph, - const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; - - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]); - vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]); - - const int out_cols = utils::val_at(-2, mat1->sizes()); - const int out_rows = utils::val_at(-2, mat2->sizes()); - - std::vector new_out_sizes(3); - if (mat1->sizes().size() == 2) { - new_out_sizes.resize(2); - new_out_sizes.at(0) = out_cols; - new_out_sizes.at(1) = out_rows; - } else { - new_out_sizes.at(0) = mat1->sizes().at(0); - new_out_sizes.at(1) = out_cols; - new_out_sizes.at(2) = out_rows; - } - - out->virtual_resize(new_out_sizes); -} - -void add_q_4w_linear_node( - ComputeGraph& graph, - const ValueRef mat1, - const ValueRef mat2_data, - const ValueRef group_size, - const ValueRef scales_and_zeros_data, - const ValueRef out) { - check_q_4w_linear_args( - graph, mat1, mat2_data, group_size, scales_and_zeros_data, out); - - utils::StorageType storage_type = graph.storage_type_of(out); - - ValueRef mat2 = prepack_direct_copy_buffer(graph, mat2_data); - - ValueRef scales_and_zeros = prepack_standard( - graph, - scales_and_zeros_data, - graph.storage_type_of(out), - utils::kWidthPacked); - - std::string kernel_name = "q_4w_linear"; - add_storage_type_suffix(kernel_name, storage_type); - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - - const uint32_t group_size_val = graph.extract_scalar(group_size); - - ValueRef mat1_W_packed = mat1; - ValueRef out_W_packed = out; - auto viewFn = VK_GET_OP_FN("aten.view_copy.default"); - // Create temporary tensors to store the width packed versions of mat1 and out - TmpTensor mat1_tmp( - &graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked); - TmpTensor out_tmp( - &graph, graph.sizes_of(out), graph.dtype_of(out), utils::kWidthPacked); - if (storage_type == utils::kTexture3D) { - if (!graph.is_buffer_storage(out) && - graph.packed_dim_of(mat1) != WHCN::kWidthDim) { - // Ensure mat1 is width packed - mat1_W_packed = mat1_tmp; - viewFn(graph, {mat1, graph.add_none(), mat1_W_packed}); - // Ensure out is packed correctly - out_W_packed = out_tmp; - } - } - - vkapi::ParamsBindList ubos({}); - ubos.append(graph.logical_limits_ubo(out_W_packed)); - ubos.append(graph.sizes_ubo(mat1_W_packed)); - ubos.append(graph.strides_ubo(mat2)); - ubos.append(graph.strides_ubo(scales_and_zeros)); - - utils::uvec3 global_wg_size = graph.logical_limits_of(out_W_packed); - utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); - - graph.execute_nodes().emplace_back(new DispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - local_wg_size, - // Inputs and Outputs - {{out_W_packed, vkapi::MemoryAccessType::WRITE}, - {{mat1_W_packed, mat2, scales_and_zeros}, - vkapi::MemoryAccessType::READ}}, - // Shader params buffers - ubos, - // Specialization Constants - {SV(group_size_val)}, - // Resizing Logic - resize_q_4w_linear_node, - {})); - if (!graph.is_buffer_storage(out) && - graph.packed_dim_of(out) != WHCN::kWidthDim) { - viewFn(graph, {out_W_packed, graph.add_none(), out}); - } -} - -void linear_weight_int4( - ComputeGraph& graph, - const std::vector& args) { - return add_q_4w_linear_node( - graph, - args[0], // mat1 - args[1], // mat2 - args[2], // group_size - args[3], // scales_and_zeros - // There is an unused variable inner_k_tiles which is used to call - // _convert_weight_to_int4pack in the AOT custom op, which is why the 4th - // argument is skipped. - args[5] // out - ); -} - REGISTER_OPERATORS { VK_REGISTER_OP(aten._weight_int8pack_mm.default, weight_int8pack_mm); - VK_REGISTER_OP(et_vk.linear_weight_int4.default, linear_weight_int4); } } // namespace vkcompute From 76718915f2aceba3d5be25f3556337fbf49244fd Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 3 Apr 2025 14:59:28 -0700 Subject: [PATCH 2/2] Update on "[ET-VK] Improve packing format for int4 linear operator + misc improvements" ## Context Improve performance of the quantized int4 linear shader by packing the scales and zeros tensor, as well as the weight tensor in a more optimal way. See the comments in the `pack_int4_linear_weight_transposed_interleave` shader for more details about how the new packing works. ## Changes * Split int8 quantized linear and int4 quantized linear into separate C++ files for better code organization * Introduce packing shader for int4 weights * Update int4 linear shader to account for packed weights ## Impact This change massively improves the performance of the weight int4 quantized linear operator. With this change, running LLaMa 3.2 1B can now achieve 10 tok/s, from 0.9 tok/s on an Adreno 740. This is a 10x improvement! With this change: ``` /home/ssjia/scratch/bin/app_bin: 1 file pushed, 0 skipped. 332.3 MB/s (74692800 bytes in 0.214s) I 00:00:00.003353 executorch:cpuinfo_utils.cpp:62] Reading file /sys/devices/soc0/image_version I 00:00:00.003533 executorch:cpuinfo_utils.cpp:78] Failed to open midr file /sys/devices/soc0/image_version I 00:00:00.003563 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu0/regs/identification/midr_el1 I 00:00:00.003685 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu1/regs/identification/midr_el1 I 00:00:00.003747 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu2/regs/identification/midr_el1 I 00:00:00.003799 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu3/regs/identification/midr_el1 I 00:00:00.003852 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu4/regs/identification/midr_el1 I 00:00:00.003902 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu5/regs/identification/midr_el1 I 00:00:00.003976 executorch:main.cpp:69] Resetting threadpool with num threads = 6 I 00:00:00.004289 executorch:runner.cpp:68] Creating LLaMa runner: model_path=/data/local/tmp/llama3-1b/vk/llama3.pte, tokenizer_path=/data/local/tmp/tokenizer.model I 00:00:04.841690 executorch:runner.cpp:101] Reading metadata from model I 00:00:04.841808 executorch:runner.cpp:126] Metadata: get_vocab_size = 128256 I 00:00:04.841830 executorch:runner.cpp:126] Metadata: get_bos_id = 128000 I 00:00:04.841851 executorch:runner.cpp:126] Metadata: use_sdpa_with_kv_cache = 1 I 00:00:04.841874 executorch:runner.cpp:126] Metadata: use_kv_cache = 1 I 00:00:04.841893 executorch:runner.cpp:126] Metadata: get_max_context_len = 128 I 00:00:04.841909 executorch:runner.cpp:126] Metadata: get_max_seq_len = 128 I 00:00:04.841927 executorch:runner.cpp:126] Metadata: enable_dynamic_shape = 0 I 00:00:04.841945 executorch:runner.cpp:133] eos_id = 128009 I 00:00:04.841951 executorch:runner.cpp:133] eos_id = 128001 I 00:00:04.841963 executorch:runner.cpp:188] RSS after loading model: 2229.828125 MiB (0 if unsupported) <|begin_of_text|><|start_header_id|>system<|end_header_id|>Tell me a short story.<|eot_id|><|start_header_id|>assistant<|end_header_id|> I 00:00:06.239633 executorch:runner.cpp:258] RSS after prompt prefill: 2229.828125 MiB (0 if unsupported) Here's a short story for you: **The Library of Lost Memories** In a small, dusty town nestled between two great rivers, there was a library that held the secrets of the past. It was a place where memories were stored, not retrieved, and the librarians were the guardians of the past. The library was called the Library of Lost Memories, and it was said that anyone who entered its doors would be given a glimpse into the memories of those who had come before. The librarians were wise and kind, and they would only allow those who wereI 00:00:17.699086 executorch:runner.cpp:272] RSS after finishing text generation: 2229.828125 MiB (0 if unsupported) I 00:00:17.699155 executorch:stats.h:108] Prompt Tokens: 14 Generated Tokens: 113 I 00:00:17.699161 executorch:stats.h:114] Model Load Time: 4.837000 (seconds) I 00:00:17.699165 executorch:stats.h:124] Total inference time: 12.857000 (seconds) Rate: 8.788987 (tokens/second) I 00:00:17.699168 executorch:stats.h:132] Prompt evaluation: 1.398000 (seconds) Rate: 10.014306 (tokens/second) I 00:00:17.699171 executorch:stats.h:143] Generated 113 tokens: 11.459000 (seconds) Rate: 9.861244 (tokens/second) I 00:00:17.699174 executorch:stats.h:151] Time to first generated token: 1.398000 (seconds) I 00:00:17.699177 executorch:stats.h:158] Sampling time over 127 tokens: 549246500.843000 (seconds) ``` Before this change: ``` /home/ssjia/scratch/bin/app_bin: 1 file pushed, 0 skipped. 302.0 MB/s (74637464 bytes in 0.236s) I 00:00:00.003050 executorch:cpuinfo_utils.cpp:62] Reading file /sys/devices/soc0/image_version I 00:00:00.003200 executorch:cpuinfo_utils.cpp:78] Failed to open midr file /sys/devices/soc0/image_version I 00:00:00.003226 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu0/regs/identification/midr_el1 I 00:00:00.003337 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu1/regs/identification/midr_el1 I 00:00:00.003396 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu2/regs/identification/midr_el1 I 00:00:00.003449 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu3/regs/identification/midr_el1 I 00:00:00.003502 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu4/regs/identification/midr_el1 I 00:00:00.003553 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu5/regs/identification/midr_el1 I 00:00:00.003629 executorch:main.cpp:69] Resetting threadpool with num threads = 6 I 00:00:00.004075 executorch:runner.cpp:68] Creating LLaMa runner: model_path=/data/local/tmp/llama3-1b/vk/llama3.pte, tokenizer_path=/data/local/tmp/tokenizer.model I 00:00:05.417531 executorch:runner.cpp:101] Reading metadata from model I 00:00:05.417647 executorch:runner.cpp:126] Metadata: get_vocab_size = 128256 I 00:00:05.417669 executorch:runner.cpp:126] Metadata: get_bos_id = 128000 I 00:00:05.417698 executorch:runner.cpp:126] Metadata: use_sdpa_with_kv_cache = 1 I 00:00:05.417716 executorch:runner.cpp:126] Metadata: use_kv_cache = 1 I 00:00:05.417735 executorch:runner.cpp:126] Metadata: get_max_context_len = 128 I 00:00:05.417751 executorch:runner.cpp:126] Metadata: get_max_seq_len = 128 I 00:00:05.417768 executorch:runner.cpp:126] Metadata: enable_dynamic_shape = 0 I 00:00:05.417787 executorch:runner.cpp:133] eos_id = 128009 I 00:00:05.417793 executorch:runner.cpp:133] eos_id = 128001 I 00:00:05.417808 executorch:runner.cpp:188] RSS after loading model: 2230.812500 MiB (0 if unsupported) <|begin_of_text|><|start_header_id|>system<|end_header_id|>Tell me a short story.<|eot_id|><|start_header_id|>assistant<|end_header_id|> I 00:00:19.689616 executorch:runner.cpp:258] RSS after prompt prefill: 2230.812500 MiB (0 if unsupported) Here's a short story for you: **The Library of Lost Memories** In a small, dusty town nestled between two great rivers, there was a library that held the secrets of the past. It was a place where memories were stored, not retrieved, and the librarians were the guardians of the past. The library was called the Library of Lost Memories, and it was said that anyone who entered its doors would be given a glimpse into the memories of those who had come before. The librarians were wise and kind, and they would only allow those who wereI 00:02:15.269693 executorch:runner.cpp:272] RSS after finishing text generation: 2230.812500 MiB (0 if unsupported) I 00:02:15.269810 executorch:stats.h:108] Prompt Tokens: 14 Generated Tokens: 113 I 00:02:15.269825 executorch:stats.h:114] Model Load Time: 5.414000 (seconds) I 00:02:15.269832 executorch:stats.h:124] Total inference time: 129.852000 (seconds) Rate: 0.870221 (tokens/second) I 00:02:15.269837 executorch:stats.h:132] Prompt evaluation: 14.271000 (seconds) Rate: 0.981010 (tokens/second) I 00:02:15.269841 executorch:stats.h:143] Generated 113 tokens: 115.581000 (seconds) Rate: 0.977669 (tokens/second) I 00:02:15.269844 executorch:stats.h:151] Time to first generated token: 14.271000 (seconds) I 00:02:15.269847 executorch:stats.h:158] Sampling time over 127 tokens: 549711269.115000 (seconds) PyTorchObserver {"prompt_tokens":14,"generated_tokens":113,"model_load_start_ms":1743712527974,"model_load_end_ms":1743712533388,"inference_start_ms":1743712533388,"inference_end_ms":1743712663240,"prompt_eval_end_ms":1743712547659,"first_token_ms":1743712547659,"aggregate_sampling_time_ms":549711269115,"SCALING_FACTOR_UNITS_PER_SECOND":1000} ``` Differential Revision: [D72412950](https://our.internmc.facebook.com/intern/diff/D72412950/) [ghstack-poisoned] --- .../runtime/graph/ops/glsl/q_4w_linear.glsl | 15 +++++++++------ ...rInt4.cpp => QuantizedLinearGroupwiseInt4.cpp} | 1 - 2 files changed, 9 insertions(+), 7 deletions(-) rename backends/vulkan/runtime/graph/ops/impl/{QuantizedLinearInt4.cpp => QuantizedLinearGroupwiseInt4.cpp} (98%) diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl index bfb3dc516d8..06a27df4f6a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl @@ -71,6 +71,9 @@ void main() { // Each thread writes out 2 texels along the width axis, equivalent to 8 // scalar elements. Therefore multiply the thread_idx.x by 8. const uint out_col = gl_GlobalInvocationID.x << 3; + // Similar reasoning to the above, each thread works on 2 texels along the + // width axis so multiply thread_idx.x by 2. + const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1; if (out_col >= out_sizes.x || out_row >= out_sizes.y) { return; @@ -87,11 +90,11 @@ void main() { VEC4_T zeros[2]; for (int block_idx = 0; block_idx < num_blocks; ++block_idx) { - scales[0] = texelFetch(t_qparams, ivec3(gl_GlobalInvocationID.x << 1, 0, block_idx), 0); - zeros[0] = texelFetch(t_qparams, ivec3(gl_GlobalInvocationID.x << 1, 1, block_idx), 0); + scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0); + zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0); - scales[1] = texelFetch(t_qparams, ivec3((gl_GlobalInvocationID.x << 1) + 1, 0, block_idx), 0); - zeros[1] = texelFetch(t_qparams, ivec3((gl_GlobalInvocationID.x << 1) + 1, 1, block_idx), 0); + scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0); + zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0); for (int g_idx = 0; g_idx < group_size; g_idx += 4) { const int k = block_idx * group_size + g_idx; @@ -113,6 +116,6 @@ void main() { } } - imageStore(t_out, ivec3((gl_GlobalInvocationID.x << 1), out_row, 0), sums[0]); - imageStore(t_out, ivec3((gl_GlobalInvocationID.x << 1) + 1, out_row, 0), sums[1]); + imageStore(t_out, ivec3(out_col_texel_idx, out_row, 0), sums[0]); + imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row, 0), sums[1]); } diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt4.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearGroupwiseInt4.cpp similarity index 98% rename from backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt4.cpp rename to backends/vulkan/runtime/graph/ops/impl/QuantizedLinearGroupwiseInt4.cpp index 6f887fe4868..fc2d6609959 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt4.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearGroupwiseInt4.cpp @@ -121,7 +121,6 @@ void add_q_4w_linear_node( ValueRef mat2 = prepack_int4_linear_weight_transposed_interleaved( graph, mat2_data, utils::kTexture3D); - // ValueRef mat2 = prepack_direct_copy_buffer(graph, mat2_data); ValueRef scales_and_zeros = prepack_standard_hw_transposed( graph,