
Commit 3f61954

wszczurekhabana authored and Eran Geva committed
[SW-187215] Add valid_seq_len feature to patched SDPA module
Change-Id: Ia627fe8134470d68a7e55fc978a972bb7f7b3d5b
1 parent 039af39 · commit 3f61954

1 file changed (+10 −0 lines)

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

@@ -725,6 +725,9 @@ def forward(
         is_causal=False,
         scale=None,
         softmax_mode="None",
+        recompute=None,
+        valid_seq_len=None,
+        seq_padding_type="None",
     ):
         qinput = self.quant_q(q).detach()
         kinput = self.quant_k(k).detach()
@@ -746,6 +749,8 @@ def forward(
             q_scale_o=self.scale_output,
             d_scale_s=self.descale_amax,
             is_amax_s=False,
+            valid_seq_len=valid_seq_len,
+            seq_padding_type=seq_padding_type
         )
         output = results[0]
         d_out = self.dequant_output(output)
@@ -761,6 +766,9 @@ def forward_measure(
         is_causal=False,
         scale=None,
         softmax_mode="fast",
+        recompute=None,
+        valid_seq_len=None,
+        seq_padding_type="None",
     ):
         dq = q.detach()
         dk = k.detach()
@@ -777,6 +785,8 @@ def forward_measure(
             # fp8_fused_sdpa in bf16 can use either FastSoftmax or regular
             softmax_mode="fast",
             is_amax_s=True,
+            valid_seq_len=valid_seq_len,
+            seq_padding_type=seq_padding_type
         )
         output = results[0]
         amax = results[1]
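
The change threads two new pass-through arguments, valid_seq_len and seq_padding_type (plus a recompute placeholder), from the patched SDPA module's forward and forward_measure into the fused fp8_fused_sdpa call, so the kernel can ignore padded positions in a batch. As a minimal sketch of the masking semantics valid_seq_len stands in for, here is the equivalent explicit-mask formulation against stock PyTorch SDPA, assuming right-padded batches; everything besides the valid_seq_len name is illustrative and not taken from this module:

import torch
import torch.nn.functional as F

# Toy shapes: [batch, heads, seq, head_dim]
batch, heads, seq, head_dim = 2, 4, 8, 16
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Per-sample count of real (non-padded) positions; the remainder of
# each sequence is padding.
valid_seq_len = torch.tensor([8, 5])

# Boolean mask over key positions, assuming right-padding:
# True = attend, False = masked out.
key_pos = torch.arange(seq)
attn_mask = key_pos[None, :] < valid_seq_len[:, None]  # [batch, seq]
attn_mask = attn_mask[:, None, None, :]                # [batch, 1, 1, seq]

# Stock SDPA with an explicit mask; a fused kernel that accepts
# valid_seq_len can apply the same masking internally without
# materializing this tensor.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

Passing the lengths directly, as this commit enables, lets the fused implementation skip the padded tail of each sequence instead of building a full attention mask; the set of accepted values for seq_padding_type is not visible in this diff.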

0 commit comments

Comments
 (0)