
Commit 74c0dff

Update base for Update on "[Executorch][llm] Enable leveraging ring kv cache via module swap"
This allows us to make some of the attention modules use a sliding-window KV cache, which will help enable models like gemma3.

Differential Revision: [D73891426](https://our.internmc.facebook.com/intern/diff/D73891426/)

[ghstack-poisoned]
2 parents 1013001 + 6759d35 commit 74c0dff

File tree

2 files changed (+55, -2 lines)

backends/cadence/aot/fuse_ops.py

Lines changed: 11 additions & 2 deletions
@@ -885,6 +885,9 @@ class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass):
     """
     Fuse transpose or permute op pairs to a single view op.
     (transpose or permutation) -> (quant or dequant) -> (transpose or permutation)
+    This happens when op2(op1) == identity, modulo unitary dimensions.
+    'unitary dimensions' example: a tensor of shape [1, 5, 30] is equivalent (in memory) to [5, 1, 30]
+    so transpose(1, 2) then transpose(0, 2) is a pseudo identity and should be fused.
     """

     # A list of ops that can be bypassed when looking for a
@@ -908,7 +911,7 @@ def can_fuse_for_chain(
         if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets):
             return False

-        # checking that permut2(permut1(identify)) == identity
+        # checking that permut2(permut1(identity)) == identity, modulo unitary dimensions
         input_shape = cast(torch.fx.Node, producer.args[0]).meta["val"].shape
         ident_dims = list(range(len(input_shape)))
         # this mapping helps to handle both transpose and permutations
@@ -918,14 +921,20 @@ def can_fuse_for_chain(
         }
         in_dims = f[producer.target](producer, ident_dims)
         out_dims = f[consumer.target](consumer, in_dims)
-        return out_dims == ident_dims
+        # Filtering out unitary dimensions
+        non_unit_ident_dims = [dim for dim in ident_dims if input_shape[dim] != 1]
+        non_unit_out_dims = [dim for dim in out_dims if input_shape[dim] != 1]
+        return non_unit_out_dims == non_unit_ident_dims

     def get_fused_node(
         self,
         producer: torch.fx.Node,
         consumer: torch.fx.Node,
         graph_module: torch.fx.GraphModule,
     ) -> torch.fx.Node:
+        # This step is important because we may fuse transpositions that are not
+        # perfect reverses of one another when unitary dimensions make them memory-equivalent.
+        # The fused view op must have the same output shape as the consumer.
         output_shape = consumer.meta["val"].shape
         with graph_module.graph.inserting_after(consumer):
             view = graph_module.graph.call_function(
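
For readers of this diff, the "pseudo-identity modulo unitary dimensions" condition is easy to reproduce outside the pass. Below is a minimal standalone sketch, not the pass's code: composed_dims and is_pseudo_identity are hypothetical helpers that mirror the comparison in can_fuse_for_chain, applied to the docstring's [1, 5, 30] example.

import torch

def composed_dims(perms, rank):
    # Apply a sequence of permutations to the identity dim order.
    dims = list(range(rank))
    for p in perms:
        dims = [dims[i] for i in p]
    return dims

def is_pseudo_identity(out_dims, input_shape):
    # A memory no-op once size-1 ("unitary") dims are dropped -- the same
    # comparison can_fuse_for_chain performs on the non-unit dims.
    non_unit_ident = [d for d in range(len(input_shape)) if input_shape[d] != 1]
    non_unit_out = [d for d in out_dims if input_shape[d] != 1]
    return non_unit_out == non_unit_ident

x = torch.arange(150).reshape(1, 5, 30)
# transpose(1, 2) is the permutation [0, 2, 1]; transpose(0, 2) is [2, 1, 0]
out_dims = composed_dims([[0, 2, 1], [2, 1, 0]], x.dim())  # -> [1, 2, 0]
assert is_pseudo_identity(out_dims, x.shape)

# The transpose pair is therefore replaceable by a single view op:
y = x.transpose(1, 2).transpose(0, 2)    # shape (5, 30, 1), no data movement
assert torch.equal(y, x.view(5, 30, 1))  # same values in the same memory order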

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 44 additions & 0 deletions
@@ -584,6 +584,28 @@ def _create_operator(
             exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
             False,
         ),
+        # transpose -> quant -> transpose is not the reverse, BUT there is a UNITARY dimension,
+        # so it ends up being the same in memory => fuse
+        (
+            True,
+            [0, 1],
+            True,
+            [0, 2],
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            True,
+            [5, 40, 1],
+        ),
+        # transpose -> quant -> transpose is not the reverse, and unitary dimensions
+        # don't help => don't fuse
+        (
+            True,
+            [0, 1],
+            True,
+            [1, 3],
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            False,
+            [5, 40, 1, 4],
+        ),
         # permutation -> quant -> opposite permutation => fuse
         (
             False,
587609
# permutation -> quant -> opposite permutation => fuse
588610
(
589611
False,
@@ -622,6 +644,28 @@ def _create_operator(
             False,
             [4, 4, 4],
         ),
+        # permutation -> quant -> a non-reverse permutation, BUT there is a UNITARY dimension,
+        # so it ends up being the same in memory => fuse
+        (
+            False,
+            [1, 3, 2, 0],
+            False,
+            [3, 2, 1, 0],
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            True,
+            [3, 1, 8, 10],
+        ),
+        # permutation -> quant -> a non-reverse permutation, and unitary dimensions
+        # don't help => don't fuse
+        (
+            False,
+            [1, 3, 2, 0],
+            False,
+            [3, 1, 2, 0],
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            False,
+            [3, 1, 8, 10],
+        ),
         # transpose -> quant -> transpose as a permutation => fuse
         (
             True,
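
The two new permutation cases can be checked by hand with the same "drop unitary dimensions" rule. A minimal sketch, where apply_perm and non_unit_dims are hypothetical helpers rather than part of the test file:

def apply_perm(dims, p):
    # Reorder dims according to permutation p.
    return [dims[i] for i in p]

def non_unit_dims(dims, shape):
    # Drop dimensions of size 1.
    return [d for d in dims if shape[d] != 1]

shape = (3, 1, 8, 10)
ident = list(range(4))  # non-unit dims: [0, 2, 3]

# fuse case: [1, 3, 2, 0] then [3, 2, 1, 0] compose to [0, 2, 3, 1]
fused = apply_perm(apply_perm(ident, [1, 3, 2, 0]), [3, 2, 1, 0])
assert non_unit_dims(fused, shape) == [0, 2, 3]  # identity modulo the size-1 dim => fuse

# don't-fuse case: [1, 3, 2, 0] then [3, 1, 2, 0] compose to [0, 3, 2, 1]
not_fused = apply_perm(apply_perm(ident, [1, 3, 2, 0]), [3, 1, 2, 0])
assert non_unit_dims(not_fused, shape) == [0, 3, 2]  # order differs => don't fuse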
