Commit 53158c3

Update on "[Executorch][SDPA] Refactor + Make quantized sdpa handle sequence at dim 1 or 2"
For quantized SDPA we want to evaluate the performance impact of having the sequence dimension at dim 1 as well as at dim 2. This diff refactors the code to enable that. The same should also be done for float SDPA, but that is left for future work.

Differential Revision: [D71833060](https://our.internmc.facebook.com/intern/diff/D71833060/)

[ghstack-poisoned]
2 parents c70f048 + e176f92 commit 53158c3
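
As context for the change described above, here is a minimal float sketch of the layout question. This is not the ExecuTorch quantized kernel; sdpa_with_seq_dim is a hypothetical helper that normalizes inputs whose sequence axis sits at dim 1 into the canonical [batch, heads, seq, head_dim] layout (seq at dim 2) before calling PyTorch's SDPA:

import torch
import torch.nn.functional as F

def sdpa_with_seq_dim(q, k, v, seq_dim: int = 2):
    """Run SDPA on inputs whose sequence axis is at dim 1 or dim 2.

    F.scaled_dot_product_attention expects [batch, num_heads, seq_len, head_dim]
    (seq at dim 2). If the caller stores tensors as [batch, seq_len, num_heads,
    head_dim] (seq at dim 1), transpose into the canonical layout first and
    transpose the output back afterwards.
    """
    if seq_dim == 1:
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)
    if seq_dim == 1:
        out = out.transpose(1, 2)
    return out

# Example: the result is the same regardless of which dim holds the sequence.
b, h, s, d = 2, 4, 16, 64
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))
out_dim2 = sdpa_with_seq_dim(q, k, v, seq_dim=2)
out_dim1 = sdpa_with_seq_dim(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), seq_dim=1
).transpose(1, 2)
assert torch.allclose(out_dim2, out_dim1)

A quantized kernel that accepts either layout directly can presumably skip such explicit transposes; the refactor in this stack is what makes that comparison possible for quantized SDPA.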

60 files changed: +577 −2010 lines changed


.ci/scripts/test_ios_ci.sh

Lines changed: 5 additions & 1 deletion
@@ -7,7 +7,7 @@
 
 set -e
 
-APP_PATH="examples/demo-apps/apple_ios/ExecuTorchDemo/ExecuTorchDemo"
+APP_PATH="executorch-examples/apple/ExecuTorchDemo/ExecuTorchDemo"
 MODEL_NAME="mv3"
 SIMULATOR_NAME="executorch"
 
@@ -34,6 +34,10 @@ say() {
   echo -e "\033[1m\n\t** $1 **\n\033[0m"
 }
 
+say "Cloning the Demo App"
+
+git clone --depth 1 https://github.com/pytorch-labs/executorch-examples.git
+
 say "Installing CoreML Backend Requirements"
 
 ./backends/apple/coreml/scripts/install_requirements.sh

.github/workflows/android-release-artifacts.yml

Lines changed: 24 additions & 1 deletion
@@ -7,6 +7,10 @@ on:
         description: Version name to be uploaded for AAR release
         required: false
         type: string
+      upload_to_maven:
+        description: Upload the AAR to maven staging repository
+        required: false
+        type: boolean
 
 concurrency:
   group: ${{ github.workflow }}-${{ github.ref }}
@@ -31,11 +35,14 @@ jobs:
   build-aar:
     name: build-aar
     needs: check-if-aar-exists
-    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    if: ${{ !github.event.pull_request.head.repo.fork }}
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@release/2.7
+    secrets: inherit
     permissions:
       id-token: write
       contents: read
     with:
+      secrets-env: EXECUTORCH_MAVEN_SIGNING_KEYID EXECUTORCH_MAVEN_SIGNING_PASSWORD EXECUTORCH_MAVEN_CENTRAL_PASSWORD EXECUTORCH_MAVEN_CENTRAL_USERNAME EXECUTORCH_MAVEN_SIGNING_GPG_KEY_CONTENTS
       runner: linux.2xlarge
       docker-image: executorch-ubuntu-22.04-clang12-android
       submodules: 'true'
@@ -52,6 +59,16 @@ jobs:
           PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool buck2
           export ARTIFACTS_DIR_NAME=artifacts-to-be-uploaded
 
+          mkdir -p ~/.gradle
+          touch ~/.gradle/gradle.properties
+          echo "signing.keyId=${SECRET_EXECUTORCH_MAVEN_SIGNING_KEYID}" >> ~/.gradle/gradle.properties
+          echo "signing.password=${SECRET_EXECUTORCH_MAVEN_SIGNING_PASSWORD}" >> ~/.gradle/gradle.properties
+          echo "mavenCentralUsername=${SECRET_EXECUTORCH_MAVEN_CENTRAL_USERNAME}" >> ~/.gradle/gradle.properties
+          echo "mavenCentralPassword=${SECRET_EXECUTORCH_MAVEN_CENTRAL_PASSWORD}" >> ~/.gradle/gradle.properties
+          echo "signing.secretKeyRingFile=/tmp/secring.gpg" >> ~/.gradle/gradle.properties
+
+          echo -n "$SECRET_EXECUTORCH_MAVEN_SIGNING_GPG_KEY_CONTENTS" | base64 -d > /tmp/secring.gpg
+
           # Build AAR Package
           mkdir aar-out
           export BUILD_AAR_DIR=aar-out
@@ -61,6 +78,12 @@
 
           shasum -a 256 "${ARTIFACTS_DIR_NAME}/executorch.aar"
 
+          # Publish to maven staging
+          UPLOAD_TO_MAVEN="${{ inputs.upload_to_maven }}"
+          if [[ "$UPLOAD_TO_MAVEN" == "true" ]]; then
+            (cd aar-out; ANDROID_HOME="${ANDROID_SDK:-/opt/android/sdk}" ./gradlew :executorch_android:publishToMavenCentral)
+          fi
+
   upload-release-aar:
     name: upload-release-aar
     needs: build-aar

backends/arm/scripts/parse_test_names.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from executorch.exir.dialects.edge.spec.utils import SAMPLE_INPUT
+
+# Add edge ops which we lower but which are not included in exir/dialects/edge/edge.yaml here.
+CUSTOM_EDGE_OPS = ["linspace.default", "eye.default"]
+ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS
+
+# Add all targets and TOSA profiles we support here.
+TARGETS = {"tosa_BI", "tosa_MI", "u55_BI", "u85_BI"}
+
+
+def get_edge_ops():
+    """
+    Returns a set of edge ops with names on the form used in unit tests:
+    1. Names are in lowercase.
+    2. The overload is ignored if it is 'default'; otherwise it is appended with an underscore.
+    3. Overly verbose names are shortened by removing certain prefixes/suffixes.
+
+    Examples:
+        abs.default -> abs
+        split_copy.Tensor -> split_tensor
+    """
+    edge_ops = set()
+    for edge_name in ALL_EDGE_OPS:
+        op, overload = edge_name.split(".")
+
+        # Normalize names
+        op = op.lower()
+        op = op.removeprefix("_")
+        op = op.removesuffix("_copy")
+        op = op.removesuffix("_with_indices")
+        op = op.removesuffix("_no_training")
+        overload = overload.lower()
+
+        if overload == "default":
+            edge_ops.add(op)
+        else:
+            edge_ops.add(f"{op}_{overload}")
+
+    return edge_ops
+
+
+def parse_test_name(test_name: str, edge_ops: set[str]) -> tuple[str, str, bool]:
+    """
+    Parses a test name on the form
+        test_OP_TARGET_<not_delegated>_<any_other_info>
+    where OP must match a string in edge_ops and TARGET must match one string in TARGETS.
+    The "not_delegated" suffix indicates that the test checks that the op is not delegated.
+
+    Examples of valid names: "test_mm_u55_BI_not_delegated" or "test_add_scalar_tosa_MI_two_inputs".
+
+    Returns a tuple (OP, TARGET, IS_DELEGATED) if valid.
+    """
+    test_name = test_name.removeprefix("test_")
+    is_delegated = "not_delegated" not in test_name
+    assert (
+        "reject" not in test_name
+    ), f"Use 'not_delegated' instead of 'reject' in {test_name}"
+
+    op = "None"
+    target = "None"
+    for potential_target in TARGETS:
+        index = test_name.find(potential_target)
+        if index != -1:
+            op = test_name[: index - 1]
+            target = potential_target
+            break
+    # Special case for convolution
+    op = op.removesuffix("_1d")
+    op = op.removesuffix("_2d")
+
+    assert target != "None", f"{test_name} does not contain one of {TARGETS}"
+    assert (
+        op in edge_ops
+    ), f"Parsed invalid OP from {test_name}, {op} does not exist in edge.yaml or CUSTOM_EDGE_OPS"
+
+    return op, target, is_delegated
+
+
+if __name__ == "__main__":
+    """Parses a list of test names given on the command line."""
+    import sys
+
+    sys.tracebacklimit = 0  # Do not print stack trace
+
+    edge_ops = get_edge_ops()
+    exit_code = 0
+
+    for test_name in sys.argv[1:]:
+        try:
+            assert test_name[:5] == "test_", f"Unexpected input: {test_name}"
+            parse_test_name(test_name, edge_ops)
+        except AssertionError as e:
+            print(e)
+            exit_code = 1
+        else:
+            print(f"{test_name} OK")
+
+    sys.exit(exit_code)
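
To illustrate the naming convention the new script enforces, here is a hypothetical usage sketch (the import path is assumed; the two test names come from the docstring above):

from parse_test_names import get_edge_ops, parse_test_name  # module path assumed

edge_ops = get_edge_ops()

# "mm" is the op, "u55_BI" is the target; "not_delegated" flips the delegation flag.
op, target, is_delegated = parse_test_name("test_mm_u55_BI_not_delegated", edge_ops)
print(op, target, is_delegated)  # -> mm u55_BI False

# Extra trailing info ("two_inputs") is ignored; the op keeps its own underscores.
op, target, is_delegated = parse_test_name("test_add_scalar_tosa_MI_two_inputs", edge_ops)
print(op, target, is_delegated)  # -> add_scalar tosa_MI True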

backends/arm/scripts/pre-push

Lines changed: 38 additions & 0 deletions
@@ -166,6 +166,44 @@ for COMMIT in ${COMMITS}; do
         fi
     fi
 
+    # Op test checks
+    op_test_files=$(echo $commit_files | grep -oE 'backends/arm/test/ops/\S+')
+    if [ "$op_test_files" ]; then
+
+        # TODO: These checks can be removed when all unittests are refactored.
+        if grep -icq "SkipIfNoCorstone" $op_test_files; then
+            echo -e "${ERROR} @SkipIfNoCorstone300/320 is deprecated;"\
+                "please use XfailIfNoCorstone300/320 instead." >&2
+            FAILED=1
+        fi
+
+        if grep -icq "conftest.expectedFailureOnFVP" $op_test_files; then
+            echo -e "${ERROR} @conftest.expectedFailureOnFVP is deprecated;"\
+                "please use XfailIfCorstone300/320 instead." >&2
+            FAILED=1
+        fi
+
+        if grep -icq "unittest.TestCase" $op_test_files; then
+            echo -e "${ERROR} Use of the Unittest test framework is deprecated;"\
+                "please use Pytest instead." >&2
+            FAILED=1
+        fi
+
+        if grep -icq "on_fvp(" $op_test_files; then
+            echo -e "${ERROR} All unittests should run on FVP if relevant,"\
+                "on_fvp suffix can be excluded." >&2
+            FAILED=1
+        fi
+
+        # Check that the tested op and target is parsed correctly from the test name
+        test_names=$(grep -h "def test_" $op_test_files | cut -d"(" -f1 | cut -d" " -f2)
+        python ./backends/arm/scripts/parse_test_names.py $test_names
+        if [ $? -ne 0 ]; then
+            echo -e "${ERROR} Failed op test name check." >&2
+            FAILED=1
+        fi
+    fi
+
     echo "" # Newline to visually separate commit processing
 done
 

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 22 additions & 22 deletions
@@ -43,7 +43,7 @@
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 
 
-act_qspec_asym8u = QuantizationSpec(
+act_qspec_asym8s = QuantizationSpec(
     dtype=torch.int8,
     quant_min=-128,
     quant_max=127,
@@ -52,7 +52,7 @@
     observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
 )
 
-wgt_qspec_asym8u = QuantizationSpec(
+wgt_qspec_asym8s = QuantizationSpec(
     dtype=torch.int8,
     quant_min=-128,
     quant_max=127,
@@ -61,7 +61,7 @@
     observer_or_fake_quant_ctr=MinMaxObserver,
 )
 
-wgt_qspec_asym8s = QuantizationSpec(
+wgt_qspec_sym8s = QuantizationSpec(
     dtype=torch.int8,
     quant_min=-128,
     quant_max=127,
@@ -72,17 +72,17 @@
 
 bias_qspec: Optional[QuantizationSpec] = None
 
-qconfig_A8uW8u = QuantizationConfig(
-    act_qspec_asym8u,
-    act_qspec_asym8u,
-    wgt_qspec_asym8u,
+qconfig_A8W8 = QuantizationConfig(
+    act_qspec_asym8s,
+    act_qspec_asym8s,
+    wgt_qspec_asym8s,
     None,
 )
 
-qconfig_A8uW8s = QuantizationConfig(
-    act_qspec_asym8u,
-    act_qspec_asym8u,
-    wgt_qspec_asym8s,
+qconfig_A8W8sym = QuantizationConfig(
+    act_qspec_asym8s,
+    act_qspec_asym8s,
+    wgt_qspec_sym8s,
     None,
 )
 
@@ -189,15 +189,15 @@ def get_supported_operators(cls) -> List[OperatorConfig]:
 
 def get_cadence_default_quantizers() -> List[Quantizer]:
     return [
-        CadenceAtenQuantizer(AddmmPattern(), qconfig_A8uW8u),
-        CadenceAtenQuantizer(BmmPattern(), qconfig_A8uW8u),
-        CadenceAtenQuantizer(Conv1dPattern(), qconfig_A8uW8s),
-        CadenceAtenQuantizer(Conv2dPattern(), qconfig_A8uW8s),
-        CadenceAtenQuantizer(LayerNormPattern(), qconfig_A8uW8u),
-        CadenceAtenQuantizer(LinearPattern(), qconfig_A8uW8u),
-        CadenceAtenQuantizer(MatmulPattern(), qconfig_A8uW8u),
-        CadenceAtenQuantizer(ReluPattern0(), qconfig_A8uW8u),
-        CadenceAtenQuantizer(ReluPattern1(), qconfig_A8uW8u),
+        CadenceAtenQuantizer(AddmmPattern(), qconfig_A8W8),
+        CadenceAtenQuantizer(BmmPattern(), qconfig_A8W8),
+        CadenceAtenQuantizer(Conv1dPattern(), qconfig_A8W8sym),
+        CadenceAtenQuantizer(Conv2dPattern(), qconfig_A8W8sym),
+        CadenceAtenQuantizer(LayerNormPattern(), qconfig_A8W8),
+        CadenceAtenQuantizer(LinearPattern(), qconfig_A8W8),
+        CadenceAtenQuantizer(MatmulPattern(), qconfig_A8W8),
+        CadenceAtenQuantizer(ReluPattern0(), qconfig_A8W8),
+        CadenceAtenQuantizer(ReluPattern1(), qconfig_A8W8),
     ]
 
 
@@ -244,6 +244,6 @@ class CadenceWakeWordQuantizer(CadenceQuantizer):
     def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
         if quantizers is None:
             quantizers = get_cadence_default_quantizers()
-            quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8uW8u))
-            quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8uW8u))
+            quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
+            quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
         super().__init__(quantizers)
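
The renames above align the identifiers with what the specs actually are: all of them are signed int8 (hence the 8s suffix), and only the weight spec used for convolutions is symmetric. As a rough sketch of the asymmetric-vs-symmetric distinction, assuming the usual torch.ao QuantizationSpec fields (the qscheme lines fall outside the diff hunks, so the values shown here are assumptions, not quotes from the file):

import torch
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.quantizer import QuantizationSpec

# Asymmetric signed 8-bit spec: the zero point may be non-zero.
act_qspec_asym8s = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,  # assumed; not shown in the hunk
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)

# Symmetric signed 8-bit spec: the zero point is fixed at 0.
wgt_qspec_sym8s = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,  # assumed; not shown in the hunk
    observer_or_fake_quant_ctr=MinMaxObserver,
)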

backends/qualcomm/_passes/layout_transform.py

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ class LayoutTransform(ExportPass):
     layout_agnostic_ops = {
         exir_ops.edge.aten.abs.default,
         exir_ops.edge.aten.add.Tensor,
+        exir_ops.edge.aten.amax.default,
         exir_ops.edge.aten.bitwise_or.Tensor,
         exir_ops.edge.aten.bmm.default,
         exir_ops.edge.aten.bitwise_and.Tensor,

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@
     op_abs,
     op_adaptive_avg_pool2d,
     op_add,
+    op_amax,
     op_and,
     op_arange,
     op_argmin,
@@ -95,6 +96,7 @@
     op_abs,
     op_adaptive_avg_pool2d,
     op_add,
+    op_amax,
     op_and,
     op_arange,
     op_argmin,
