@@ -277,9 +277,8 @@ def forward(
         `vllm.forward_context.get_forward_context().attn_metadata`.
         """
         if self.calculate_kv_scales:
-            attn_metadata = get_forward_context().attn_metadata
-            if attn_metadata.enable_kv_scales_calculation:
-                self.calc_kv_scales(query, key, value)
+            torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
+                                                self.layer_name)

         output_dtype = query.dtype
         if self.query_quant is not None:
@@ -341,6 +340,16 @@ def forward(
         return torch.ops.vllm.unified_attention(
             query, key, value, self.layer_name)

+    def calc_kv_scales(self, query, key, value):
+        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
+        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
+        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
+        self._q_scale_float = self._q_scale.item()
+        self._k_scale_float = self._k_scale.item()
+        self._v_scale_float = self._v_scale.item()
+        # We only calculate the scales once
+        self.calculate_kv_scales = False
+
     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore
         s += f", num_heads={self.impl.num_heads}"  # type: ignore
@@ -544,47 +553,41 @@ def maybe_save_kv_layer_to_connector(
                                 attn_metadata[layer_name])


-def unified_kv_scale_calc(
+def maybe_calc_kv_scales(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    q_scale: torch.Tensor,
-    k_scale: torch.Tensor,
-    v_scale: torch.Tensor,
-    q_range: torch.Tensor,
-    k_range: torch.Tensor,
-    v_range: torch.Tensor,
-    scale_calc: bool,
+    layer_name: str,
 ) -> None:

-    if not scale_calc:
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
+
+    if attn_metadata is None or not getattr(
+            attn_metadata, 'enable_kv_scales_calculation', False):
         return

-    q_scale.copy_(torch.abs(query).max() / q_range)
-    k_scale.copy_(torch.abs(key).max() / k_range)
-    v_scale.copy_(torch.abs(value).max() / v_range)
+    self = forward_context.no_compile_layers[layer_name]
+    self.calc_kv_scales(query, key, value)


-def unified_kv_scale_calc_fake(
+def maybe_calc_kv_scales_fake(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    q_scale: torch.Tensor,
-    k_scale: torch.Tensor,
-    v_scale: torch.Tensor,
-    q_range: torch.Tensor,
-    k_range: torch.Tensor,
-    v_range: torch.Tensor,
-    scale_calc: bool,
+    layer_name: str,
 ) -> None:
     return


 direct_register_custom_op(
-    op_name="unified_kv_scale_calc",
-    op_func=unified_kv_scale_calc,
-    mutates_args=["q_scale", "k_scale", "v_scale"],
-    fake_impl=unified_kv_scale_calc_fake,
+    op_name="maybe_calc_kv_scales",
+    op_func=maybe_calc_kv_scales,
+    mutates_args=[],
+    fake_impl=maybe_calc_kv_scales_fake,
     dispatch_key=current_platform.dispatch_key,
     tags=tag_cudagraph_unsafe,
 )
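For context on what `torch.ops.vllm.maybe_calc_kv_scales` ends up doing: it resolves the attention layer from the forward context by `layer_name` and runs a one-shot per-tensor abs-max calibration, after which `calculate_kv_scales` is flipped off so the scales are never recomputed. Below is a minimal standalone sketch of that calibration step only; the range constants are hypothetical stand-ins for the layer's configured q/k/v quantization ranges (e.g. the fp8 e4m3 maximum of 448), not values taken from this diff.

```python
import torch

# Hypothetical quantization ranges standing in for the layer's q/k/v ranges.
Q_RANGE = K_RANGE = V_RANGE = torch.tensor(448.0)

def calc_kv_scales(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
    """One-shot per-tensor calibration: scale = max(|x|) / range."""
    q_scale = torch.abs(query).max() / Q_RANGE
    k_scale = torch.abs(key).max() / K_RANGE
    v_scale = torch.abs(value).max() / V_RANGE
    return q_scale, k_scale, v_scale

# Example: calibrate once from the first batch of activations.
q, k, v = (torch.randn(2, 16, 128) for _ in range(3))
print(calc_kv_scales(q, k, v))
```

Note that the custom op now takes only the q/k/v tensors and a `layer_name` string; the scale buffers are reached through `forward_context.no_compile_layers[layer_name]` and updated in place by the layer's `calc_kv_scales` method, which is why the scale and range tensors drop out of the op signature and `mutates_args` becomes empty.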