Skip to content

Commit 6ec29a1

Browse files
committed
Update custom op reg and add e2e testing
1 parent 180e264 commit 6ec29a1

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

tests/compile/test_full_graph.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -139,6 +139,21 @@ def test_custom_compile_config(
139 139
run_model(compilation_config, model, model_kwargs)
140 140

141 141

142+
@pytest.mark.parametrize(
143+
"optimization_level",
144+
[CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE],
145+
)
146+
def test_fp8_kv_scale_compile(optimization_level: int):
147+
model = "Qwen/Qwen2-0.5B"
148+
model_kwargs = {
149+
"quantization": "fp8",
150+
"kv_cache_dtype": "fp8_e4m3",
151+
"calculate_kv_scales": True,
152+
"max_model_len": 512,
153+
}
154+
run_model(optimization_level, model, model_kwargs)
155+
156+
142 157
def test_inductor_graph_partition_attn_fusion(caplog_vllm):
143 158
if not is_torch_equal_or_newer("2.9.0.dev"):
144 159
pytest.skip("inductor graph partition is only available "

vllm/attention/layer.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -586,10 +586,8 @@ def maybe_calc_kv_scales_fake(
586 586
direct_register_custom_op(
587 587
op_name="maybe_calc_kv_scales",
588 588
op_func=maybe_calc_kv_scales,
589-
mutates_args=[],
589+
mutates_args=["query", "key", "value"],
590 590
fake_impl=maybe_calc_kv_scales_fake,
591-
dispatch_key=current_platform.dispatch_key,
592-
tags=tag_cudagraph_unsafe,
593 591
)
594 592

595 593

0 commit comments

Comments
 (0)