neural_compressor/torch/algorithms/fp8_quant/_quant_common (1 file changed: +10, -0 lines)

@@ -725,6 +725,9 @@ def forward(
         is_causal=False,
         scale=None,
         softmax_mode="None",
+        recompute=None,
+        valid_seq_len=None,
+        seq_padding_type="None",
     ):
         qinput = self.quant_q(q).detach()
         kinput = self.quant_k(k).detach()
@@ -746,6 +749,8 @@ def forward(
             q_scale_o=self.scale_output,
             d_scale_s=self.descale_amax,
             is_amax_s=False,
+            valid_seq_len=valid_seq_len,
+            seq_padding_type=seq_padding_type
         )
         output = results[0]
         d_out = self.dequant_output(output)
@@ -761,6 +766,9 @@ def forward_measure(
         is_causal=False,
         scale=None,
         softmax_mode="fast",
+        recompute=None,
+        valid_seq_len=None,
+        seq_padding_type="None",
     ):
         dq = q.detach()
         dk = k.detach()
@@ -777,6 +785,8 @@ def forward_measure(
             # fp8_fused_sdpa in bf16 can use either FastSoftmax or regular
             softmax_mode="fast",
             is_amax_s=True,
+            valid_seq_len=valid_seq_len,
+            seq_padding_type=seq_padding_type
         )
         output = results[0]
         amax = results[1]
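
For context, here is a minimal self-contained sketch of the pattern this diff introduces: both `forward` and `forward_measure` gain `recompute`, `valid_seq_len`, and `seq_padding_type` parameters, and the latter two are threaded through to the fused-SDPA call. The class and the kernel stand-in below are invented for illustration and are not the actual neural_compressor implementation; the quantize/dequantize steps from the diff are replaced by identity ops, and the stub uses `torch.nn.functional.scaled_dot_product_attention` (PyTorch >= 2.1 for the `scale` kwarg) in place of the real fused FP8 kernel.

```python
import torch
import torch.nn.functional as F


def _fused_sdpa_stub(q, k, v, *, is_causal=False, scale=None, softmax_mode="None",
                     is_amax_s=False, valid_seq_len=None, seq_padding_type="None"):
    """Stand-in for the fused FP8 SDPA kernel; returns (output, amax) like the
    real call site expects. valid_seq_len / seq_padding_type are accepted but
    ignored in this sketch."""
    out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal, scale=scale)
    return out, out.abs().amax()


class SketchPatchedFusedSDPA(torch.nn.Module):
    """Illustrative wrapper mirroring the diff's signature changes only."""

    def forward(self, q, k, v, is_causal=False, scale=None, softmax_mode="None",
                recompute=None, valid_seq_len=None, seq_padding_type="None"):
        results = _fused_sdpa_stub(
            q.detach(), k.detach(), v.detach(),
            is_causal=is_causal,
            scale=scale,
            softmax_mode=softmax_mode,
            is_amax_s=False,
            valid_seq_len=valid_seq_len,        # new argument threaded through
            seq_padding_type=seq_padding_type,  # new argument threaded through
        )
        return results[0]

    def forward_measure(self, q, k, v, is_causal=False, scale=None, softmax_mode="fast",
                        recompute=None, valid_seq_len=None, seq_padding_type="None"):
        results = _fused_sdpa_stub(
            q.detach(), k.detach(), v.detach(),
            is_causal=is_causal,
            scale=scale,
            softmax_mode="fast",
            is_amax_s=True,
            valid_seq_len=valid_seq_len,
            seq_padding_type=seq_padding_type,
        )
        output, amax = results[0], results[1]
        return output, amax


if __name__ == "__main__":
    q = k = v = torch.randn(1, 2, 16, 8)
    print(SketchPatchedFusedSDPA()(q, k, v).shape)  # torch.Size([1, 2, 16, 8])
```

Note that the new parameters default to `None` / `"None"`, presumably so call sites that do not pass them keep their previous behavior unchanged.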