Skip to content

Commit 8579f5d

Browse files
committed
fix import error for non-rocm platform
Signed-off-by: charlifu <[email protected]>
1 parent ac49dd2 commit 8579f5d

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

tests/compile/test_fusion.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from vllm.compilation.matcher_utils import QUANT_OPS
1414
from vllm.compilation.noop_elimination import NoOpEliminationPass
1515
from vllm.compilation.post_cleanup import PostCleanupPass
16-
from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormFp8GroupQuantFusionPass
1716
from vllm.config import (
1817
CompilationConfig,
1918
CompilationMode,
@@ -241,6 +240,10 @@ def test_fusion_rmsnorm_quant(
241240
# Reshape pass is needed for the fusion pass to work
242241
noop_pass = NoOpEliminationPass(vllm_config)
243242
if model_class is TestAiterRmsnormGroupFp8QuantModel:
243+
from vllm.compilation.rocm_aiter_fusion import (
244+
RocmAiterRMSNormFp8GroupQuantFusionPass,
245+
)
246+
244247
fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config)
245248
else:
246249
fusion_pass = RMSNormQuantFusionPass(vllm_config)

tests/compile/test_silu_mul_quant_fusion.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from vllm.compilation.fusion import QUANT_OPS
1818
from vllm.compilation.noop_elimination import NoOpEliminationPass
1919
from vllm.compilation.post_cleanup import PostCleanupPass
20-
from vllm.compilation.rocm_aiter_fusion import RocmAiterSiluMulFp8GroupQuantFusionPass
2120
from vllm.config import (
2221
CompilationConfig,
2322
CompilationMode,
@@ -220,6 +219,10 @@ def test_fusion_silu_and_mul_quant(
220219
with set_current_vllm_config(config):
221220
fusion_pass = ActivationQuantFusionPass(config)
222221
if model_class == TestAiterSiluMulGroupFp8QuantModel:
222+
from vllm.compilation.rocm_aiter_fusion import (
223+
RocmAiterSiluMulFp8GroupQuantFusionPass,
224+
)
225+
223226
fusion_pass = RocmAiterSiluMulFp8GroupQuantFusionPass(config)
224227

225228
passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]

0 commit comments

Comments
 (0)