Skip to content

Commit d817493

Browse files
committed
Update on "[ET-VK] Introduce generic export pass for fusing Q/DQ nodes"
## Context When quantizing models with the PT2E quantization flow, quantize/dequantize nodes will be inserted into the graph. However, these quantize/dequantize nodes must be fused with operators such as `aten.linear.default` to produce nodes corresponding to quantized operators (e.g. `weight_int8pack_mm`) in order for quantized operator implementations to be called at runtime. Currently, the op fusion is done by the `fuse_dequant_linear.py` pass; however, this only handles one specific fusion pattern to generate a `weight_int8pack_mm` operator. As more quantized operators are to be supported in ET-VK via the PT2E quantization flow, a more generic fusion pass is needed that can handle a variety of fusion patterns. ## Changes Introduce the `FuseQuantizedOpsTransform()` pass. I elected to introduce a new pass under the `backends/vulkan/_passes` directory, as opposed to modifying the existing pass, because I anticipate the majority of the fusion patterns to be specific to ET-VK. Remove the existing `FuseDequantLinearPass()`. Switch to using the `FuseQuantizedOpsTransform` pass instead of the old `FuseDequantLinear` pass. Add `test_vulkan_passes` Python test to test export passes. Some small refactors to `test_vulkan_delegate` Python test to improve code organization. Differential Revision: [D73794042](https://our.internmc.facebook.com/intern/diff/D73794042/) [ghstack-poisoned]
2 parents 732236a + 3097f54 commit d817493

File tree

201 files changed

+3124
-731
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

201 files changed

+3124
-731
lines changed

.github/workflows/_link_check.yml

+30-14
Original file line numberDiff line numberDiff line change
@@ -7,35 +7,51 @@ on:
77

88
jobs:
99
lint-urls:
10+
if: ${{ github.event_name != 'pull_request' || !contains(github.event.pull_request.labels.*.name, 'skip-url-lint') }}
1011
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
1112
with:
1213
runner: linux.2xlarge
1314
docker-image: executorch-ubuntu-22.04-linter
14-
submodules: 'none'
15+
submodules: false
1516
fetch-depth: 0
1617
ref: ${{ inputs.ref }}
17-
timeout: 90
18+
timeout: 120
1819
script: |
1920
./scripts/lint_urls.sh $(
20-
[ "${{ github.event_name }}" = "pull_request" ] \
21-
&& git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} \
22-
|| [ "${{ github.event_name }}" = "push" ] \
23-
&& git diff --name-only ${{ github.event.before }} ${{ github.sha }}
24-
)
21+
{ [ "${{ github.event_name }}" = "pull_request" ] \
22+
&& git diff --name-only "${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }}"; } \
23+
|| \
24+
{ [ "${{ github.event_name }}" = "push" ] \
25+
&& git diff --name-only "${{ github.event.before }}...${{ github.sha }}"; }
26+
) || {
27+
echo
28+
echo "URL lint failed."
29+
echo "If this is a transient outage, you can bypass it by adding the \`skip-url-lint\` label to your PR."
30+
echo "Or add \`@lint-ignore\` somewhere on the same line as the URL you want to skip checking."
31+
exit 1
32+
}
2533
2634
lint-xrefs:
35+
if: ${{ github.event_name != 'pull_request' || !contains(github.event.pull_request.labels.*.name, 'skip-xref-lint') }}
2736
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
2837
with:
2938
runner: linux.2xlarge
3039
docker-image: executorch-ubuntu-22.04-linter
31-
submodules: 'none'
40+
submodules: false
3241
fetch-depth: 0
3342
ref: ${{ inputs.ref }}
34-
timeout: 90
43+
timeout: 60
3544
script: |
3645
./scripts/lint_xrefs.sh $(
37-
[ "${{ github.event_name }}" = "pull_request" ] \
38-
&& git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} \
39-
|| [ "${{ github.event_name }}" = "push" ] \
40-
&& git diff --name-only ${{ github.event.before }} ${{ github.sha }}
41-
)
46+
{ [ "${{ github.event_name }}" = "pull_request" ] \
47+
&& git diff --name-only "${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }}"; } \
48+
|| \
49+
{ [ "${{ github.event_name }}" = "push" ] \
50+
&& git diff --name-only "${{ github.event.before }}...${{ github.sha }}"; }
51+
) || {
52+
echo
53+
echo "Xref lint failed."
54+
echo "If this is a transient outage, you can bypass it by adding the \`skip-xref-lint\` label to your PR."
55+
echo "Or add \`@lint-ignore\` somewhere on the same line as the reference you want to skip checking."
56+
exit 1
57+
}

.github/workflows/build-presets.yml

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
name: Build Presets
2+
3+
on:
4+
pull_request:
5+
push:
6+
branches:
7+
- main
8+
- release/*
9+
workflow_dispatch:
10+
11+
concurrency:
12+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
13+
cancel-in-progress: true

CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ endif()
608608
# any backends.
609609
#
610610
add_library(executorch ${_executorch__srcs})
611-
target_link_libraries(executorch PUBLIC executorch_core)
611+
target_link_libraries(executorch PRIVATE executorch_core)
612612
target_include_directories(executorch PUBLIC ${_common_include_directories})
613613
target_compile_definitions(executorch PUBLIC C10_USING_CUSTOM_GENERATED_MACROS)
614614
target_compile_options(executorch PUBLIC ${_common_compile_options})

backends/arm/_passes/TARGETS

+1
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ python_library(
1111
"//executorch/backends/xnnpack/_passes:xnnpack_passes",
1212
"//executorch/exir:lib",
1313
"//executorch/backends/transforms:utils",
14+
"//executorch/backends/transforms:decompose_sdpa",
1415
],
1516
)

backends/arm/_passes/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
2020
from .convert_to_clamp import ConvertToClampPass # noqa
2121
from .decompose_batchnorm_pass import DecomposeBatchNormPass # noqa
22+
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
2223
from .decompose_div_pass import DecomposeDivPass # noqa
2324
from .decompose_gelu_pass import DecomposeGeluPass # noqa
2425
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
@@ -57,4 +58,5 @@
5758
from .size_adjust_conv2d_pass import SizeAdjustConv2DPass # noqa
5859
from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa
5960
from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass # noqa
61+
from .replace_inf_values_pass import ReplaceInfValues # noqa # usort: skip
6062
from .arm_pass_manager import ArmPassManager # noqa # usort: skip

backends/arm/_passes/annotate_decomposed_matmul.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,14 @@ def call(self, graph_module: GraphModule) -> PassResult:
7070
if quantized_input:
7171
matmul_args = matmul_node.all_input_nodes
7272
for node in matmul_args:
73+
# Find the dq-node connected to this mm/bmm arg
7374
input_node = self._match_partition_to_node(
7475
node, partition.input_nodes
7576
)
76-
77-
# Remove partition input dq-node
78-
input_node.replace_all_uses_with(input_node.all_input_nodes[0])
79-
graph_module.graph.erase_node(input_node)
8077
input_node_qargs = QuantArgs.from_operator(
8178
input_node.target, input_node.args
8279
)
83-
80+
# Insert new dq-node just before the mm/bmm with input_node's qparams
8481
with graph_module.graph.inserting_before(matmul_node):
8582
# Create new dq-node before matmul
8683
dq_node = create_node(
@@ -90,6 +87,13 @@ def call(self, graph_module: GraphModule) -> PassResult:
9087
dq_node.args = (node, *input_node_qargs)
9188
matmul_node.replace_input_with(node, dq_node)
9289

90+
for partition_input in partition.input_nodes:
91+
# Remove partition input dq-node
92+
partition_input.replace_all_uses_with(
93+
partition_input.all_input_nodes[0]
94+
)
95+
graph_module.graph.erase_node(partition_input)
96+
9397
partition_output = list(partition.output_nodes[0].users)[0]
9498
quantized_output = partition_output.target == q_op
9599
if quantized_output:

backends/arm/_passes/arm_pass_manager.py

+4
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ConvertSqueezesToViewPass,
2525
ConvertToClampPass,
2626
DecomposeBatchNormPass,
27+
DecomposeCosineSimilarityPass,
2728
DecomposeDivPass,
2829
DecomposeGeluPass,
2930
DecomposeLayerNormPass,
@@ -49,6 +50,7 @@
4950
MatchWhereSelfDtypePass,
5051
QuantizeOperatorArguments,
5152
RemoveClonePass,
53+
ReplaceInfValues,
5254
ReplaceScalarWithTensorArgPassTOSABI,
5355
ReplaceScalarWithTensorArgPassTOSAMI,
5456
RetraceFoldedDtypesPass,
@@ -204,6 +206,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
204206
self.add_pass(DecomposeVarPass())
205207
self.add_pass(DecomposeMeanDimPass())
206208
self.add_pass(DecomposeNotEqualPass())
209+
self.add_pass(DecomposeCosineSimilarityPass())
207210
self.add_pass(DecomposeDivPass())
208211
self.add_pass(DecomposeLeakyReLUPass())
209212
self.add_pass(DecomposeSqrtPass())
@@ -216,4 +219,5 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
216219
self.add_pass(DecomposeSoftmaxPass())
217220

218221
self.add_pass(ConvertMinMaxPass())
222+
self.add_pass(ReplaceInfValues())
219223
return self._transform(graph_module)

backends/arm/_passes/convert_split_to_slice.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
2-
# All rights reserved.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
65

76
# pyre-unsafe
87

98
import torch.fx
10-
from executorch.backends.arm._passes.arm_pass_utils import create_node
11-
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
9+
from executorch.backends.arm._passes.arm_pass_utils import (
10+
create_node,
11+
get_first_fake_tensor,
12+
)
1213
from executorch.exir.dialects._ops import ops as exir_ops
1314
from executorch.exir.pass_base import ExportPass, PassResult
1415

@@ -34,7 +35,7 @@ def call(self, graph_module: torch.fx.GraphModule):
3435
split_node = node
3536
input_node = split_node.all_input_nodes[0]
3637
output_nodes = split_node.users.copy()
37-
_, shape, _ = extract_tensor_meta(input_node.meta)
38+
shape = get_first_fake_tensor(input_node).shape
3839
rank = len(shape)
3940
split_lengths = split_node.args[1]
4041
dim = split_node.args[2] if len(split_node.args) > 2 else 0
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.exir.pass_base import ExportPass
8+
9+
torch_cosine_similarity = (torch.ops.aten.cosine_similarity.default,)
10+
11+
12+
class DecomposeCosineSimilarityPass(ExportPass):
    """
    Decomposition of aten.cosine_similarity:

      dot    = sum(mul(x1, x2), dims, keepdim=False)
      norm   = pow( sum(mul(x, x), dims, keepdim=False), 0.5 )
      eps    = full_like( norm1, eps_scalar )
      n1c    = max(norm1, eps)
      n2c    = max(norm2, eps)
      denom  = mul(n1c, n2c)
      out    = div(dot, denom)

    ``dim`` and ``eps`` are accepted either positionally (``args[2]`` and
    ``args[3]``) or as keyword arguments; when absent they default to
    ``1`` and ``1e-8`` respectively.
    """

    def call_operator(self, op, args, kwargs, meta):
        """
        Replace a cosine_similarity call with its elementwise decomposition.

        Any op other than aten.cosine_similarity is forwarded unchanged to
        the parent implementation.
        """
        if op not in torch_cosine_similarity:
            return super().call_operator(op, args, kwargs, meta)

        x1, x2 = args[0], args[1]
        # dim/eps are positional-or-keyword in the aten schema, so a traced
        # graph may carry them in args rather than kwargs; prefer the
        # positional value and fall back to kwargs / schema defaults.
        dim = args[2] if len(args) > 2 else kwargs.get("dim", 1)
        eps = args[3] if len(args) > 3 else kwargs.get("eps", 1e-8)
        dims = [dim] if isinstance(dim, int) else list(dim)

        # 1) dot
        prod = super().call_operator(torch.ops.aten.mul.Tensor, (x1, x2), {}, meta)
        dot = super().call_operator(
            torch.ops.aten.sum.dim_IntList, (prod, dims, False), {}, meta
        )

        # 2a) norm1 = pow(sum(x1*x1), 0.5)
        x1_sq = super().call_operator(torch.ops.aten.mul.Tensor, (x1, x1), {}, meta)
        s1 = super().call_operator(
            torch.ops.aten.sum.dim_IntList, (x1_sq, dims, False), {}, meta
        )
        norm1 = super().call_operator(
            torch.ops.aten.pow.Tensor_Scalar, (s1, 0.5), {}, meta
        )

        # 2b) norm2 = pow(sum(x2*x2), 0.5)
        x2_sq = super().call_operator(torch.ops.aten.mul.Tensor, (x2, x2), {}, meta)
        s2 = super().call_operator(
            torch.ops.aten.sum.dim_IntList, (x2_sq, dims, False), {}, meta
        )
        norm2 = super().call_operator(
            torch.ops.aten.pow.Tensor_Scalar, (s2, 0.5), {}, meta
        )

        # 3) eps scalar - we need to broadcast it ourselves, as TOSA does not
        #    do this for scalars
        eps_t = super().call_operator(
            torch.ops.aten.full_like.default, (norm1, eps), {}, meta
        )

        # 4) clamp to avoid zero division
        n1c = super().call_operator(
            torch.ops.aten.maximum.default, (norm1, eps_t), {}, meta
        )
        n2c = super().call_operator(
            torch.ops.aten.maximum.default, (norm2, eps_t), {}, meta
        )

        # 5) denom and divide
        denom = super().call_operator(torch.ops.aten.mul.Tensor, (n1c, n2c), {}, meta)
        out = super().call_operator(torch.ops.aten.div.Tensor, (dot, denom), {}, meta)

        return out
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# Copyright 2025 Arm Limited and/or its affiliates.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# This pass is based on backends/qualcomm/_passes/replace_inf_values.py
8+
# with some modifications to the values that replace inf.
9+
10+
import torch
11+
from executorch.exir.pass_base import ExportPass, PassResult
12+
13+
14+
class ReplaceInfValues(ExportPass):
    """
    Due to a limitation in the Quantizer, we need to change inf/-inf to more
    quantizable values.

    Floating-point buffers are edited in place and scalar node arguments are
    rewritten: +inf becomes 255 and -inf becomes -255.
    """

    def __init__(self):
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Replace +/-inf in buffers and node args of ``graph_module``.

        Returns a PassResult whose ``modified`` flag is True only when at
        least one inf value was actually replaced; the module is recompiled
        only in that case.
        """
        modified = False
        pos_inf = float("inf")
        neg_inf = float("-inf")

        for buf_name, tensor in graph_module.named_buffers():
            if not tensor.is_floating_point():
                continue
            pos_mask = tensor == pos_inf
            neg_mask = tensor == neg_inf
            # Leave buffers without any inf untouched so the pass does not
            # report a modification (and trigger a recompile) for nothing.
            if not (pos_mask.any() or neg_mask.any()):
                continue
            # 255 here is mainly for attention_mask in Llama for reasonable quant scale
            tensor[pos_mask] = 255
            tensor[neg_mask] = -255
            setattr(graph_module, buf_name, tensor)
            modified = True

        for node in graph_module.graph.nodes:
            arg_list = list(node.args)
            node_changed = False
            for index, arg in enumerate(arg_list):
                if arg == neg_inf:
                    arg_list[index] = -255
                    node_changed = True
                elif arg == pos_inf:
                    arg_list[index] = +255
                    node_changed = True
            # Only write args back when something was replaced.
            if node_changed:
                node.args = tuple(arg_list)
                modified = True

        if modified:
            graph_module.recompile()
        return PassResult(graph_module, modified)

backends/arm/operator_support/slice_copy_support.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
SupportedTOSAOperatorCheck,
1313
)
1414
from executorch.backends.arm.tosa_specification import TosaSpecification
15-
from executorch.backends.arm.tosa_utils import getNodeArgs
1615
from executorch.exir.dialects._ops import ops as exir_ops
1716

1817
logger = logging.getLogger(__name__)
@@ -33,8 +32,8 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification) ->
3332
if tosa_spec not in self.tosa_specs:
3433
return False
3534

36-
inputs = getNodeArgs(node)
37-
if len(inputs) == 5 and (step := inputs[4].number) != 1:
35+
args = node.args
36+
if len(args) == 5 and (step := args[4]) != 1:
3837
logging.warning(f"{node.target} with step size of {step} not supported.")
3938
return False
4039
return True

backends/arm/operator_support/tosa_supported_operators.py

+2
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def is_node_supported(
194194
exir_ops.edge.aten.mul.Tensor,
195195
exir_ops.edge.aten.ne.Tensor,
196196
exir_ops.edge.aten.ne.Scalar,
197+
exir_ops.edge.aten.neg.default,
197198
exir_ops.edge.aten.add.Scalar,
198199
exir_ops.edge.aten.sub.Scalar,
199200
exir_ops.edge.aten.mul.Scalar,
@@ -311,6 +312,7 @@ class CheckProperQuantization(OperatorSupportBase):
311312
exir_ops.edge.aten.max_pool2d_with_indices.default,
312313
exir_ops.edge.aten.mm.default,
313314
exir_ops.edge.aten.mul.Tensor,
315+
exir_ops.edge.aten.neg.default,
314316
exir_ops.edge.aten.relu.default,
315317
exir_ops.edge.aten.sub.Tensor,
316318
exir_ops.edge.aten.upsample_bilinear2d.vec,

backends/arm/operators/TARGETS

+6
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,19 @@ python_library(
1010
],
1111
)
1212

13+
python_library(
14+
name = "operator_validation_utils",
15+
srcs = ["operator_validation_utils.py"],
16+
)
17+
1318
python_library(
1419
name = "ops",
1520
srcs = glob(["op_*.py", "ops_*.py"]),
1621
deps = [
1722
"fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/tosa:tosa",
1823
"fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/tosa:tosa",
1924
":node_visitor",
25+
":operator_validation_utils",
2026
"//executorch/backends/arm:tosa_mapping",
2127
"//executorch/backends/arm:tosa_quant_utils",
2228
"//executorch/backends/arm:tosa_utils",

backends/arm/operators/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
op_maximum,
3232
op_minimum,
3333
op_mul,
34+
op_neg,
3435
op_permute,
3536
op_pow,
3637
op_reciprocal,

0 commit comments

Comments
 (0)