Skip to content

[ET-VK] Improve packing format for int4 linear operator + misc improvements #9949

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions backends/vulkan/runtime/api/containers/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,9 +497,7 @@ vTensor::vTensor(
VK_CHECK_COND(
dim_order_is_valid(dim_order_), "computed dim order is invalid");

if (storage_type != utils::kBuffer) {
set_logical_limits(storage_.image_extents_);
}
set_logical_limits(storage_.image_extents_);
}

// NOLINTNEXTLINE
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

// 8-bit integer types (uint8_t, u8vec4) require explicit arithmetic type
// extensions; these codegen helpers emit the needed #extension directives.
${define_required_extensions("uint8")}
${define_required_extensions("int8")}

layout(std430) buffer;

// Output: transposed + interleaved packed weights. STORAGE selects texture3d
// or buffer (see the accompanying .yaml shader variants).
${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE, is_scalar_array=False)}
// Input: source weight data as a flat uint8 buffer, 2 4-bit values per byte.
${layout_declare_tensor(B, "r", "nchw_4x2", "uint8", "buffer")}

// Sizes of the packed output tensor — presumably (width in uint8 elements,
// height, ...); verify against the calling op.
layout(push_constant) uniform restrict Block {
  ivec4 qmat2_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Extract the value stored in the high (leftmost) 4 bits of a packed byte.
uint8_t get_first(const uint8_t packed_val) {
  return uint8_t((packed_val >> 4) & 0x0F);
}

// Extract the value stored in the low (rightmost) 4 bits of a packed byte.
uint8_t get_second(const uint8_t packed_val) {
  return uint8_t(packed_val & 15);
}

// Pack two 4-bit values into one byte: `first` occupies the upper 4 bits,
// `second` the lower 4 bits.
uint8_t combine(const uint8_t first, const uint8_t second) {
  const uint8_t hi_shifted = uint8_t(first << 4);
  return uint8_t(hi_shifted | second);
}

/*
* This shader packs the weight tensor into a texture.
*
* The original tensor has a (W, H) shape of (K / 2, N) and each scalar element
* is a uint8_t, which contains 2 packed 4 bit uint values.
*
* The transform performed by this shader is to first transpose the tensor, so
* the shape of the packed tensor becomes (N / 2, K). Then, the 4 bit integers
* are re-packed in groups of 8. For each 4 uint8_t values, the "left" 4-bits
* of each value contain the 0, 1, 2, 3 4-bit values, and the "right" 4-bits of
* each value contain the 4, 5, 6, 7 4-bit values.
*
* As a concrete example, consider the following weight tensor. The | demarks
* the packing boundary, so 1| 2 represents a single uint8_t value with 1 in the
* leftmost 4 bits and 2 in the rightmost 4 bits.
*
* 1| 2, 3| 4, 5| 6, 7| 8,
* 9|10, 11|12, 13|14, 15|16,
* 17|18, 19|20, 21|22, 23|24,
* 25|26, 27|28, 29|30, 31|32,
* 33|34, 35|36, 37|38, 39|40,
* 41|42, 43|44, 45|46, 47|48,
* 49|50, 51|52, 53|54, 55|56,
* 57|58, 59|60, 61|62, 63|64,
*
* After packing, the packed tensor would contain
*
* 1|33, 9|41, 17|49, 25|57,
* 2|34, 10|42, 18|50, 26|58,
* 3|35, 11|43, 19|51, 27|59,
* 4|36, 12|44, 20|52, 28|60,
* 5|37, 13|45, 21|53, 29|61,
* 6|38, 14|46, 22|54, 30|62,
* 7|39, 15|47, 23|55, 31|63,
* 8|40, 16|48, 24|56, 32|64,
*
* The purpose of interleaving is to make it easier to extract the unpacked
* values in order using the u8vec4 vectorized type. With the packing in place,
* The 4-bit values can be extracted via
*
* u8vec4 packed;
* u8vec4 vals_0123 = (packed & 0xF0) >> 4;
 * u8vec4 vals_4567 = (packed & 0x0F);
*/
/*
 * Each thread loads a column of 4-bit values spanning 8 input rows,
 * transposes them, and writes 2 interleaved output texels. See the packing
 * description above for the exact output layout.
 *
 * Change vs. original: removed the unused local `in_numcols`.
 */
void main() {
  // Each thread writes 2 output texels along the height axis
  ivec2 packed_pos = ivec2(
      gl_GlobalInvocationID.x,
      gl_GlobalInvocationID.y << 1);

  // The packed tensor is width packed. packed_pos.x indexes u8vec4 texels, so
  // scale by 4 to compare against qmat2_sizes.x, which appears to be the
  // packed width in uint8 elements — TODO confirm against the calling op.
  if ((packed_pos.x << 2) >= qmat2_sizes.x || packed_pos.y >= qmat2_sizes.y) {
    return;
  }

  // Each pair of output texels covers 8 unpacked 4-bit values along the width.
  int out_col = packed_pos.x << 3;
  int out_row = packed_pos.y;

  // The output is the transpose of the input, so row/column indices swap.
  int in_col = out_row;
  // Two 4-bit columns share one uint8 in the input, hence the halving.
  int in_int8_col = in_col >> 1;
  int in_row = out_col;

  // The input has twice as many rows as the packed output's width in bytes.
  int in_numrows = qmat2_sizes.x << 1;
  int in_num_int8_cols = qmat2_sizes.y >> 1;

  // in_vals[r][0] / in_vals[r][1] hold the high / low nibble of the packed
  // input byte read from row (in_row + r).
  uint8_t in_vals[8][2];
  for (int r = 0; r < 8; ++r) {
    if (in_row + r < in_numrows) {
      uint8_t in_val_packed = nchw_4x2[(in_row + r) * in_num_int8_cols + in_int8_col];
      in_vals[r][0] = get_first(in_val_packed);
      in_vals[r][1] = get_second(in_val_packed);
    } else {
      // Sentinel fill for out-of-range rows; presumably these lanes are never
      // consumed by the matmul shader — TODO confirm.
      in_vals[r][0] = uint8_t(254);
      in_vals[r][1] = uint8_t(254);
    }
  }

  // First texel: high nibbles, interleaved so row offsets 0..3 occupy the
  // left 4 bits and row offsets 4..7 occupy the right 4 bits of each byte.
  u8vec4 out_tex_1 = u8vec4(
      combine(in_vals[0][0], in_vals[4][0]),
      combine(in_vals[1][0], in_vals[5][0]),
      combine(in_vals[2][0], in_vals[6][0]),
      combine(in_vals[3][0], in_vals[7][0]));

  // Second texel: the corresponding low nibbles in the same interleaved order.
  u8vec4 out_tex_2 = u8vec4(
      combine(in_vals[0][1], in_vals[4][1]),
      combine(in_vals[1][1], in_vals[5][1]),
      combine(in_vals[2][1], in_vals[6][1]),
      combine(in_vals[3][1], in_vals[7][1]));

  $if STORAGE == "buffer":
    // Row stride in u8vec4 texels (qmat2_sizes.x is in uint8 elements).
    int stride = qmat2_sizes.x >> 2;
    t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1;
    t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2;
  $else:
    imageStore(t_qmat2, ivec3(packed_pos.xy, 0), out_tex_1);
    imageStore(t_qmat2, ivec3(packed_pos.x, packed_pos.y + 1, 0), out_tex_2);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Codegen spec for the int4 weight prepacking shader. STORAGE selects where
# the packed output tensor lives; one variant is generated per value.
pack_int4_linear_weight_transposed_interleaved:
  parameter_names_with_default_values:
    STORAGE: texture3d
  shader_variants:
    # Default variant: packed weights written to a 3D texture.
    - NAME: pack_int4_linear_weight_transposed_interleaved_texture3d
    # Buffer variant: packed weights written to an SSBO.
    - NAME: pack_int4_linear_weight_transposed_interleaved_buffer
      STORAGE: buffer
134 changes: 72 additions & 62 deletions backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -8,44 +8,42 @@

#version 450 core

#include "indexing_utils.h"

#define PRECISION ${PRECISION}

#define FOUR 4

#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
#define FLOAT_T ${buffer_scalar_type(DTYPE)}
#define T ${buffer_scalar_type(DTYPE)}
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}

${define_active_storage_type(STORAGE)}

${define_required_extensions([DTYPE, "uint8", "uint16"])}
#extension GL_EXT_control_flow_attributes : require
${define_required_extensions(DTYPE)}
${define_required_extensions("int8")}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "ret", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "x", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "weights", "uint8", "buffer")}
${layout_declare_tensor(B, "r", "qparams", DTYPE, STORAGE)}
${layout_declare_ubo(B, "ivec3", "ret_limits")}
${layout_declare_ubo(B, "ivec4", "x_sizes")}
${layout_declare_ubo(B, "ivec4", "weights_strides")}
${layout_declare_ubo(B, "ivec4", "qparams_strides")}
${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "texture3D")}

layout(push_constant) uniform restrict Block {
ivec4 out_sizes;
ivec4 mat1_sizes;
ivec4 qmat2_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int group_size = 1;
layout(constant_id = 3) const int group_size = 64;

/*
* This shader computes a linear operator between a floating point input matrix
* x and a weights matrix that is quantized to 4 bits.
*
* The (W, H, C) shape of each tensor is:
* - x: (K, M)
* - weights: (K / 2, N)
* - weights: (N / 2, K)
* - The weights tensor has a data type of `uint8`. Each element in the tensor
* contains 2 4-bit values packed into a uint8.
 * - See the pack_int4_linear_weight_transposed_interleaved shader for more
* details on how the weight tensor is stored.
* - qparams: (2, N, number_of_groups)
* - This tensor contains the scales and zeros quantization parameters for the
* weights tensor. The weight tensor is quantized group-wise, which means
Expand All @@ -57,56 +55,68 @@ layout(constant_id = 3) const int group_size = 1;
* Note that this shader assumes that all tensors are width packed.
*/
void main() {
// output positions being calculated are (n, m), (n + 1, m), ...
// This means multiplying the m-th row of x with the n-th, (n+1)-th, ... rows
// of the weights tensor.
const u16vec3 ret_pos = u16vec3(gl_GlobalInvocationID);
if (any(greaterThanEqual(ret_pos, ret_limits))) {
const uint out_row = gl_GlobalInvocationID.y;
// Each thread writes out 2 texels along the width axis, equivalent to 8
// scalar elements. Therefore multiply the thread_idx.x by 8.
const uint out_col = gl_GlobalInvocationID.x << 3;
// Similar reasoning to the above, each thread works on 2 texels along the
// width axis so multiply thread_idx.x by 2.
const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1;

if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
return;
}

// Since ret is width packed, need to multiply by 4
const uint16_t n = uint16_t(ret_pos.x * 4);
const int num_blocks = mat1_sizes.x / group_size;

// K is guaranteed to be a multiple of group size
const uint16_t num_blocks = uint16_t(x_sizes.x / group_size);
VEC4_T sums[2];

uint16_t k_texel_i = uint16_t(0);
vec4 sums = vec4(0.0);
for (uint16_t block_idx = uint16_t(0); block_idx < num_blocks; block_idx++) {
vec4 scales;
vec4 zeros;
sums[0] = VEC4_T(0);
sums[1] = VEC4_T(0);

[[unroll]] for (int comp = 0; comp < 4; comp++) {
const vec4 scale_and_zero = load_texel(
qparams, u16vec3(0, n + comp, block_idx));
scales[comp] = scale_and_zero.x;
zeros[comp] = scale_and_zero.y;
}
VEC4_T scales[2];
VEC4_T zeros[2];

$if WEIGHT_STORAGE == "buffer":
const int qmat2_stride = qmat2_sizes.x >> 2;

for (int block_idx = 0; block_idx < num_blocks; ++block_idx) {
scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0);
zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0);

for (uint16_t i = uint16_t(0); i < group_size; i += uint16_t(4), k_texel_i++) {
const VEC4_T x_texel = load_texel(
x, u16vec3(k_texel_i, ret_pos.y, ret_pos.z));

[[unroll]] for (int comp = 0; comp < 4; comp++) {
const int weights_bufi = (n + comp) * weights_strides.y + (k_texel_i * 2);
// Need to read 4 unpacked values, which corresponds to 2 packed values
const uint8_t weights_val_1 = weights[weights_bufi];
const uint8_t weights_val_2 = weights[weights_bufi + 1];

const u8vec4 weights_texel = u8vec4(
(weights_val_1 & 0xF0) >> 4,
weights_val_1 & 0x0F,
(weights_val_2 & 0xF0) >> 4,
weights_val_2 & 0x0F);

// Note that the unpacked 4-bit values are unsigned, therefore they must
// first be "centered" around 0 by subtracting 8 before applying the
// scale and zero point.
sums[comp] += dot(
x_texel, (vec4(weights_texel) - 8.0) * scales[comp] + zeros[comp]);
scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0);
zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0);

for (int g_idx = 0; g_idx < group_size; g_idx += 4) {
const int k = block_idx * group_size + g_idx;

$if IN_STORAGE == "buffer":
const VEC4_T mat1_tex = t_mat1[(out_row * mat1_sizes.x + k) >> 2];
$else:
const VEC4_T mat1_tex = texelFetch(t_mat1, ivec3(k >> 2, out_row, 0), 0);

for (int comp = 0; comp < 4; ++comp) {
$if WEIGHT_STORAGE == "buffer":
const u8vec4 packed_weight_tex = t_qmat2[(k + comp) * qmat2_stride + gl_GlobalInvocationID.x];
$else:
const uvec4 packed_weight_tex = texelFetch(
t_qmat2,
ivec3(gl_GlobalInvocationID.x, k + comp, 0),
0);

const uvec4 weight_tex_1 = (packed_weight_tex & 0xF0) >> 4;
const uvec4 weight_tex_2 = packed_weight_tex & 0x0F;

sums[0] += mat1_tex[comp] * ((vec4(weight_tex_1) - 8.0) * scales[0] + zeros[0]);
sums[1] += mat1_tex[comp] * ((vec4(weight_tex_2) - 8.0) * scales[1] + zeros[1]);
}
}
}
write_texel(ret, ret_pos, sums);

$if OUT_STORAGE == "buffer":
t_out[(out_row * out_sizes.x + out_col) >> 2] = sums[0];
t_out[(out_row * out_sizes.x + out_col + 4) >> 2] = sums[1];
$else:
imageStore(t_out, ivec3(out_col_texel_idx, out_row, 0), sums[0]);
imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row, 0), sums[1]);
}
19 changes: 13 additions & 6 deletions backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,17 @@
q_4w_linear:
parameter_names_with_default_values:
DTYPE: float
STORAGE: texture3d
generate_variant_forall:
DTYPE:
- VALUE: float
- VALUE: half
OUT_STORAGE: texture3d
IN_STORAGE: texture3d
WEIGHT_STORAGE: texture3d
shader_variants:
- NAME: q_4w_linear_texture3d
- NAME: q_4w_linear_texture3d_texture3d_texture3d_float
- NAME: q_4w_linear_texture3d_buffer_texture3d_float
IN_STORAGE: buffer
- NAME: q_4w_linear_buffer_buffer_texture3d_float
OUT_STORAGE: buffer
IN_STORAGE: buffer
- NAME: q_4w_linear_buffer_buffer_buffer_float
OUT_STORAGE: buffer
IN_STORAGE: buffer
WEIGHT_STORAGE: buffer
Loading
Loading