38 changes: 38 additions & 0 deletions docs/design/plugin_system.md
@@ -56,3 +56,41 @@ Every plugin has three parts:
## Compatibility Guarantee

vLLM guarantees that the interfaces of documented plugins, such as `ModelRegistry.register_model`, will always be available for plugins to register models. However, it is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they target. For example, `"vllm_add_dummy_model.my_llava:MyLlava"` should be compatible with the vLLM version the plugin targets, since the model interface may change during vLLM's development.
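
For illustration, a plugin's entry point typically calls `ModelRegistry.register_model` to perform the registration. A minimal sketch, assuming the `vllm_add_dummy_model` package and `MyLlava` class from the example above:

```python
def register():
    from vllm import ModelRegistry

    # Entry point invoked when vLLM loads the plugin. Guard against
    # re-registration, since plugins may be imported more than once.
    if "MyLlava" not in ModelRegistry.get_supported_archs():
        ModelRegistry.register_model(
            "MyLlava", "vllm_add_dummy_model.my_llava:MyLlava")
```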

## Class Extensions

To plug a custom implementation into a specific class, use the `ExtensionManager` interface. An `ExtensionManager` lets you register custom implementations on top of an existing base class (i.e. an extension group) and instantiate them by name at runtime.

If the class you are extending already has an `ExtensionManager`, simply reuse it. Otherwise, send a PR to add a new `ExtensionManager` for your target base class.

Below is a minimal example of how it works:

### 1. Create an ExtensionManager on the base class

```python
from vllm.plugins.extension_manager import ExtensionManager

class FooBase:
    ...

foo_manager = ExtensionManager(base_cls=FooBase)
```

### 2. Register your custom extension

```python
from ... import foo_manager

@foo_manager.register(names=["foo_impl"])
class FooImpl(FooBase):
    ...
```
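
`names` takes a list, so a single class can be registered under several aliases that all resolve to the same extension (the test suite in this PR exercises this with `a2` and `a2_alias`). A sketch with hypothetical names:

```python
@foo_manager.register(names=["foo_impl_v2", "foo_impl_latest"])
class FooImplV2(FooBase):
    ...
```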

### 3. Instantiate at runtime

```python
from ... import foo_manager

foo_impl_object = foo_manager.create(name="foo_impl")
```
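
Beyond `create`, the manager exposes helpers for introspection; a short sketch, assuming only the `foo_impl` registration from step 2 (behavior matching the tests added in this PR):

```python
# All registered names, including aliases.
assert "foo_impl" in foo_manager.get_valid_extension_names()

# Fetch the class itself without instantiating it.
assert foo_manager.get_extension_class("foo_impl") is FooImpl

# Unknown names raise ValueError rather than returning None.
try:
    foo_manager.create(name="unknown_impl")
except ValueError:
    pass
```

Note that constructing a second `ExtensionManager` for a base class that already has one raises `ValueError`, which is why you should reuse an existing manager wherever one exists.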
@@ -10,7 +10,9 @@
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction, run_tool_extraction_streaming)
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
tool_parser_manager)


def make_tool_call(name, arguments):
@@ -88,7 +90,7 @@ def make_tool_call(name, arguments):
def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls,
expected_content):
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
tool_parser: ToolParser = tool_parser_manager.get_extension_class(
"hunyuan_a13b")(mock_tokenizer)
content, tool_calls = run_tool_extraction(tool_parser,
model_output,
@@ -141,7 +143,7 @@ def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls,
def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls):
mock_tokenizer = MagicMock()

tool_parser: ToolParser = ToolParserManager.get_tool_parser(
tool_parser: ToolParser = tool_parser_manager.get_extension_class(
"hunyuan_a13b")(mock_tokenizer)
reconstructor = run_tool_extraction_streaming(
tool_parser, model_deltas, assert_one_tool_per_delta=False)
@@ -8,7 +8,9 @@
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction, run_tool_extraction_streaming)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
tool_parser_manager)

# Test cases similar to pythonic parser but with Llama4 specific format
SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]"
@@ -59,7 +61,7 @@
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool):
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
tool_parser: ToolParser = tool_parser_manager.get_extension_class(
"llama4_pythonic")(mock_tokenizer)
model_output = "How can I help you today?"

@@ -161,7 +163,7 @@ def test_no_tool_call(streaming: bool):
def test_tool_call(streaming: bool, model_output: str,
expected_tool_calls: list[FunctionCall]):
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
tool_parser: ToolParser = tool_parser_manager.get_extension_class(
"llama4_pythonic")(mock_tokenizer)

content, tool_calls = run_tool_extraction(tool_parser,
@@ -176,7 +178,7 @@ def test_tool_call(streaming: bool, model_output: str,

def test_streaming_tool_call_with_large_steps():
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
tool_parser: ToolParser = tool_parser_manager.get_extension_class(
"llama4_pythonic")(mock_tokenizer)
model_output_deltas = [
"<|python_start|>[get_weather(city='LA', metric='C'), "
@@ -198,7 +200,7 @@ def test_streaming_tool_call_with_large_steps():
def test_regex_timeout_handling(streaming: bool):
"""test regex timeout is handled gracefully"""
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
tool_parser: ToolParser = tool_parser_manager.get_extension_class(
"llama4_pythonic")(mock_tokenizer)

fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
@@ -8,7 +8,9 @@
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction, run_tool_extraction_streaming)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
tool_parser_manager)

# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
@@ -58,8 +60,8 @@
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool):
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
mock_tokenizer)
tool_parser: ToolParser = tool_parser_manager.get_extension_class(
"pythonic")(mock_tokenizer)
model_output = "How can I help you today?"

content, tool_calls = run_tool_extraction(tool_parser,
@@ -127,8 +129,8 @@ def test_no_tool_call(streaming: bool):
def test_tool_call(streaming: bool, model_output: str,
expected_tool_calls: list[FunctionCall]):
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
mock_tokenizer)
tool_parser: ToolParser = tool_parser_manager.get_extension_class(
"pythonic")(mock_tokenizer)

content, tool_calls = run_tool_extraction(tool_parser,
model_output,
@@ -143,8 +145,8 @@ def test_tool_call(streaming: bool, model_output: str,

def test_streaming_tool_call_with_large_steps():
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
mock_tokenizer)
tool_parser: ToolParser = tool_parser_manager.get_extension_class(
"pythonic")(mock_tokenizer)
model_output_deltas = [
"[get_weather(city='San",
" Francisco', metric='celsius'), "
@@ -166,7 +168,7 @@ def test_streaming_tool_call_with_large_steps():
def test_regex_timeout_handling(streaming: bool):
"""test regex timeout is handled gracefully"""
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
tool_parser: ToolParser = tool_parser_manager.get_extension_class(
"llama4_pythonic")(mock_tokenizer)

fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
20 changes: 5 additions & 15 deletions tests/model_executor/model_loader/test_registry.py
@@ -1,16 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
from torch import nn

from vllm.config import LoadConfig, ModelConfig
from vllm.model_executor.model_loader import (get_model_loader,
register_model_loader)
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.base_loader import (BaseModelLoader,
model_loader_manager)


@register_model_loader("custom_load_format")
@model_loader_manager.register(names=["custom_load_format"])
class CustomModelLoader(BaseModelLoader):

def __init__(self, load_config: LoadConfig) -> None:
@@ -25,13 +23,5 @@ def load_weights(self, model: nn.Module,


def test_register_model_loader():
load_config = LoadConfig(load_format="custom_load_format")
assert isinstance(get_model_loader(load_config), CustomModelLoader)


def test_invalid_model_loader():
with pytest.raises(ValueError):

@register_model_loader("invalid_load_format")
class InValidModelLoader:
pass
assert isinstance(model_loader_manager.create("custom_load_format"),
CustomModelLoader)
104 changes: 104 additions & 0 deletions tests/plugins/test_extension_manager.py
@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm.plugins.extension_manager import ExtensionManager


class BaseA:

def __init__(self) -> None:
pass


class BaseB:

def __init__(self) -> None:
pass


extension_manager_a = ExtensionManager(base_cls=BaseA)
extension_manager_b = ExtensionManager(base_cls=BaseB)


@extension_manager_a.register(names=["a1"])
class ChildA1(BaseA):

def __init__(self) -> None:
super().__init__()


@extension_manager_a.register(names=["a2", "a2_alias"])
class ChildA2(BaseA):

def __init__(self) -> None:
super().__init__()


@extension_manager_b.register(names=["b1"])
class ChildB1(BaseB):

def __init__(self) -> None:
super().__init__()


@extension_manager_b.register(names=["b2"])
class ChildB2(BaseB):

def __init__(self) -> None:
super().__init__()


def test_extension_manager_can_register_and_create():
a1_obj = extension_manager_a.create("a1")
a2_obj = extension_manager_a.create("a2")

assert isinstance(a1_obj, ChildA1)
assert isinstance(a2_obj, ChildA2)

b1_obj = extension_manager_b.create("b1")
b2_obj = extension_manager_b.create("b2")

assert isinstance(b1_obj, ChildB1)
assert isinstance(b2_obj, ChildB2)


def test_extension_manager_can_register_and_get_type():
a1_cls = extension_manager_a.get_extension_class("a1")
a2_cls = extension_manager_a.get_extension_class("a2")

assert a1_cls is ChildA1
assert a2_cls is ChildA2

b1_cls = extension_manager_b.get_extension_class("b1")
b2_cls = extension_manager_b.get_extension_class("b2")

assert b1_cls is ChildB1
assert b2_cls is ChildB2


def test_extension_manager_can_register_and_create_with_alias():
a2_alias_obj = extension_manager_a.create("a2_alias")

assert isinstance(a2_alias_obj, ChildA2)


def test_extension_manager_throws_error_on_unknown_names():
with pytest.raises(ValueError):
extension_manager_a.create("c1")

with pytest.raises(ValueError):
extension_manager_b.create("c1")


def test_extension_manager_valid_names():
assert extension_manager_a.get_valid_extension_names() == [
"a1", "a2", "a2_alias"
]
assert extension_manager_b.get_valid_extension_names() == ["b1", "b2"]


def test_extension_manager_must_be_unique_per_base_class():
with pytest.raises(ValueError):
_ = ExtensionManager(base_cls=BaseA)
@@ -2,8 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import SamplingParams
from vllm.config import LoadConfig
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.base_loader import model_loader_manager

load_format = "runai_streamer"
test_model = "openai-community/gpt2"
@@ -19,8 +18,7 @@


def get_runai_model_loader():
load_config = LoadConfig(load_format=load_format)
return get_model_loader(load_config)
return model_loader_manager.create(load_format)


def test_get_model_loader_with_runai_flag():
4 changes: 2 additions & 2 deletions tests/utils.py
@@ -34,7 +34,7 @@
init_distributed_environment)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.cli.serve import ServeSubcommand
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.base_loader import model_loader_manager
from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import (FlexibleArgumentParser, GB_bytes,
@@ -151,7 +151,7 @@ def __init__(self,
model_config = engine_args.create_model_config()
load_config = engine_args.create_load_config()

model_loader = get_model_loader(load_config)
model_loader = model_loader_manager.create(load_config.load_format)
model_loader.download_model(model_config)

self._start_server(model, vllm_serve_args, env_dict)
10 changes: 6 additions & 4 deletions vllm/entrypoints/openai/api_server.py
@@ -93,12 +93,14 @@
OpenAIServingTokenization)
from vllm.entrypoints.openai.serving_transcription import (
OpenAIServingTranscription, OpenAIServingTranslation)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
tool_parser_manager)
from vllm.entrypoints.tool_server import (DemoToolServer, MCPToolServer,
ToolServer)
from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
log_non_default_args, with_cancellation)
from vllm.logger import init_logger
from vllm.plugins.extension_manager import ExtensionManagerRegistry
from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
@@ -1854,7 +1856,7 @@ def create_server_unix_socket(path: str) -> socket.socket:


def validate_api_server_args(args):
valid_tool_parses = ToolParserManager.tool_parsers.keys()
valid_tool_parses = tool_parser_manager.get_valid_extension_names()
if args.enable_auto_tool_choice \
and args.tool_call_parser not in valid_tool_parses:
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
@@ -1876,7 +1878,7 @@ def setup_server(args):
log_non_default_args(args)

if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
ExtensionManagerRegistry.import_extension(args.tool_parser_plugin)

validate_api_server_args(args)

@@ -1928,7 +1930,7 @@ async def run_server_worker(listen_address,
"""Run a single API server worker."""

if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
ExtensionManagerRegistry.import_extension(args.tool_parser_plugin)

server_index = client_config.get("client_index", 0) if client_config else 0
