Commit b005f10

[ET-VK] Re-implement (de)quantize_per_tensor.default (#15753)

Re-implement the `quantized_decomposed.(de)quantize_per_tensor.default` ops with `add_quantize_and_pack_4w4c_node`. As a consequence, the `et_vk.quantize_q8ta_for_conv2d.default` and `et_vk.dequantize_q8to_from_conv2d.default` ops are no longer needed. The overall goal is to streamline the quantize/dequantize interface in ET-VK.

Differential Revision: [D86702457](https://our.internmc.facebook.com/intern/diff/D86702457/)

1 parent 1ca3252 commit b005f10
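
For reference, `quantized_decomposed.quantize_per_tensor` and its `dequantize` counterpart implement plain affine quantization, which is the behavior the Vulkan shaders must reproduce. A minimal sketch of the semantics (the `_ref` helper names are illustrative, not part of the codebase):

import torch

def quantize_per_tensor_ref(x, scale, zero_point, quant_min=-128, quant_max=127):
    # q = clamp(round(x / scale) + zero_point, quant_min, quant_max)
    q = torch.round(x / scale) + zero_point
    return torch.clamp(q, quant_min, quant_max).to(torch.int8)

def dequantize_per_tensor_ref(q, scale, zero_point):
    # x_hat = (q - zero_point) * scale
    return (q.to(torch.float32) - zero_point) * scale

# With scale=0.05, zero_point=10: 1.0 -> round(1.0 / 0.05) + 10 = 30,
# and (30 - 10) * 0.05 = 1.0 recovers the input exactly.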

File tree: 7 files changed, +48, -175 lines

backends/vulkan/_passes/TARGETS

Lines changed: 0 additions & 14 deletions

@@ -104,19 +104,6 @@ runtime.python_library(
     ],
 )
 
-runtime.python_library(
-    name = "replace_qdq",
-    srcs = ["replace_qdq.py"],
-    visibility = [
-        "//executorch/backends/...",
-    ],
-    deps = [
-        "//caffe2:torch",
-        "//executorch/backends/vulkan:utils_lib",
-        "//executorch/exir:pass_base",
-    ],
-)
-
 runtime.python_library(
     name = "fuse_patterns",
     srcs = ["fuse_patterns.py"],
@@ -149,7 +136,6 @@ runtime.python_library(
         ":insert_prepack_nodes",
         ":remove_asserts",
         ":remove_redundant_ops",
-        ":replace_qdq",
         ":squeeze_unsqueeze_inputs",
         ":tag_memory_meta_pass",
     ]

backends/vulkan/_passes/__init__.py

Lines changed: 0 additions & 2 deletions

@@ -19,7 +19,6 @@
 from executorch.backends.vulkan._passes.remove_redundant_ops import (
     RemoveRedundantOpsTransform,
 )
-from executorch.backends.vulkan._passes.replace_qdq import ReplaceQDQPass
 from executorch.backends.vulkan._passes.squeeze_unsqueeze_inputs import (
     SqueezeUnsqueezeInputs,
 )
@@ -33,7 +32,6 @@
     "remove_asserts",
     "RemoveAssertsTransform",
     "RemoveRedundantOpsTransform",
-    "ReplaceQDQPass",
     "SqueezeUnsqueezeInputs",
     "TagMemoryMetaPass",
 ]

backends/vulkan/_passes/replace_qdq.py

Lines changed: 0 additions & 93 deletions
This file was deleted.

backends/vulkan/custom_ops_lib.py

Lines changed: 0 additions & 36 deletions

@@ -539,42 +539,6 @@ def apply_rotary_emb_impl(
 lib.impl(name, apply_rotary_emb_impl, "CompositeExplicitAutograd")
 apply_rotary_emb_op = getattr(getattr(torch.ops, namespace), name)
 
-#############################
-## quantize/dequantize ops ##
-#############################
-
-
-def quantize_q8ta_for_conv2d_impl(
-    input: torch.Tensor,
-    scale: float,
-    zero_point: int,
-):
-    return torch.ops.quantized_decomposed.quantize_per_tensor(
-        input, scale, zero_point, -128, 127, torch.int8
-    )
-
-
-name = "quantize_q8ta_for_conv2d"
-lib.define(f"{name}(Tensor input, float scale, int zero_point) -> Tensor")
-lib.impl(name, quantize_q8ta_for_conv2d_impl, "CompositeExplicitAutograd")
-quantize_q8ta_for_conv2d_op = getattr(getattr(torch.ops, namespace), name)
-
-
-def dequantize_q8to_from_conv2d_impl(
-    input: torch.Tensor,
-    scale: float,
-    zero_point: int,
-):
-    return torch.ops.quantized_decomposed.dequantize_per_tensor(
-        input, scale, zero_point, -128, 127, input.dtype
-    )
-
-
-name = "dequantize_q8to_from_conv2d"
-lib.define(f"{name}(Tensor input, float scale, int zero_point) -> Tensor")
-lib.impl(name, dequantize_q8to_from_conv2d_impl, "CompositeExplicitAutograd")
-dequantize_q8to_from_conv2d_op = getattr(getattr(torch.ops, namespace), name)
-
 ########################
 ## add_q8ta_q8ta_q8to ##
 ########################
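
The deleted `et_vk` ops above were thin wrappers that fixed the int8 range at [-128, 127]; after this change, export flows call the decomposed op directly. A hedged sketch of the equivalent call (assuming the `quantized_decomposed` op library has been registered, e.g. via the `torch.ao.quantization.fx._decomposed` import in recent PyTorch releases):

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401 -- registers quantized_decomposed ops

x = torch.randn(1, 4, 8, 8)

# Previously: torch.ops.et_vk.quantize_q8ta_for_conv2d(x, 0.05, 10)
# Now, with the wrapper's fixed qmin/qmax/dtype spelled out:
q = torch.ops.quantized_decomposed.quantize_per_tensor(
    x, 0.05, 10, -128, 127, torch.int8
)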

backends/vulkan/op_registry.py

Lines changed: 7 additions & 11 deletions

@@ -144,13 +144,9 @@ def register_ephemeral_op():
 
 @update_features(
     [
-        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
-        exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
         exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
-        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
-        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
-        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
         exir_ops.edge.quantized_decomposed.quantize_per_token.default,
+        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
         exir_ops.edge.quantized_decomposed.dequantize_per_token.default,
     ]
 )
@@ -630,35 +626,35 @@ def register_quantized_binary_op():
 
 @update_features(
     [
-        exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default,
+        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+        exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
     ]
 )
-def register_quantize_for_conv2d_op():
+def register_quantize_op():
     return OpFeatures(
         inputs_storage=[
             utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER,
         ],
         outputs_storage=[
             utils.PACKED_INT8_4W4C_BUFFER,
         ],
-        supports_resize=False,
     )
 
 
 @update_features(
     [
-        exir_ops.edge.et_vk.dequantize_q8to_from_conv2d.default,
+        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
     ]
 )
-def register_dequantize_for_conv2d_op():
+def register_dequantize_op():
     return OpFeatures(
         inputs_storage=[
             utils.PACKED_INT8_4W4C_BUFFER,
         ],
         outputs_storage=[
             utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER,
         ],
-        supports_resize=False,
     )
 

backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp

Lines changed: 41 additions & 17 deletions

@@ -366,30 +366,52 @@ void add_unpack_4w4c_and_dequantize_node(
 // Operator Entrypoints
 //
 
-void quantize_q8ta_for_conv2d(
+void quantize_per_tensor_impl(
     ComputeGraph& graph,
     const std::vector<ValueRef>& args) {
-  int32_t idx = 0;
-  const ValueRef fp_input = args.at(idx++);
-  const ValueRef scale = args.at(idx++);
-  const ValueRef zero_point = args.at(idx++);
-  const ValueRef packed_int8_input = args.at(idx++);
+  int32_t arg_idx = 0;
+  const ValueRef fp_input = args[arg_idx++];
+  const ValueRef scale = args[arg_idx++];
+  const ValueRef zero_point = args[arg_idx++];
+  const ValueRef quant_min = args[arg_idx++];
+  (void)quant_min;
+  const ValueRef quant_max = args[arg_idx++];
+  (void)quant_max;
+  const ValueRef dtype = args[arg_idx++];
+  (void)dtype;
+
+  const ValueRef int8_output = args[arg_idx++];
+
+  VK_CHECK_COND(
+      graph.estimate_memory_layout_of(int8_output) == utils::kPackedInt8_4W4C);
 
   add_quantize_and_pack_4w4c_node(
-      graph, fp_input, scale, zero_point, packed_int8_input);
+      graph, fp_input, scale, zero_point, int8_output);
 }
 
-void dequantize_q8to_from_conv2d(
+void dequantize_per_tensor_impl(
     ComputeGraph& graph,
     const std::vector<ValueRef>& args) {
-  int32_t idx = 0;
-  const ValueRef packed_int8_output = args.at(idx++);
-  const ValueRef scale = args.at(idx++);
-  const ValueRef zero_point = args.at(idx++);
-  const ValueRef fp_output = args.at(idx++);
+  int32_t arg_idx = 0;
+  const ValueRef int8_input = args[arg_idx++];
+  const ValueRef scale = args[arg_idx++];
+  const ValueRef zero_point = args[arg_idx++];
+  const ValueRef quant_min = args[arg_idx++];
+  (void)quant_min;
+  const ValueRef quant_max = args[arg_idx++];
+  (void)quant_max;
+  const ValueRef dtype = args[arg_idx++];
+  (void)dtype;
+  const ValueRef output_dtype = args[arg_idx++];
+  (void)output_dtype;
+
+  const ValueRef fp_output = args[arg_idx++];
+
+  VK_CHECK_COND(
+      graph.estimate_memory_layout_of(int8_input) == utils::kPackedInt8_4W4C);
 
   add_unpack_4w4c_and_dequantize_node(
-      graph, packed_int8_output, scale, zero_point, fp_output);
+      graph, int8_input, scale, zero_point, fp_output);
 }
 
 void qdq8ta_conv2d_input(
@@ -416,11 +438,13 @@ void qdq8ta_conv2d_input(
 }
 
 REGISTER_OPERATORS {
-  VK_REGISTER_OP(etvk.qdq8ta_conv2d_input.default, qdq8ta_conv2d_input);
   VK_REGISTER_OP(
-      et_vk.quantize_q8ta_for_conv2d.default, quantize_q8ta_for_conv2d);
+      quantized_decomposed.quantize_per_tensor.default,
+      quantize_per_tensor_impl);
   VK_REGISTER_OP(
-      et_vk.dequantize_q8to_from_conv2d.default, dequantize_q8to_from_conv2d);
+      quantized_decomposed.dequantize_per_tensor.default,
+      dequantize_per_tensor_impl);
+  VK_REGISTER_OP(etvk.qdq8ta_conv2d_input.default, qdq8ta_conv2d_input);
 }
 
 } // namespace vkcompute
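
The new entrypoints read arguments in the order of the `quantized_decomposed` schemas — (input, scale, zero_point, quant_min, quant_max, dtype), plus a trailing output dtype for dequantize — and discard the range/dtype values via the `(void)` casts, since the packed 4w4c path binds them in the shader. A round-trip sketch at the eager PyTorch level, under the same assumption that the decomposed op library has been imported:

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401

x = torch.randn(2, 3).clamp(-12.0, 12.0)  # keep values inside the int8 range for scale=0.1
scale, zero_point = 0.1, 0

q = torch.ops.quantized_decomposed.quantize_per_tensor(
    x, scale, zero_point, -128, 127, torch.int8
)
x_hat = torch.ops.quantized_decomposed.dequantize_per_tensor(
    q, scale, zero_point, -128, 127, torch.int8
)
assert torch.allclose(x, x_hat, atol=scale)  # rounding error is at most scale / 2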

backends/vulkan/vulkan_preprocess.py

Lines changed: 0 additions & 2 deletions

@@ -21,7 +21,6 @@
     FuseQuantizedOpsTransform,
     insert_prepack_nodes,
     RemoveRedundantOpsTransform,
-    ReplaceQDQPass,
     SqueezeUnsqueezeInputs,
     TagMemoryMetaPass,
 )
@@ -162,7 +161,6 @@ def preprocess(  # noqa: C901
             AddmmToLinearTransform(),
             RemoveRedundantOpsTransform(),
             FuseQuantizedOpsTransform(),
-            ReplaceQDQPass(),
             FoldQDQPass(),
             SqueezeUnsqueezeInputs(),
             FuseViewCopyTransform(),
