1414import operator
1515from collections import deque
1616from numbers import Number
17- from typing import cast , Sequence
17+ from typing import Any , Callable , cast
1818
1919# Import these for the cadence function signatures.
2020import executorch .backends .cadence .aot .ops_registrations # noqa: F401
@@ -881,9 +881,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
881881
882882
883883@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
884- class FuseTransposeOpPairsPass (FuseOpPairsAcrossBranchesPass ):
884+ class FuseTransposeOrPermuteOpPairsPass (FuseOpPairsAcrossBranchesPass ):
885885 """
886- Fuse transpose op pairs to a single view op.
886+ Fuse transpose or permute op pairs to a single view op.
887+ (transpose or permutation) -> (quant or dequant) -> (transpose or permutation)
887888 """
888889
889890 # A list of ops that can be bypassed when looking for a
@@ -907,42 +908,17 @@ def can_fuse_for_chain(
907908 if not super ().can_fuse_for_chain (producer , consumer , consumer_op_packets ):
908909 return False
909910
910- def get_dims (node : torch .fx .Node ) -> tuple [int , int ]:
911- def canonicalize (dim : int ) -> int :
912- if dim < 0 :
913- dim += len (node .meta ["val" ].shape )
914- return dim
915-
916- return tuple (canonicalize (cast (int , d )) for d in node .args [1 :3 ])
917-
918- def is_equivalent (
919- shape : Sequence [int ],
920- transpose0 : tuple [int , int ],
921- transpose1 : tuple [int , int ],
922- ) -> bool :
923- def permute_order (
924- order : Sequence [int ], dims : tuple [int , int ]
925- ) -> Sequence [int ]:
926- new_order = list (order )
927- new_order [dims [0 ]], new_order [dims [1 ]] = (
928- new_order [dims [1 ]],
929- new_order [dims [0 ]],
930- )
931- return new_order
932-
933- order = permute_order (range (len (shape )), transpose0 )
934- order = permute_order (order , transpose1 )
935-
936- non_unit_dims = [dim for dim in range (len (shape )) if shape [dim ] != 1 ]
937- non_unit_dims_permuted = [dim for dim in order if shape [dim ] != 1 ]
938-
939- return non_unit_dims == non_unit_dims_permuted
940-
941- return is_equivalent (
942- cast (torch .fx .Node , producer .args [0 ]).meta ["val" ].shape ,
943- get_dims (producer ),
944- get_dims (consumer ),
945- )
911+ # check that permute2(permute1(identity)) == identity
912+ input_shape = cast (torch .fx .Node , producer .args [0 ]).meta ["val" ].shape
913+ ident_dims = list (range (len (input_shape )))
914+ # this mapping lets us handle both transposes and permutations uniformly
915+ f : dict [Any , Callable ] = {
916+ exir_ops .edge .aten .transpose_copy .int : get_transposed_dims ,
917+ exir_ops .edge .aten .permute_copy .default : get_permuted_dims ,
918+ }
919+ in_dims = f [producer .target ](producer , ident_dims )
920+ out_dims = f [consumer .target ](consumer , in_dims )
921+ return out_dims == ident_dims
946922
947923 def get_fused_node (
948924 self ,
@@ -960,11 +936,17 @@ def get_fused_node(
960936 return view
961937
962938 def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
963- # Remove any dequantize op that has only quantize ops as its users .
939+ # Remove transpose/permutation op pairs that cancel each other.
964940 self .find_and_fuse (
965941 graph_module ,
966- producer_op_packets = {exir_ops .edge .aten .transpose_copy },
967- consumer_op_packets = {exir_ops .edge .aten .transpose_copy },
942+ producer_op_packets = {
943+ exir_ops .edge .aten .transpose_copy ,
944+ exir_ops .edge .aten .permute_copy ,
945+ },
946+ consumer_op_packets = {
947+ exir_ops .edge .aten .transpose_copy ,
948+ exir_ops .edge .aten .permute_copy ,
949+ },
968950 bypass_ops = self .bypass_ops ,
969951 )
970952 result = super ().call (graph_module )
@@ -1028,5 +1010,5 @@ class CadenceFuseOpsInGraph:
10281010 FuseQuantDequantToRequantizePass ,
10291011 FuseMulIntoDequantPass ,
10301012 FuseFullThenReshapePass ,
1031- FuseTransposeOpPairsPass ,
1013+ FuseTransposeOrPermuteOpPairsPass ,
10321014 ]
0 commit comments