
Commit d089855

Update on "[Executorch][llm] Make custom update cache op operate on indices"

This allows us to use ring buffer kv cache.

Differential Revision: [D73891424](https://our.internmc.facebook.com/intern/diff/D73891424/)

[ghstack-poisoned]
2 parents 5cbab1f + 7a74ffc commit d089855

File tree: 2 files changed (+55 −2 lines)


backends/cadence/aot/fuse_ops.py

Lines changed: 11 additions & 2 deletions
@@ -885,6 +885,9 @@ class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass):
     """
     Fuse transpose or permute op pairs to a single view op.
     (transpose or permutation) -> (quant or dequant) -> (transpose or permutation)
+    This happens when op2(op1) == identity, modulo unitary dimensions.
+    'unitary dimensions' example: a tensor of shape [1, 5, 30] is equivalent (in memory) to [5, 1, 30]
+    so transpose(1, 2) then transpose(0, 2) is a pseudo identity and should be fused.
     """

     # A list of ops that can be bypassed when looking for a
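The "pseudo identity" claim in the added docstring can be checked with plain torch tensors, independently of the ExecuTorch graph machinery. A minimal sketch: composing transpose(1, 2) and transpose(0, 2) on a [1, 5, 30] tensor only relocates the size-1 dimension, so the result is reproducible by a single view, which is exactly what the pass fuses the pair into.

```python
import torch

# Sketch: transpose(1, 2) -> transpose(0, 2) on shape [1, 5, 30] is not the
# identity permutation, but it only moves the size-1 dimension.
x = torch.arange(1 * 5 * 30).reshape(1, 5, 30)
y = x.transpose(1, 2).transpose(0, 2)  # shape [5, 30, 1]

# The underlying memory order is unchanged, so a single view reproduces it.
assert torch.equal(y, x.reshape(5, 30, 1))
assert torch.equal(y.contiguous().view(-1), x.view(-1))
```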
@@ -908,7 +911,7 @@ def can_fuse_for_chain(
         if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets):
             return False

-        # checking that permut2(permut1(identify)) == identity
+        # checking that permut2(permut1(identity)) == identity, modulo unitary dimensions
         input_shape = cast(torch.fx.Node, producer.args[0]).meta["val"].shape
         ident_dims = list(range(len(input_shape)))
         # this mapping helps to handle both transpose and permutations
@@ -918,14 +921,20 @@ def can_fuse_for_chain(
         }
         in_dims = f[producer.target](producer, ident_dims)
         out_dims = f[consumer.target](consumer, in_dims)
-        return out_dims == ident_dims
+        # Filtering out unitary dimensions
+        non_unit_ident_dims = [dim for dim in ident_dims if input_shape[dim] != 1]
+        non_unit_out_dims = [dim for dim in out_dims if input_shape[dim] != 1]
+        return non_unit_out_dims == non_unit_ident_dims

     def get_fused_node(
         self,
         producer: torch.fx.Node,
         consumer: torch.fx.Node,
         graph_module: torch.fx.GraphModule,
     ) -> torch.fx.Node:
+        # This step is important because of how we can fuse transpositions that are not perfectly
+        # reverse one of another but will be fused if there are unitary dimensions.
+        # The fused operation must have the same output shape as the consumer.
         output_shape = consumer.meta["val"].shape
         with graph_module.graph.inserting_after(consumer):
             view = graph_module.graph.call_function(
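For reference, the "identity modulo unitary dimensions" check above can be reproduced as a standalone sketch. The helper name `check_pseudo_identity` and the explicit permutation lists are illustrative only; the real pass extracts the dims from the producer/consumer torch.fx nodes.

```python
# Hypothetical standalone version of the dim-tracking check.
def check_pseudo_identity(input_shape, perm1, perm2):
    ident_dims = list(range(len(input_shape)))
    in_dims = [ident_dims[p] for p in perm1]   # after the first permutation
    out_dims = [in_dims[p] for p in perm2]     # after the second permutation
    # Size-1 dimensions do not affect the memory layout, so ignore them.
    non_unit_ident_dims = [d for d in ident_dims if input_shape[d] != 1]
    non_unit_out_dims = [d for d in out_dims if input_shape[d] != 1]
    return non_unit_out_dims == non_unit_ident_dims

# Mirrors the new permutation test cases: with a size-1 dim at index 1,
# [1, 3, 2, 0] followed by [3, 2, 1, 0] only moves that dim => fusable ...
assert check_pseudo_identity([3, 1, 8, 10], [1, 3, 2, 0], [3, 2, 1, 0])
# ... while [3, 1, 2, 0] swaps the size-8 and size-10 dims => not fusable.
assert not check_pseudo_identity([3, 1, 8, 10], [1, 3, 2, 0], [3, 1, 2, 0])
```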

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 44 additions & 0 deletions
@@ -584,6 +584,28 @@ def _create_operator(
             exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
             False,
         ),
+        # transpose -> quant -> transpose is not the reverse BUT there is a UNITARY dimension
+        # so it ends up being the same on memory => fuse
+        (
+            True,
+            [0, 1],
+            True,
+            [0, 2],
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            True,
+            [5, 40, 1],
+        ),
+        # transpose -> quant -> transpose is not the reverse, and unitary dimensions
+        # don't help => don't fuse
+        (
+            True,
+            [0, 1],
+            True,
+            [1, 3],
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            False,
+            [5, 40, 1, 4],
+        ),
         # permutation -> quant -> opposite permutation => fuse
         (
             False,
@@ -622,6 +644,28 @@ def _create_operator(
             False,
             [4, 4, 4],
         ),
+        # permutation -> quant -> a non reverse permutation BUT there is a UNITARY dimension
+        # so it ends up being the same on memory => fuse
+        (
+            False,
+            [1, 3, 2, 0],
+            False,
+            [3, 2, 1, 0],
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            True,
+            [3, 1, 8, 10],
+        ),
+        # permutation -> quant -> a non reverse permutation, and unitary dimensions
+        # don't help => don't fuse
+        (
+            False,
+            [1, 3, 2, 0],
+            False,
+            [3, 1, 2, 0],
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            False,
+            [3, 1, 8, 10],
+        ),
         # transpose -> quant -> transpose as a permutation => fuse
         (
             True,
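The expected fuse / don't-fuse outcomes of the new permutation cases can also be sanity-checked in eager torch. This is illustrative only; the actual tests run the operator pair through the Cadence fusion pass.

```python
import torch

x = torch.arange(3 * 1 * 8 * 10).reshape(3, 1, 8, 10)

# [1, 3, 2, 0] then [3, 2, 1, 0]: not inverse permutations, but the composition
# only moves the size-1 dimension, so a plain reshape of x reproduces it => fuse.
y = x.permute(1, 3, 2, 0).permute(3, 2, 1, 0)
assert y.shape == (3, 8, 10, 1)
assert torch.equal(y, x.reshape(3, 8, 10, 1))

# [1, 3, 2, 0] then [3, 1, 2, 0]: the size-8 and size-10 dims end up swapped,
# so no view of x can reproduce the data order => don't fuse.
z = x.permute(1, 3, 2, 0).permute(3, 1, 2, 0)
assert z.shape == (3, 10, 8, 1)
assert not torch.equal(z.contiguous().view(-1), x.view(-1))
```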
