Skip to content

Commit 1907ae2

Browse files
committed
[ET-VK] Efficient tiled int8 matmul
Pull Request resolved: #9766 ## Context Introduce a optimized tiled implementation for computing the weight int8-quantized linear operation. This implementation takes advantage of the following principles to squeeze out performance: * Compute an output tile with each thread, rather than a single output element. This allows for better memory re-use of loaded input tensor data. * Compute the output tile by iteratively loading tiles of the input matrices, caching them in registers, and then performing the `fma` accumulations to obtain a partial output. By splitting the data loading and computation into distinct steps, the GPU is able to perform latency hiding more effectively, i.e. switching to a warp that needs to perform compute when the current warp is waiting on data load * Use a work group size of `{N, 1, 1}`. This makes it so that all the threads in a work group load the same row of the input matrx, and consecutive columns of the weight matrix. This way, the row of the input is kept hot in the cache, and accesses to the weight matrix can be coalesced due to the previous diff un-transposing the weight matrix. Differential Revision: [D72066587](https://our.internmc.facebook.com/intern/diff/D72066587/) ghstack-source-id: 275180032
1 parent de2aab6 commit 1907ae2

File tree

5 files changed

+184
-310
lines changed

5 files changed

+184
-310
lines changed

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl

Lines changed: 0 additions & 212 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml

Lines changed: 0 additions & 35 deletions
This file was deleted.
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define T ${buffer_scalar_type(DTYPE)}
14+
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}
15+
16+
#define TILE_ROWS ${TILE_ROWS}
17+
18+
${define_required_extensions(DTYPE)}
19+
20+
$if STORAGE == "buffer":
21+
${define_required_extensions("int8")}
22+
23+
#extension GL_EXT_control_flow_attributes : require
24+
25+
layout(std430) buffer;
26+
27+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=False)}
28+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=False)}
29+
${layout_declare_tensor(B, "r", "t_weight", "int8", STORAGE, is_scalar_array=False)}
30+
${layout_declare_tensor(B, "r", "t_scales", DTYPE, STORAGE, is_scalar_array=False)}
31+
32+
33+
layout(push_constant) uniform restrict Block {
34+
ivec4 out_sizes;
35+
ivec4 in_sizes;
36+
ivec4 weight_sizes;
37+
};
38+
39+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
40+
41+
void main() {
42+
const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
43+
const uint out_col = gl_GlobalInvocationID.x << 2;
44+
45+
if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
46+
return;
47+
}
48+
49+
VEC4_T a[TILE_ROWS];
50+
VEC4_T b[4];
51+
VEC4_T c[TILE_ROWS];
52+
53+
$if STORAGE == "buffer":
54+
const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
55+
$else:
56+
const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec3(out_col >> 2, 0, 0), 0));
57+
58+
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
59+
c[i] = VEC4_T(0.0);
60+
}
61+
62+
for (int pos = 0; pos < in_sizes.x; pos += 4) {
63+
// Preload weight tensor
64+
[[unroll]] for (int i = 0; i < 4; i++) {
65+
$if STORAGE == "buffer":
66+
b[i] = t_weight[((pos + i) * B_sizes.x + out_col) >> 2];
67+
$else:
68+
b[i] = VEC4_T(texelFetch(t_weight, ivec3(out_col >> 2, pos + i, 0), 0));
69+
}
70+
71+
// Preload input tensor
72+
[[unroll]] for (int i = 0; i < TILE_ROWS; i++) {
73+
$if STORAGE == "buffer":
74+
a[i] = t_in[((out_row + i) * in_sizes.x + (pos)) >> 2];
75+
$else:
76+
a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0));
77+
}
78+
79+
// Compute partial output
80+
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
81+
c[i] += a[i].x * b[0] + a[i].y * b[1] + a[i].z * b[2] + a[i].w * b[3];
82+
}
83+
}
84+
85+
// Store output tensor
86+
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
87+
$if STORAGE == "buffer":
88+
t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
89+
$else:
90+
imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales);
91+
}
92+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
q_8w_linear_tiled:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
STORAGE: texture3d
11+
TILE_ROWS: 4
12+
shader_variants:
13+
- NAME: q_8w_linear_tiled_o4x4_texture3d_float
14+
STORAGE: texture3d
15+
TILE_ROWS: 4
16+
- NAME: q_8w_linear_tiled_o4x6_texture3d_float
17+
STORAGE: texture3d
18+
TILE_ROWS: 6

0 commit comments

Comments
 (0)