
# Commit f23c350: Update on "[ET-VK][ez] Use standard quant naming scheme for quantized ops"
## Context

Use the standard naming scheme for quantized operators introduced in the previous PR. For weight-only quantized linear operators, the names introduced are:

`linear_qcsnw`:

* q - quantized
* c - per-channel / channelwise
* s - symmetric
* n - number of bits (qcs4w for 4-bit quant, qcs8w for 8-bit quant)
* w - weight quantized

`linear_qga4w`:

* q - quantized
* g - per-group / groupwise
* a - affine
* 4 - quantized to 4 bits
* w - weight quantized

## Changes

Rename instances of `q_8w_linear` to `linear_qcs8w` or `linear_qcsnw`. Rename instances of `q_4w_linear` to `linear_qga4w`. Rename cpp files to match the new naming convention.

Differential Revision: [D73941992](https://our.internmc.facebook.com/intern/diff/D73941992/)

[ghstack-poisoned]
2 parents: 6c96ac3 + d19d277
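As a worked illustration of the legend above, here is a small, hypothetical Python helper (not part of this commit) that expands the new operator names:

```python
# Hypothetical helper, not part of this commit: decodes the operator
# names per the legend in the commit message.
LEGEND = {
    "q": "quantized",
    "c": "per-channel / channelwise",
    "g": "per-group / groupwise",
    "s": "symmetric",
    "a": "affine",
    "n": "number of bits (placeholder)",
    "w": "weight quantized",
}


def describe(name: str) -> list:
    """Decode a name like 'linear_qcs8w' or 'linear_qga4w'."""
    suffix = name.split("linear_", 1)[1]
    return [
        f"quantized to {ch} bits" if ch.isdigit() else LEGEND[ch]
        for ch in suffix
    ]


print(describe("linear_qcs8w"))
# ['quantized', 'per-channel / channelwise', 'symmetric',
#  'quantized to 8 bits', 'weight quantized']
print(describe("linear_qga4w"))
# ['quantized', 'per-group / groupwise', 'affine',
#  'quantized to 4 bits', 'weight quantized']
```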

## Files changed

108 files changed (+1990, -1313 lines). Large commits have some content hidden by default, so only a subset of the changed files is shown below.

### .ci/scripts/build-qnn-sdk.sh (+1)

```diff
@@ -33,6 +33,7 @@ set_up_aot() {
   cmake .. \
     -DCMAKE_INSTALL_PREFIX=$PWD \
     -DEXECUTORCH_BUILD_QNN=ON \
+    -DANDROID_NATIVE_API_LEVEL=30 \
     -DQNN_SDK_ROOT=${QNN_SDK_ROOT} \
     -DEXECUTORCH_BUILD_DEVTOOLS=ON \
     -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
```

### .ci/scripts/test_model.sh (+1, -1)

```diff
@@ -222,7 +222,7 @@ test_model_with_coreml() {

   DTYPE=float16

-  "${PYTHON_EXECUTABLE}" -m examples.apple.coreml.scripts.export --model_name="${MODEL_NAME}" --compute_precision "${DTYPE}"
+  "${PYTHON_EXECUTABLE}" -m examples.apple.coreml.scripts.export --model_name="${MODEL_NAME}" --compute_precision "${DTYPE}" --use_partitioner
   EXPORTED_MODEL=$(find "." -type f -name "${MODEL_NAME}*.pte" -print -quit)

   if [ -n "$EXPORTED_MODEL" ]; then
```

### Package.swift (+4, -4)

```diff
@@ -77,7 +77,9 @@ let package = Package(
         name: "\(key)_dependencies",
         dependencies: [.target(name: key)],
         path: ".Package.swift/\(key)",
-        linkerSettings:
+        linkerSettings: [
+          .linkedLibrary("c++")
+        ] +
           (value["frameworks"] as? [String] ?? []).map { .linkedFramework($0) } +
           (value["libraries"] as? [String] ?? []).map { .linkedLibrary($0) }
       ),
@@ -94,10 +96,8 @@ let package = Package(
         .copy("resources/add.pte")
       ],
       linkerSettings: [
-        .linkedLibrary("c++"),
         .unsafeFlags([
-          "-Xlinker", "-force_load",
-          "-Xlinker", "cmake-out/kernels_portable.xcframework/macos-arm64/libkernels_portable_macos.a",
+          "-Xlinker", "-all_load",
         ])
       ]
     )
```

### backends/apple/coreml/runtime/delegate/coreml_backend_delegate.mm (+15, -1)

```diff
@@ -88,9 +88,17 @@
         ET_LOG(Error, "%s: DataType=%d is not supported", ETCoreMLStrings.delegateIdentifier.UTF8String, (int)tensor.scalar_type());
         return std::nullopt;
     }
-    
+
     std::vector<ssize_t> strides(tensor.strides().begin(), tensor.strides().end());
     std::vector<size_t> shape(tensor.sizes().begin(), tensor.sizes().end());
+
+    // If tensor is rank 0, wrap in rank 1
+    // See https://github.com/apple/coremltools/blob/8.2/coremltools/converters/mil/frontend/torch/exir_utils.py#L73
+    if (shape.size() == 0) {
+        shape.push_back(1);
+        strides.push_back(1);
+    }
+
     MultiArray::MemoryLayout layout(dataType.value(), std::move(shape), std::move(strides));
     switch (argType) {
         case ArgType::Input: {
@@ -233,6 +241,12 @@ ModelLoggingOptions get_logging_options(BackendExecutionContext& context) {
     std::array<SizesType, kTensorDimensionLimit> new_shape;
     for (size_t i = nInputs; i < nInputs + nOutputs; i++) {
         Tensor& t = args[i]->toTensor();
+        // If t has rank 0, do not resize. delegate_args[i] will have rank 1
+        // because we resized it in get_multi_array
+        if (t.dim() == 0) {
+            continue;
+        }
+
         int rank = delegate_args[i].layout().rank();
         assert(rank <= new_shape.size());
         for (int d = 0; d < rank; d++) {
```
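A quick Python illustration of why the wrap is needed (assumes PyTorch; illustration only, not code from this commit): a rank-0 tensor has an empty shape and stride tuple, so the delegate represents it as a rank-1 array of size 1.

```python
# Illustration only (assumes PyTorch): a rank-0 tensor has empty shape
# and strides, so the delegate wraps it as a rank-1 array of size 1.
import torch

t = torch.tensor(3.14)              # rank 0: t.dim() == 0
shape = list(t.shape) or [1]        # () -> [1]
strides = list(t.stride()) or [1]   # () -> [1]
print(t.dim(), shape, strides)      # 0 [1] [1]
```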

### backends/apple/coreml/runtime/test/ETCoreMLModelManagerTests.mm (+1, -1)

```diff
@@ -113,7 +113,7 @@ - (void)testAddModelExecution {
     XCTAssertNotNil(inputs);
     MLMultiArray *output = [ETCoreMLTestUtils filledMultiArrayWithShape:inputs[0].shape dataType:inputs[0].dataType repeatedValue:@(0) error:&localError];
     NSArray<MLMultiArray *> *args = [inputs arrayByAddingObject:output];
-    XCTAssertTrue([self.modelManager executeModelWithHandle:handle
+    XCTAssertTrue([self.modelManager executeModelWithHandle:handle
                                                        args:args
                                              loggingOptions:executorchcoreml::ModelLoggingOptions()
                                                 eventLogger:nullptr
```

### backends/apple/coreml/scripts/install_requirements.sh (+1, -1)

```diff
@@ -12,7 +12,7 @@ SCRIPT_DIR_PATH="$(

 # TODO(jathu): remove the need to fetch coremltools to build deps for coreml_executor_runner.
 # Keep this version in sync with: pyproject.toml
-COREMLTOOLS_VERSION="8.2"
+COREMLTOOLS_VERSION="8.3"

 red=`tput setaf 1`
 green=`tput setaf 2`
```

### backends/arm/operators/op_upsample_bilinear2d.py (+8, -6)

```diff
@@ -34,9 +34,8 @@ def define_node(
         inputs: List[TosaArg],
         output: TosaArg,
     ) -> None:
-        assert (
-            inputs[0].shape is not None and output.shape is not None
-        ), "Only static shapes are supported"
+        if inputs[0].shape is None or output.shape is None:
+            raise ValueError("Only static shapes are supported")

         input_dtype = inputs[0].dtype

@@ -55,9 +54,12 @@
         def in_int16_range(x):
             return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)

-        assert in_int16_range(scale_n_yx)
-        assert in_int16_range(scale_d_yx)
-        assert in_int16_range(border_yx)
+        if not in_int16_range(scale_n_yx):
+            raise ValueError("scale_n_yx is out of the int16 range")
+        if not in_int16_range(scale_d_yx):
+            raise ValueError("scale_d_yx is out of the int16 range")
+        if not in_int16_range(border_yx):
+            raise ValueError("border_yx is out of the int16 range")

         attr = ts.TosaSerializerAttribute()
         attr.ResizeAttribute(
```
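A side note on the assert-to-raise pattern (a general Python behavior, not stated in the diff): `assert` statements are stripped when the interpreter runs with `-O`, so raising explicitly keeps the validation in optimized runs. A minimal sketch with a hypothetical helper:

```python
# Hypothetical helper: explicit raises survive `python -O`, asserts do not.
def check_static_shapes(input_shape, output_shape) -> None:
    if input_shape is None or output_shape is None:
        raise ValueError("Only static shapes are supported")


check_static_shapes((1, 3, 8, 8), (1, 3, 16, 16))  # passes silently
# check_static_shapes(None, (1, 3, 16, 16))        # would raise ValueError
```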

### backends/arm/operators/op_upsample_nearest2d.py (+8, -6)

```diff
@@ -36,9 +36,8 @@ def define_node(
     ) -> None:
         import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore

-        assert (
-            inputs[0].shape is not None and output.shape is not None
-        ), "Only static shapes are supported"
+        if inputs[0].shape is None or output.shape is None:
+            raise ValueError("Only static shapes are supported")

         # tosa_shape output is NHWC, take HW
         input_size_yx = torch.tensor(
@@ -55,9 +54,12 @@
         def in_int16_range(x):
             return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)

-        assert in_int16_range(scale_n_yx)
-        assert in_int16_range(scale_d_yx)
-        assert in_int16_range(border_yx)
+        if not in_int16_range(scale_n_yx):
+            raise ValueError("scale_n_yx is out of the int16 range")
+        if not in_int16_range(scale_d_yx):
+            raise ValueError("scale_d_yx is out of the int16 range")
+        if not in_int16_range(border_yx):
+            raise ValueError("border_yx is out of the int16 range")

         attr = ts.TosaSerializerAttribute()
         attr.ResizeAttribute(
```
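Both upsample diffs keep the same `in_int16_range` helper; a self-contained sketch of what it checks, assuming PyTorch:

```python
import torch


def in_int16_range(x: torch.Tensor) -> bool:
    """True iff every element fits in signed 16 bits: [-32768, 32767]."""
    return bool(torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1))


print(in_int16_range(torch.tensor([1024, 32767])))  # True
print(in_int16_range(torch.tensor([40000])))        # False -> ValueError path
```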

### backends/cadence/aot/compiler_utils.py (+12, -18)

```diff
@@ -109,12 +109,12 @@ def get_cascaded_ops(
     return nodes


-# Capture the effect of transpose op on incoming dimension order
-def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
+def get_transposed_dims(
+    node: torch.fx.Node, dims: Optional[List[int]] = None
+) -> List[int]:
     """
-    Given a transpose node, and the incoming dimension ordering of the input
-    tensor to the transpose node, return the net effect of transpose op on the
-    dimension order.
+    Applies the transposition given by the node onto the dimensions given as
+    input, e.g. (1, 2) on [a, b, c, d] would return [a, c, b, d].
     """
     assert node.target == exir_ops.edge.aten.transpose_copy.int
     # Assert that the dims is not empty
@@ -127,28 +127,22 @@ def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
     assert isinstance(transpose_dims1, int)
     dim0 = transpose_dims0 if transpose_dims0 >= 0 else transpose_dims0 + dim_len
     dim1 = transpose_dims1 if transpose_dims1 >= 0 else transpose_dims1 + dim_len
-    # Perform transpose on dimension ordering (dims)
-    dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
-    return dims
+    new_dims = list(dims)
+    new_dims[dim0], new_dims[dim1] = dims[dim1], dims[dim0]
+    return new_dims


-# Capture the effect of permute op on incoming dimension order
-def get_permuted_dims(node: torch.fx.Node, dims: Optional[Sequence[int]]) -> List[int]:
+def get_permuted_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
     """
-    Given a permute node, and the incoming dimension ordering of the input
-    tensor to the permute node, return the net effect of permute op on the
-    dimension order.
+    Applies the permutation given by the node onto the dimensions given as
+    input, e.g. (1, 2, 0) on [a, b, c] would return [b, c, a].
     """
     assert node.target == exir_ops.edge.aten.permute_copy.default
     # Permute each index of the dimension ordering (dims)
     # pyre-fixme[6]: This combined typecheck isn't supported yet.
     permute_dims: List[int] = list(node.args[1])
     assert all(isinstance(x, int) for x in permute_dims)
-    # If the dims is empty, we can simply return the permute order
-    if not dims:
-        return permute_dims
-    dims = [dims[x] for x in permute_dims]
-    return dims
+    return [dims[x] for x in permute_dims]


 # Return the tensor of buffer/parameter op
```
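A standalone sketch of the dimension-order arithmetic these helpers perform (plain-list stand-ins with hypothetical names, without the FX-node plumbing):

```python
# Plain-list stand-ins (hypothetical) for the FX-node helpers above.
from typing import List, Tuple


def transposed_dims(dims: List[int], swap: Tuple[int, int]) -> List[int]:
    """Swap two entries without mutating the input list."""
    new_dims = list(dims)
    new_dims[swap[0]], new_dims[swap[1]] = dims[swap[1]], dims[swap[0]]
    return new_dims


def permuted_dims(dims: List[int], order: List[int]) -> List[int]:
    """Reorder entries according to a permutation."""
    return [dims[i] for i in order]


print(transposed_dims([0, 1, 2, 3], (1, 2)))  # [0, 2, 1, 3]
print(permuted_dims([0, 1, 2], [1, 2, 0]))    # [1, 2, 0]
```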

### backends/cadence/aot/fuse_ops.py (+25, -43)

```diff
@@ -14,7 +14,7 @@
 import operator
 from collections import deque
 from numbers import Number
-from typing import cast, Sequence
+from typing import Any, Callable, cast

 # Import these for the cadence function signatures.
 import executorch.backends.cadence.aot.ops_registrations  # noqa: F401
@@ -881,9 +881,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:


 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class FuseTransposeOpPairsPass(FuseOpPairsAcrossBranchesPass):
+class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass):
     """
-    Fuse transpose op pairs to a single view op.
+    Fuse transpose or permute op pairs to a single view op.
+    (transpose or permutation) -> (quant or dequant) -> (transpose or permutation)
     """

     # A list of ops that can be bypassed when looking for a
@@ -907,42 +908,17 @@ def can_fuse_for_chain(
         if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets):
             return False

-        def get_dims(node: torch.fx.Node) -> tuple[int, int]:
-            def canonicalize(dim: int) -> int:
-                if dim < 0:
-                    dim += len(node.meta["val"].shape)
-                return dim
-
-            return tuple(canonicalize(cast(int, d)) for d in node.args[1:3])
-
-        def is_equivalent(
-            shape: Sequence[int],
-            transpose0: tuple[int, int],
-            transpose1: tuple[int, int],
-        ) -> bool:
-            def permute_order(
-                order: Sequence[int], dims: tuple[int, int]
-            ) -> Sequence[int]:
-                new_order = list(order)
-                new_order[dims[0]], new_order[dims[1]] = (
-                    new_order[dims[1]],
-                    new_order[dims[0]],
-                )
-                return new_order
-
-            order = permute_order(range(len(shape)), transpose0)
-            order = permute_order(order, transpose1)
-
-            non_unit_dims = [dim for dim in range(len(shape)) if shape[dim] != 1]
-            non_unit_dims_permuted = [dim for dim in order if shape[dim] != 1]
-
-            return non_unit_dims == non_unit_dims_permuted
-
-        return is_equivalent(
-            cast(torch.fx.Node, producer.args[0]).meta["val"].shape,
-            get_dims(producer),
-            get_dims(consumer),
-        )
+        # Check that permute2(permute1(identity)) == identity
+        input_shape = cast(torch.fx.Node, producer.args[0]).meta["val"].shape
+        ident_dims = list(range(len(input_shape)))
+        # This mapping lets the same check handle both transpose and permute
+        f: dict[Any, Callable] = {
+            exir_ops.edge.aten.transpose_copy.int: get_transposed_dims,
+            exir_ops.edge.aten.permute_copy.default: get_permuted_dims,
+        }
+        in_dims = f[producer.target](producer, ident_dims)
+        out_dims = f[consumer.target](consumer, in_dims)
+        return out_dims == ident_dims

     def get_fused_node(
         self,
@@ -960,11 +936,17 @@ def get_fused_node(
         return view

     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        # Remove any dequantize op that has only quantize ops as its users.
+        # Remove any transpose/permute op pairs that cancel each other.
         self.find_and_fuse(
             graph_module,
-            producer_op_packets={exir_ops.edge.aten.transpose_copy},
-            consumer_op_packets={exir_ops.edge.aten.transpose_copy},
+            producer_op_packets={
+                exir_ops.edge.aten.transpose_copy,
+                exir_ops.edge.aten.permute_copy,
+            },
+            consumer_op_packets={
+                exir_ops.edge.aten.transpose_copy,
+                exir_ops.edge.aten.permute_copy,
+            },
             bypass_ops=self.bypass_ops,
         )
         result = super().call(graph_module)
@@ -1028,5 +1010,5 @@ class CadenceFuseOpsInGraph:
     FuseQuantDequantToRequantizePass,
     FuseMulIntoDequantPass,
     FuseFullThenReshapePass,
-    FuseTransposeOpPairsPass,
+    FuseTransposeOrPermuteOpPairsPass,
 ]
```
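A minimal sketch of the cancellation test used in `can_fuse_for_chain`: a producer/consumer pair fuses away exactly when composing their dim reorderings on the identity order yields the identity again.

```python
# Minimal sketch of the identity-composition check above.
def compose(first, second):
    """Apply `second` to the order produced by `first`."""
    return [first[i] for i in second]


ident = [0, 1, 2, 3]
swap_1_2 = [0, 2, 1, 3]                          # transpose of dims 1 and 2
print(compose(swap_1_2, swap_1_2) == ident)      # True: the pair cancels
print(compose(swap_1_2, [1, 0, 2, 3]) == ident)  # False: no fusion
```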

### backends/cadence/aot/passes.py (+2, -2)

```diff
@@ -14,7 +14,7 @@
 from executorch.backends.cadence.aot.fuse_ops import (
     CadenceFuseOpsInGraph,
     FuseFullThenReshapePass,
-    FuseTransposeOpPairsPass,
+    FuseTransposeOrPermuteOpPairsPass,
 )
 from executorch.backends.cadence.aot.pass_utils import (
     CadencePassAttribute,
@@ -83,7 +83,7 @@ def get_passes_in_default_order() -> List[ExportPass]:
         CadenceSimplifyOpsInGraph.passes,
         FinalizePipeline,
         FuseFullThenReshapePass,
-        FuseTransposeOpPairsPass,
+        FuseTransposeOrPermuteOpPairsPass,
         RemoveNopSliceOrViewOpPass,
     ]
     return pytree.tree_flatten(passes)[0]
```

### backends/cadence/aot/replace_ops.py (+23, -10)

```diff
@@ -2263,9 +2263,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:


 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class ReplacePowWithMullPass(ExportPass):
+class ReplacePowWithMulPass(ExportPass):
     """
-    Replace the pow op with degree 2 for a mul op.
+    Replace the pow op with a chain of mul ops.
     """

     def call_operator(
@@ -2275,19 +2275,32 @@ def call_operator(
         kwargs: Dict[str, Argument],
         meta: NodeMetadata,
     ) -> ProxyValue:
-        # TODO(eigen): Add support for other degrees.
-        if (
-            op
-            not in {
-                exir_ops.edge.aten.pow.Scalar,
+        if not (
+            len(args) > 1
+            and isinstance(args[1], int)
+            and cast(int, args[1]) > 1
+            and cast(int, args[1]) < 5
+            and op
+            in {
+                exir_ops.edge.aten.pow.Tensor_Scalar,
             }
-            or args[0] != 2
         ):
             return super().call_operator(op, args, kwargs, meta)

+        x = args[0]
+        exponent = cast(int, args[1])
+
+        if exponent > 2:
+            for _ in range(exponent, 2, -1):
+                x = super().call_operator(
+                    exir_ops.edge.aten.mul.Tensor,
+                    (x, args[0]),
+                    {},
+                    meta,
+                )
         return super().call_operator(
             exir_ops.edge.aten.mul.Tensor,
-            (args[1], args[1]),
+            (x, args[0]),
             {},
             meta,
         )
@@ -2429,5 +2442,5 @@ class CadenceReplaceOpsInGraph:
     ReplaceWhereWithFullArgsWithWhereScalar,
     ReplaceGeluWithApproximateGeluPass,
     ReplaceSplitWithSlicePass,
-    ReplacePowWithMullPass,
+    ReplacePowWithMulPass,
 ]
```
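A minimal sketch of the rewrite (illustration only, with a hypothetical helper name): pow with an integer exponent strictly between 1 and 5 becomes exponent - 1 multiplications by the original input, mirroring the chain of `aten.mul.Tensor` calls emitted above.

```python
# Illustration of the pow-to-mul expansion performed by the pass.
def pow_as_muls(x: float, exponent: int) -> float:
    assert 1 < exponent < 5, "the pass only rewrites exponents 2, 3, 4"
    result = x
    for _ in range(exponent - 1):
        result = result * x  # one aten.mul.Tensor per step in the pass
    return result


print(pow_as_muls(3.0, 2), 3.0**2)  # 9.0 9.0
print(pow_as_muls(2.0, 4), 2.0**4)  # 16.0 16.0
```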
