@@ -885,6 +885,9 @@ class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass):
885
885
"""
886
886
Fuse transpose or permute op pairs to a single view op.
887
887
(transpose or permutation) -> (quant or dequant) -> (transpose or permutation)
888
+ This happens when op2(op1) == identity, modulo unitary dimensions.
889
+ 'unitary dimensions' example: a tensor of shape [1, 5, 30] is equivalent (in memory) to [5, 1, 30]
890
+ so transpose(1, 2) then transpose(0, 2) is a pseudo identity and should be fused.
888
891
"""
889
892
890
893
# A list of ops that can be bypassed when looking for a
@@ -908,7 +911,7 @@ def can_fuse_for_chain(
908
911
if not super ().can_fuse_for_chain (producer , consumer , consumer_op_packets ):
909
912
return False
910
913
911
- # checking that permut2(permut1(identify )) == identity
914
+ # checking that permut2(permut1(identity )) == identity, modulo unitary dimensions
912
915
input_shape = cast (torch .fx .Node , producer .args [0 ]).meta ["val" ].shape
913
916
ident_dims = list (range (len (input_shape )))
914
917
# this mapping helps to handle both transpose and permutations
@@ -918,14 +921,20 @@ def can_fuse_for_chain(
918
921
}
919
922
in_dims = f [producer .target ](producer , ident_dims )
920
923
out_dims = f [consumer .target ](consumer , in_dims )
921
- return out_dims == ident_dims
924
+ # Filtering out unitary dimensions
925
+ non_unit_ident_dims = [dim for dim in ident_dims if input_shape [dim ] != 1 ]
926
+ non_unit_out_dims = [dim for dim in out_dims if input_shape [dim ] != 1 ]
927
+ return non_unit_out_dims == non_unit_ident_dims
922
928
923
929
def get_fused_node (
924
930
self ,
925
931
producer : torch .fx .Node ,
926
932
consumer : torch .fx .Node ,
927
933
graph_module : torch .fx .GraphModule ,
928
934
) -> torch .fx .Node :
935
+ # This step is important because of how we can fuse transpositions that are not perfectly
936
+ # reverse one of another but will be fused if there are unitary dimensions.
937
+ # The fused operation must have the same output shape as the consumer.
929
938
output_shape = consumer .meta ["val" ].shape
930
939
with graph_module .graph .inserting_after (consumer ):
931
940
view = graph_module .graph .call_function (
0 commit comments