diff --git a/docs/design/plugin_system.md b/docs/design/plugin_system.md index ca1c2c2305d9..b20bf73fa20d 100644 --- a/docs/design/plugin_system.md +++ b/docs/design/plugin_system.md @@ -56,3 +56,41 @@ Every plugin has three parts: ## Compatibility Guarantee vLLM guarantees the interface of documented plugins, such as `ModelRegistry.register_model`, will always be available for plugins to register models. However, it is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they are targeting. For example, `"vllm_add_dummy_model.my_llava:MyLlava"` should be compatible with the version of vLLM that the plugin targets. The interface for the model may change during vLLM's development. + +## Class Extensions + +To plug a custom implementation into a specific class, use the `ExtensionManager` interface. An `ExtensionManager` lets you register custom implementations on top of an existing base class (an extension group) and instantiate them by name at runtime. + +If you are extending a class that already has an `ExtensionManager`, you can simply reuse it. Otherwise, send a PR to add a new `ExtensionManager` for your target base class. + +Below is a minimal example of how it works: + +### 1. Create an ExtensionManager on the base class + +```python +from vllm.plugins.extension_manager import ExtensionManager + +class FooBase: + ... + +foo_manager = ExtensionManager(base_cls=FooBase) + +``` + +### 2. Register your custom extension + +```python +from ... import foo_manager + +@foo_manager.register(names=["foo_impl"]) +class FooImpl(FooBase): + ... +``` + +### 3. Instantiate at runtime + +```python +from ... import foo_manager + +foo_impl_object = foo_manager.create(name="foo_impl") +``` diff --git a/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py index bd8e06513e13..e77a867e2a7f 100644 --- a/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py @@ -10,7 +10,9 @@ from tests.entrypoints.openai.tool_parsers.utils import ( run_tool_extraction, run_tool_extraction_streaming) from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall -from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager +from vllm.entrypoints.openai.tool_parsers import ToolParser +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + tool_parser_manager) def make_tool_call(name, arguments): @@ -88,7 +90,7 @@ def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls, expected_content): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( + tool_parser: ToolParser = tool_parser_manager.get_extension_class( "hunyuan_a13b")(mock_tokenizer) content, tool_calls = run_tool_extraction(tool_parser, model_output, @@ -141,7 +143,7 @@ def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls, def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( + tool_parser: ToolParser = tool_parser_manager.get_extension_class( "hunyuan_a13b")(mock_tokenizer) reconstructor = run_tool_extraction_streaming( tool_parser, model_deltas, assert_one_tool_per_delta=False) diff --git 
a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py index 8c86b4889e15..f0f21fd9f19d 100644 --- a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py @@ -8,7 +8,9 @@ from tests.entrypoints.openai.tool_parsers.utils import ( run_tool_extraction, run_tool_extraction_streaming) from vllm.entrypoints.openai.protocol import FunctionCall -from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager +from vllm.entrypoints.openai.tool_parsers import ToolParser +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + tool_parser_manager) # Test cases similar to pythonic parser but with Llama4 specific format SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]" @@ -59,7 +61,7 @@ @pytest.mark.parametrize("streaming", [True, False]) def test_no_tool_call(streaming: bool): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( + tool_parser: ToolParser = tool_parser_manager.get_extension_class( "llama4_pythonic")(mock_tokenizer) model_output = "How can I help you today?" @@ -161,7 +163,7 @@ def test_no_tool_call(streaming: bool): def test_tool_call(streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall]): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( + tool_parser: ToolParser = tool_parser_manager.get_extension_class( "llama4_pythonic")(mock_tokenizer) content, tool_calls = run_tool_extraction(tool_parser, @@ -176,7 +178,7 @@ def test_tool_call(streaming: bool, model_output: str, def test_streaming_tool_call_with_large_steps(): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( + tool_parser: ToolParser = tool_parser_manager.get_extension_class( "llama4_pythonic")(mock_tokenizer) model_output_deltas = [ "<|python_start|>[get_weather(city='LA', metric='C'), " @@ -198,7 +200,7 @@ def test_streaming_tool_call_with_large_steps(): def test_regex_timeout_handling(streaming: bool): """test regex timeout is handled gracefully""" mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( + tool_parser: ToolParser = tool_parser_manager.get_extension_class( "llama4_pythonic")(mock_tokenizer) fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2 diff --git a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py index d83137472598..26557eed5676 100644 --- a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py @@ -8,7 +8,9 @@ from tests.entrypoints.openai.tool_parsers.utils import ( run_tool_extraction, run_tool_extraction_streaming) from vllm.entrypoints.openai.protocol import FunctionCall -from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager +from vllm.entrypoints.openai.tool_parsers import ToolParser +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + tool_parser_manager) # https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1 SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')" @@ -58,8 +60,8 @@ @pytest.mark.parametrize("streaming", [True, False]) def test_no_tool_call(streaming: 
bool): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( - mock_tokenizer) + tool_parser: ToolParser = tool_parser_manager.get_extension_class( + "pythonic")(mock_tokenizer) model_output = "How can I help you today?" content, tool_calls = run_tool_extraction(tool_parser, @@ -127,8 +129,8 @@ def test_no_tool_call(streaming: bool): def test_tool_call(streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall]): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( - mock_tokenizer) + tool_parser: ToolParser = tool_parser_manager.get_extension_class( + "pythonic")(mock_tokenizer) content, tool_calls = run_tool_extraction(tool_parser, model_output, @@ -143,8 +145,8 @@ def test_tool_call(streaming: bool, model_output: str, def test_streaming_tool_call_with_large_steps(): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( - mock_tokenizer) + tool_parser: ToolParser = tool_parser_manager.get_extension_class( + "pythonic")(mock_tokenizer) model_output_deltas = [ "[get_weather(city='San", " Francisco', metric='celsius'), " @@ -166,7 +168,7 @@ def test_streaming_tool_call_with_large_steps(): def test_regex_timeout_handling(streaming: bool): """test regex timeout is handled gracefully""" mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( + tool_parser: ToolParser = tool_parser_manager.get_extension_class( "llama4_pythonic")(mock_tokenizer) fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2 diff --git a/tests/model_executor/model_loader/test_registry.py b/tests/model_executor/model_loader/test_registry.py index 93a3e34835b5..e2c6ab69fc9b 100644 --- a/tests/model_executor/model_loader/test_registry.py +++ b/tests/model_executor/model_loader/test_registry.py @@ -1,16 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest from torch import nn from vllm.config import LoadConfig, ModelConfig -from vllm.model_executor.model_loader import (get_model_loader, - register_model_loader) -from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.base_loader import (BaseModelLoader, + model_loader_manager) -@register_model_loader("custom_load_format") +@model_loader_manager.register(names=["custom_load_format"]) class CustomModelLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig) -> None: @@ -25,13 +23,7 @@ def load_weights(self, model: nn.Module, def test_register_model_loader(): load_config = LoadConfig(load_format="custom_load_format") - assert isinstance(get_model_loader(load_config), CustomModelLoader) - - -def test_invalid_model_loader(): - with pytest.raises(ValueError): - - @register_model_loader("invalid_load_format") - class InValidModelLoader: - pass + assert isinstance( + model_loader_manager.create("custom_load_format", load_config), + CustomModelLoader) diff --git a/tests/plugins/test_extension_manager.py b/tests/plugins/test_extension_manager.py new file mode 100644 index 000000000000..aab18d12415c --- /dev/null +++ b/tests/plugins/test_extension_manager.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.plugins.extension_manager import ExtensionManager + + +class BaseA: + + def __init__(self) -> None: + pass + + +class BaseB: + + def 
__init__(self) -> None: + pass + + +extension_manager_a = ExtensionManager(base_cls=BaseA) +extension_manager_b = ExtensionManager(base_cls=BaseB) + + +@extension_manager_a.register(names=["a1"]) +class ChildA1(BaseA): + + def __init__(self) -> None: + super().__init__() + + +@extension_manager_a.register(names=["a2", "a2_alias"]) +class ChildA2(BaseA): + + def __init__(self) -> None: + super().__init__() + + +@extension_manager_b.register(names=["b1"]) +class ChildB1(BaseB): + + def __init__(self) -> None: + super().__init__() + + +@extension_manager_b.register(names=["b2"]) +class ChildB2(BaseB): + + def __init__(self) -> None: + super().__init__() + + +def test_extension_manager_can_register_and_create(): + a1_obj = extension_manager_a.create("a1") + a2_obj = extension_manager_a.create("a2") + + assert isinstance(a1_obj, ChildA1) + assert isinstance(a2_obj, ChildA2) + + b1_obj = extension_manager_b.create("b1") + b2_obj = extension_manager_b.create("b2") + + assert isinstance(b1_obj, ChildB1) + assert isinstance(b2_obj, ChildB2) + + +def test_extension_manager_can_register_and_get_type(): + a1_cls = extension_manager_a.get_extension_class("a1") + a2_cls = extension_manager_a.get_extension_class("a2") + + assert a1_cls is ChildA1 + assert a2_cls is ChildA2 + + b1_cls = extension_manager_b.get_extension_class("b1") + b2_cls = extension_manager_b.get_extension_class("b2") + + assert b1_cls is ChildB1 + assert b2_cls is ChildB2 + + +def test_extension_manager_can_register_and_create_with_alias(): + a2_alias_obj = extension_manager_a.create("a2_alias") + + assert isinstance(a2_alias_obj, ChildA2) + + +def test_extension_manager_throws_error_on_unknown_names(): + with pytest.raises(ValueError): + extension_manager_a.create("c1") + + with pytest.raises(ValueError): + extension_manager_b.create("c1") + + +def test_extension_manager_valid_names(): + assert extension_manager_a.get_valid_extension_names() == [ + "a1", "a2", "a2_alias" + ] + assert extension_manager_b.get_valid_extension_names() == ["b1", "b2"] + + +def test_extension_manager_must_be_unique_per_base_class(): + with pytest.raises(ValueError): + _ = ExtensionManager(base_cls=BaseA) diff --git a/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py b/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py index 84c615b6b8db..8288bac1284d 100644 --- a/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py +++ b/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm import SamplingParams from vllm.config import LoadConfig -from vllm.model_executor.model_loader import get_model_loader +from vllm.model_executor.model_loader.base_loader import model_loader_manager load_format = "runai_streamer" test_model = "openai-community/gpt2" @@ -19,8 +19,8 @@ def get_runai_model_loader(): load_config = LoadConfig(load_format=load_format) - return get_model_loader(load_config) + return model_loader_manager.create(load_format, load_config) def test_get_model_loader_with_runai_flag(): diff --git a/tests/utils.py b/tests/utils.py index 9d2073f3c103..44c0db24cae4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -34,7 +34,7 @@ init_distributed_environment) from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.cli.serve import ServeSubcommand -from vllm.model_executor.model_loader import get_model_loader +from vllm.model_executor.model_loader.base_loader import model_loader_manager from vllm.platforms 
import current_platform from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils import (FlexibleArgumentParser, GB_bytes, @@ -151,7 +151,8 @@ def __init__(self, model_config = engine_args.create_model_config() load_config = engine_args.create_load_config() - model_loader = get_model_loader(load_config) + model_loader = model_loader_manager.create(load_config.load_format, + load_config) model_loader.download_model(model_config) self._start_server(model, vllm_serve_args, env_dict) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index db02767fdfd7..3c06d6378513 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -93,12 +93,14 @@ OpenAIServingTokenization) from vllm.entrypoints.openai.serving_transcription import ( OpenAIServingTranscription, OpenAIServingTranslation) -from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + tool_parser_manager) from vllm.entrypoints.tool_server import (DemoToolServer, MCPToolServer, ToolServer) from vllm.entrypoints.utils import (cli_env_setup, load_aware_call, log_non_default_args, with_cancellation) from vllm.logger import init_logger +from vllm.plugins.extension_manager import ExtensionManagerRegistry from vllm.reasoning import ReasoningParserManager from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) @@ -1854,7 +1856,7 @@ def create_server_unix_socket(path: str) -> socket.socket: def validate_api_server_args(args): - valid_tool_parses = ToolParserManager.tool_parsers.keys() + valid_tool_parses = tool_parser_manager.get_valid_extension_names() if args.enable_auto_tool_choice \ and args.tool_call_parser not in valid_tool_parses: raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " @@ -1876,7 +1878,7 @@ def setup_server(args): log_non_default_args(args) if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: - ToolParserManager.import_tool_parser(args.tool_parser_plugin) + ExtensionManagerRegistry.import_extension(args.tool_parser_plugin) validate_api_server_args(args) @@ -1928,7 +1930,7 @@ async def run_server_worker(listen_address, """Run a single API server worker.""" if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: - ToolParserManager.import_tool_parser(args.tool_parser_plugin) + ExtensionManagerRegistry.import_extension(args.tool_parser_plugin) server_index = client_config.get("client_index", 0) if client_config else 0 diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 6e4eff5c8024..7307241ee433 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -23,7 +23,8 @@ from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT, H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT) from vllm.entrypoints.openai.serving_models import LoRAModulePath -from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + tool_parser_manager) from vllm.logger import init_logger from vllm.utils import FlexibleArgumentParser @@ -210,7 +211,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: frontend_kwargs["middleware"]["default"] = [] # Special case: Tool call parser shows built-in options. 
- valid_tool_parsers = list(ToolParserManager.tool_parsers.keys()) + valid_tool_parsers = tool_parser_manager.get_valid_extension_names() parsers_str = ",".join(valid_tool_parsers) frontend_kwargs["tool_call_parser"]["metavar"] = ( f"{{{parsers_str}}} or name registered in --tool-parser-plugin") diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 8b50153f0115..692c7fbed90f 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -37,7 +37,9 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing, clamp_prompt_logprobs) from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager +from vllm.entrypoints.openai.tool_parsers import ToolParser +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + tool_parser_manager) from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( MistralToolCall) from vllm.entrypoints.utils import get_max_tokens @@ -116,8 +118,9 @@ def __init__( logger.warning( "Llama3.2 models may struggle to emit valid pythonic" " tool calls") - self.tool_parser = ToolParserManager.get_tool_parser( - tool_parser) + assert tool_parser is not None + self.tool_parser = tool_parser_manager.get_extension_class( + name=tool_parser) except Exception as e: raise TypeError("Error: --enable-auto-tool-choice requires " f"tool_parser:'{tool_parser}' which has not " diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 44aa1208a54c..95e1aeefa6c8 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .abstract_tool_parser import ToolParser, ToolParserManager +from .abstract_tool_parser import ToolParser from .deepseekv3_tool_parser import DeepSeekV3ToolParser from .deepseekv31_tool_parser import DeepSeekV31ToolParser from .glm4_moe_tool_parser import Glm4MoeModelToolParser @@ -25,7 +25,6 @@ __all__ = [ "ToolParser", - "ToolParserManager", "Granite20bFCToolParser", "GraniteToolParser", "Hermes2ProToolParser", diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py index 02aeab613631..8f70d714278c 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os from collections.abc import Sequence from functools import cached_property from typing import Callable, Optional, Union @@ -10,8 +9,10 @@ DeltaMessage, ExtractedToolCallInformation) from vllm.logger import init_logger +from vllm.plugins.extension_manager import (ExtensionManager, + ExtensionManagerRegistry) from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import import_from_path, is_list_of +from vllm.utils import is_list_of logger = init_logger(__name__) @@ -80,20 +81,12 @@ def extract_tool_calls_streaming( "implemented!") -class ToolParserManager: - tool_parsers: dict[str, type] = {} +tool_parser_manager = ExtensionManager(base_cls=ToolParser) - @classmethod - def get_tool_parser(cls, name) -> type: - """ - Get tool parser by name which is 
registered by `register_module`. - - Raise a KeyError exception if the name is not registered. - """ - if name in cls.tool_parsers: - return cls.tool_parsers[name] - raise KeyError(f"tool helper: '{name}' not found in tool_parsers") +# Legacy ToolParserManager class, kept for compatibility. +# Use `@tool_parser_manager.register(names=["foo"])` to register new tool parsers. +class ToolParserManager: @classmethod def _register_module(cls, @@ -108,12 +101,14 @@ def _register_module(cls, module_name = module.__name__ if isinstance(module_name, str): module_name = [module_name] + if ToolParser.__name__ not in ExtensionManagerRegistry._registry: + ExtensionManagerRegistry._registry[ToolParser.__name__] = {} for name in module_name: - if not force and name in cls.tool_parsers: - existed_module = cls.tool_parsers[name] - raise KeyError(f'{name} is already registered ' - f'at {existed_module.__module__}') - cls.tool_parsers[name] = module + if not force and name in ExtensionManagerRegistry._registry[ + ToolParser.__name__]: + raise KeyError(f'Tool parser {name} is already registered') + ExtensionManagerRegistry._registry[ + ToolParser.__name__][name] = module @classmethod def register_module( @@ -147,18 +142,3 @@ def _register(module): return module return _register - - @classmethod - def import_tool_parser(cls, plugin_path: str) -> None: - """ - Import a user-defined tool parser by the path of the tool parser define - file. - """ - module_name = os.path.splitext(os.path.basename(plugin_path))[0] - - try: - import_from_path(module_name, plugin_path) - except Exception: - logger.exception("Failed to load module '%s' from %s.", - module_name, plugin_path) - return diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py index ac272b0c3b20..209e738a8f74 100644 --- a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py @@ -13,14 +13,14 @@ ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, tool_parser_manager) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module("deepseek_v3") +@tool_parser_manager.register(names=["deepseek_v3"]) class DeepSeekV3ToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): diff --git a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py index 8fd14f171d0a..94d8f9c3e5d9 100644 --- a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py @@ -15,14 +15,14 @@ ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, tool_parser_manager) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module("glm45") +@tool_parser_manager.register(names=["glm45"]) class Glm4MoeModelToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py index 824b100f357b..05a62df57981 100644 --- 
a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py @@ -17,7 +17,7 @@ ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, tool_parser_manager) from vllm.entrypoints.openai.tool_parsers.utils import (consume_space, find_common_prefix, is_complete_json, @@ -28,7 +28,7 @@ logger = init_logger(__name__) -@ToolParserManager.register_module("granite-20b-fc") +@tool_parser_manager.register(names=["granite-20b-fc"]) class Granite20bFCToolParser(ToolParser): """ Tool call parser for the granite-20b-functioncalling model intended diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py index ac517616a95b..8e4648b28993 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py @@ -15,7 +15,7 @@ ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, tool_parser_manager) from vllm.entrypoints.openai.tool_parsers.utils import (consume_space, find_common_prefix, is_complete_json, @@ -26,7 +26,7 @@ logger = init_logger(__name__) -@ToolParserManager.register_module("granite") +@tool_parser_manager.register(names=["granite"]) class GraniteToolParser(ToolParser): """ Tool call parser for the granite 3.0 models. Intended diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index a6ce33af6bd0..75c86ee122e9 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -16,14 +16,14 @@ ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, tool_parser_manager) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module("hermes") +@tool_parser_manager.register(names=["hermes"]) class Hermes2ProToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): diff --git a/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py index 2b65f2579fb4..4b92ca90c88b 100644 --- a/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py @@ -14,7 +14,7 @@ ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, tool_parser_manager) from vllm.entrypoints.openai.tool_parsers.utils import consume_space from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -23,7 +23,7 @@ logger = init_logger(__name__) -@ToolParserManager.register_module("hunyuan_a13b") +@tool_parser_manager.register(names=["hunyuan_a13b"]) class HunyuanA13BToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index 
6ef8fadf59ac..b0d060e516ab 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -15,7 +15,7 @@ ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, tool_parser_manager) from vllm.entrypoints.openai.tool_parsers.utils import ( extract_intermediate_diff) from vllm.logger import init_logger @@ -24,7 +24,7 @@ logger = init_logger(__name__) -@ToolParserManager.register_module(["internlm"]) +@tool_parser_manager.register(names=["internlm"]) class Internlm2ToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): diff --git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py index 3b41f6034704..e3b7d0b8e29f 100644 --- a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py @@ -15,7 +15,9 @@ DeltaToolCall, ExtractedToolCallInformation, FunctionCall, ToolCall) -from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager +from vllm.entrypoints.openai.tool_parsers import ToolParser +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + tool_parser_manager) from vllm.entrypoints.openai.tool_parsers.utils import ( extract_intermediate_diff) from vllm.logger import init_logger @@ -25,7 +27,7 @@ logger = init_logger(__name__) -@ToolParserManager.register_module("jamba") +@tool_parser_manager.register(names=["jamba"]) class JambaToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): diff --git a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py index 834b33052b45..177bf56313c7 100644 --- a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py @@ -13,14 +13,14 @@ ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, tool_parser_manager) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module(["kimi_k2"]) +@tool_parser_manager.register(names=["kimi_k2"]) class KimiK2ToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): diff --git a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py index 6bf44a4345a9..4f86a8bd4b21 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py @@ -15,7 +15,7 @@ ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, tool_parser_manager) from vllm.logger import init_logger logger = init_logger(__name__) @@ -25,7 +25,7 @@ class _UnexpectedAstError(Exception): pass -@ToolParserManager.register_module("llama4_pythonic") +@tool_parser_manager.register(names=["llama4_pythonic"]) class Llama4PythonicToolParser(ToolParser): """ Toolcall parser for Llama4 that produce tool calls in a pythonic style diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py 
b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index 31b19c8db416..7a8718ee874b 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -17,7 +17,7 @@ ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, tool_parser_manager) from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix, is_complete_json, partial_json_loads) @@ -26,8 +26,7 @@ logger = init_logger(__name__) -@ToolParserManager.register_module("llama3_json") -@ToolParserManager.register_module("llama4_json") +@tool_parser_manager.register(names=["llama3_json", "llama4_json"]) class Llama3JsonToolParser(ToolParser): """ Tool call parser for Llama 3.x and 4 models intended for use with the diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py index 283e6095013d..abdf67bb7734 100644 --- a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py @@ -14,7 +14,7 @@ ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, tool_parser_manager) from vllm.entrypoints.openai.tool_parsers.utils import ( extract_intermediate_diff) from vllm.logger import init_logger @@ -23,7 +23,7 @@ logger = init_logger(__name__) -@ToolParserManager.register_module("minimax") +@tool_parser_manager.register(names=["minimax"]) class MinimaxToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index c0691f122904..d802e9a74680 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -18,7 +18,7 @@ ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, tool_parser_manager) from vllm.entrypoints.openai.tool_parsers.utils import ( extract_intermediate_diff) from vllm.logger import init_logger @@ -49,7 +49,7 @@ def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool: and model_tokenizer.version >= 11 -@ToolParserManager.register_module("mistral") +@tool_parser_manager.register(names=["mistral"]) class MistralToolParser(ToolParser): """ Tool call parser for Mistral 7B Instruct v0.3, intended for use with diff --git a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py index 85dd56213c6a..69dd064b925b 100644 --- a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py @@ -14,13 +14,13 @@ ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, tool_parser_manager) from vllm.logger import init_logger logger = init_logger(__name__) -@ToolParserManager.register_module("phi4_mini_json") +@tool_parser_manager.register(names=["phi4_mini_json"]) class Phi4MiniJsonToolParser(ToolParser): """ Tool call parser for phi-4-mini models intended for use with the diff --git 
a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py index 73329cdf701d..ca7c48f06130 100644 --- a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py @@ -16,7 +16,7 @@ ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, tool_parser_manager) from vllm.logger import init_logger logger = init_logger(__name__) @@ -26,7 +26,7 @@ class _UnexpectedAstError(Exception): pass -@ToolParserManager.register_module("pythonic") +@tool_parser_manager.register(names=["pythonic"]) class PythonicToolParser(ToolParser): """ Tool call parser for models that produce tool calls in a pythonic style, diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py index 2501d6739e8f..3abb603ad1d9 100644 --- a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py @@ -15,14 +15,14 @@ ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, tool_parser_manager) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module(["qwen3_coder"]) +@tool_parser_manager.register(names=["qwen3_coder"]) class Qwen3CoderToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): diff --git a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py index a20d18eb5254..52fe53c1cfe2 100644 --- a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py @@ -14,7 +14,7 @@ ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, tool_parser_manager) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid @@ -22,7 +22,7 @@ logger = init_logger(__name__) -@ToolParserManager.register_module(["step3"]) +@tool_parser_manager.register(names=["step3"]) class Step3ToolParser(ToolParser): """ Tool parser for a model that uses a specific XML-like format for tool calls. 
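Every tool-parser change above is the same mechanical migration: the class-level `@ToolParserManager.register_module(...)` decorator becomes `@tool_parser_manager.register(names=[...])`, with aliases passed as extra names instead of stacked decorators. A minimal sketch of the pattern for an out-of-tree parser, using a hypothetical `my_parser` name and alias:

```python
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    tool_parser_manager)


# "my_parser" and "my_parser_alias" are illustrative names only.
@tool_parser_manager.register(names=["my_parser", "my_parser_alias"])
class MyToolParser(ToolParser):
    ...


# Lookup mirrors the updated tests: fetch the class by name, then
# instantiate it with a tokenizer.
parser_cls = tool_parser_manager.get_extension_class("my_parser")
```
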
diff --git a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py index 87cd413b3720..9c66e1c89fad 100644 --- a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py @@ -14,7 +14,7 @@ ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, tool_parser_manager) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid @@ -22,7 +22,7 @@ logger = init_logger(__name__) -@ToolParserManager.register_module("xlam") +@tool_parser_manager.register(names=["xlam"]) class xLAMToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index 2dada794a8f3..2e09f7de6328 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -5,9 +5,10 @@ from torch import nn -from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.config import ModelConfig, VllmConfig from vllm.logger import init_logger -from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.base_loader import (BaseModelLoader, + model_loader_manager) from vllm.model_executor.model_loader.bitsandbytes_loader import ( BitsAndBytesModelLoader) from vllm.model_executor.model_loader.default_loader import DefaultModelLoader @@ -40,79 +41,13 @@ "sharded_state", "tensorizer", ] -_LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = { - "auto": DefaultModelLoader, - "bitsandbytes": BitsAndBytesModelLoader, - "dummy": DummyModelLoader, - "fastsafetensors": DefaultModelLoader, - "gguf": GGUFModelLoader, - "mistral": DefaultModelLoader, - "npcache": DefaultModelLoader, - "pt": DefaultModelLoader, - "runai_streamer": RunaiModelStreamerLoader, - "runai_streamer_sharded": ShardedStateLoader, - "safetensors": DefaultModelLoader, - "sharded_state": ShardedStateLoader, - "tensorizer": TensorizerLoader, -} - - -def register_model_loader(load_format: str): - """Register a customized vllm model loader. - - When a load format is not supported by vllm, you can register a customized - model loader to support it. - - Args: - load_format (str): The model loader format name. - - Examples: - >>> from vllm.config import LoadConfig - >>> from vllm.model_executor.model_loader import get_model_loader, register_model_loader - >>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader - >>> - >>> @register_model_loader("my_loader") - ... class MyModelLoader(BaseModelLoader): - ... def download_model(self): - ... pass - ... - ... def load_weights(self): - ... 
pass - >>> - >>> load_config = LoadConfig(load_format="my_loader") - >>> type(get_model_loader(load_config)) - - """ # noqa: E501 - - def _wrapper(model_loader_cls): - if load_format in _LOAD_FORMAT_TO_MODEL_LOADER: - logger.warning( - "Load format `%s` is already registered, and will be " - "overwritten by the new loader class `%s`.", load_format, - model_loader_cls) - if not issubclass(model_loader_cls, BaseModelLoader): - raise ValueError("The model loader must be a subclass of " - "`BaseModelLoader`.") - _LOAD_FORMAT_TO_MODEL_LOADER[load_format] = model_loader_cls - logger.info("Registered model loader `%s` with load format `%s`", - model_loader_cls, load_format) - return model_loader_cls - - return _wrapper - - -def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: - """Get a model loader based on the load format.""" - load_format = load_config.load_format - if load_format not in _LOAD_FORMAT_TO_MODEL_LOADER: - raise ValueError(f"Load format `{load_format}` is not supported") - return _LOAD_FORMAT_TO_MODEL_LOADER[load_format](load_config) def get_model(*, vllm_config: VllmConfig, model_config: Optional[ModelConfig] = None) -> nn.Module: - loader = get_model_loader(vllm_config.load_config) + loader = model_loader_manager.create(vllm_config.load_config.load_format, + vllm_config.load_config) if model_config is None: model_config = vllm_config.model_config return loader.load_model(vllm_config=vllm_config, @@ -121,11 +56,9 @@ def get_model(*, __all__ = [ "get_model", - "get_model_loader", "get_architecture_class_name", "get_model_architecture", "get_model_cls", - "register_model_loader", "BaseModelLoader", "BitsAndBytesModelLoader", "GGUFModelLoader", diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index 4cf6c7988960..9dc5cfea78aa 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -9,6 +9,7 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, set_default_torch_dtype) +from vllm.plugins.extension_manager import ExtensionManager logger = init_logger(__name__) @@ -49,3 +50,6 @@ def load_model(self, vllm_config: VllmConfig, self.load_weights(model, model_config) process_weights_after_loading(model, model_config, target_device) return model.eval() + + +model_loader_manager = ExtensionManager(base_cls=BaseModelLoader) diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index b8393956eed3..e18c412a5253 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -27,7 +27,8 @@ QKVParallelLinear, ReplicatedLinear, RowParallelLinear) -from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.base_loader import (BaseModelLoader, + model_loader_manager) from vllm.model_executor.model_loader.utils import (ParamMapping, set_default_torch_dtype) from vllm.model_executor.model_loader.weight_utils import ( @@ -51,6 +52,7 @@ def is_moe_model(model: torch.nn.Module) -> bool: isinstance(module, FusedMoE) for module in model.modules())) +@model_loader_manager.register(names=["bitsandbytes"]) class BitsAndBytesModelLoader(BaseModelLoader): """Model loader to load model weights with BitAndBytes quantization.""" diff --git a/vllm/model_executor/model_loader/default_loader.py 
b/vllm/model_executor/model_loader/default_loader.py index 34b8d8e4ed62..c8e3daf21fc8 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -15,7 +15,8 @@ from vllm import envs from vllm.config import LoadConfig, ModelConfig from vllm.logger import init_logger -from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.base_loader import (BaseModelLoader, + model_loader_manager) from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, fastsafetensors_weights_iterator, filter_duplicate_safetensors_files, @@ -26,6 +27,9 @@ logger = init_logger(__name__) +@model_loader_manager.register(names=[ + "auto", "fastsafetensors", "mistral", "npcache", "pt", "safetensors" +]) class DefaultModelLoader(BaseModelLoader): """Model loader that can load different file types from disk.""" diff --git a/vllm/model_executor/model_loader/dummy_loader.py b/vllm/model_executor/model_loader/dummy_loader.py index f4a7da5744e0..38f067c1bb48 100644 --- a/vllm/model_executor/model_loader/dummy_loader.py +++ b/vllm/model_executor/model_loader/dummy_loader.py @@ -3,11 +3,13 @@ import torch.nn as nn from vllm.config import LoadConfig, ModelConfig -from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.base_loader import (BaseModelLoader, + model_loader_manager) from vllm.model_executor.model_loader.weight_utils import ( initialize_dummy_weights) +@model_loader_manager.register(names=["dummy"]) class DummyModelLoader(BaseModelLoader): """Model loader that will set model weights to random values.""" diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 9877cb3b7c06..45c851c62982 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -10,7 +10,8 @@ from transformers import AutoModelForCausalLM from vllm.config import LoadConfig, ModelConfig, VllmConfig -from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.base_loader import (BaseModelLoader, + model_loader_manager) from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, set_default_torch_dtype) from vllm.model_executor.model_loader.weight_utils import ( @@ -18,6 +19,7 @@ gguf_quant_weights_iterator) +@model_loader_manager.register(names=["gguf"]) class GGUFModelLoader(BaseModelLoader): """ Model loader that can load GGUF files. 
This is useful for loading models diff --git a/vllm/model_executor/model_loader/runai_streamer_loader.py b/vllm/model_executor/model_loader/runai_streamer_loader.py index 83e0f386c108..6a7bca1bfffb 100644 --- a/vllm/model_executor/model_loader/runai_streamer_loader.py +++ b/vllm/model_executor/model_loader/runai_streamer_loader.py @@ -11,7 +11,8 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from vllm.config import LoadConfig, ModelConfig -from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.base_loader import (BaseModelLoader, + model_loader_manager) from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, runai_safetensors_weights_iterator) @@ -19,6 +20,7 @@ from vllm.transformers_utils.utils import is_s3 +@model_loader_manager.register(names=["runai_streamer"]) class RunaiModelStreamerLoader(BaseModelLoader): """ Model loader that can load safetensors diff --git a/vllm/model_executor/model_loader/sharded_state_loader.py b/vllm/model_executor/model_loader/sharded_state_loader.py index 3edd4ec4007e..d9de663460d3 100644 --- a/vllm/model_executor/model_loader/sharded_state_loader.py +++ b/vllm/model_executor/model_loader/sharded_state_loader.py @@ -12,7 +12,8 @@ from vllm.config import LoadConfig, ModelConfig from vllm.logger import init_logger -from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.base_loader import (BaseModelLoader, + model_loader_manager) from vllm.model_executor.model_loader.weight_utils import ( download_weights_from_hf, runai_safetensors_weights_iterator) from vllm.transformers_utils.s3_utils import glob as s3_glob @@ -21,6 +22,8 @@ logger = init_logger(__name__) +@model_loader_manager.register( + names=["runai_streamer_sharded", "sharded_state"]) class ShardedStateLoader(BaseModelLoader): """ Model loader that directly loads each worker's model state dict, which diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py index fa01758ab4ce..3ac3ac48d04a 100644 --- a/vllm/model_executor/model_loader/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -10,7 +10,8 @@ from vllm.config import LoadConfig, ModelConfig, ParallelConfig, VllmConfig from vllm.logger import init_logger -from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.base_loader import (BaseModelLoader, + model_loader_manager) from vllm.model_executor.model_loader.tensorizer import ( TensorizerConfig, deserialize_tensorizer_model, init_tensorizer_model, is_vllm_tensorized, serialize_vllm_model, tensorizer_weights_iterator) @@ -33,6 +34,7 @@ def validate_config(config: dict): raise ValueError(f"{k} is not an allowed Tensorizer argument.") +@model_loader_manager.register(names=["tensorizer"]) class TensorizerLoader(BaseModelLoader): """Model loader using CoreWeave's tensorizer library.""" diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index ef1380bdb614..4e7f1c87bcff 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -13,6 +13,7 @@ from PIL import Image from vllm import envs +from vllm.plugins.extension_manager import ExtensionManager from .base import MediaIO from .image import ImageMediaIO @@ -62,30 +63,12 @@ def load_bytes(cls, raise NotImplementedError -class VideoLoaderRegistry: +video_loader_manager = 
ExtensionManager(base_cls=VideoLoader) +# Kept for legacy import compatibility +VIDEO_LOADER_REGISTRY = video_loader_manager - def __init__(self) -> None: - self.name2class: dict[str, type] = {} - def register(self, name: str): - - def wrap(cls_to_register): - self.name2class[name] = cls_to_register - return cls_to_register - - return wrap - - @staticmethod - def load(cls_name: str) -> VideoLoader: - cls = VIDEO_LOADER_REGISTRY.name2class.get(cls_name) - assert cls is not None, f"VideoLoader class {cls_name} not found" - return cls() - - -VIDEO_LOADER_REGISTRY = VideoLoaderRegistry() - - -@VIDEO_LOADER_REGISTRY.register("opencv") +@video_loader_manager.register(names=["opencv"]) class OpenCVVideoBackend(VideoLoader): def get_cv2_video_api(self): @@ -178,7 +161,7 @@ def __init__( # for flexible control. self.kwargs = kwargs video_loader_backend = envs.VLLM_VIDEO_LOADER_BACKEND - self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend) + self.video_loader = video_loader_manager.create(video_loader_backend) def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]: return self.video_loader.load_bytes(data, diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 1a1760df82c0..31f47fc1f107 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -70,3 +70,9 @@ def load_general_plugins(): # general plugins, we only need to execute the loaded functions for func in plugins.values(): func() + + +__all__ = [ + "load_plugins_by_group", + "load_general_plugins", +] diff --git a/vllm/plugins/extension_manager.py b/vllm/plugins/extension_manager.py new file mode 100644 index 000000000000..de0d0c319e35 --- /dev/null +++ b/vllm/plugins/extension_manager.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from typing import Any, Union + +from vllm.logger import init_logger +from vllm.utils import import_from_path + +logger = init_logger(__name__) + + +class ExtensionManagerRegistry: + _registry: dict[str, dict[str, type]] = {} + + @staticmethod + def _group_key(base_cls: type) -> str: + return f"{base_cls.__module__}.{base_cls.__name__}" + + @staticmethod + def _register(base_cls: type, names: list[str]): + + def wrap(impl_cls: type): + for name in names: + if base_cls.__name__ not in ExtensionManagerRegistry._registry: + ExtensionManagerRegistry._registry[base_cls.__name__] = {} + if name in ExtensionManagerRegistry._registry[ + base_cls.__name__]: + raise ValueError( + f"Extension {name} already registered in group {base_cls.__name__}" # noqa: E501 + ) + ExtensionManagerRegistry._registry[ + base_cls.__name__][name] = impl_cls + return impl_cls + + return wrap + + @staticmethod + def _create(base_cls: type, name: str, *args, **kwargs) -> Any: + if extension_group := ExtensionManagerRegistry._registry.get( + base_cls.__name__): + if impl_cls := extension_group.get(name): + return impl_cls(*args, **kwargs) + else: + raise ValueError( + f"Extension {name} not found in group {base_cls.__name__}") + else: + raise ValueError(f"Extension group {base_cls.__name__} not found") + + @staticmethod + def _get_extension_class(base_cls: type, name: str) -> type: + if extension_group := ExtensionManagerRegistry._registry.get( + base_cls.__name__): + if impl_cls := extension_group.get(name): + return impl_cls + else: + raise ValueError( + f"Extension {name} not found in group {base_cls.__name__}") + else: + raise ValueError( + f"Extension base class {base_cls.__name__} not 
found") + + @staticmethod + def _get_valid_extension_names(base_cls: type) -> list[str]: + if extension_group := ExtensionManagerRegistry._registry.get( + base_cls.__name__): + return list(extension_group.keys()) + else: + return [] + + @staticmethod + def import_extension(extension_path: str) -> None: + """ + Import a user-defined extension by the path of the extension file. + """ + module_name = os.path.splitext(os.path.basename(extension_path))[0] + + try: + import_from_path(module_name, extension_path) + except Exception: + logger.exception("Failed to load module '%s' from %s.", + module_name, extension_path) + return + + +class ExtensionManager: + + def __init__(self, base_cls: type) -> None: + if base_cls.__name__ in ExtensionManagerRegistry._registry: + raise ValueError( + f"Extension group {base_cls.__name__} already exists.") + ExtensionManagerRegistry._registry[base_cls.__name__] = {} + self.base_cls = base_cls + + def register(self, names: Union[str, list[str]]): + if isinstance(names, str): + names = [names] + return ExtensionManagerRegistry._register(self.base_cls, names) + + def create(self, name: str, *args, **kwargs) -> Any: + return ExtensionManagerRegistry._create(self.base_cls, name, *args, + **kwargs) + + def get_extension_class(self, name: str) -> type: + return ExtensionManagerRegistry._get_extension_class( + self.base_cls, name) + + def get_valid_extension_names(self) -> list[str]: + return ExtensionManagerRegistry._get_valid_extension_names( + self.base_cls) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 14f2305dadc5..f3dca479f727 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -38,7 +38,8 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader +from vllm.model_executor.model_loader import TensorizerLoader +from vllm.model_executor.model_loader.base_loader import model_loader_manager from vllm.model_executor.models.interfaces import (is_mixture_of_experts, supports_eagle3, supports_transcription) @@ -1957,7 +1958,8 @@ def load_model(self, eep_scale_up: bool = False) -> None: with DeviceMemoryProfiler() as m: time_before_load = time.perf_counter() - model_loader = get_model_loader(self.load_config) + model_loader = model_loader_manager.create( + self.load_config.load_format, self.load_config) logger.info("Loading model from scratch...") self.model = model_loader.load_model( vllm_config=self.vllm_config, model_config=self.model_config) @@ -2021,7 +2023,8 @@ def load_model(self, eep_scale_up: bool = False) -> None: def reload_weights(self) -> None: assert getattr(self, "model", None) is not None, \ "Cannot reload weights before model is loaded." 
- model_loader = get_model_loader(self.load_config) + model_loader = model_loader_manager.create( + self.load_config.load_format, self.load_config) logger.info("Reloading weights inplace...") model = self.get_model() model_loader.load_weights(model, model_config=self.model_config) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 4a485b7e077d..cf904bd0335e 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -26,7 +26,7 @@ from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.lora.layers import BaseLayerWithLoRA -from vllm.model_executor.model_loader import get_model_loader +from vllm.model_executor.model_loader.base_loader import model_loader_manager from vllm.model_executor.model_loader.tpu import TPUModelLoader from vllm.model_executor.models.interfaces import supports_transcription from vllm.model_executor.models.interfaces_base import ( @@ -1198,7 +1198,8 @@ def load_model(self) -> None: model_config=self.vllm_config.model_config, mesh=self.mesh) else: - model_loader = get_model_loader(self.load_config) + model_loader = model_loader_manager.create( + self.load_config.load_format, self.load_config) logger.info("Loading model from scratch...") model = model_loader.load_model( vllm_config=self.vllm_config, @@ -1227,7 +1228,8 @@ def load_model(self) -> None: def reload_weights(self) -> None: assert getattr(self, "model", None) is not None, \ "Cannot reload weights before model is loaded." - model_loader = get_model_loader(self.load_config) + model_loader = model_loader_manager.create( + self.load_config.load_format, self.load_config) logger.info("Reloading weights inplace...") model_loader.load_weights(self.model, model_config=self.model_config)
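
For reference, a minimal end-to-end sketch of the model-loader flow after this change, mirroring the updated `test_registry.py`; `my_format` is a hypothetical load format. Note that `create()` forwards its extra positional arguments to the loader constructor, which is why call sites pass the `LoadConfig` alongside the format name:

```python
from torch import nn

from vllm.config import LoadConfig, ModelConfig
from vllm.model_executor.model_loader.base_loader import (BaseModelLoader,
                                                          model_loader_manager)


# "my_format" is an illustrative load format, not one vLLM ships.
@model_loader_manager.register(names=["my_format"])
class MyModelLoader(BaseModelLoader):

    def download_model(self, model_config: ModelConfig) -> None:
        pass

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        pass


load_config = LoadConfig(load_format="my_format")
# create() instantiates MyModelLoader(load_config) under the hood.
loader = model_loader_manager.create(load_config.load_format, load_config)
assert isinstance(loader, MyModelLoader)
```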