Skip to content

Commit 92f3957

Browse files
pytorchbot authored and kirklandsign committed
[ET-VK] Store weights transposed for int8 linear (#9803)
## Context The weight tensor of a linear layer is usually stored in a transposed manner, such that when computing the matrix multiplication, the reduction traverses along the rows of the weight tensor as opposed to the columns. This results in a better memory access pattern for CPUs. However, for GPUs, I have found that "un-transposing" the weight tensor results in better performance. This is likely because GPUs compute multiple output elements in parallel, so reading along the columns allows memory loads to be coalesced among threads in a work group. ## Changes * Introduce the ability to transpose the height and width dims when transferring tensor data to the GPU. * Prepack the weight tensor "un-transposed" for the int8 quantized linear operator. Differential Revision: [D72066588](https://our.internmc.facebook.com/intern/diff/D72066588/)
1 parent 12e39b4 commit 92f3957

File tree

8 files changed

+110
-26
lines changed

8 files changed

+110
-26
lines changed

backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ ${layout_declare_ubo(B, "ivec4", "sizes")}
2727
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2828

2929
${layout_declare_spec_const(C, "int", "t_layout", "DEFAULT_LAYOUT")}
30+
${layout_declare_spec_const(C, "int", "transpose_hw", "0")}
31+
3032
const lowp ivec4 axis_map = unhash_axis_map(t_layout);
3133
const lowp int packed_dim = unhash_packed_dim(t_layout);
3234

@@ -41,8 +43,23 @@ int extend_sign(int x) {
4143
}
4244

4345
ivec4 read_texel(ivec4 tidx) {
46+
ivec4 tidx_to_use = tidx;
47+
ivec4 sizes_to_use = sizes;
48+
int packed_dim_to_use = packed_dim;
49+
if (transpose_hw == 1) {
50+
sizes_to_use.xy = sizes_to_use.yx;
51+
tidx_to_use.xy = tidx.yx;
52+
53+
if (packed_dim == 1) {
54+
packed_dim_to_use = 0;
55+
}
56+
if (packed_dim == 0) {
57+
packed_dim_to_use = 1;
58+
}
59+
}
60+
4461
const ivec4 buf_indices = tidx_to_nchwi(
45-
tidx, sizes, packed_dim);
62+
tidx_to_use, sizes_to_use, packed_dim_to_use);
4663

4764
int shift = (1 << 8) - 1;
4865
ivec4 masks;
@@ -70,7 +87,7 @@ ivec4 read_texel(ivec4 tidx) {
7087

7188
void main() {
7289
const ivec3 lpos = ivec3(gl_GlobalInvocationID);
73-
const ivec4 tidx = lpos_to_tidx(lpos, sizes, axis_map.w, packed_dim);
90+
ivec4 tidx = lpos_to_tidx(lpos, sizes, axis_map.w, packed_dim);
7491

7592
if (any(greaterThanEqual(tidx, sizes))) {
7693
return;

backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2121
// This constant is unused in this shader but is kept so that the signature is
2222
// consistent with nchw_to_image.
2323
${layout_declare_spec_const(C, "int", "UNUSED_layout", "0")}
24+
${layout_declare_spec_const(C, "int", "transpose_hw", "0")}
2425

2526
void main() {
2627
int out_bufi = int(gl_GlobalInvocationID.x);
@@ -29,7 +30,13 @@ void main() {
2930
}
3031

3132
ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides);
32-
const int in_nchwi = tidx_to_nchwi(out_tidx, out_sizes);
33+
34+
ivec4 sizes = out_sizes;
35+
if (transpose_hw == 1) {
36+
sizes.xy = sizes.yx;
37+
out_tidx.xy = out_tidx.yx;
38+
}
39+
const int in_nchwi = tidx_to_nchwi(out_tidx, sizes);
3340

3441
t_out[out_bufi] = nchw_in[in_nchwi];
3542
}

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,31 @@ $if not FROM_STAGING:
3030
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3131

3232
${layout_declare_spec_const(C, "int", "t_layout", "DEFAULT_LAYOUT")}
33+
${layout_declare_spec_const(C, "int", "transpose_hw", "0")}
34+
3335
const lowp ivec4 axis_map = unhash_axis_map(t_layout);
3436
const lowp int packed_dim = unhash_packed_dim(t_layout);
3537

3638
VEC4_T read_texel(ivec4 tidx) {
39+
ivec4 tidx_to_use = tidx;
40+
ivec4 sizes_to_use = sizes;
41+
int packed_dim_to_use = packed_dim;
42+
if (transpose_hw == 1) {
43+
sizes_to_use.xy = sizes_to_use.yx;
44+
tidx_to_use.xy = tidx.yx;
45+
46+
if (packed_dim == 1) {
47+
packed_dim_to_use = 0;
48+
}
49+
if (packed_dim == 0) {
50+
packed_dim_to_use = 1;
51+
}
52+
}
53+
3754
$if FROM_STAGING:
38-
const ivec4 buf_indices = tidx_to_nchwi(tidx, sizes, packed_dim);
55+
const ivec4 buf_indices = tidx_to_nchwi(tidx_to_use, sizes_to_use, packed_dim_to_use);
3956
$else:
40-
const ivec4 buf_indices = tidx_to_4bufi(tidx, buf_strides, packed_dim);
57+
const ivec4 buf_indices = tidx_to_4bufi(tidx_to_use, buf_strides, packed_dim_to_use);
4158

4259
VEC4_T texel = VEC4_T(0);
4360
if (tidx[packed_dim] < sizes[packed_dim]) {

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -64,24 +64,21 @@ void main() {
6464

6565
FLOAT_T outval = FLOAT_T(0.0);
6666

67-
// Initial mat1 tensor idx will be (0, out_tidx.y, out_tidx.z, 0)
6867
int mat1_offset = out_tidx.y * mat1_strides.y + out_tidx.z * qmat2_strides.z;
69-
// Initial qmat2 tensor idx wil be (0, out_tidx.x, 0, 0); note that the qmat2
70-
// tensor is transposed
71-
int qmat2_offset = out_tidx.x * qmat2_strides.y;
68+
int qmat2_offset = out_tidx.x;
7269

7370
// TODO(ssjia): optimize memory access pattern by traversing mat1 x in inner loop
7471
for (int i = 0; i < mat1_sizes.x; i++) {
7572
const FLOAT_T mat1_val = t_mat1[mat1_offset];
76-
const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale;
73+
const FLOAT_T mat2_val = FLOAT_T(t_qmat2[qmat2_offset]);
7774

7875
outval += mat1_val * mat2_val;
7976

8077
mat1_offset++;
81-
qmat2_offset++;
78+
qmat2_offset += qmat2_strides.y;
8279
}
8380

84-
t_out[out_bufi] = outval;
81+
t_out[out_bufi] = outval * scale;
8582
}
8683

8784
#else // USING_TEXTURE
@@ -97,25 +94,27 @@ void main() {
9794
return;
9895
}
9996

100-
const uint16_t qmat2_pos_y = out_pos.x * uint16_t(4);
97+
const uint16_t qmat2_pos_x = out_pos.x;
10198

10299
VEC4_T outtex = VEC4_T(0);
103100

104101
const VEC4_T scales = load_texel(t_scales, u16vec3(out_pos.x, 0, 0));
105102

103+
VEC4_T mat1_tex;
104+
VEC4_T mat2_tex[4];
106105
for (
107106
uint16_t i = uint16_t(0), x = uint16_t(0);
108107
i < uint16_t(mat1_sizes.x);
109108
i += uint16_t(4), x++)
110109
{
111-
const VEC4_T mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0));
112-
const VEC4_T sums = VEC4_T(
113-
dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y, 0))),
114-
dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(1), 0))),
115-
dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(2), 0))),
116-
dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(3), 0))));
117-
118-
outtex += sums;
110+
mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0));
111+
112+
mat2_tex[0] = load_texel(t_qmat2, u16vec3(out_pos.x, i, 0));
113+
mat2_tex[1] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(1), 0));
114+
mat2_tex[2] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(2), 0));
115+
mat2_tex[3] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(3), 0));
116+
117+
outtex += mat1_tex.x * mat2_tex[0] + mat1_tex.y * mat2_tex[1] + mat1_tex.z * mat2_tex[2] + mat1_tex.w * mat2_tex[3];
119118
}
120119

121120
outtex *= scales;

backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ void resize_q_8w_linear_node(
4848
vTensorPtr qmat2 = graph->get_tensor(args[1].refs[1]);
4949

5050
const int out_cols = utils::val_at(-2, mat1->sizes());
51-
const int out_rows = utils::val_at(-2, qmat2->sizes());
51+
const int out_rows = utils::val_at(-1, qmat2->sizes());
5252

5353
std::vector<int64_t> new_out_sizes(3);
5454
if (mat1->sizes().size() == 2) {
@@ -86,7 +86,7 @@ void add_q_8w_linear_node(
8686
// Ensure out is packed correctly
8787
out_W_packed = out_tmp;
8888
}
89-
ValueRef q_mat2 = prepack_standard(
89+
ValueRef q_mat2 = prepack_standard_hw_transposed(
9090
graph, q_mat2_data, graph.storage_type_of(out), utils::kWidthPacked);
9191
ValueRef scales = prepack_standard(
9292
graph, scales_data, graph.storage_type_of(out), utils::kWidthPacked);

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ void add_tensor_to_staging_node(
113113
void add_prepack_standard_node(
114114
ComputeGraph& graph,
115115
const ValueRef tensor_data,
116-
const ValueRef tensor) {
116+
const ValueRef tensor,
117+
const bool transpose_hw = false) {
117118
vkapi::ShaderInfo shader = get_nchw_to_tensor_shader(
118119
*graph.get_tensor(tensor), graph.int8_buffers_enabled());
119120

@@ -127,6 +128,8 @@ void add_prepack_standard_node(
127128
ubos.append({graph.sizes_ubo(tensor)});
128129
}
129130

131+
int transpose_hw_spec = transpose_hw ? 1 : 0;
132+
130133
graph.prepack_nodes().emplace_back(new PrepackNode(
131134
graph,
132135
shader,
@@ -138,7 +141,7 @@ void add_prepack_standard_node(
138141
// Parameter Buffers
139142
ubos,
140143
// Specialization Constants
141-
{graph.hashed_layout_of(tensor)}));
144+
{graph.hashed_layout_of(tensor), transpose_hw_spec}));
142145
}
143146

144147
ValueRef prepack_standard(
@@ -158,6 +161,33 @@ ValueRef prepack_standard(
158161
return tensor;
159162
}
160163

164+
ValueRef prepack_standard_hw_transposed(
165+
ComputeGraph& graph,
166+
const ValueRef tensor_data,
167+
const utils::StorageType storage_type,
168+
const utils::GPUMemoryLayout layout,
169+
const bool passthrough,
170+
const utils::AxisMapLayout axis_map_layout) {
171+
(void)passthrough;
172+
173+
VK_CHECK_COND(graph.val_is_tref(tensor_data));
174+
std::vector<int64_t> new_out_sizes = graph.sizes_of(tensor_data);
175+
const int w_dim = new_out_sizes.size() - 1;
176+
const int h_dim = new_out_sizes.size() - 2;
177+
const int64_t tmp = new_out_sizes.at(w_dim);
178+
new_out_sizes.at(w_dim) = new_out_sizes.at(h_dim);
179+
new_out_sizes.at(h_dim) = tmp;
180+
ValueRef tensor = graph.add_tensor(
181+
new_out_sizes,
182+
graph.dtype_of(tensor_data),
183+
storage_type,
184+
layout,
185+
-1,
186+
axis_map_layout);
187+
add_prepack_standard_node(graph, tensor_data, tensor, true);
188+
return tensor;
189+
}
190+
161191
ValueRef prepack_standard_like(
162192
ComputeGraph& graph,
163193
const ValueRef tensor_data,

backends/vulkan/runtime/graph/ops/impl/Staging.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,18 @@ ValueRef prepack_standard(
5151
const bool passthrough = false,
5252
const utils::AxisMapLayout axis_map_layout = utils::kDefaultAxisMap);
5353

54+
/*
55+
* Same as prepack_standard, but transpose the height and width dimensions of
56+
* the tensor while packing.
57+
*/
58+
ValueRef prepack_standard_hw_transposed(
59+
ComputeGraph& graph,
60+
const ValueRef tensor_data,
61+
const utils::StorageType storage_type,
62+
const utils::GPUMemoryLayout layout,
63+
const bool passthrough = false,
64+
const utils::AxisMapLayout axis_map_layout = utils::kDefaultAxisMap);
65+
5466
/*
5567
* Equivalent to `prepack_standard()` function, except the `storage_type` and
5668
* `memory_layout` are set to match `to_copy`, which must be a `Tensor`.

backends/vulkan/test/op_tests/cases.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,12 +157,14 @@ def get_weight_int8pack_mm_inputs():
157157
[6, 1024, 256],
158158
[6, 256, 256],
159159
[6, 256, 512],
160+
[4, 768, 4096],
161+
[1024, 1024, 1024],
160162
]
161163

162164
inputs_list = [((M, K), (N, K), (N)) for M, K, N in MKN_list]
163165

164166
test_suite = VkTestSuite(inputs_list)
165-
test_suite.dtypes = ["at::kFloat", "at::kHalf"]
167+
test_suite.dtypes = ["at::kFloat"]
166168
test_suite.layouts = ["utils::kWidthPacked"]
167169
test_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"]
168170
test_suite.prepacked_args = ["mat2", "scales"]

0 commit comments

Comments
 (0)