Skip to content
Open
54 changes: 28 additions & 26 deletions backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,20 @@ ${layout_declare_tensor(2, "r", "t_qmat2", "int8", STORAGE)}
${layout_declare_tensor(3, "r", "t_scales", DTYPE, STORAGE)}

$if STORAGE == "buffer":
${layout_declare_ubo(4, "ivec4", "out_sizes")}
${layout_declare_ubo(5, "ivec4", "out_strides")}
${layout_declare_ubo(6, "int", "out_numel")}
${layout_declare_ubo(7, "ivec4", "mat1_sizes")}
${layout_declare_ubo(8, "ivec4", "mat1_strides")}
${layout_declare_ubo(9, "ivec4", "qmat2_strides")}
${layout_declare_ubo(10, "ivec4", "scales_strides")}
layout(push_constant) uniform restrict Block {
ivec4 out_sizes;
ivec4 out_strides;
ivec4 mat1_sizes;
ivec4 mat1_strides;
ivec4 qmat2_strides;
ivec4 scales_strides;
int out_numel;
};
$else:
${layout_declare_ubo(4, "ivec3", "out_limits")}
${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
layout(push_constant) uniform restrict Block {
ivec3 out_limits;
ivec4 mat1_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

Expand Down Expand Up @@ -83,42 +87,40 @@ void main() {

#else // USING_TEXTURE

#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

// Texture-storage entry point of the shader, as captured from a diff view.
// The +/- markers were lost in extraction, so several statements appear twice
// in a row (once per side of the change) and this text is not compilable as-is.
// NOTE(review): the u16vec*/uint16_t lines look like the removed side and the
// ivec*/int lines like the added side -- confirm against the actual patch.
void main() {
// Variant 1: output position taken directly from the x and y invocation IDs.
const u16vec2 out_pos = u16vec2(
gl_GlobalInvocationID.x,
gl_GlobalInvocationID.y);
// Variant 2: only gl_GlobalInvocationID.x is used; it is treated as a flat
// index and split into (x, y) by remainder / quotient with out_limits.x.
const ivec2 out_pos = ivec2(
gl_GlobalInvocationID.x % out_limits.x,
gl_GlobalInvocationID.x / out_limits.x);

// Bounds guard, one line per variant. The second form checks y only; with the
// flat-index mapping x is a remainder modulo out_limits.x, so it cannot reach
// out_limits.x (assuming out_limits.x is positive).
if (out_pos.x >= out_limits.x || out_pos.y >= out_limits.y) {
if (out_pos.y >= out_limits.y) {
return;
}

// NOTE(review): qmat2_pos_x is not read anywhere else in the visible lines;
// the loads below index t_qmat2 with out_pos.x directly.
const uint16_t qmat2_pos_x = out_pos.x;
const int qmat2_pos_x = out_pos.x;

// Accumulator for the output texel, starting at zero.
VEC4_T outtex = VEC4_T(0);

// Scale texel for this output column, read from row 0 of t_scales.
const VEC4_T scales = load_texel(t_scales, u16vec3(out_pos.x, 0, 0));
const VEC4_T scales = load_texel(t_scales, ivec3(out_pos.x, 0, 0));

VEC4_T mat1_tex;
VEC4_T mat2_tex[4];
// Walk the reduction dimension: i advances by 4 up to mat1_sizes.x while x
// advances by 1, so each iteration pairs one t_mat1 texel with four t_qmat2
// rows. The loop header appears once per variant (16-bit, then int).
for (
uint16_t i = uint16_t(0), x = uint16_t(0);
i < uint16_t(mat1_sizes.x);
i += uint16_t(4), x++)
int i = 0, x = 0;
i < mat1_sizes.x;
i += 4, x++)
{
// One t_mat1 texel at (x, out_pos.y); its four components are used below.
mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0));
mat1_tex = load_texel(t_mat1, ivec3(x, out_pos.y, 0));

// Four t_qmat2 texels at column out_pos.x, rows i .. i+3 (both variants).
// NOTE(review): rows i+1 .. i+3 are read without a check against
// mat1_sizes.x -- presumably the texture is padded; verify.
mat2_tex[0] = load_texel(t_qmat2, u16vec3(out_pos.x, i, 0));
mat2_tex[1] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(1), 0));
mat2_tex[2] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(2), 0));
mat2_tex[3] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(3), 0));
mat2_tex[0] = load_texel(t_qmat2, ivec3(out_pos.x, i, 0));
mat2_tex[1] = load_texel(t_qmat2, ivec3(out_pos.x, i + 1, 0));
mat2_tex[2] = load_texel(t_qmat2, ivec3(out_pos.x, i + 2, 0));
mat2_tex[3] = load_texel(t_qmat2, ivec3(out_pos.x, i + 3, 0));

// Each mat1 component scales the matching t_qmat2 texel; the sum is accumulated.
outtex += mat1_tex.x * mat2_tex[0] + mat1_tex.y * mat2_tex[1] + mat1_tex.z * mat2_tex[2] + mat1_tex.w * mat2_tex[3];
}

// The scales are applied once, after the accumulation loop, and the result is
// written to t_out at (out_pos, 0).
outtex *= scales;
write_texel(t_out, u16vec3(out_pos, 0), outtex);
write_texel(t_out, ivec3(out_pos, 0), outtex);
}

#endif
58 changes: 19 additions & 39 deletions backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,47 +98,25 @@ void add_q_8w_linear_node(
add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed));
add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed));

vkapi::ParamsBindList ubos({});
std::vector<PushConstantDataInfo> pcs;
if (graph.is_buffer_storage(out_W_packed)) {
ubos.append(
{graph.sizes_ubo(out_W_packed),
graph.strides_ubo(out_W_packed),
graph.numel_ubo(out_W_packed),
graph.sizes_ubo(mat1_W_packed),
graph.strides_ubo(mat1),
graph.strides_ubo(q_mat2),
graph.strides_ubo(scales)});
pcs = {
graph.sizes_pc_of(out_W_packed),
graph.strides_pc_of(out_W_packed),
graph.sizes_pc_of(mat1_W_packed),
graph.strides_pc_of(mat1),
graph.strides_pc_of(q_mat2),
graph.strides_pc_of(scales),
graph.numel_pc_of(out_W_packed)};
} else {
ubos.append(
{graph.logical_limits_ubo(out_W_packed),
graph.sizes_ubo(mat1_W_packed)});
pcs = {
graph.logical_limits_pc_of(out_W_packed),
graph.sizes_pc_of(mat1_W_packed)};
}

utils::uvec3 global_wg;
if (graph.is_buffer_storage(out)) {
global_wg = {static_cast<uint32_t>(graph.numel_of(out_W_packed)), 1, 1};
} else {
global_wg = graph.logical_limits_of(out_W_packed);
}

utils::uvec3 local_wg{8, 8, 1};
int32_t out_W = graph.size_at<int32_t>(-1, out_W_packed);

if (graph.is_buffer_storage(out_W_packed)) {
local_wg[0] = 64;
local_wg[1] = 1;
local_wg[2] = 1;
} else {
if (out_W % 8 != 0) {
if (out_W % 4 == 0) {
local_wg[0] = 4;
local_wg[1] = 16;
} else {
local_wg[0] = 2;
local_wg[1] = 32;
}
}
}
const utils::uvec3 global_wg = {
static_cast<uint32_t>(graph.numel_of(out_W_packed)), 1, 1};
const utils::uvec3 local_wg{64, 1, 1};

graph.execute_nodes().emplace_back(new DispatchNode(
graph,
Expand All @@ -149,11 +127,13 @@ void add_q_8w_linear_node(
{{out_W_packed, vkapi::MemoryAccessType::WRITE},
{{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
// Shader params buffers
ubos,
{},
// Specialization Constants
{},
// Resizing Logic
resize_q_8w_linear_node));
resize_q_8w_linear_node,
{},
pcs));
if (!graph.is_buffer_storage(out) &&
graph.packed_dim_of(out) != WHCN::kWidthDim) {
viewFn(graph, {out_W_packed, graph.add_none(), out});
Expand Down
Loading