Skip to content

Commit 6c96ac3

Browse files
committed
[ET-VK][ez] Use standard quant naming scheme for quantized ops
## Context

Use the standard naming scheme for quantized operators introduced in the previous PR. For weight-only quantized linear operators, the names introduced are:

`linear_qcsnw`:

* q - quantized
* c - per-channel / channelwise
* s - symmetric
* n - number of bits (qcs4w for 4-bit quant, qcs8w for 8-bit quant)
* w - weight quantized

`linear_qga4w`:

* q - quantized
* g - per-group / groupwise
* a - affine
* 4 - quantized to 4 bits
* w - weight quantized

## Changes

Rename instances of `q_8w_linear` to `linear_qcs8w` or `linear_qcsnw`. Rename instances of `q_4w_linear` to `linear_qga4w`. Rename cpp files to match the new naming convention.

Differential Revision: [D73941992](https://our.internmc.facebook.com/intern/diff/D73941992/)

[ghstack-poisoned]
1 parent 7b643ae commit 6c96ac3

13 files changed

+52
-52
lines changed

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw.yaml

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
q_8w_linear:
7+
linear_qcsnw:
88
parameter_names_with_default_values:
99
DTYPE: float
1010
STORAGE: texture3d
@@ -18,6 +18,6 @@ q_8w_linear:
1818
- VALUE: texture3d
1919
- VALUE: buffer
2020
shader_variants:
21-
- NAME: q_8w_linear_W_packed_W_packed
22-
- NAME: q_8w_linear_W_packed_H_packed
21+
- NAME: linear_qcs8w_W_packed_W_packed
22+
- NAME: linear_qcs8w_W_packed_H_packed
2323
MAT2_PACKING: H_packed

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
q_8w_linear_coop:
7+
linear_qcsnw_coop:
88
parameter_names_with_default_values:
99
DTYPE: float
1010
IN_STORAGE: texture3d
@@ -17,11 +17,11 @@ q_8w_linear_coop:
1717
- VALUE: 1
1818
SUFFIX: o4x1
1919
shader_variants:
20-
- NAME: q_8w_linear_coop_texture3d_texture3d_texture2d_texture2d_float
21-
- NAME: q_8w_linear_coop_buffer_buffer_texture2d_texture2d_float
20+
- NAME: linear_qcs8w_coop_texture3d_texture3d_texture2d_texture2d_float
21+
- NAME: linear_qcs8w_coop_buffer_buffer_texture2d_texture2d_float
2222
IN_STORAGE: buffer
2323
OUT_STORAGE: buffer
24-
- NAME: q_8w_linear_coop_buffer_buffer_buffer_buffer_float
24+
- NAME: linear_qcs8w_coop_buffer_buffer_buffer_buffer_float
2525
IN_STORAGE: buffer
2626
OUT_STORAGE: buffer
2727
WEIGHT_STORAGE: buffer

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
q_8w_linear_tiled:
7+
linear_qcsnw_tiled:
88
parameter_names_with_default_values:
99
DTYPE: float
1010
IN_STORAGE: texture3d
@@ -21,11 +21,11 @@ q_8w_linear_tiled:
2121
- VALUE: 4
2222
SUFFIX: o4x4
2323
shader_variants:
24-
- NAME: q_8w_linear_tiled_texture3d_texture3d_texture2d_texture2d_float
25-
- NAME: q_8w_linear_tiled_buffer_buffer_texture2d_texture2d_float
24+
- NAME: linear_qcs8w_tiled_texture3d_texture3d_texture2d_texture2d_float
25+
- NAME: linear_qcs8w_tiled_buffer_buffer_texture2d_texture2d_float
2626
IN_STORAGE: buffer
2727
OUT_STORAGE: buffer
28-
- NAME: q_8w_linear_tiled_buffer_buffer_buffer_buffer_float
28+
- NAME: linear_qcs8w_tiled_buffer_buffer_buffer_buffer_float
2929
IN_STORAGE: buffer
3030
OUT_STORAGE: buffer
3131
WEIGHT_STORAGE: buffer

backends/vulkan/runtime/graph/ops/glsl/q_4w_linear_coop.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
q_4w_linear_coop:
7+
linear_qga4w_coop:
88
parameter_names_with_default_values:
99
DTYPE: float
1010
OUT_STORAGE: texture3d
@@ -13,11 +13,11 @@ q_4w_linear_coop:
1313
PARAMS_STORAGE: buffer
1414
TILE_ROWS: 1
1515
shader_variants:
16-
- NAME: q_4w_linear_coop_texture3d_texture3d_texture2d_float
17-
- NAME: q_4w_linear_coop_buffer_buffer_texture2d_float
16+
- NAME: linear_qga4w_coop_texture3d_texture3d_texture2d_float
17+
- NAME: linear_qga4w_coop_buffer_buffer_texture2d_float
1818
OUT_STORAGE: buffer
1919
IN_STORAGE: buffer
20-
- NAME: q_4w_linear_coop_buffer_buffer_buffer_float
20+
- NAME: linear_qga4w_coop_buffer_buffer_buffer_float
2121
OUT_STORAGE: buffer
2222
IN_STORAGE: buffer
2323
WEIGHT_STORAGE: buffer

backends/vulkan/runtime/graph/ops/glsl/q_4w_linear_tiled.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_tiled.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
q_4w_linear_tiled:
7+
linear_qga4w_tiled:
88
parameter_names_with_default_values:
99
DTYPE: float
1010
OUT_STORAGE: texture3d
@@ -13,11 +13,11 @@ q_4w_linear_tiled:
1313
PARAMS_STORAGE: buffer
1414
TILE_ROWS: 3
1515
shader_variants:
16-
- NAME: q_4w_linear_tiled_texture3d_texture3d_texture2d_float
17-
- NAME: q_4w_linear_tiled_buffer_buffer_texture2d_float
16+
- NAME: linear_qga4w_tiled_texture3d_texture3d_texture2d_float
17+
- NAME: linear_qga4w_tiled_buffer_buffer_texture2d_float
1818
OUT_STORAGE: buffer
1919
IN_STORAGE: buffer
20-
- NAME: q_4w_linear_tiled_buffer_buffer_buffer_float
20+
- NAME: linear_qga4w_tiled_buffer_buffer_buffer_float
2121
OUT_STORAGE: buffer
2222
IN_STORAGE: buffer
2323
WEIGHT_STORAGE: buffer

backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp renamed to backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp

+11-11
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
namespace vkcompute {
1717

18-
void check_q_8w_linear_args(
18+
void check_linear_qcsnw_args(
1919
const ComputeGraph& graph,
2020
const ValueRef mat1,
2121
const ValueRef qmat2_data,
@@ -37,7 +37,7 @@ void check_q_8w_linear_args(
3737
utils::val_at(-1, scales_sizes) == utils::val_at(-2, qmat2_sizes));
3838
}
3939

40-
void resize_q_8w_linear_node(
40+
void resize_linear_qcs8w_node(
4141
ComputeGraph* graph,
4242
const std::vector<ArgGroup>& args,
4343
const std::vector<ValueRef>& extra_args) {
@@ -64,7 +64,7 @@ void resize_q_8w_linear_node(
6464
out->virtual_resize(new_out_sizes);
6565
}
6666

67-
void add_q_8w_linear_node(
67+
void add_linear_qcs8w_node(
6868
ComputeGraph& graph,
6969
const ValueRef mat1,
7070
const ValueRef q_mat2_data,
@@ -91,7 +91,7 @@ void add_q_8w_linear_node(
9191
ValueRef scales = prepack_standard(
9292
graph, scales_data, graph.storage_type_of(out), utils::kWidthPacked);
9393

94-
std::string kernel_name = "q_8w_linear";
94+
std::string kernel_name = "linear_qcs8w";
9595
kernel_name.reserve(kShaderNameReserve);
9696
add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1_W_packed));
9797
add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2));
@@ -131,7 +131,7 @@ void add_q_8w_linear_node(
131131
// Specialization Constants
132132
{},
133133
// Resizing Logic
134-
resize_q_8w_linear_node,
134+
resize_linear_qcs8w_node,
135135
{},
136136
pcs));
137137
if (!graph.is_buffer_storage(out) &&
@@ -140,7 +140,7 @@ void add_q_8w_linear_node(
140140
}
141141
}
142142

143-
void add_q_8w_linear_tiled_node(
143+
void add_linear_qcs8w_tiled_node(
144144
ComputeGraph& graph,
145145
const bool use_coop_algorithm,
146146
const ValueRef mat1,
@@ -170,7 +170,7 @@ void add_q_8w_linear_tiled_node(
170170
prepack_standard(graph, scales_data, scales_storage, utils::kWidthPacked);
171171

172172
std::string kernel_name =
173-
use_coop_algorithm ? "q_8w_linear_coop" : "q_8w_linear_tiled";
173+
use_coop_algorithm ? "linear_qcs8w_coop" : "linear_qcs8w_tiled";
174174
kernel_name.reserve(kShaderNameReserve);
175175
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
176176
add_storage_type_suffix(kernel_name, graph.storage_type_of(mat1));
@@ -218,7 +218,7 @@ void add_q_8w_linear_tiled_node(
218218
// Specialization Constants
219219
{},
220220
// Resizing Logic
221-
resize_q_8w_linear_node,
221+
resize_linear_qcs8w_node,
222222
{},
223223
// Push Constants
224224
{{graph.sizes_pc_of(out), graph.sizes_pc_of(mat1)}}));
@@ -280,13 +280,13 @@ bool can_use_coop_impl(ComputeGraph& graph, const ValueRef mat1) {
280280
void weight_int8pack_mm(
281281
ComputeGraph& graph,
282282
const std::vector<ValueRef>& args) {
283-
check_q_8w_linear_args(graph, args[0], args[1], args[2], args[3]);
283+
check_linear_qcsnw_args(graph, args[0], args[1], args[2], args[3]);
284284
if (can_use_tiled_impl(graph, args[0], args[1], args[2], args[3])) {
285285
bool use_coop_algorithm = can_use_coop_impl(graph, args[0]);
286-
return add_q_8w_linear_tiled_node(
286+
return add_linear_qcs8w_tiled_node(
287287
graph, use_coop_algorithm, args[0], args[1], args[2], args[3]);
288288
}
289-
return add_q_8w_linear_node(graph, args[0], args[1], args[2], args[3]);
289+
return add_linear_qcs8w_node(graph, args[0], args[1], args[2], args[3]);
290290
}
291291

292292
REGISTER_OPERATORS {

backends/vulkan/runtime/graph/ops/impl/QuantizedLinearGroupwiseInt4.cpp renamed to backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp

+7-7
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
namespace vkcompute {
1717

18-
void check_q_4w_linear_args(
18+
void check_linear_qga4w_args(
1919
ComputeGraph& graph,
2020
const ValueRef mat1,
2121
const ValueRef mat2_data,
@@ -43,7 +43,7 @@ void check_q_4w_linear_args(
4343
VK_CHECK_COND(graph.has_standard_axis_map(out));
4444
}
4545

46-
void resize_q_4w_linear_node(
46+
void resize_linear_qga4w_node(
4747
ComputeGraph* graph,
4848
const std::vector<ArgGroup>& args,
4949
const std::vector<ValueRef>& extra_args) {
@@ -118,14 +118,14 @@ ValueRef prepack_int4_linear_weight_transposed_interleaved(
118118
return qmat2;
119119
}
120120

121-
void add_q_4w_linear_node(
121+
void add_linear_qga4w_node(
122122
ComputeGraph& graph,
123123
const ValueRef mat1,
124124
const ValueRef mat2_data,
125125
const ValueRef group_size,
126126
const ValueRef scales_and_zeros_data,
127127
const ValueRef out) {
128-
check_q_4w_linear_args(
128+
check_linear_qga4w_args(
129129
graph, mat1, mat2_data, group_size, scales_and_zeros_data, out);
130130

131131
const uint32_t group_size_val = graph.extract_scalar<uint32_t>(group_size);
@@ -143,7 +143,7 @@ void add_q_4w_linear_node(
143143
ValueRef scales_and_zeros = prepack_standard_hw_transposed(
144144
graph, scales_and_zeros_data, utils::kBuffer, utils::kWidthPacked);
145145

146-
std::string kernel_name = "q_4w_linear";
146+
std::string kernel_name = "linear_qga4w";
147147
if (use_coop_algorithm) {
148148
kernel_name += "_coop";
149149
} else {
@@ -176,7 +176,7 @@ void add_q_4w_linear_node(
176176
// Specialization Constants
177177
{SV(group_size_val)},
178178
// Resizing Logic
179-
resize_q_4w_linear_node,
179+
resize_linear_qga4w_node,
180180
{},
181181
// Push Constants
182182
{graph.sizes_pc_of(out),
@@ -187,7 +187,7 @@ void add_q_4w_linear_node(
187187
void linear_weight_int4(
188188
ComputeGraph& graph,
189189
const std::vector<ValueRef>& args) {
190-
return add_q_4w_linear_node(
190+
return add_linear_qga4w_node(
191191
graph,
192192
args[0], // mat1
193193
args[1], // mat2

backends/vulkan/test/op_tests/linear_weight_int4_test.cpp

+15-15
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
// Reference Implementations
2121
//
2222

23-
at::Tensor linear_weight_int4_reference_impl(
23+
at::Tensor linear_qga4w_reference_impl(
2424
const at::Tensor& x,
2525
const at::Tensor& weights_4x2,
2626
const int64_t groupsize,
@@ -101,7 +101,7 @@ at::Tensor dequantize_and_linear(
101101
// Test functions
102102
//
103103

104-
void test_reference_linear_int4(
104+
void test_reference_linear_qga4w(
105105
const int B,
106106
const int M,
107107
const int K,
@@ -119,7 +119,7 @@ void test_reference_linear_int4(
119119
at::Tensor scales_and_zeros =
120120
at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat));
121121

122-
at::Tensor out = linear_weight_int4_reference_impl(
122+
at::Tensor out = linear_qga4w_reference_impl(
123123
x,
124124
at::_convert_weight_to_int4pack_for_cpu(weights_int, group_size),
125125
group_size,
@@ -152,7 +152,7 @@ vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
152152
}
153153
}
154154

155-
void test_vulkan_linear_int4_impl(
155+
void test_vulkan_linear_qga4w_impl(
156156
const int B,
157157
const int M,
158158
const int K,
@@ -174,7 +174,7 @@ void test_vulkan_linear_int4_impl(
174174
at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat));
175175

176176
at::Tensor weights_int = unpack_weights_4x2(weights_4x2);
177-
at::Tensor out_ref = linear_weight_int4_reference_impl(
177+
at::Tensor out_ref = linear_qga4w_reference_impl(
178178
x,
179179
at::_convert_weight_to_int4pack_for_cpu(weights_int, group_size),
180180
group_size,
@@ -237,14 +237,14 @@ void test_vulkan_linear_int4_impl(
237237
ASSERT_TRUE(at::allclose(vk_out, out_ref, 1e-4, 1e-4));
238238
}
239239

240-
void test_vulkan_linear_int4(
240+
void test_vulkan_linear_qga4w(
241241
const int B,
242242
const int M,
243243
const int K,
244244
const int N,
245245
const int group_size = 32,
246246
const int inner_k_tiles = 8) {
247-
test_vulkan_linear_int4_impl(
247+
test_vulkan_linear_qga4w_impl(
248248
B,
249249
M,
250250
K,
@@ -254,7 +254,7 @@ void test_vulkan_linear_int4(
254254
vkcompute::utils::kBuffer,
255255
vkcompute::utils::kBuffer);
256256

257-
test_vulkan_linear_int4_impl(
257+
test_vulkan_linear_qga4w_impl(
258258
B,
259259
M,
260260
K,
@@ -265,30 +265,30 @@ void test_vulkan_linear_int4(
265265
vkcompute::utils::kTexture3D);
266266
}
267267

268-
TEST(VulkanInt4LinearTest, test_reference_impl) {
269-
test_reference_linear_int4(
268+
TEST(VulkanLinearQGA4WTest, test_reference_impl) {
269+
test_reference_linear_qga4w(
270270
/*B = */ 1,
271271
/*M = */ 4,
272272
/*K = */ 128,
273273
/*N = */ 32);
274274
}
275275

276-
TEST(VulkanInt4LinearTest, test_vulkan_impl_small_m) {
277-
test_vulkan_linear_int4(
276+
TEST(VulkanLinearQGA4WTest, test_vulkan_impl_small_m) {
277+
test_vulkan_linear_qga4w(
278278
/*B = */ 1,
279279
/*M = */ 4,
280280
/*K = */ 128,
281281
/*N = */ 32);
282282

283-
test_vulkan_linear_int4(
283+
test_vulkan_linear_qga4w(
284284
/*B = */ 1,
285285
/*M = */ 1,
286286
/*K = */ 256,
287287
/*N = */ 256);
288288
}
289289

290-
TEST(VulkanInt4LinearTest, test_vulkan_impl_gemm) {
291-
test_vulkan_linear_int4(
290+
TEST(VulkanLinearQGA4WTest, test_vulkan_impl_gemm) {
291+
test_vulkan_linear_qga4w(
292292
/*B = */ 1,
293293
/*M = */ 256,
294294
/*K = */ 256,

0 commit comments

Comments
 (0)