diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index aa3cca5f384..8502e254ec5 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -377,7 +377,12 @@ def register_mm_op(features: OpFeatures): return features -@update_features(exir_ops.edge.aten._weight_int8pack_mm.default) +@update_features( + [ + exir_ops.edge.aten._weight_int8pack_mm.default, + exir_ops.edge.et_vk.linear_qcs4w.default, + ] +) def register_int8_mm_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( uses_axis_map=False, diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h index 2126104430f..2b41d2b7e1a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -41,22 +41,32 @@ /* * Fast division by 4 using bit shifting */ -#define div4(x) (x >> 2) +#define div4(x) ((x) >> 2) + +/* + * Fast multiplication by 4 using bit shifting + */ +#define mul4(x) ((x) << 2) /* * Divides input and rounds up to 4 */ -#define divup4(x) ((x + 3) >> 2) +#define divup4(x) (((x) + 3) >> 2) + +/* + * Divides input by denominator and rounds up + */ +#define divup(x, d) (((x) + (d) - 1) / (d)) /* * Aligns input to the next multiple of 4 */ -#define alignup4(x) ((x + 3) & -4) +#define alignup4(x) (((x) + 3) & -4) /* * Fast modulo by 4 using bit masking */ -#define mod4(x) (x & 3) +#define mod4(x) ((x) & 3) /* * Find the packed dimension of a tensor given its strides. The packed dimension diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl index 3ad9e759910..c766a3cd7d0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl @@ -14,6 +14,7 @@ #define VEC4_T ${buffer_gvec_type(DTYPE, 4)} #define TILE_ROWS ${TILE_ROWS} +#define TILE_TXCOLS ${TILE_TXCOLS} #define NGROUPS 8 #define NWORKERS 8 @@ -29,7 +30,10 @@ layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)} +$if QUANT_NBITS == 4: + ${layout_declare_tensor(B, "r", "t_weight", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} +$else: + ${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array=False)} layout(push_constant) uniform restrict Block { @@ -42,12 +46,23 @@ layout(push_constant) uniform restrict Block { layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -shared VEC4_T partial_c[NGROUPS][NWORKERS][TILE_ROWS]; +shared VEC4_T partial_sums[NGROUPS][NWORKERS][TILE_ROWS][TILE_TXCOLS]; void main() { - const uint out_width_ntexels = divup4(out_sizes.x); - const uint out_col = (gl_GlobalInvocationID.x % out_width_ntexels) << 2; - const uint out_row = (gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS; + // txcol stands for "texel column". One txcol corresponds to 4 scalar columns. 
+ $if TILE_TXCOLS > 1: + const uint global_wg_x = uint(divup(out_sizes.x, 4 * TILE_TXCOLS)); + const uint out_txcol = uint( + (gl_GlobalInvocationID.x % global_wg_x) * TILE_TXCOLS); + $else: + const uint global_wg_x = uint(divup4(out_sizes.x)); + const uint out_txcol = uint(gl_GlobalInvocationID.x % global_wg_x); + + const uint out_row = uint( + (gl_GlobalInvocationID.x / global_wg_x) * TILE_ROWS); + + $if QUANT_NBITS == 4: + const uint weight_txcol = uint(out_txcol / 2); const int gid = int(gl_LocalInvocationID.x); // group id const int wid = int(gl_LocalInvocationID.z); // worker id @@ -56,46 +71,78 @@ void main() { return; } - VEC4_T a[TILE_ROWS]; - VEC4_T b[4]; - VEC4_T local_c[TILE_ROWS]; + VEC4_T mat1[TILE_ROWS]; + VEC4_T qmat2[4][TILE_TXCOLS]; + VEC4_T local_sums[TILE_ROWS][TILE_TXCOLS]; - [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { - local_c[i] = VEC4_T(0.0); + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $for c in range(TILE_TXCOLS): + local_sums[r][${c}] = VEC4_T(0.0); } - $if SCALES_STORAGE == "buffer": - const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]); - $else: - const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2(out_col >> 2, 0), 0)); - - for (int pos = 4 * wid; pos < in_sizes.x; pos += (4 * NWORKERS)) { - // Preload t_weight - [[unroll]] for (int i = 0; i < 4; i++) { - $if WEIGHT_STORAGE == "buffer": - b[i] = t_weight[((pos + i) * weight_sizes.x + out_col) >> 2]; + VEC4_T scales[TILE_TXCOLS]; + $for c in range(TILE_TXCOLS): + $if SCALES_STORAGE == "buffer": + scales[${c}] = VEC4_T(t_scales[out_txcol + ${c}]); + $else: + scales[${c}] = VEC4_T( + texelFetch(t_scales, ivec2(out_txcol + ${c}, 0), 0)); + + for (int pos = (4 * wid), txpos = wid; + pos < in_sizes.x; + pos += (4 * NWORKERS), txpos += NWORKERS) { + $if WEIGHT_STORAGE == "buffer": + uint qmat2_bufi; + uint weight_row_txstride = div4(weight_sizes.x); + + // Preload weight tensor + [[unroll]] for (int r = 0; r < 4; r++) { + $if QUANT_NBITS == 4: + $for c in range(0, TILE_TXCOLS, 2): + $if WEIGHT_STORAGE == "buffer": + qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol; + const u8vec4 packed_weight_tex = t_weight[qmat2_bufi + ${c}] + $else: + const uvec4 packed_weight_tex = texelFetch( + t_weight, ivec2(weight_txcol + ${c}, pos + r), 0); + + qmat2[r][${c}] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0); + qmat2[r][${c + 1}] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0); $else: - b[i] = VEC4_T(texelFetch(t_weight, ivec2(out_col >> 2, pos + i), 0)); + $for c in range(TILE_TXCOLS): + $if WEIGHT_STORAGE == "buffer": + qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol; + qmat2[r][${c}] = t_weight[qmat2_bufi + ${c}]; + $else: + qmat2[r][${c}] = VEC4_T( + texelFetch(t_weight, ivec2(out_txcol + ${c}, pos + r), 0)); } - // Preload t_in - for (int i = 0; i < TILE_ROWS; i++) { + + $if IN_STORAGE == "buffer": + uint in_row_txstride = div4(in_sizes.x); + + // Preload input tensor + [[unroll]] for (int i = 0; i < TILE_ROWS; i++) { $if IN_STORAGE == "buffer": - a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2]; + mat1[i] = t_in[(out_row + i) * in_row_txstride + txpos]; $else: - a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0)); + mat1[i] = VEC4_T( + texelFetch(t_in, ivec3(txpos, out_row + i, 0), 0)); } // Accumulate partial output - [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { - local_c[i] += a[i].x * b[0] + - a[i].y * b[1] + - a[i].z * b[2] + - a[i].w * b[3]; + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $for c in range(TILE_TXCOLS): + local_sums[r][${c}] += mat1[r].x 
* qmat2[0][${c}] + + mat1[r].y * qmat2[1][${c}] + + mat1[r].z * qmat2[2][${c}] + + mat1[r].w * qmat2[3][${c}]; } } - [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { - partial_c[gid][wid][i] = local_c[i]; + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $for c in range(TILE_TXCOLS): + partial_sums[gid][wid][r][${c}] = local_sums[r][${c}]; } memoryBarrierShared(); @@ -105,21 +152,33 @@ void main() { return; } - VEC4_T c[TILE_ROWS]; + VEC4_T sums[TILE_ROWS][TILE_TXCOLS]; + + for (int r = 0; r < TILE_ROWS; ++r) { + $for c in range(TILE_TXCOLS): + sums[r][${c}] = VEC4_T(0.0); - for (int row = 0; row < TILE_ROWS; ++row) { - c[row] = VEC4_T(0.0); [[unroll]] for (int worker = 0; worker < NWORKERS; ++worker) { - c[row] += partial_c[gid][worker][row]; + $for c in range(TILE_TXCOLS): + sums[r][${c}] += partial_sums[gid][worker][r][${c}]; } } - [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { - $if OUT_STORAGE == "buffer": - if (out_row + i < out_sizes.y) { - t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales; - } - $else: - imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales); + $if OUT_STORAGE == "buffer": + uint out_bufi; + uint out_row_txstride = div4(out_sizes.x); + + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $for c in range(TILE_TXCOLS): + $if OUT_STORAGE == "buffer": + if (out_row + r < out_sizes.y) { + out_bufi = (out_row + r) * out_row_txstride + out_txcol; + t_out[out_bufi + ${c}] = sums[r][${c}] * scales[${c}]; + } + $else: + imageStore( + t_out, + ivec3(out_txcol + ${c}, out_row + r, 0), + sums[r][${c}] * scales[${c}]); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml index e0477a3a3d1..bb222fcd965 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml @@ -12,6 +12,8 @@ linear_qcsnw_coop: WEIGHT_STORAGE: texture2d SCALES_STORAGE: texture2d TILE_ROWS: 4 + TILE_TXCOLS: 1 + QUANT_NBITS: 8 generate_variant_forall: TILE_ROWS: - VALUE: 1 diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl index 3ef952ea34d..f6f05aab7ca 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl @@ -14,6 +14,7 @@ #define VEC4_T ${buffer_gvec_type(DTYPE, 4)} #define TILE_ROWS ${TILE_ROWS} +#define TILE_TXCOLS ${TILE_TXCOLS} ${define_required_extensions(DTYPE)} @@ -26,7 +27,10 @@ layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)} +$if QUANT_NBITS == 4: + ${layout_declare_tensor(B, "r", "t_weight", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} +$else: + ${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array=False)} @@ -43,57 +47,110 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require void main() { - const uint16_t out_width_ntexels = uint16_t(divup4(out_sizes.x)); - const uint16_t out_col = uint16_t((gl_GlobalInvocationID.x % out_width_ntexels) << 2); - const uint16_t out_row = 
uint16_t((gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS); + // txcol stands for "texel column". One txcol corresponds to 4 scalar columns. + $if TILE_TXCOLS > 1: + const uint16_t global_wg_x = uint16_t(divup(out_sizes.x, 4 * TILE_TXCOLS)); + const uint16_t out_txcol = uint16_t( + (gl_GlobalInvocationID.x % global_wg_x) * TILE_TXCOLS); + $else: + const uint16_t global_wg_x = uint16_t(divup4(out_sizes.x)); + const uint16_t out_txcol = uint16_t(gl_GlobalInvocationID.x % global_wg_x); + + const uint16_t out_row = uint16_t( + (gl_GlobalInvocationID.x / global_wg_x) * TILE_ROWS); + + $if QUANT_NBITS == 4: + const uint16_t weight_txcol = uint16_t(out_txcol / 2); if (out_row >= uint16_t(out_sizes.y)) { return; } - VEC4_T a[TILE_ROWS]; - VEC4_T b[4]; - VEC4_T c[TILE_ROWS]; + VEC4_T mat1[TILE_ROWS]; + VEC4_T qmat2[4][TILE_TXCOLS]; + VEC4_T sums[TILE_ROWS][TILE_TXCOLS]; - $if SCALES_STORAGE == "buffer": - const VEC4_T scales = VEC4_T(t_scales[int(out_col >> 2)]); - $else: - const VEC4_T scales = VEC4_T(texelFetch(t_scales, u16vec2(out_col >> 2, 0), 0)); + VEC4_T scales[TILE_TXCOLS]; + $for c in range(TILE_TXCOLS): + $if SCALES_STORAGE == "buffer": + scales[${c}] = VEC4_T(t_scales[out_txcol + ${c}]); + $else: + scales[${c}] = VEC4_T( + texelFetch(t_scales, u16vec2(out_txcol + ${c}, 0), 0)); - [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { - c[i] = VEC4_T(0.0); + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $for c in range(TILE_TXCOLS): + sums[r][${c}] = VEC4_T(0.0); } - for (uint16_t pos = uint16_t(0); pos < uint16_t(in_sizes.x); pos += uint16_t(4)) { + for (uint16_t pos = uint16_t(0), txpos = uint16_t(0); + pos < uint16_t(in_sizes.x); + pos += uint16_t(4), txpos += uint16_t(1)) { + $if WEIGHT_STORAGE == "buffer": + uint qmat2_bufi; + uint weight_row_txstride = div4(weight_sizes.x); + // Preload weight tensor - [[unroll]] for (int i = 0; i < 4; i++) { - $if WEIGHT_STORAGE == "buffer": - b[i] = t_weight[((pos + i) * out_sizes.x + out_col) >> 2]; + [[unroll]] for (int r = 0; r < 4; r++) { + $if QUANT_NBITS == 4: + $for c in range(0, TILE_TXCOLS, 2): + $if WEIGHT_STORAGE == "buffer": + qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol; + const u8vec4 packed_weight_tex = t_weight[qmat2_bufi + ${c}] + $else: + const uvec4 packed_weight_tex = texelFetch( + t_weight, u16vec2(weight_txcol + ${c}, pos + r), 0); + + qmat2[r][${c}] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0); + qmat2[r][${c + 1}] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0); $else: - b[i] = VEC4_T(texelFetch(t_weight, u16vec2(out_col >> 2, pos + i), 0)); + $for c in range(TILE_TXCOLS): + $if WEIGHT_STORAGE == "buffer": + qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol; + qmat2[r][${c}] = t_weight[qmat2_bufi + ${c}]; + $else: + qmat2[r][${c}] = VEC4_T( + texelFetch(t_weight, u16vec2(out_txcol + ${c}, pos + r), 0)); } + $if IN_STORAGE == "buffer": + uint in_row_txstride = div4(in_sizes.x); + // Preload input tensor [[unroll]] for (int i = 0; i < TILE_ROWS; i++) { $if IN_STORAGE == "buffer": - a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2]; + mat1[i] = t_in[(out_row + i) * in_row_txstride + txpos]; $else: - a[i] = VEC4_T(texelFetch(t_in, u16vec3(pos >> 2, out_row + i, 0), 0)); + mat1[i] = VEC4_T( + texelFetch(t_in, u16vec3(txpos, out_row + i, 0), 0)); } // Accumulate output - [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { - c[i] += a[i].x * b[0] + a[i].y * b[1] + a[i].z * b[2] + a[i].w * b[3]; + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $for c in range(TILE_TXCOLS): + sums[r][${c}] += 
mat1[r].x * qmat2[0][${c}] + + mat1[r].y * qmat2[1][${c}] + + mat1[r].z * qmat2[2][${c}] + + mat1[r].w * qmat2[3][${c}]; } } // Store to output tensor - [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { - $if OUT_STORAGE == "buffer": - if (out_row + i < out_sizes.y) { - t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales; - } - $else: - imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales); + $if OUT_STORAGE == "buffer": + uint out_bufi; + uint out_row_txstride = div4(out_sizes.x); + + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $for c in range(TILE_TXCOLS): + $if OUT_STORAGE == "buffer": + if (out_row + r < out_sizes.y) { + out_bufi = (out_row + r) * out_row_txstride + out_txcol; + t_out[out_bufi + ${c}] = sums[r][${c}] * scales[${c}]; + } + $else: + imageStore( + t_out, + ivec3(out_txcol + ${c}, out_row + r, 0), + sums[r][${c}] * scales[${c}]); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml index f9f0134d995..1c9ec4e524a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml @@ -12,6 +12,8 @@ linear_qcsnw_tiled: WEIGHT_STORAGE: texture2d SCALES_STORAGE: texture2d TILE_ROWS: 4 + TILE_TXCOLS: 1 + QUANT_NBITS: 8 generate_variant_forall: TILE_ROWS: - VALUE: 1 @@ -30,3 +32,11 @@ linear_qcsnw_tiled: OUT_STORAGE: buffer WEIGHT_STORAGE: buffer SCALES_STORAGE: buffer + - NAME: linear_qcs4w_tiled_texture3d_texture3d_texture2d_texture2d_float + TILE_TXCOLS: 2 + QUANT_NBITS: 4 + - NAME: linear_qcs4w_tiled_buffer_buffer_texture2d_texture2d_float + IN_STORAGE: buffer + OUT_STORAGE: buffer + TILE_TXCOLS: 2 + QUANT_NBITS: 4 diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp index 7269b75ae6e..60475663a61 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp @@ -17,6 +17,7 @@ namespace vkcompute { void check_linear_qcsnw_args( const ComputeGraph& graph, + const int quant_nbits, const ValueRef mat1, const ValueRef qmat2_data, const ValueRef scales, @@ -31,13 +32,20 @@ void check_linear_qcsnw_args( VK_CHECK_COND(graph.packed_dim_of(mat1) == graph.packed_dim_of(out)); - VK_CHECK_COND( - utils::val_at(-1, mat1_sizes) == utils::val_at(-1, qmat2_sizes)); - VK_CHECK_COND( - utils::val_at(-1, scales_sizes) == utils::val_at(-2, qmat2_sizes)); + if (quant_nbits == 4) { + VK_CHECK_COND( + utils::val_at(-1, mat1_sizes) == utils::val_at(-1, qmat2_sizes) * 2); + VK_CHECK_COND( + utils::val_at(-1, scales_sizes) == utils::val_at(-2, qmat2_sizes)); + } else { + VK_CHECK_COND( + utils::val_at(-1, mat1_sizes) == utils::val_at(-1, qmat2_sizes)); + VK_CHECK_COND( + utils::val_at(-1, scales_sizes) == utils::val_at(-2, qmat2_sizes)); + } } -void resize_linear_qcs8w_node( +void resize_linear_qcsnw_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { @@ -48,7 +56,12 @@ void resize_linear_qcs8w_node( vTensorPtr qmat2 = graph->get_tensor(args[1].refs[1]); const int out_cols = utils::val_at(-2, mat1->sizes()); - const int out_rows = utils::val_at(-1, qmat2->sizes()); + int out_rows = utils::val_at(-1, qmat2->sizes()); + // Byte dtype suggests 4-bit quantization in which case the weight tensor is + // packed with 2 values per byte. 
+ if (qmat2->dtype() == vkapi::kByte) { + out_rows *= 2; + } std::vector new_out_sizes(3); if (mat1->sizes().size() == 2) { @@ -131,7 +144,7 @@ void add_linear_qcs8w_node( // Specialization Constants {}, // Resizing Logic - resize_linear_qcs8w_node, + resize_linear_qcsnw_node, {}, pcs)); if (!graph.is_buffer_storage(out) && @@ -140,27 +153,33 @@ void add_linear_qcs8w_node( } } -void add_linear_qcs8w_tiled_node( +void add_linear_qcsnw_tiled_node( ComputeGraph& graph, const bool use_coop_algorithm, + const int quant_nbits, const ValueRef mat1, const ValueRef q_mat2_data, const ValueRef scales_data, const ValueRef out) { - utils::StorageType q_mat2_storage = utils::kTexture2D; - uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); std::vector qmat2_orig_sizes = graph.sizes_of(q_mat2_data); const int64_t ndim = graph.dim_of(q_mat2_data); const int64_t K = qmat2_orig_sizes.at(ndim - 1); const int64_t N = qmat2_orig_sizes.at(ndim - 2); - if (N > max_extent * 4 || K > max_extent) { - q_mat2_storage = utils::kBuffer; - } + ValueRef q_mat2; + if (quant_nbits == 4) { + q_mat2 = + prepack_int4_linear_weight_transposed_interleaved(graph, q_mat2_data); + } else { + utils::StorageType q_mat2_storage = utils::kTexture2D; + if (N > max_extent * 4 || K > max_extent) { + q_mat2_storage = utils::kBuffer; + } - ValueRef q_mat2 = prepack_standard_hw_transposed( - graph, q_mat2_data, q_mat2_storage, utils::kWidthPacked); + q_mat2 = prepack_standard_hw_transposed( + graph, q_mat2_data, q_mat2_storage, utils::kWidthPacked); + } utils::StorageType scales_storage = utils::kTexture2D; if (N > max_extent) { @@ -169,8 +188,14 @@ void add_linear_qcs8w_tiled_node( ValueRef scales = prepack_standard(graph, scales_data, scales_storage, utils::kWidthPacked); - std::string kernel_name = - use_coop_algorithm ? "linear_qcs8w_coop" : "linear_qcs8w_tiled"; + std::string kernel_name; + if (quant_nbits == 4) { + kernel_name = + use_coop_algorithm ? "linear_qcs4w_coop" : "linear_qcs4w_tiled"; + } else { + kernel_name = + use_coop_algorithm ? 
"linear_qcs8w_coop" : "linear_qcs8w_tiled"; + } kernel_name.reserve(kShaderNameReserve); add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); add_storage_type_suffix(kernel_name, graph.storage_type_of(mat1)); @@ -195,9 +220,16 @@ void add_linear_qcs8w_tiled_node( out_tile_nrows = 4; } + // Number of output texels in the output tile + uint32_t out_tile_ntxcols = 1; + if (quant_nbits == 4) { + out_tile_ntxcols = 2; + } + utils::uvec3 out_limits = graph.logical_limits_of(out); + uint32_t global_wg_x = utils::div_up(out_limits[0], out_tile_ntxcols); utils::uvec3 global_wg_size = { - out_limits[0] * (utils::div_up(out_limits[1], out_tile_nrows)), + global_wg_x * (utils::div_up(out_limits[1], out_tile_nrows)), 1, out_limits[2]}; @@ -218,7 +250,7 @@ void add_linear_qcs8w_tiled_node( // Specialization Constants {}, // Resizing Logic - resize_linear_qcs8w_node, + resize_linear_qcsnw_node, {}, // Push Constants {{graph.sizes_pc_of(out), graph.sizes_pc_of(mat1)}})); @@ -235,7 +267,7 @@ bool can_use_tiled_impl( // Check if mat1 is not a 3D tensor or that batches = 1 // TODO(ssjia): Add support for batches in the tiled impl - if (graph.dim_of(mat1) == 3 && graph.size_at(-1, mat1) != 1) { + if (graph.dim_of(mat1) == 3 && graph.size_at(0, mat1) != 1) { return false; } // Check that K is a multiple of 4 @@ -280,17 +312,27 @@ bool can_use_coop_impl(ComputeGraph& graph, const ValueRef mat1) { void weight_int8pack_mm( ComputeGraph& graph, const std::vector& args) { - check_linear_qcsnw_args(graph, args[0], args[1], args[2], args[3]); + check_linear_qcsnw_args(graph, 8, args[0], args[1], args[2], args[3]); if (can_use_tiled_impl(graph, args[0], args[1], args[2], args[3])) { bool use_coop_algorithm = can_use_coop_impl(graph, args[0]); - return add_linear_qcs8w_tiled_node( - graph, use_coop_algorithm, args[0], args[1], args[2], args[3]); + return add_linear_qcsnw_tiled_node( + graph, use_coop_algorithm, 8, args[0], args[1], args[2], args[3]); } return add_linear_qcs8w_node(graph, args[0], args[1], args[2], args[3]); } +void linear_qcs4w(ComputeGraph& graph, const std::vector& args) { + check_linear_qcsnw_args(graph, 4, args[0], args[1], args[2], args[3]); + + VK_CHECK_COND(can_use_tiled_impl(graph, args[0], args[1], args[2], args[3])); + bool use_coop_algorithm = can_use_coop_impl(graph, args[0]); + return add_linear_qcsnw_tiled_node( + graph, use_coop_algorithm, 4, args[0], args[1], args[2], args[3]); +} + REGISTER_OPERATORS { VK_REGISTER_OP(aten._weight_int8pack_mm.default, weight_int8pack_mm); + VK_REGISTER_OP(et_vk.linear_qcs4w.default, linear_qcs4w); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp index ec718bea7da..f7803fb2e16 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp @@ -70,54 +70,6 @@ void resize_linear_qga4w_node( out->virtual_resize(new_out_sizes); } -ValueRef prepack_int4_linear_weight_transposed_interleaved( - ComputeGraph& graph, - const ValueRef qmat2_data) { - std::vector qmat2_orig_sizes = graph.sizes_of(qmat2_data); - const int64_t ndim = graph.dim_of(qmat2_data); - - const int64_t K = qmat2_orig_sizes.at(ndim - 1) * 2; - const int64_t N = qmat2_orig_sizes.at(ndim - 2); - const int64_t N_div2 = N / int64_t(2); - - utils::StorageType storage_type = utils::kTexture2D; - uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); - if (N_div2 > 
max_extent * 4 || K > max_extent) { - storage_type = utils::kBuffer; - } - - std::vector qmat2_sizes{K, N_div2}; - ValueRef qmat2 = graph.add_tensor( - qmat2_sizes, vkcompute::vkapi::kByte, storage_type, utils::kWidthPacked); - - utils::uvec3 global_wg_size; - global_wg_size = graph.logical_limits_of(qmat2); - global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(2)); - - std::string kernel_name = - graph.context()->adapter_ptr()->has_full_int8_buffers_support() - ? "pack_int4_linear_weight_transposed_interleaved" - : "pack_int4_linear_weight_transposed_interleaved_nobitw8buffer"; - add_storage_type_suffix(kernel_name, storage_type); - - graph.prepack_nodes().emplace_back(new PrepackNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - graph.create_local_wg_size(global_wg_size), - // Inputs and Outputs - qmat2_data, - qmat2, - // UBOs - {}, - // Specialization Constants - {}, - // Push Constants - {graph.sizes_pc_of(qmat2)})); - - return qmat2; -} - void add_linear_qga4w_node( ComputeGraph& graph, const ValueRef mat1, diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index f59d1cd65d9..e8f562d39c2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -237,6 +237,54 @@ ValueRef prepack_direct_copy_buffer( return tensor; } +ValueRef prepack_int4_linear_weight_transposed_interleaved( + ComputeGraph& graph, + const ValueRef qmat2_data) { + std::vector qmat2_orig_sizes = graph.sizes_of(qmat2_data); + const int64_t ndim = graph.dim_of(qmat2_data); + + const int64_t K = qmat2_orig_sizes.at(ndim - 1) * 2; + const int64_t N = qmat2_orig_sizes.at(ndim - 2); + const int64_t N_div2 = N / int64_t(2); + + utils::StorageType storage_type = utils::kTexture2D; + uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); + if (N_div2 > max_extent * 4 || K > max_extent) { + storage_type = utils::kBuffer; + } + + std::vector qmat2_sizes{K, N_div2}; + ValueRef qmat2 = graph.add_tensor( + qmat2_sizes, vkcompute::vkapi::kByte, storage_type, utils::kWidthPacked); + + utils::uvec3 global_wg_size; + global_wg_size = graph.logical_limits_of(qmat2); + global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(2)); + + std::string kernel_name = + graph.context()->adapter_ptr()->has_full_int8_buffers_support() + ? 
"pack_int4_linear_weight_transposed_interleaved" + : "pack_int4_linear_weight_transposed_interleaved_nobitw8buffer"; + add_storage_type_suffix(kernel_name, storage_type); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + graph.create_local_wg_size(global_wg_size), + // Inputs and Outputs + qmat2_data, + qmat2, + // UBOs + {}, + // Specialization Constants + {}, + // Push Constants + {graph.sizes_pc_of(qmat2)})); + + return qmat2; +} + void prepack_op(ComputeGraph& graph, const std::vector& args) { return add_prepack_standard_node(graph, args[0], args[1]); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.h b/backends/vulkan/runtime/graph/ops/impl/Staging.h index 1b6f245bd34..090a3718295 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.h +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.h @@ -87,4 +87,11 @@ ValueRef prepack_direct_copy_buffer( ComputeGraph& graph, const ValueRef tensor_data); +// +// Op specific prepack functions + +ValueRef prepack_int4_linear_weight_transposed_interleaved( + ComputeGraph& graph, + const ValueRef qmat2_data); + } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp index 5d08ee57859..b95b7b3aa6d 100644 --- a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp +++ b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp @@ -62,7 +62,7 @@ at::Tensor unpack_weights_4x2(const at::Tensor& weights_4x2) { return weights_unpacked; } -at::Tensor dequantize_and_linear( +at::Tensor dequantize_and_linear_qga4w( const at::Tensor& x, const at::Tensor& weights_4x2, const int64_t groupsize, @@ -97,6 +97,56 @@ at::Tensor dequantize_and_linear( return at::linear(x, weights_dequantized); } +at::Tensor dequantize_and_linear_qcs4w( + const at::Tensor& x, + const at::Tensor& weights_4x2, + const at::Tensor& scales) { + std::vector weights_shape(weights_4x2.sizes().vec()); + weights_shape[1] *= 2; + + at::Tensor weights_dequantized = + at::empty(weights_shape, at::device(at::kCPU).dtype(at::kFloat)); + + const int64_t N = weights_dequantized.size(0); + const int64_t K = weights_dequantized.size(1); + + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k += 2) { + // const int scale_idx = k_groups * n + group_idx; + const uint8_t packed_val = weights_4x2[n][k / 2].item().to(); + const uint8_t second_val = packed_val & 0x0F; + const uint8_t first_val = (packed_val & 0xF0) >> 4; + + const float scale = scales[n].item().to(); + + weights_dequantized[n][k] = (float(first_val) - 8.0) * scale; + weights_dequantized[n][k + 1] = (float(second_val) - 8.0) * scale; + } + } + + return at::linear(x, weights_dequantized); +} + +at::Tensor linear_qcs4w_reference_impl( + const at::Tensor& x, + const at::Tensor& weights_4x2, + const at::Tensor& scales) { + const std::vector original_x_size(x.sizes().vec()); + const size_t ndim = original_x_size.size(); + const int64_t out_features = weights_4x2.size(0); + const at::Tensor x_flattened = x.reshape({-1, original_x_size[ndim - 1]}); + + const at::Tensor weights_unpacked = + (unpack_weights_4x2(weights_4x2) - 8).to(at::kChar); + at::Tensor out = + at::_weight_int8pack_mm(x_flattened, weights_unpacked, scales); + + std::vector out_shape( + original_x_size.begin(), original_x_size.end()); + out_shape.at(ndim - 1) = out_features; + return out.reshape(out_shape); +} + // // Test functions // @@ -126,12 +176,31 @@ void test_reference_linear_qga4w( 
scales_and_zeros, inner_k_tiles); - at::Tensor out_ref = dequantize_and_linear( + at::Tensor out_ref = dequantize_and_linear_qga4w( x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles); ASSERT_TRUE(at::allclose(out, out_ref)); } +void test_reference_linear_qcs4w( + const int B, + const int M, + const int K, + const int N) { + at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor weights_4x2 = + at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte)); + at::Tensor weights_int = unpack_weights_4x2(weights_4x2); + + at::Tensor scales = at::rand({N}, at::device(at::kCPU).dtype(at::kFloat)); + + at::Tensor out = linear_qcs4w_reference_impl(x, weights_4x2, scales); + + at::Tensor out_ref = dequantize_and_linear_qcs4w(x, weights_4x2, scales); + + ASSERT_TRUE(at::allclose(out, out_ref)); +} + vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { using namespace vkcompute; switch (at_scalartype) { @@ -265,6 +334,85 @@ void test_vulkan_linear_qga4w( vkcompute::utils::kTexture3D); } +void test_vulkan_linear_qcs4w_impl( + const int B, + const int M, + const int K, + const int N, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { + at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor weights_4x2 = + at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte)); + + at::Tensor scales = at::rand({N}, at::device(at::kCPU).dtype(at::kFloat)); + + at::Tensor out_ref = linear_qcs4w_reference_impl(x, weights_4x2, scales); + + // Build Vulkan graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(utils::kTexture3D); + ComputeGraph graph(config); + +#define MAKE_TENSORREF_FOR(x) \ + ValueRef r_##x = graph.add_tensorref( \ + x.sizes().vec(), \ + from_at_scalartype(x.scalar_type()), \ + x.const_data_ptr()); + + MAKE_TENSORREF_FOR(weights_4x2); + MAKE_TENSORREF_FOR(scales); + + IOValueRef r_x = graph.add_input_tensor( + x.sizes().vec(), from_at_scalartype(x.scalar_type()), in_storage); + + const ValueRef r_out = graph.add_tensor( + out_ref.sizes().vec(), + from_at_scalartype(out_ref.scalar_type()), + out_storage); + + VK_GET_OP_FN("et_vk.linear_qcs4w.default") + (graph, {r_x.value, r_weights_4x2, r_scales, r_out}); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // + // Run model + // + + graph.propagate_resize(); + graph.copy_into_staging(r_x.staging, x.const_data_ptr(), x.numel()); + + graph.execute(); + + at::Tensor vk_out = at::empty_like(out_ref); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + ASSERT_TRUE(at::allclose(vk_out, out_ref, 1e-4, 1e-4)); +} + +void test_vulkan_linear_qcs4w( + const int B, + const int M, + const int K, + const int N) { + test_vulkan_linear_qcs4w_impl( + B, M, K, N, vkcompute::utils::kBuffer, vkcompute::utils::kBuffer); + + test_vulkan_linear_qcs4w_impl( + B, M, K, N, vkcompute::utils::kTexture3D, vkcompute::utils::kTexture3D); +} + TEST(VulkanLinearQGA4WTest, test_reference_impl) { test_reference_linear_qga4w( /*B = */ 1, @@ -294,3 +442,33 @@ TEST(VulkanLinearQGA4WTest, test_vulkan_impl_gemm) { /*K = */ 256, /*N = */ 256); } + +TEST(VulkanLinearQCS4WTest, test_reference_impl) { + test_reference_linear_qcs4w( + /*B = */ 1, + /*M = */ 4, + /*K = */ 128, + 
/*N = */ 32); +} + +TEST(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) { + test_vulkan_linear_qcs4w( + /*B = */ 1, + /*M = */ 4, + /*K = */ 128, + /*N = */ 32); + + test_vulkan_linear_qcs4w( + /*B = */ 1, + /*M = */ 1, + /*K = */ 256, + /*N = */ 256); +} + +TEST(VulkanLinearQCS4WTest, test_vulkan_impl_gemm) { + test_vulkan_linear_qcs4w( + /*B = */ 1, + /*M = */ 32, + /*K = */ 32, + /*N = */ 32); +}
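
The indexing_utils.h hunk above only adds parentheses around the macro arguments and two new helpers (mul4, divup); the arithmetic is otherwise unchanged. Below is a minimal standalone check of those macros, written in C++ for convenience (the expansions are identical to the GLSL header), with a deliberately unparenthesized variant included only to show what the added parentheses guard against.

#include <cassert>

// The helpers from indexing_utils.h, reproduced to show the arithmetic.
#define div4(x) ((x) >> 2)                  // x / 4
#define mul4(x) ((x) << 2)                  // x * 4
#define divup4(x) (((x) + 3) >> 2)          // ceil(x / 4)
#define divup(x, d) (((x) + (d) - 1) / (d)) // ceil(x / d)
#define alignup4(x) (((x) + 3) & -4)        // round x up to a multiple of 4
#define mod4(x) ((x) & 3)                   // x % 4

// Unparenthesized variant, only to illustrate the precedence hazard.
#define div4_unsafe(x) (x >> 2)

int main() {
  assert(divup4(9) == 3);   // 9 scalar columns occupy 3 width-4 texels
  assert(divup(9, 8) == 2); // generic round-up division
  assert(alignup4(9) == 12);
  assert(mod4(9) == 1);

  // With a non-trivial argument the unparenthesized macro binds the shift to
  // only part of the expression: (flag ? a : b >> 2) vs. ((flag ? a : b) >> 2).
  const bool flag = true;
  const int a = 8, b = 20;
  assert(div4_unsafe(flag ? a : b) == 8); // shift never applied to a
  assert(div4(flag ? a : b) == 2);
  return 0;
}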
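
The new test's dequantize_and_linear_qcs4w and linear_qcs4w_reference_impl pin down the qcs4w packing convention: each byte of the [N, K/2] weight tensor stores two unsigned 4-bit values, high nibble for the even column and low nibble for the odd column, both offset by 8 and scaled by one per-output-channel scale. The shaders' (packed & 0xF0) >> 4 and packed & 0x0F extraction follows the same layout after prepacking. Below is a rough plain-C++ sketch of that convention; the helper name is hypothetical and there is no ATen dependency.

#include <cstdint>
#include <vector>

// Hypothetical helper illustrating the qcs4w packing convention used by the
// reference implementations above: each byte of the [N, K/2] weight tensor
// holds two unsigned 4-bit values (high nibble = even column, low nibble =
// odd column), stored with a +8 offset and one scale per output channel.
std::vector<float> dequantize_qcs4w(
    const std::vector<uint8_t>& packed, // N * (K / 2) bytes, row-major
    const std::vector<float>& scales,   // N per-channel scales
    int64_t N,
    int64_t K) {
  std::vector<float> out(static_cast<size_t>(N * K));
  for (int64_t n = 0; n < N; ++n) {
    for (int64_t k = 0; k < K; k += 2) {
      const uint8_t byte = packed[n * (K / 2) + k / 2];
      const int hi = (byte & 0xF0) >> 4; // dequantizes to column k
      const int lo = byte & 0x0F;        // dequantizes to column k + 1
      out[n * K + k] = float(hi - 8) * scales[n];
      out[n * K + k + 1] = float(lo - 8) * scales[n];
    }
  }
  return out;
}

int main() {
  // One output channel, two columns: 0x9F unpacks to (9 - 8) and (15 - 8).
  const std::vector<uint8_t> packed = {0x9F};
  const std::vector<float> scales = {0.5f};
  const std::vector<float> w = dequantize_qcs4w(packed, scales, 1, 2);
  return (w[0] == 0.5f && w[1] == 3.5f) ? 0 : 1;
}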
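
Both the tiled and coop shaders now derive their output tile from a single flat gl_GlobalInvocationID.x, which the host sizes as div_up(out_limits[0], out_tile_ntxcols) * div_up(out_limits[1], out_tile_nrows). Below is a small sketch of that index decomposition under the same assumptions (one txcol covers 4 scalar columns); the names are illustrative, not part of the runtime.

#include <cassert>
#include <cstdint>

// Illustrative only: how a flat invocation index maps onto an output tile in
// the tiled/coop shaders above. The host divides the texel extent by
// TILE_TXCOLS, which is the same as divup(out_sizes.x, 4 * TILE_TXCOLS).
struct Tile {
  uint32_t out_txcol; // first output texel column of the tile
  uint32_t out_row;   // first output row of the tile
};

constexpr uint32_t div_up(uint32_t n, uint32_t d) {
  return (n + d - 1) / d;
}

Tile tile_for_invocation(
    uint32_t invocation_x, // gl_GlobalInvocationID.x
    uint32_t out_cols,     // out_sizes.x (scalar columns)
    uint32_t tile_txcols,  // TILE_TXCOLS
    uint32_t tile_rows) {  // TILE_ROWS
  const uint32_t global_wg_x = div_up(out_cols, 4 * tile_txcols);
  return Tile{
      (invocation_x % global_wg_x) * tile_txcols,
      (invocation_x / global_wg_x) * tile_rows};
}

int main() {
  // 256 output columns with 2 texel columns and 4 rows per tile gives 32
  // tiles across, so invocation 33 lands on the second row block.
  const Tile t = tile_for_invocation(33, 256, 2, 4);
  assert(t.out_txcol == 2); // texel column 2 == scalar column 8
  assert(t.out_row == 4);
  return 0;
}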
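
prepack_int4_linear_weight_transposed_interleaved (moved into Staging.cpp above) decides the prepacked shape and storage type before dispatching the packing shader, which is not part of this diff. The sketch below covers only that decision, assuming the [N, K/2] byte input and the {K, N/2} width-packed output visible in the function; the interleaving itself is omitted.

#include <cstdint>
#include <vector>

// Sketch of the shape/storage choice made during int4 weight prepacking:
// the [N, K/2] byte weights become a [K, N/2] width-packed byte tensor, and
// buffer storage is used when that layout exceeds the 2D texture limits
// (4 bytes per texel along the width).
struct PrepackPlan {
  std::vector<int64_t> sizes;
  bool use_buffer_storage;
};

PrepackPlan plan_int4_weight_prepack(
    const std::vector<int64_t>& qmat2_sizes, // original weights, e.g. {N, K / 2}
    uint32_t max_texture2d_dim) {
  const int64_t ndim = static_cast<int64_t>(qmat2_sizes.size());
  const int64_t K = qmat2_sizes[ndim - 1] * 2;
  const int64_t N = qmat2_sizes[ndim - 2];
  const int64_t N_div2 = N / 2;

  const bool use_buffer = N_div2 > int64_t(max_texture2d_dim) * 4 ||
      K > int64_t(max_texture2d_dim);
  return PrepackPlan{{K, N_div2}, use_buffer};
}

int main() {
  // N = 4096, K = 4096: the transposed {4096, 2048} layout still fits a
  // 16384-wide texture, so texture storage is kept.
  const PrepackPlan p = plan_int4_weight_prepack({4096, 2048}, 16384);
  return (!p.use_buffer_storage && p.sizes[0] == 4096 && p.sizes[1] == 2048)
      ? 0
      : 1;
}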