Skip to content

[ET-VK] Implement linear_qcs4w #10588

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: gh/SS-JIA/222/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,12 @@ def register_mm_op(features: OpFeatures):
return features


@update_features(exir_ops.edge.aten._weight_int8pack_mm.default)
@update_features(
[
exir_ops.edge.aten._weight_int8pack_mm.default,
exir_ops.edge.et_vk.linear_qcs4w.default,
]
)
def register_int8_mm_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
uses_axis_map=False,
Expand Down
18 changes: 14 additions & 4 deletions backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,32 @@
/*
* Fast division by 4 using bit shifting
*/
#define div4(x) (x >> 2)
#define div4(x) ((x) >> 2)

/*
* Fast multiplication by 4 using bit shifting
*/
#define mul4(x) ((x) << 2)

/*
* Divides input by 4 and rounds up
*/
#define divup4(x) ((x + 3) >> 2)
#define divup4(x) (((x) + 3) >> 2)

/*
* Divides input by denominator and rounds up
*/
#define divup(x, d) (((x) + (d) - 1) / (d))

/*
* Aligns input to the next multiple of 4
*/
#define alignup4(x) ((x + 3) & -4)
#define alignup4(x) (((x) + 3) & -4)

/*
* Fast modulo by 4 using bit masking
*/
#define mod4(x) (x & 3)
#define mod4(x) ((x) & 3)

/*
* Find the packed dimension of a tensor given its strides. The packed dimension
Expand Down
145 changes: 102 additions & 43 deletions backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}

#define TILE_ROWS ${TILE_ROWS}
#define TILE_TXCOLS ${TILE_TXCOLS}

#define NGROUPS 8
#define NWORKERS 8
Expand All @@ -29,7 +30,10 @@ layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)}
$if QUANT_NBITS == 4:
${layout_declare_tensor(B, "r", "t_weight", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
$else:
${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array=False)}

layout(push_constant) uniform restrict Block {
Expand All @@ -42,12 +46,23 @@ layout(push_constant) uniform restrict Block {

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

shared VEC4_T partial_c[NGROUPS][NWORKERS][TILE_ROWS];
shared VEC4_T partial_sums[NGROUPS][NWORKERS][TILE_ROWS][TILE_TXCOLS];

void main() {
const uint out_width_ntexels = divup4(out_sizes.x);
const uint out_col = (gl_GlobalInvocationID.x % out_width_ntexels) << 2;
const uint out_row = (gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS;
// txcol stands for "texel column". One txcol corresponds to 4 scalar columns.
$if TILE_TXCOLS > 1:
const uint global_wg_x = uint(divup(out_sizes.x, 4 * TILE_TXCOLS));
const uint out_txcol = uint(
(gl_GlobalInvocationID.x % global_wg_x) * TILE_TXCOLS);
$else:
const uint global_wg_x = uint(divup4(out_sizes.x));
const uint out_txcol = uint(gl_GlobalInvocationID.x % global_wg_x);

const uint out_row = uint(
(gl_GlobalInvocationID.x / global_wg_x) * TILE_ROWS);

$if QUANT_NBITS == 4:
const uint weight_txcol = uint(out_txcol / 2);

const int gid = int(gl_LocalInvocationID.x); // group id
const int wid = int(gl_LocalInvocationID.z); // worker id
Expand All @@ -56,46 +71,78 @@ void main() {
return;
}

VEC4_T a[TILE_ROWS];
VEC4_T b[4];
VEC4_T local_c[TILE_ROWS];
VEC4_T mat1[TILE_ROWS];
VEC4_T qmat2[4][TILE_TXCOLS];
VEC4_T local_sums[TILE_ROWS][TILE_TXCOLS];

[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
local_c[i] = VEC4_T(0.0);
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
$for c in range(TILE_TXCOLS):
local_sums[r][${c}] = VEC4_T(0.0);
}

$if SCALES_STORAGE == "buffer":
const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
$else:
const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2(out_col >> 2, 0), 0));

for (int pos = 4 * wid; pos < in_sizes.x; pos += (4 * NWORKERS)) {
// Preload t_weight
[[unroll]] for (int i = 0; i < 4; i++) {
$if WEIGHT_STORAGE == "buffer":
b[i] = t_weight[((pos + i) * weight_sizes.x + out_col) >> 2];
VEC4_T scales[TILE_TXCOLS];
$for c in range(TILE_TXCOLS):
$if SCALES_STORAGE == "buffer":
scales[${c}] = VEC4_T(t_scales[out_txcol + ${c}]);
$else:
scales[${c}] = VEC4_T(
texelFetch(t_scales, ivec2(out_txcol + ${c}, 0), 0));

for (int pos = (4 * wid), txpos = wid;
pos < in_sizes.x;
pos += (4 * NWORKERS), txpos += NWORKERS) {
$if WEIGHT_STORAGE == "buffer":
uint qmat2_bufi;
uint weight_row_txstride = div4(weight_sizes.x);

// Preload weight tensor
[[unroll]] for (int r = 0; r < 4; r++) {
$if QUANT_NBITS == 4:
$for c in range(0, TILE_TXCOLS, 2):
// Load one texel of packed 4-bit weights for row (pos + r); the high and low
// nibbles of each component are unpacked into two qmat2 columns right after.
$if WEIGHT_STORAGE == "buffer":
qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol;
const u8vec4 packed_weight_tex = t_weight[qmat2_bufi + ${c}];
$else:
const uvec4 packed_weight_tex = texelFetch(
t_weight, ivec2(weight_txcol + ${c}, pos + r), 0);

qmat2[r][${c}] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0);
qmat2[r][${c + 1}] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0);
$else:
b[i] = VEC4_T(texelFetch(t_weight, ivec2(out_col >> 2, pos + i), 0));
$for c in range(TILE_TXCOLS):
$if WEIGHT_STORAGE == "buffer":
qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol;
qmat2[r][${c}] = t_weight[qmat2_bufi + ${c}];
$else:
qmat2[r][${c}] = VEC4_T(
texelFetch(t_weight, ivec2(out_txcol + ${c}, pos + r), 0));
}
// Preload t_in
for (int i = 0; i < TILE_ROWS; i++) {

$if IN_STORAGE == "buffer":
uint in_row_txstride = div4(in_sizes.x);

// Preload input tensor
[[unroll]] for (int i = 0; i < TILE_ROWS; i++) {
$if IN_STORAGE == "buffer":
a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2];
mat1[i] = t_in[(out_row + i) * in_row_txstride + txpos];
$else:
a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0));
mat1[i] = VEC4_T(
texelFetch(t_in, ivec3(txpos, out_row + i, 0), 0));
}

// Accumulate partial output
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
local_c[i] += a[i].x * b[0] +
a[i].y * b[1] +
a[i].z * b[2] +
a[i].w * b[3];
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
$for c in range(TILE_TXCOLS):
local_sums[r][${c}] += mat1[r].x * qmat2[0][${c}] +
mat1[r].y * qmat2[1][${c}] +
mat1[r].z * qmat2[2][${c}] +
mat1[r].w * qmat2[3][${c}];
}
}

[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
partial_c[gid][wid][i] = local_c[i];
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
$for c in range(TILE_TXCOLS):
partial_sums[gid][wid][r][${c}] = local_sums[r][${c}];
}

memoryBarrierShared();
Expand All @@ -105,21 +152,33 @@ void main() {
return;
}

VEC4_T c[TILE_ROWS];
VEC4_T sums[TILE_ROWS][TILE_TXCOLS];

for (int r = 0; r < TILE_ROWS; ++r) {
$for c in range(TILE_TXCOLS):
sums[r][${c}] = VEC4_T(0.0);

for (int row = 0; row < TILE_ROWS; ++row) {
c[row] = VEC4_T(0.0);
[[unroll]] for (int worker = 0; worker < NWORKERS; ++worker) {
c[row] += partial_c[gid][worker][row];
$for c in range(TILE_TXCOLS):
sums[r][${c}] += partial_sums[gid][worker][r][${c}];
}
}

[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
$if OUT_STORAGE == "buffer":
if (out_row + i < out_sizes.y) {
t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
}
$else:
imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales);
$if OUT_STORAGE == "buffer":
uint out_bufi;
uint out_row_txstride = div4(out_sizes.x);

[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
$for c in range(TILE_TXCOLS):
$if OUT_STORAGE == "buffer":
if (out_row + r < out_sizes.y) {
out_bufi = (out_row + r) * out_row_txstride + out_txcol;
t_out[out_bufi + ${c}] = sums[r][${c}] * scales[${c}];
}
$else:
imageStore(
t_out,
ivec3(out_txcol + ${c}, out_row + r, 0),
sums[r][${c}] * scales[${c}]);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ linear_qcsnw_coop:
WEIGHT_STORAGE: texture2d
SCALES_STORAGE: texture2d
TILE_ROWS: 4
TILE_TXCOLS: 1
QUANT_NBITS: 8
generate_variant_forall:
TILE_ROWS:
- VALUE: 1
Expand Down
Loading
Loading