diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index 39e3212e9..7bd1834e6 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -3,6 +3,7 @@ from __future__ import annotations import types +from enum import Enum from typing import Generic, Literal, TypeVar, Union, get_args, get_origin from pydantic import BaseModel @@ -37,7 +38,7 @@ class CancelledElicitation(BaseModel): # Primitive types allowed in elicitation schemas -_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool) +_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool, Enum) def _validate_elicitation_schema(schema: type[BaseModel]) -> None: @@ -70,6 +71,10 @@ def _is_primitive_field(field_info: FieldInfo) -> bool: # All args must be primitive types or None return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args) + # Handle Enum types + if isinstance(annotation, type) and issubclass(annotation, str) and issubclass(annotation, Enum): + return True + return False diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index 896eb1f80..4beb62b19 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -2,6 +2,7 @@ Test the elicitation feature using stdio transport. """ +from enum import Enum from typing import Any import pytest @@ -142,6 +143,39 @@ async def elicitation_callback(context: RequestContext[ClientSession, None], par assert "Validation failed as expected" in result.content[0].text assert field_name in result.content[0].text + # Test valid Enum types (should not fail validation) + class Status(str, Enum): + ACTIVE = "active" + INACTIVE = "inactive" + + class ValidStrEnumSchema(BaseModel): + status: Status = Field(description="Status using StrEnum") + + def create_valid_validation_tool(name: str, schema_class: type[BaseModel]): + @mcp.tool(name=name, description=f"Tool testing {name}") + async def tool(ctx: Context[ServerSession, None]) -> str: + # This should succeed without validation error + result = await ctx.elicit(message="Test valid schema", schema=schema_class) + return f"Success: {result.action}" + + return tool + + create_valid_validation_tool("valid_strenum", ValidStrEnumSchema) + + async def enum_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + # Return the required status field + return ElicitResult(action="accept", content={"status": "active"}) + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=enum_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("valid_strenum", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert "Success: accept" == result.content[0].text + @pytest.mark.anyio async def test_elicitation_with_optional_fields():