Commit 180e264

Move KV scales logic to custom operator for torch.compile compatibility
Signed-off-by: adabeyta <[email protected]>
1 parent 0ee023e commit 180e264

File tree

2 files changed (+39, -27 lines)

vllm/attention/layer.py

Lines changed: 30 additions & 27 deletions
@@ -277,9 +277,8 @@ def forward(
             `vllm.forward_context.get_forward_context().attn_metadata`.
         """
         if self.calculate_kv_scales:
-            attn_metadata = get_forward_context().attn_metadata
-            if attn_metadata.enable_kv_scales_calculation:
-                self.calc_kv_scales(query, key, value)
+            torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
+                                                self.layer_name)
 
         output_dtype = query.dtype
         if self.query_quant is not None:
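
The call above replaces inline Python control flow with a registered custom op, so torch.compile sees a single opaque node instead of tracing (and graph-breaking on) the data-dependent scale logic. Below is a minimal, standalone sketch of that pattern using torch.library.custom_op (PyTorch 2.4+) rather than vLLM's direct_register_custom_op helper; the "demo::absmax_scale" name and the 448.0 range are assumptions for illustration, not part of this commit.

import torch

@torch.library.custom_op("demo::absmax_scale", mutates_args=())
def absmax_scale(x: torch.Tensor, qmax: float) -> torch.Tensor:
    # The op body runs eagerly even under torch.compile, so .item() and
    # branching are safe here without causing graph breaks at the call site.
    m = torch.abs(x).max().item()
    return torch.tensor(m / qmax, device=x.device, dtype=torch.float32)

@absmax_scale.register_fake
def _(x: torch.Tensor, qmax: float) -> torch.Tensor:
    # Shape/dtype-only stand-in used while the compiler traces the graph.
    return torch.empty((), device=x.device, dtype=torch.float32)

@torch.compile
def fwd(x: torch.Tensor) -> torch.Tensor:
    scale = torch.ops.demo.absmax_scale(x, 448.0)
    return x / scale

print(fwd(torch.randn(8)))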
@@ -341,6 +340,16 @@ def forward(
             return torch.ops.vllm.unified_attention(
                 query, key, value, self.layer_name)
 
+    def calc_kv_scales(self, query, key, value):
+        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
+        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
+        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
+        self._q_scale_float = self._q_scale.item()
+        self._k_scale_float = self._k_scale.item()
+        self._v_scale_float = self._v_scale.item()
+        # We only calculate the scales once
+        self.calculate_kv_scales = False
+
     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore
         s += f", num_heads={self.impl.num_heads}"  # type: ignore
@@ -544,47 +553,41 @@ def maybe_save_kv_layer_to_connector(
                                           attn_metadata[layer_name])
 
 
-def unified_kv_scale_calc(
+def maybe_calc_kv_scales(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    q_scale: torch.Tensor,
-    k_scale: torch.Tensor,
-    v_scale: torch.Tensor,
-    q_range: torch.Tensor,
-    k_range: torch.Tensor,
-    v_range: torch.Tensor,
-    scale_calc: bool,
+    layer_name: str,
 ) -> None:
 
-    if not scale_calc:
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
+
+    if attn_metadata is None or not getattr(
+            attn_metadata, 'enable_kv_scales_calculation', False):
         return
 
-    q_scale.copy_(torch.abs(query).max() / q_range)
-    k_scale.copy_(torch.abs(key).max() / k_range)
-    v_scale.copy_(torch.abs(value).max() / v_range)
+    self = forward_context.no_compile_layers[layer_name]
+    self.calc_kv_scales(query, key, value)
 
 
-def unified_kv_scale_calc_fake(
+def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
-    q_scale: torch.Tensor,
-    k_scale: torch.Tensor,
-    v_scale: torch.Tensor,
-    q_range: torch.Tensor,
-    k_range: torch.Tensor,
-    v_range: torch.Tensor,
-    scale_calc: bool,
+    layer_name: str,
 ) -> None:
     return
 
 
 direct_register_custom_op(
-    op_name="unified_kv_scale_calc",
-    op_func=unified_kv_scale_calc,
-    mutates_args=["q_scale", "k_scale", "v_scale"],
-    fake_impl=unified_kv_scale_calc_fake,
+    op_name="maybe_calc_kv_scales",
+    op_func=maybe_calc_kv_scales,
+    mutates_args=[],
+    fake_impl=maybe_calc_kv_scales_fake,
     dispatch_key=current_platform.dispatch_key,
     tags=tag_cudagraph_unsafe,
 )
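
Note the signature change: instead of passing scale/range tensors plus a bool, the op now takes only the query/key/value tensors and a layer_name string, and resolves the stateful Attention layer from forward_context.no_compile_layers at runtime. A rough sketch of that lookup pattern, with hypothetical stand-ins (FakeLayer, _LAYER_REGISTRY) in place of vLLM's actual classes:

import torch

_LAYER_REGISTRY: dict[str, "FakeLayer"] = {}  # stand-in for no_compile_layers

class FakeLayer:
    def __init__(self, name: str) -> None:
        self.k_scale = torch.ones(())
        _LAYER_REGISTRY[name] = self

    def calc_kv_scales(self, key: torch.Tensor) -> None:
        # Abs-max scale as in the layer.py change above (448.0 is illustrative).
        self.k_scale.copy_(key.abs().max() / 448.0)

def maybe_calc_kv_scales(key: torch.Tensor, layer_name: str) -> None:
    # Only tensors and a string cross the op boundary; the stateful layer
    # object is looked up from the registry at call time.
    _LAYER_REGISTRY[layer_name].calc_kv_scales(key)

layer = FakeLayer("model.layers.0.attn")
maybe_calc_kv_scales(torch.randn(4, 8), "model.layers.0.attn")
print(layer.k_scale)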

vllm/v1/worker/gpu_model_runner.py

Lines changed: 9 additions & 0 deletions
@@ -2275,6 +2275,15 @@ def execute_model(
             cudagraph_runtime_mode, batch_descriptor = \
                 self.cudagraph_dispatcher.dispatch(batch_descriptor)
 
+            # Set cudagraph mode to none if calc_kv_scales is true.
+            if attn_metadata is not None:
+                metadata_list = (attn_metadata.values() if isinstance(
+                    attn_metadata, dict) else [attn_metadata])
+                if any(
+                        getattr(m, 'enable_kv_scales_calculation', False)
+                        for m in metadata_list):
+                    cudagraph_runtime_mode = CUDAGraphMode.NONE
+
             # This is currently to get around the assert in the DPMetadata
             # where it wants `num_tokens_across_dp` to align with `num_tokens`
             if ubatch_slices is not None:
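
This override forces eager execution whenever any layer still needs its KV scales computed, presumably because calc_kv_scales() mutates layer state and calls .item(), which is not capture-safe; it matches the tag_cudagraph_unsafe tag on the op registration above. A standalone sketch of the same check with hypothetical stand-in types (Mode, Meta, pick_mode):

from dataclasses import dataclass
from enum import Enum, auto

class Mode(Enum):          # stand-in for CUDAGraphMode
    FULL = auto()
    NONE = auto()

@dataclass
class Meta:                # stand-in for per-layer attention metadata
    enable_kv_scales_calculation: bool = False

def pick_mode(attn_metadata, default: Mode = Mode.FULL) -> Mode:
    # Same shape of check as the runner change above: fall back to eager
    # (no CUDA graph) if any layer still needs its KV scales computed.
    metas = (attn_metadata.values()
             if isinstance(attn_metadata, dict) else [attn_metadata])
    if any(getattr(m, "enable_kv_scales_calculation", False) for m in metas):
        return Mode.NONE
    return default

assert pick_mode({"layer.0": Meta(True), "layer.1": Meta(False)}) is Mode.NONE
assert pick_mode({"layer.0": Meta(False)}) is Mode.FULL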
