diff --git a/tests/compile/backend.py b/tests/compile/backend.py index 2c4287950dcf..f25c367433f4 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import weakref from collections.abc import Sequence from copy import deepcopy from typing import Callable, Union @@ -10,7 +11,26 @@ from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.inductor_pass import InductorPass -from vllm.config import get_current_vllm_config +from vllm.compilation.pass_manager import with_pattern_match_debug +from vllm.compilation.vllm_inductor_pass import VllmInductorPass +from vllm.config import VllmConfig, get_current_vllm_config + + +class LazyInitPass(InductorPass): + """ + If there's a pass that we want to initialize lazily in a test, + we can wrap it in LazyInitPass, which will initialize the pass when invoked + and then immediately invoke it. + """ + + def __init__(self, pass_cls: type[VllmInductorPass], + vllm_config: VllmConfig): + self.pass_cls = pass_cls + self.vllm_config = weakref.proxy(vllm_config) # avoid cycle + + def __call__(self, graph: fx.Graph) -> None: + self.pass_ = self.pass_cls(self.vllm_config) + self.pass_(graph) class TestBackend: @@ -40,10 +60,16 @@ def __call__(self, graph: fx.GraphModule, example_inputs): example_inputs, config_patches=self.inductor_config) + @with_pattern_match_debug def post_pass(self, graph: fx.Graph): self.graph_pre_pass = deepcopy(graph) + + VllmInductorPass.dump_prefix = 0 for pass_ in self.custom_passes: pass_(graph) + VllmInductorPass.dump_prefix += 1 + + VllmInductorPass.dump_prefix = None self.graph_post_pass = deepcopy(graph) # assign by reference, will reflect the final state of the graph diff --git a/tests/compile/test_async_tp.py b/tests/compile/test_async_tp.py index 9a51e6b3514f..1dc21365d557 100644 --- a/tests/compile/test_async_tp.py +++ b/tests/compile/test_async_tp.py @@ -294,6 +294,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int, compiled_model = torch.compile(model, backend=backend) compiled_model(hidden_states) + assert async_tp_pass.matched_count == 1 + # In pre-nodes, all gather or reduce scatter should exist, # fused_matmul_reduce_scatter or fused_all_gather_matmul should not backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 90e8e0ff9585..7afd6251bbbd 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -4,7 +4,7 @@ import vllm from vllm.compilation.counter import compilation_counter -from vllm.config import VllmConfig +from vllm.config import CompilationConfig, VllmConfig from vllm.utils import _is_torch_equal_or_newer @@ -26,6 +26,14 @@ def test_use_cudagraphs_dynamic(monkeypatch): assert not vllm_config.compilation_config.use_cudagraph +def test_custom_op(): + # proper syntax + _ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"]) + + with pytest.raises(ValueError, match="Invalid syntax '"): + _ = CompilationConfig(custom_ops=["quant_fp8"]) + + # forked needed to workaround https://github.com/vllm-project/vllm/issues/21073 @pytest.mark.forked # NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 0c7e6fbccf20..2ee9aa7476be 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -8,9 +8,10 @@ from vllm import LLM, SamplingParams from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.compilation.fusion import FUSED_OPS, FusionPass +from vllm.compilation.fusion import FUSED_OPS, RMSNormQuantFusionPass from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym) @@ -58,11 +59,12 @@ def test_fix_functionalization(model: str, quant_key: QuantKey, vllm_config.compilation_config = CompilationConfig( pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)) noop_pass = NoOpEliminationPass(vllm_config) - fusion_pass = FusionPass.instance(vllm_config) + fusion_pass = RMSNormQuantFusionPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) - passes = [noop_pass, fusion_pass, act_quant_fusion_pass - ] if do_fusion else [noop_pass] + passes = [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass + ] if do_fusion else [noop_pass, cleanup_pass] func_pass = FixFunctionalizationPass(vllm_config) backend_func = TestBackend(*passes, func_pass) backend_no_func = TestBackend(*passes) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index eedb9bdcd529..3d8897d3f18b 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -4,11 +4,11 @@ import pytest import torch -import vllm.envs as envs import vllm.plugins from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, - FusionPass) + RMSNormQuantFusionPass) from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, VllmConfig) from vllm.model_executor.layers.layernorm import RMSNorm @@ -79,15 +79,15 @@ def ops_in_model_after(self): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [64, 3392, 4096]) -@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) +@pytest.mark.parametrize("hidden_size", [64]) +@pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("static", [True, False]) # cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. @pytest.mark.parametrize("cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], +@pytest.mark.skipif(not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm") def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, cuda_force_torch): @@ -104,9 +104,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work noop_pass = NoOpEliminationPass(vllm_config) - fusion_pass = FusionPass.instance(vllm_config) + fusion_pass = RMSNormQuantFusionPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) - backend = TestBackend(noop_pass, fusion_pass) + backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) model = TestModel(hidden_size, eps, static, cuda_force_torch) # First dimension dynamic @@ -128,6 +129,8 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) + assert fusion_pass.matched_count == 2 + # In pre-nodes, fp8 quant should be there and fused kernels should not backend.check_before_ops(model.ops_in_model_before()) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index dd31e0db1f59..60f32c863208 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -9,6 +9,7 @@ from vllm.compilation.collective_fusion import AllReduceFusionPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig, ModelConfig, PassConfig, VllmConfig) from vllm.distributed import tensor_model_parallel_all_reduce @@ -215,8 +216,10 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) - backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass) + backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, + cleanup_pass) token_num = batch_size * seq_len model = test_model_cls(hidden_size, token_num) @@ -227,6 +230,7 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, compiled_model = torch.compile(model, backend=backend) compiled_model(hidden_states, residual) + assert all_reduce_fusion_pass.matched_count == 1 backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) backend.check_after_ops(model.ops_in_model_after()) del all_reduce_fusion_pass diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 022f183b3193..40e47c8c591c 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -6,18 +6,19 @@ import pytest import torch._dynamo -from tests.compile.backend import TestBackend +from tests.compile.backend import LazyInitPass, TestBackend from tests.models.utils import check_outputs_equal from tests.v1.attention.utils import (BatchSpec, _Backend, create_common_attn_metadata) from vllm import LLM, SamplingParams from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant -from vllm.attention import Attention +from vllm.attention import Attention, AttentionMetadata from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, ModelConfig, PassConfig, SchedulerConfig, VllmConfig, set_current_vllm_config) @@ -105,7 +106,7 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str, # AttnFusionPass needs attention layers to be registered in config upon init # so we initialize it during compilation. - attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw) + attn_pass = LazyInitPass(AttnFusionPass, vllm_config) backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass) llm2 = LLM(model, enforce_eager=True, @@ -198,7 +199,8 @@ def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, device=self.device, ) - def build_attn_metadata(self, batch_size: int, use_hnd: bool): + def build_attn_metadata(self, batch_size: int, use_hnd: bool) \ + -> AttentionMetadata: """Initialize attention metadata.""" # Create common attn metadata @@ -447,9 +449,10 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, # Create test backend with fusion passes enabled noop_pass = NoOpEliminationPass(vllm_config) - attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw - ) - test_backend = TestBackend(noop_pass, attn_pass) + attn_pass = LazyInitPass(AttnFusionPass, vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + + test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass) # Compile model with fusion enabled model_compiled = torch.compile(model_fused, @@ -485,6 +488,9 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True) + # access the underlying `AttnFusionPass` on the `LazyInitPass` + assert attn_pass.pass_.matched_count == sum(attn_fusion_supported) + # Check attention ops in the graph before and after fusion attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass)) attn_nodes_post = list(find_op_nodes(ATTN_OP, diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index fb9f9dde2279..b2734e915bbb 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -6,10 +6,12 @@ import vllm.envs as envs from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.compilation.fusion import FusionPass +from vllm.compilation.fusion import RMSNormQuantFusionPass from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.sequence_parallelism import SequenceParallelismPass +from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, PassConfig, VllmConfig) from vllm.distributed import tensor_model_parallel_all_reduce @@ -104,7 +106,7 @@ def __init__(self, # Initialize weights torch.nn.init.normal_(self.gate_proj, std=0.02) - self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False) + self.fp8_linear = Fp8LinearOp(act_quant_static=True) self.scale = torch.rand(1, dtype=torch.float32) # Create a weight that is compatible with torch._scaled_mm, @@ -137,8 +139,7 @@ def forward(self, hidden_states, residual): # layer normalization norm_output, residual_output = self.norm(all_reduce, residual) - # for static input quantization - # self.fp8_linear is initialized with use_per_token_if_dynamic=False + # scaled_mm with static input quantization fp8_linear_result = self.fp8_linear.apply(norm_output, self.w, self.wscale, @@ -253,16 +254,20 @@ def sequence_parallelism_pass_on_test_model( dtype=dtype, seed=42) - sequence_parallelism_pass = SequenceParallelismPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config) + sequence_parallelism_pass = SequenceParallelismPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) - passes_for_backend = [noop_pass, sequence_parallelism_pass] + passes_for_backend: list[VllmInductorPass] = \ + [noop_pass, sequence_parallelism_pass] if enable_fusion: - fusion_pass = FusionPass.instance(vllm_config) + fusion_pass = RMSNormQuantFusionPass(vllm_config) passes_for_backend.append(fusion_pass) + passes_for_backend.append(cleanup_pass) + backend_no_func = TestBackend(*passes_for_backend) backend_func = TestBackend(*passes_for_backend, func_pass) @@ -279,6 +284,8 @@ def sequence_parallelism_pass_on_test_model( compiled_model_func = torch.compile(model, backend=backend_func) compiled_model_func(hidden_states, residual) + assert sequence_parallelism_pass.matched_count == 1 + # In pre-nodes, all reduce should be there, # reduce scatter and all gather should not backend_no_func.check_before_ops(model.ops_in_model_before()) diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index ae190d25cad6..c445f4dde2cc 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -15,6 +15,7 @@ # yapf: enable from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -69,6 +70,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module): def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs): super().__init__() + from vllm.compilation.activation_quant_fusion import ( + silu_and_mul_nvfp4_quant_supported) + assert silu_and_mul_nvfp4_quant_supported + self.silu_and_mul = SiluAndMul() # create nvfp4 weight @@ -127,7 +132,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class, pass_config=PassConfig(enable_fusion=True, enable_noop=True)) fusion_pass = ActivationQuantFusionPass(config) - backend = TestBackend(NoOpEliminationPass(config), fusion_pass) + passes = [ + NoOpEliminationPass(config), fusion_pass, + PostCleanupPass(config) + ] + backend = TestBackend(*passes) model = model_class(hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x) @@ -151,6 +160,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class, atol=atol, rtol=rtol) + assert fusion_pass.matched_count == 1 + # In pre-nodes, quant op should be present and fused kernels should not backend.check_before_ops(model.ops_in_model_before()) diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py index f2fbb1200eec..74462fb37ca9 100644 --- a/vllm/compilation/activation_quant_fusion.py +++ b/vllm/compilation/activation_quant_fusion.py @@ -17,7 +17,7 @@ from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 from .inductor_pass import enable_fake_mode -from .vllm_inductor_pass import VllmInductorPass +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -152,7 +152,7 @@ def replacement(result: torch.Tensor, output_scale: torch.Tensor, register_replacement(pattern, replacement, inputs, fwd_only, pm_pass) -class ActivationQuantFusionPass(VllmInductorPass): +class ActivationQuantFusionPass(VllmPatternMatcherPass): """ This pass fuses a pre-defined set of custom ops into fused ops. It uses the torch pattern matcher to find the patterns and replace them. @@ -176,16 +176,12 @@ def __init__(self, config: VllmConfig): pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern() pattern_silu_mul_nvfp4.register(self.patterns) - def __call__(self, graph: torch.fx.Graph): - self.begin() - self.dump_graph(graph, "before_act_quant_fusion") - - count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns in ActivationQuantFusionPass", - count) + self.dump_patterns(config, self.patterns) - self.dump_graph(graph, "after_act_quant_fusion") - self.end_and_log() + @VllmInductorPass.time_and_log + def __call__(self, graph: torch.fx.Graph): + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) def uuid(self): return VllmInductorPass.hash_source(self, ActivationQuantPattern, diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 0658b59a2e21..331cd8a87392 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -20,7 +20,7 @@ from vllm.utils import direct_register_custom_op from .inductor_pass import enable_fake_mode -from .vllm_inductor_pass import VllmInductorPass +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass FP8_DTYPE = current_platform.fp8_dtype() @@ -348,7 +348,7 @@ def replacement(x: torch.Tensor, weight: torch.Tensor, pm.fwd_only, pm_pass) -class AsyncTPPass(VllmInductorPass): +class AsyncTPPass(VllmPatternMatcherPass): @enable_fake_mode def __init__(self, config: VllmConfig): @@ -378,18 +378,17 @@ def __init__(self, config: VllmConfig): AllGatherCutlassScaledMMPattern( self.model_dtype, self.device).register(self.patterns) + self.dump_patterns(config, self.patterns) + def is_applicable_for_shape(self, shape: Optional[int]) -> bool: # only do replace for specific shapes tp_size = get_tensor_model_parallel_world_size() return shape is not None and shape % tp_size == 0 + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): - self.begin() - self.dump_graph(graph, "before_async_tp_pass") - count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns with async TP pass.", count) - self.dump_graph(graph, "after_async_tp_pass") - self.end_and_log() + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) if flashinfer_comm is not None: @@ -1068,7 +1067,7 @@ def replacement(quant_result: torch.Tensor, residual: torch.Tensor, pm.fwd_only, pm_pass) -class AllReduceFusionPass(VllmInductorPass): +class AllReduceFusionPass(VllmPatternMatcherPass): def __init__(self, config: VllmConfig): super().__init__(config) @@ -1124,6 +1123,7 @@ def __init__(self, config: VllmConfig): fuse_rms_quant=config.compilation_config.pass_config.enable_fusion) self.register_patterns() + self.dump_patterns(config, self.patterns) @enable_fake_mode def register_patterns(self): @@ -1172,15 +1172,14 @@ def register_patterns(self): self.disabled = False + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): if self.disabled: + logger.debug("AllReduceFusionPass disabled") return - self.begin() - self.dump_graph(graph, "before_all_reduce_fusion_pass") - count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns", count) - self.dump_graph(graph, "after_all_reduce_fusion_pass") - self.end_and_log() + + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) def __del__(self): if getattr(self, "disabled", True): diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 6bc721eec3d4..54403c1f7ca3 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -26,6 +26,7 @@ class FixFunctionalizationPass(VllmInductorPass): To add new nodes to defunctionalize, add to the if-elif chain in __call__. """ + @VllmInductorPass.time_and_log def __call__(self, graph: torch.fx.Graph): # XPU does not support auto-functionalization yet. # Will enable this when switch to vllm-xpu-kernels. @@ -34,9 +35,6 @@ def __call__(self, graph: torch.fx.Graph): "pass currently.") return - self.begin() - self.dump_graph(graph, "before_fix_functionalization") - self.nodes_to_remove: list[torch.fx.Node] = [] count = 0 for node in graph.nodes: @@ -111,7 +109,7 @@ def __call__(self, graph: torch.fx.Graph): count += 1 - self.dump_graph(graph, "before_fix_functionalization_cleanup") + self.dump_graph(graph, "before_cleanup") # Remove the nodes all at once count_removed = len(self.nodes_to_remove) @@ -120,8 +118,7 @@ def __call__(self, graph: torch.fx.Graph): logger.debug("De-functionalized %s nodes, removed %s nodes", count, count_removed) - self.dump_graph(graph, "after_fix_functionalization") - self.end_and_log() + self.nodes_to_remove.clear() def _remove(self, node_or_nodes: Union[torch.fx.Node, Iterable[torch.fx.Node]]): diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index afa739c966a5..3034b6eaeaca 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, NamedTuple, Optional +from typing import Any, NamedTuple import torch import torch._inductor.pattern_matcher as pm @@ -16,10 +16,8 @@ kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale) from vllm.platforms import current_platform -from .fx_utils import find_getitem_maybe from .inductor_pass import enable_fake_mode -from .multi_output_match import MultiOutputMatch -from .vllm_inductor_pass import VllmInductorPass +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) FP8_DTYPE = current_platform.fp8_dtype() @@ -50,8 +48,7 @@ def empty_i32(*args, **kwargs): torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): - QUANT_OPS[ - kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 + QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default class FusedRMSQuantKey(NamedTuple): @@ -80,68 +77,6 @@ def __str__(self): } -class QuantMultiOutputMatch(MultiOutputMatch): - - def __init__(self, match: pm.Match, quant_op, fused_op): - super().__init__(match) - assert isinstance(quant_op, OpOverload) - assert isinstance(fused_op, OpOverload) - self.QUANT_OP = quant_op # in-place quant op - self.FUSED_OP = fused_op # in-place fused quant op - - def insert_fused_node(self, fused_return_mapping: dict[int, tuple[fx.Node, - int]], - **kwargs): - """ - This utility function inserts an auto-functionalized node for FUSED_OP. - It also correctly sets its meta value and rebinds the users of the - unfused nodes to use the fused node instead. - - :param fused_return_mapping: A dictionary, mapping from getitem indices - of the fused node result to a tuple of the old node and a getitem index. - :param kwargs: kwargs that get directly forwarded to the auto_fn node - - Example: - If we want to replace this graph: - _, x1, x2 = auto_fn(op1) - _, y1, y2 = auto_fn(op2) - - with - _, x1, y2, x2 = auto_fn(FUSED_OP) - - we would call: - insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)} - - Note that the 0th element is None for auto-functionalized in-place ops. - Hence, others appear 1-indexed. - """ - fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs) - indices = fused_return_mapping.keys() - getitem_nodes = self.insert_getitems(fused_node, indices) - - # Prepare the meta value, use a list so it's mutable - meta_val = [None] * (max(indices) + 1) - - # Iterate through elements of the tuple produced by fused_node - for idx, getitem_node in zip(indices, getitem_nodes): - old_node, old_idx = fused_return_mapping[idx] - - # If the old value was never used, the old_getitem might not exist - old_getitem = find_getitem_maybe(old_node, old_idx) - if old_getitem is not None: - # Rebind the users of match getitem nodes to use the new nodes. - # The old nodes will be removed by DCE at the end of the pass. - old_getitem.replace_all_uses_with(getitem_node) - getitem_node.meta["val"] = old_getitem.meta["val"] - - # Extract the appropriate meta value - # It is present even if the getitem node does not exist - meta_val[idx] = old_node.meta["val"][old_idx] - - # Fix the meta value on the new fused node - fused_node.meta["val"] = tuple(meta_val) - - class RMSNormQuantPattern: def __init__(self, epsilon: float, key: FusedRMSQuantKey): @@ -224,8 +159,7 @@ def __init__(self, symmetric=symmetric)) super().__init__(epsilon, key) - def register(self, pm_pass: PatternMatcherPass, - record_match: Callable[[MultiOutputMatch], bool]): + def register(self, pm_pass: PatternMatcherPass): def pattern(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, @@ -271,36 +205,7 @@ def replacement(result: torch.Tensor, input: torch.Tensor, inputs, pm.fwd_only, pm_pass, - extra_check=lambda m: record_match( - self.Match(m, self.QUANT_OP, self.FUSED_OP))) - - class Match(QuantMultiOutputMatch): - - def process(self): - # Find the nodes in the match that we need to rebind - rms_node = self.find_auto_fn(RMS_ADD_OP) - quant_node = self.find_auto_fn(self.QUANT_OP) - - assert len(rms_node.users) == 2 - assert len(quant_node.users) == 1 - - # First, insert a new auto_functionalized node for the fused op, - # as well as getitem nodes to extract the result and residual. - # The auto_fn node returns a tuple of (None, result, residual). - # - # The resulting graph looks like this: - # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa - # result_node_new = at[1] - # residual_node_new = at[2] - with self.inserting_after_match(): - # Missing epsilon, scalars cannot be inputs to the pattern - kwargs = self.match.kwargs.copy() - - # 0 is always None - fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)} - self.insert_fused_node(fused_return_mapping, - **kwargs, - epsilon=rms_node.kwargs["epsilon"]) + ) class RMSNormDynamicQuantPattern(RMSNormQuantPattern): @@ -317,8 +222,7 @@ def __init__(self, symmetric=symmetric)) super().__init__(epsilon, key) - def register(self, pm_pass: PatternMatcherPass, - record_match: Callable[[MultiOutputMatch], bool]): + def register(self, pm_pass: PatternMatcherPass): def pattern(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, @@ -366,39 +270,7 @@ def replacement(result: torch.Tensor, result_rms: torch.Tensor, inputs, pm.fwd_only, pm_pass, - extra_check=lambda m: record_match( - self.Match(m, self.QUANT_OP, self.FUSED_OP))) - - class Match(QuantMultiOutputMatch): - - def process(self): - # Find the nodes in the match that we need to rebind - rms_node = self.find_auto_fn(RMS_OP) - quant_node = self.find_auto_fn(self.QUANT_OP) - - assert len(rms_node.users) == 1 - assert len(quant_node.users) == 2 - - # First, insert a new auto_functionalized node for the fused op, - # as well as getitem nodes to extract the result and scale. - # The auto_fn node returns a tuple of (None, result, scale). - # - # The resulting graph looks like this: - # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa - # result_node_new = at[1] - # scale_node_new = at[2] - with self.inserting_after_match(): - # Missing epsilon, scalars cannot be inputs to the pattern - kwargs = self.match.kwargs.copy() - del kwargs["result_rms"] # not used in the fused op - - fused_return_mapping = {1: (quant_node, 1), 2: (quant_node, 2)} - self.insert_fused_node( - fused_return_mapping, - epsilon=rms_node.kwargs["epsilon"], - scale_ub=None, # not used but required - residual=None, # not used but required - **kwargs) + ) class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): @@ -415,8 +287,7 @@ def __init__(self, symmetric=symmetric)) super().__init__(epsilon, key) - def register(self, pm_pass: PatternMatcherPass, - record_match: Callable[[MultiOutputMatch], bool]): + def register(self, pm_pass: PatternMatcherPass): def pattern(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, @@ -464,137 +335,49 @@ def replacement(result: torch.Tensor, input: torch.Tensor, inputs, pm.fwd_only, pm_pass, - extra_check=lambda m: record_match( - self.Match(m, self.QUANT_OP, self.FUSED_OP))) - - class Match(QuantMultiOutputMatch): - - def process(self): - # Find the nodes in the match that we need to rebind - rms_node = self.find_auto_fn(RMS_ADD_OP) - quant_node = self.find_auto_fn(self.QUANT_OP) - - assert len(rms_node.users) == 2 - assert len(quant_node.users) == 2 - - # First, insert a new auto_functionalized node for the fused op, - # as well as getitem nodes to extract result, scale, and residual. - # The auto_fn node returns a tuple (None, result, scale, residual). - # - # The resulting graph looks like this: - # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa - # result_node_new = at[1] - # scale_node_new = at[2] - # residual_node_new = at[3] - with self.inserting_after_match(): - # Missing epsilon, scalars cannot be inputs to the pattern - kwargs = self.match.kwargs.copy() - - fused_return_mapping = { - 1: (quant_node, 1), # result - 2: (quant_node, 2), # scale - 3: (rms_node, 2), # residual - } - self.insert_fused_node( - fused_return_mapping, - epsilon=rms_node.kwargs["epsilon"], - scale_ub=None, # not used but required - **kwargs) - - -class FusionPass(VllmInductorPass): + ) + + +class RMSNormQuantFusionPass(VllmPatternMatcherPass): """ - This pass fuses a pre-defined set of custom ops into fused ops. - It uses the torch pattern matcher to find the patterns and replace them. - It also manually processes multi-output matches, as those are broken in - the torch pattern matcher. - - Because patterns can only be registered once, the pass is a singleton. - This will be addressed in a future version of PyTorch: - https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 + This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op. + It also supports fused_add_rms_norm. """ - _instance: 'Optional[FusionPass]' = None - - @classmethod - def instance(cls, config: VllmConfig): - """ - Get the singleton instance of the FusionPass. - If the instance exists, the config is updated but - initialization is not repeated. - """ - if cls._instance is None: - cls._instance = FusionPass(config) - else: - cls._instance.pass_config = config.compilation_config.pass_config - return cls._instance - @enable_fake_mode def __init__(self, config: VllmConfig): - assert self.__class__._instance is None, \ - "FusionPass singleton instance already exists" super().__init__(config) - self.matches: list[MultiOutputMatch] = [] self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="fusion_pass") + pass_name="rmsnorm_quant_fusion_pass") for epsilon in [1e-5, 1e-6]: # Fuse rms_norm + static fp8 quant RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) - # Matches for patterns below have 2 or more outputs, - # so we need to process them manually (see process_matches) - - # Fuse rms_norm + static fp8 quant + # Fuse fused_add_rms_norm + static fp8 quant FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns, self.record_match) + self.patterns) # Fuse rms_norm + dynamic per-token fp8 quant - RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns, self.record_match) + RMSNormDynamicQuantPattern(epsilon, + FP8_DTYPE).register(self.patterns) # Fuse fused_add_rms_norm + dynamic per-token fp8 quant FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns, self.record_match) - - # WARNING: This is a hack to clear the pattern matcher cache - # and allow multiple values of epsilon. - torch._inductor.pattern_matcher._seen_patterns.clear() - - def record_match(self, match: MultiOutputMatch) -> bool: - # Hijack the extra_check to record the match and - # save it for post-processing. - self.matches.append(match) - - # Return False to prevent automatic replacement. - return False - - def process_matches(self, graph: fx.Graph): - """ - Manually process multi-output matches and replace them with fused nodes. - See MultiOutputMatch for more details. - """ - for match in self.matches: - match.process() + self.patterns) - # Finally, remove matched nodes - graph.eliminate_dead_code() - assert all(node not in graph.nodes for match in self.matches - for node in match.match.nodes) + self.dump_patterns(config, self.patterns) + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): - self.begin() - self.dump_graph(graph, "before_fusion") - - count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns", count) - self.dump_graph(graph, "after_pattern_match") - - # Manually process multi-output matches (and run DCE) - self.process_matches(graph) - logger.debug("Post-processed %s matches", len(self.matches)) - self.dump_graph(graph, "after_fusion") - self.matches.clear() - self.end_and_log() + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) + + def uuid(self) -> Any: + return self.hash_source(self, RMSNormQuantPattern, + RMSNormStaticQuantPattern, + RMSNormDynamicQuantPattern, + FusedAddRMSNormStaticQuantPattern, + FusedAddRMSNormDynamicQuantPattern) diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index e3677b3dd62d..2c6cf8f12fdc 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -18,7 +18,7 @@ from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 from .inductor_pass import enable_fake_mode -from .vllm_inductor_pass import VllmInductorPass +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -245,7 +245,7 @@ def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, pm_pass) -class AttnFusionPass(VllmInductorPass): +class AttnFusionPass(VllmPatternMatcherPass): """ This pass fuses post-attention quantization onto attention if supported. @@ -282,20 +282,12 @@ def __init__(self, config: VllmConfig): "were found in CompilationConfig.static_forward_context " "so no fusion patterns were registered.") - def __call__(self, graph: torch.fx.graph.Graph) -> None: - self.begin() - self.dump_graph(graph, "before_attn_fusion") - - count = self.patterns.apply(graph) + self.dump_patterns(config, self.patterns) - # TODO: Move this to pass_manager.py after the fx graph broken issue - # has been resolved. - # see https://github.com/vllm-project/vllm/issues/23091 - graph.eliminate_dead_code() - - logger.debug("Fused quantization onto %s attention nodes", count) - self.dump_graph(graph, "after_attn_fusion") - self.end_and_log() + @VllmInductorPass.time_and_log + def __call__(self, graph: torch.fx.graph.Graph) -> None: + self.matched_count = self.patterns.apply(graph) + logger.debug("Fused quant onto %s attention nodes", self.matched_count) def uuid(self): return VllmInductorPass.hash_source(self, AttentionQuantPattern, diff --git a/vllm/compilation/multi_output_match.py b/vllm/compilation/multi_output_match.py deleted file mode 100644 index 6d1893777cec..000000000000 --- a/vllm/compilation/multi_output_match.py +++ /dev/null @@ -1,109 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import abc -import operator -from abc import abstractmethod -from collections.abc import Iterable - -from torch import fx -from torch._higher_order_ops.auto_functionalize import auto_functionalized -from torch._inductor import pattern_matcher as pm -from torch._ops import OpOverload -from torch.fx import Node - -from vllm.compilation.fx_utils import find_auto_fn - - -class MultiOutputMatch(abc.ABC): - """ - This class provides utilities to process multi-output matches and - manually insert replacements. - - This is necessary because the automatic replacement for multi-output - matches is broken: https://github.com/pytorch/pytorch/issues/137280 - """ - - def __init__(self, match: pm.Match): - self.match = match - - @abstractmethod - def process(self): - """ - Process a multi-output match and manually insert the replacement. - - This method should: - 1. Insert the replacement nodes after the last node in the match. - 2. Rebind the users of nodes in the match to use the new nodes. - 3. Set meta["val"] for de-functionalization. - - The result of an auto-functionalized node is a tuple of tensors. - The first element is the return value of the function, usually None. - The remaining elements are the mutated args of the function. - - All auto-functionalized nodes must contain a proper meta["val"], - as it is used by de-functionalization. meta["val"] has to contain the - value of the node (tuple of tensors) that would be returned by the - functionalized node during tracing. - - Existing nodes in the graph all have this property set, but we have - to set it manually for new nodes we insert. - - Example: - # op schema: foo(a: Tensor!, b: Tensor, c: Tensor!) -> None - at = auto_functionalized(torch.ops._C.foo.default, a, b, c) - # at.meta["val"] = (None, a, c) - """ - raise NotImplementedError - - @property - def nodes(self) -> list[fx.Node]: - return self.match.nodes - - @property - def graph(self) -> fx.Graph: - return self.match.graph - - def find_auto_fn(self, op) -> fx.Node: - """ - Find the first auto_functionalized node with the given op in the match. - """ - return find_auto_fn(self.nodes, op) - - def inserting_after_match(self): - """ - Insert nodes after the last node in the match. - This is done to avoid use-before-definition errors after inserting - replacement nodes. - """ - - # match.nodes is not guaranteed to be sorted. - # Find the last node in the match. - for last_node_in_match in reversed(self.graph.nodes): - if last_node_in_match in self.match.nodes: - break - else: - raise ValueError("No nodes in graph") - - return self.graph.inserting_after(last_node_in_match) - - def insert_getitems(self, tuple_node: fx.Node, - indices: Iterable[int]) -> tuple[fx.Node, ...]: - """ - Insert operator.getitem nodes to extract elements from a tuple node. - - :param tuple_node: The tuple node to extract elements from. - :param indices: The indices of the elements to extract. - :return: Tuple of the new getitem nodes, corresponding to the indices. - """ - with self.graph.inserting_after(tuple_node): - return tuple( - self.graph.call_function(operator.getitem, (tuple_node, idx)) - for idx in indices) - - def insert_auto_fn(self, op: OpOverload, kwargs) -> Node: - """ - Insert an auto_functionalized node with the given op and kwargs. - """ - return self.graph.call_function(auto_functionalized, (op, ), - kwargs=kwargs) diff --git a/vllm/compilation/noop_elimination.py b/vllm/compilation/noop_elimination.py index 17e85e70218d..2c453daf873d 100644 --- a/vllm/compilation/noop_elimination.py +++ b/vllm/compilation/noop_elimination.py @@ -64,9 +64,8 @@ class NoOpEliminationPass(VllmInductorPass): out: "f16[s0, 4096]" = at[1] """ + @VllmInductorPass.time_and_log def __call__(self, graph: torch.fx.Graph): - self.begin() - self.dump_graph(graph, "before_noop_elimination") count = 0 # Remove no-op reshapes/views: for node in graph.nodes: @@ -121,8 +120,6 @@ def __call__(self, graph: torch.fx.Graph): count += 1 logger.debug("Removed %s no-op reshapes and slices", count) - self.dump_graph(graph, "after_noop_elimination") - self.end_and_log() # ---------------------- Reshape helpers ---------------------- def reshape_dims_equivalent(self, dim: Union[int, torch.fx.Node], diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 1b1cbe4fa12c..e323fa1f7734 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -1,15 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools from torch import fx as fx +from vllm import envs from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.utils import set_env_var + +from .post_cleanup import PostCleanupPass +from .vllm_inductor_pass import VllmInductorPass if current_platform.is_cuda_alike(): from .activation_quant_fusion import ActivationQuantFusionPass - from .fusion import FusionPass + from .fusion import RMSNormQuantFusionPass from .fusion_attn import AttnFusionPass if current_platform.is_cuda(): @@ -19,11 +25,28 @@ from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context from .noop_elimination import NoOpEliminationPass from .sequence_parallelism import SequenceParallelismPass -from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) +def with_pattern_match_debug(fn): + """ + Function decorator that turns on inductor pattern match debug + for the duration of the call. + Used to avoid logging builtin Inductor pattern matching. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if (debug_val := envs.VLLM_PATTERN_MATCH_DEBUG) is not None: + # optionally check rank here + with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val): + return fn(*args, **kwargs) + return fn(*args, **kwargs) + + return wrapper + + class PostGradPassManager(CustomGraphPass): """ The pass manager for post-grad passes. @@ -40,16 +63,26 @@ class PostGradPassManager(CustomGraphPass): """ def __init__(self): - self.passes: list[VllmInductorPass] = [] + self.passes: list[InductorPass] = [] + @with_pattern_match_debug def __call__(self, graph: fx.Graph): + VllmInductorPass.dump_prefix = 0 # reset dump index + shape = get_pass_context().runtime_shape for pass_ in self.passes: if pass_.is_applicable_for_shape(shape): pass_(graph) + VllmInductorPass.dump_prefix += 1 + + # post-cleanup goes before fix_functionalization + # because it requires a functional graph + self.post_cleanup(graph) + VllmInductorPass.dump_prefix += 1 # always run fix_functionalization last self.fix_functionalization(graph) + VllmInductorPass.dump_prefix = None # Cleanup index def configure(self, config: VllmConfig): self.pass_config = config.compilation_config.pass_config @@ -61,14 +94,18 @@ def configure(self, config: VllmConfig): if self.pass_config.enable_async_tp: self.passes += [AsyncTPPass(config)] + if self.pass_config.enable_fi_allreduce_fusion: + self.passes += [AllReduceFusionPass(config)] + if self.pass_config.enable_fusion: - self.passes += [FusionPass.instance(config)] + self.passes += [RMSNormQuantFusionPass(config)] self.passes += [ActivationQuantFusionPass(config)] if self.pass_config.enable_attn_fusion: self.passes += [AttnFusionPass(config)] - if self.pass_config.enable_fi_allreduce_fusion: - self.passes += [AllReduceFusionPass(config)] + + # needs a functional graph + self.post_cleanup = PostCleanupPass(config) self.fix_functionalization = FixFunctionalizationPass(config) def add(self, pass_: InductorPass): diff --git a/vllm/compilation/post_cleanup.py b/vllm/compilation/post_cleanup.py new file mode 100644 index 000000000000..6a31f3935da7 --- /dev/null +++ b/vllm/compilation/post_cleanup.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from torch import fx + +from vllm.compilation.vllm_inductor_pass import VllmInductorPass + + +class PostCleanupPass(VllmInductorPass): + """ + This pass performs cleanup after custom passes. + It topologically sorts the graph and removes unused nodes. + This is needed because the pattern matcher does not guarantee producing + a topologically sorted graph, and there may be unused nodes left around. + """ + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph) -> None: + from torch._inductor.pattern_matcher import stable_topological_sort + stable_topological_sort(graph) + graph.eliminate_dead_code() diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 1758ed4c86d2..a6ca50c925a2 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -15,7 +15,7 @@ from vllm.platforms import current_platform from .inductor_pass import enable_fake_mode -from .vllm_inductor_pass import VllmInductorPass +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -417,7 +417,7 @@ def replacement( pm.fwd_only, pm_pass) -class SequenceParallelismPass(VllmInductorPass): +class SequenceParallelismPass(VllmPatternMatcherPass): """ This pass enables sequence parallelism for models. It identifies patterns where an AllReduce operation is followed by @@ -466,19 +466,13 @@ def __init__(self, config: VllmConfig): LastAllReduceRMSNormPattern(epsilon, self.model_dtype, self.device).register(self.patterns) - - # WARNING: This is a hack to clear the pattern matcher cache - # and allow multiple values of epsilon. - torch._inductor.pattern_matcher._seen_patterns.clear() + self.dump_patterns(config, self.patterns) def is_applicable_for_shape(self, shape: Optional[int]) -> bool: tp_size = get_tensor_model_parallel_world_size() return shape is not None and shape % tp_size == 0 + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): - self.begin() - self.dump_graph(graph, "before_sequence_parallelism_pass") - count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns with sequence parallelism", count) - self.dump_graph(graph, "after_sequence_parallelism_pass") - self.end_and_log() + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index b822b05b0f1e..837770d18199 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -1,10 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import functools +import operator import time +from pathlib import Path +from typing import ClassVar, Optional +import regex as re import torch from torch._dynamo.utils import lazy_format_graph_code +from torch._inductor.pattern_matcher import (PatternMatcherPass, + PatternPrettyPrinter) from vllm.config import VllmConfig from vllm.logger import init_logger @@ -19,6 +25,8 @@ class VllmInductorPass(InductorPass): An inductor pass with access to vLLM PassConfig. It provides timing, logging, and dumping utilities. """ + dump_prefix: ClassVar[Optional[int]] = None + """Keep track of pass index for debug dump ordering.""" def __init__(self, config: VllmConfig): self.pass_config = config.compilation_config.pass_config @@ -28,8 +36,24 @@ def __init__(self, config: VllmConfig): else None self.pass_name = self.__class__.__name__ + @staticmethod + def time_and_log(call_fn): + + @functools.wraps(call_fn) + def wrapped(self: VllmInductorPass, graph: torch.fx.Graph): + self.begin() + self.dump_graph(graph, "before") + call_fn(self, graph) + self.dump_graph(graph, "after") + self.end_and_log() + + return wrapped + def dump_graph(self, graph: torch.fx.Graph, stage: str): - lazy_format_graph_code(stage, graph.owning_module) + i = VllmInductorPass.dump_prefix + i_str = "" if i is None else f".{i}" + lazy_format_graph_code(f"post_grad{i_str}.{self.pass_name}.{stage}", + graph.owning_module) def begin(self): self._start_time = time.perf_counter_ns() @@ -40,6 +64,88 @@ def end_and_log(self): logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms) +class VllmPatternMatcherPass(VllmInductorPass): + """ + A VllmInductorPass that uses the Inductor pattern matcher. + Its main use is providing the dump_patterns utility that dumps the + Inductor pattern matcher patterns into a file, which greatly aids debugging. + + TODO(luka) move more utilities to this pass. + """ + matched_count: int = 0 + """The number of matched patterns in the pass.""" + + _OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile( + r"") + + def _replace_op_overloads(self, string: str) -> str: + """Replace with nicer formulations""" + return self._OP_OVERLOAD_PATTERN.sub( + lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}", + string, + ) + + def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass): + """ + If debug dumping is enabled, dump the Inductor pattern-matcher patterns + into the debug_dump_path folder next to the dumped fx graphs. + + This method does its best to print something that looks like Python code + for easier debugging and potentially navigation. If any errors appear in + the output, please add to this method. + + TODO(luka): use pattern object to manually produce pattern graph + """ + debug_dump_path = config.compilation_config.debug_dump_path + if not debug_dump_path: + return + + rank = config.parallel_config.rank + debug_dump_path = Path(debug_dump_path) / f"rank_{rank}" + debug_dump_path.mkdir(parents=True, exist_ok=True) + + from vllm.utils import unique_filepath + file_path = unique_filepath( + lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py") + + with file_path.open("w") as f: + print( + f'# This file was produced by VllmPatternMatcherPass.' + f'dump_patterns for {self.pass_name}.\n' + f'# It does its best to produce valid-Python-looking code but' + f' please add to dump_patterns if there are any errors.\n\n' + f'from torch._higher_order_ops.auto_functionalize import ' + f'auto_functionalized as auto_functionalized\n' + f'from torch._inductor.pattern_matcher import *', + file=f) + + for node, patterns in pm_pass.patterns.items(): + # fix the operator.getitem repr + if node[1] == operator.getitem: + node_repr = f"({repr(node[0])}, operator.getitem)" + else: + node_repr = repr(node) + + node_repr = self._replace_op_overloads(node_repr) + + print(f"\n\n# Patterns for op: {node_repr}", file=f) + for i, pattern in enumerate(patterns): + # reserve auto_functionalized ahead of time + pp = PatternPrettyPrinter() + pp.namespace.create_name("auto_functionalized", None) + + # Assemble pattern + out_node = pp.pretty_print(pattern.pattern) + pattern_repr = "\n".join([f"def pattern_{i}():"] + [ + f"{pp.memoized_objs_names[key]} = " + f"{pp.memoized_objs_pp[key]}" + for key in pp.memoized_objs_names + ] + [f"return {out_node}"]).replace("\n", "\n ") + + pattern_repr = self._replace_op_overloads(pattern_repr) + print(f"{pattern_repr}\n", file=f) + + class PrinterInductorPass(VllmInductorPass): def __init__(self, name: str, config: VllmConfig): diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index ddd8de4324f6..9e88dad48deb 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -909,10 +909,9 @@ def set_current_vllm_config(vllm_config: VllmConfig, except Exception: raise else: - logger.debug("enabled custom ops: %s", - vllm_config.compilation_config.enabled_custom_ops) - logger.debug("disabled custom ops: %s", - vllm_config.compilation_config.disabled_custom_ops) + if check_compile: + vllm_config.compilation_config.custom_op_log_check() + if check_compile and \ vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ and compilation_counter.num_models_seen == num_models_seen: diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 22b38daf46c3..34fa7fcfe7e8 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -487,6 +487,12 @@ def __post_init__(self) -> None: "supported with torch>=2.9.0.dev. Set " "use_inductor_graph_partition=False instead.") + for op in self.custom_ops: + if op[0] not in {'+', '-'} and op not in {'all', 'none'}: + raise ValueError(f"Invalid syntax '{op}' for custom op, " + "must be 'all', 'none', '+op' or '-op' " + "(where 'op' is the registered op name)") + def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: if self.level == CompilationLevel.NO_COMPILATION: raise ValueError("No compilation level is set.") @@ -532,8 +538,8 @@ def init_with_cudagraph_sizes(self, for x in self.compile_sizes: if isinstance(x, str): assert x == "cudagraph_capture_sizes", \ - "Unrecognized size type in compile_sizes, " \ - f"expect 'cudagraph_capture_sizes', got {x}" + "Unrecognized size type in compile_sizes, " \ + f"expect 'cudagraph_capture_sizes', got {x}" computed_compile_sizes.extend(self.cudagraph_capture_sizes) else: assert isinstance(x, int) @@ -628,3 +634,41 @@ def is_attention_compiled_piecewise(self) -> bool: return use_fx_graph_piecewise_compilation or \ use_inductor_piecewise_compilation + + def custom_op_log_check(self): + """ + This method logs the enabled/disabled custom ops and checks that the + passed custom_ops field only contains relevant ops. + It is called at the end of set_current_vllm_config, + after the custom ops have been instantiated. + """ + + if len(self.enabled_custom_ops) + len(self.disabled_custom_ops) == 0: + logger.debug("No custom ops found in model.") + return + + logger.debug("enabled custom ops: %s", self.enabled_custom_ops) + logger.debug("disabled custom ops: %s", self.disabled_custom_ops) + + all_ops_in_model = (self.enabled_custom_ops | self.disabled_custom_ops) + for op in self.custom_ops: + if op in {"all", "none"}: + continue + + assert op[0] in {'+', '-'}, "Invalid custom op syntax " \ + "(should be checked during init)" + + # check if op name exists in model + op_name = op[1:] + if op_name not in all_ops_in_model: + from vllm.model_executor.custom_op import CustomOp + + # Does op exist at all or is it just not present in this model? + # Note: Only imported op classes appear in the registry. + missing_str = "doesn't exist (or wasn't imported/registered)" \ + if op_name not in CustomOp.op_registry \ + else "not present in model" + + enable_str = "enabling" if op[0] == '+' else "disabling" + logger.warning_once("Op '%s' %s, %s with '%s' has no effect", + op_name, missing_str, enable_str, op) diff --git a/vllm/envs.py b/vllm/envs.py index 3991a789d80f..92cf3eece324 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -188,6 +188,7 @@ VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER" + VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None def get_default_cache_root(): @@ -440,6 +441,11 @@ def get_vllm_port() -> Optional[int]: "VLLM_USE_STANDALONE_COMPILE": lambda: os.environ.get("VLLM_USE_STANDALONE_COMPILE", "1") == "1", + # Debug pattern matching inside custom passes. + # Should be set to the fx.Node name (e.g. 'getitem_34' or 'scaled_mm_3'). + "VLLM_PATTERN_MATCH_DEBUG": + lambda: os.environ.get("VLLM_PATTERN_MATCH_DEBUG", None), + # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 968bba664f0a..3271822ea875 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -3451,7 +3451,7 @@ def length_from_prompt_token_ids_or_embeds( prompt_token_ids: Optional[list[int]], prompt_embeds: Optional[torch.Tensor], ) -> int: - """Calculate the request length (in number of tokens) give either + """Calculate the request length (in number of tokens) give either prompt_token_ids or prompt_embeds. """ prompt_token_len = None if prompt_token_ids is None else len( @@ -3472,3 +3472,16 @@ def length_from_prompt_token_ids_or_embeds( f" prompt_token_ids={prompt_token_len}" f" prompt_embeds={prompt_embeds_len}") return prompt_token_len + + +@contextlib.contextmanager +def set_env_var(key, value): + old = os.environ.get(key) + os.environ[key] = value + try: + yield + finally: + if old is None: + del os.environ[key] + else: + os.environ[key] = old