diff --git a/src/mcp/types.py b/src/mcp/types.py index 871322740..cce8b1be0 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -1,7 +1,8 @@ +import re from collections.abc import Callable from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar -from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel +from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel, field_validator from pydantic.networks import AnyUrl, UrlConstraints from typing_extensions import deprecated @@ -39,6 +40,10 @@ RequestId = Annotated[int, Field(strict=True)] | str AnyFunction: TypeAlias = Callable[..., Any] +# Tool name validation pattern (ASCII letters, digits, underscore, dash, dot) +# Pattern ensures entire string contains only valid characters by using ^ and $ anchors +TOOL_NAME_PATTERN = re.compile(r"^[A-Za-z0-9_.-]+$") + class RequestParams(BaseModel): class Meta(BaseModel): @@ -891,6 +896,22 @@ class Tool(BaseMetadata): """ model_config = ConfigDict(extra="allow") + @field_validator("name") + @classmethod + def _validate_tool_name(cls, value: str) -> str: + if not (1 <= len(value) <= 128): + raise ValueError(f"Invalid tool name length: {len(value)}. Tool name must be between 1 and 128 characters.") + + if not TOOL_NAME_PATTERN.fullmatch(value): + raise ValueError("Invalid tool name characters. Allowed: A-Z, a-z, 0-9, underscore (_), dash (-), dot (.).") + + return value + + """ + See [MCP specification](https://modelcontextprotocol.io/specification/draft/server/tools#tool-names) + for more information on tool naming conventions. + """ + class ListToolsResult(PaginatedResult): """The server's response to a tools/list request from the client.""" diff --git a/tests/test_types.py b/tests/test_types.py index 415eba66a..a505fa84d 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -9,6 +9,7 @@ InitializeRequestParams, JSONRPCMessage, JSONRPCRequest, + Tool, ) @@ -56,3 +57,37 @@ async def test_method_initialization(): assert initialize_request.method == "initialize", "method should be set to 'initialize'" assert initialize_request.params is not None assert initialize_request.params.protocolVersion == LATEST_PROTOCOL_VERSION + + +@pytest.mark.parametrize( + "name", + [ + "getUser", + "DATA_EXPORT_v2", + "admin.tools.list", + "a", + "Z9_.-", + "x" * 128, # max length + ], +) +def test_tool_allows_valid_names(name: str) -> None: + Tool(name=name, inputSchema={"type": "object"}) + + +@pytest.mark.parametrize( + ("name", "expected"), + [ + ("", "Invalid tool name length: 0. Tool name must be between 1 and 128 characters."), + ("x" * 129, "Invalid tool name length: 129. Tool name must be between 1 and 128 characters."), + ("has space", "Invalid tool name characters. Allowed: A-Z, a-z, 0-9, underscore (_), dash (-), dot (.)."), + ("comma,name", "Invalid tool name characters. Allowed: A-Z, a-z, 0-9, underscore (_), dash (-), dot (.)."), + ("not/allowed", "Invalid tool name characters. Allowed: A-Z, a-z, 0-9, underscore (_), dash (-), dot (.)."), + ("name@", "Invalid tool name characters. Allowed: A-Z, a-z, 0-9, underscore (_), dash (-), dot (.)."), + ("name#", "Invalid tool name characters. Allowed: A-Z, a-z, 0-9, underscore (_), dash (-), dot (.)."), + ("name$", "Invalid tool name characters. Allowed: A-Z, a-z, 0-9, underscore (_), dash (-), dot (.)."), + ], +) +def test_tool_rejects_invalid_names(name: str, expected: str) -> None: + with pytest.raises(ValueError) as exc_info: + Tool(name=name, inputSchema={"type": "object"}) + assert expected in str(exc_info.value)