
Commit fd9423a

yewentao256 authored and charlifu committed
[Perf] Apply torch.compile for per_block_cast_to_fp8 (vllm-project#24611)
Signed-off-by: yewentao256 <[email protected]>
Signed-off-by: charlifu <[email protected]>
1 parent 42337d2 commit fd9423a

1 file changed (+2, -2 lines)

vllm/utils/deep_gemm.py
Lines changed: 2 additions & 2 deletions

@@ -135,7 +135,7 @@ def _align(x: int, y: int) -> int:
 
 
 # Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38
-# TODO(wentao): optimize this function, using triton or cuda kernel
+@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def per_block_cast_to_fp8(
     x: torch.Tensor,
     block_size: list[int] = DEFAULT_BLOCK_SIZE,
@@ -187,4 +187,4 @@ def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype,
     "is_deep_gemm_e8m0_used",
     "is_deep_gemm_supported",
     "should_use_deepgemm_for_fp8_linear",
-]
+]
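The change swaps the removed TODO (hand-writing a Triton or CUDA kernel) for a torch.compile decorator, letting the compiler fuse the padding, per-block reduction, and FP8 cast that make up per_block_cast_to_fp8. Below is a minimal, self-contained sketch of the pattern, not vLLM's actual implementation: the 128x128 block size, the "inductor" backend string (the real decorator passes current_platform.simple_compile_backend), and all helper names are illustrative assumptions.

# Sketch: per-block FP8 quantization wrapped in torch.compile.
# Assumptions (not from the commit): 128x128 blocks, the "inductor"
# backend, and the helper/function names are made up for illustration.
import torch


def _align_up(n: int, multiple: int) -> int:
    # Round n up to the nearest multiple of `multiple`.
    return ((n + multiple - 1) // multiple) * multiple


@torch.compile(dynamic=True, backend="inductor")
def per_block_cast_to_fp8_sketch(
    x: torch.Tensor, block: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
    m, n = x.shape
    # Zero-pad so every block is complete.
    padded = x.new_zeros(_align_up(m, block), _align_up(n, block))
    padded[:m, :n] = x
    bm, bn = padded.size(0) // block, padded.size(1) // block
    # Expose blocks as separate dims: (bm, block, bn, block).
    blocks = padded.view(bm, block, bn, block)
    # One absolute max per block; clamp so all-zero blocks do not divide by 0.
    amax = blocks.abs().float().amax(dim=(1, 3), keepdim=True).clamp(min=1e-4)
    # Rescale each block into the float8_e4m3 range (max finite value 448).
    q = (blocks * (448.0 / amax)).to(torch.float8_e4m3fn)
    scales = (amax / 448.0).view(bm, bn)
    return q.view_as(padded)[:m, :n].contiguous(), scales

Because the function is compiled with dynamic=True, varying input shapes reuse the same compiled graph instead of triggering a recompile per shape, which is what makes the decorator a cheap drop-in speedup over the eager sequence of ops.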
