diff --git a/examples/shared/__init__.py b/examples/shared/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/shared/in_memory_task_store.py b/examples/shared/in_memory_task_store.py new file mode 100644 index 000000000..94b9a2bb3 --- /dev/null +++ b/examples/shared/in_memory_task_store.py @@ -0,0 +1,175 @@ +""" +In-memory implementation of TaskStore for demonstration purposes. + +This implementation stores all tasks in memory and provides automatic cleanup +based on the keepAlive duration specified in the task metadata. + +Note: This is not suitable for production use as all data is lost on restart. +For production, consider implementing TaskStore with a database or distributed cache. +""" + +import asyncio +from dataclasses import dataclass +from typing import Any + +from mcp.shared.task import TaskStatus, TaskStore, is_terminal +from mcp.types import Request, RequestId, Result, Task, TaskMetadata + + +@dataclass +class StoredTask: + """Internal storage representation of a task.""" + + task: Task + request: Request[Any, Any] + request_id: RequestId + result: Result | None = None + + +class InMemoryTaskStore(TaskStore): + """ + A simple in-memory implementation of TaskStore for demonstration purposes. + + This implementation stores all tasks in memory and provides automatic cleanup + based on the keepAlive duration specified in the task metadata. + + Note: This is not suitable for production use as all data is lost on restart. + For production, consider implementing TaskStore with a database or distributed cache. + """ + + def __init__(self) -> None: + self._tasks: dict[str, StoredTask] = {} + self._cleanup_tasks: dict[str, asyncio.Task[None]] = {} + + async def create_task(self, task: TaskMetadata, request_id: RequestId, request: Request[Any, Any]) -> None: + """Create a new task with the given metadata and original request.""" + task_id = task.taskId + + if task_id in self._tasks: + raise ValueError(f"Task with ID {task_id} already exists") + + task_obj = Task( + taskId=task_id, + status="submitted", + keepAlive=task.keepAlive, + pollInterval=500, # Default 500ms poll frequency + ) + + self._tasks[task_id] = StoredTask(task=task_obj, request=request, request_id=request_id) + + # Schedule cleanup if keepAlive is specified + if task.keepAlive is not None: + self._schedule_cleanup(task_id, task.keepAlive / 1000.0) + + async def get_task(self, task_id: str) -> Task | None: + """Get the current status of a task.""" + stored = self._tasks.get(task_id) + if stored is None: + return None + + # Return a copy to prevent external modification + return Task(**stored.task.model_dump()) + + async def store_task_result(self, task_id: str, result: Result) -> None: + """Store the result of a completed task.""" + stored = self._tasks.get(task_id) + if stored is None: + raise ValueError(f"Task with ID {task_id} not found") + + stored.result = result + stored.task.status = "completed" + + # Reset cleanup timer to start from now (if keepAlive is set) + if stored.task.keepAlive is not None: + self._cancel_cleanup(task_id) + self._schedule_cleanup(task_id, stored.task.keepAlive / 1000.0) + + async def get_task_result(self, task_id: str) -> Result: + """Retrieve the stored result of a task.""" + stored = self._tasks.get(task_id) + if stored is None: + raise ValueError(f"Task with ID {task_id} not found") + + if stored.result is None: + raise ValueError(f"Task {task_id} has no result stored") + + return stored.result + + async def update_task_status(self, task_id: str, status: TaskStatus, error: str | None = None) -> None: + """Update a task's status.""" + stored = self._tasks.get(task_id) + if stored is None: + raise ValueError(f"Task with ID {task_id} not found") + + stored.task.status = status + if error is not None: + stored.task.error = error + + # If task is in a terminal state and has keepAlive, start cleanup timer + if is_terminal(status) and stored.task.keepAlive is not None: + self._cancel_cleanup(task_id) + self._schedule_cleanup(task_id, stored.task.keepAlive / 1000.0) + + async def list_tasks(self, cursor: str | None = None) -> dict[str, Any]: + """ + List tasks, optionally starting from a pagination cursor. + + Returns a dict with 'tasks' list and optional 'nextCursor' string. + """ + PAGE_SIZE = 10 + all_task_ids = list(self._tasks.keys()) + + start_index = 0 + if cursor is not None: + try: + cursor_index = all_task_ids.index(cursor) + start_index = cursor_index + 1 + except ValueError: + raise ValueError(f"Invalid cursor: {cursor}") + + page_task_ids = all_task_ids[start_index : start_index + PAGE_SIZE] + tasks = [Task(**self._tasks[tid].task.model_dump()) for tid in page_task_ids] + + next_cursor = page_task_ids[-1] if start_index + PAGE_SIZE < len(all_task_ids) and page_task_ids else None + + return {"tasks": tasks, "nextCursor": next_cursor} + + async def delete_task(self, task_id: str) -> None: + """Delete a task from storage.""" + if task_id not in self._tasks: + raise ValueError(f"Task with ID {task_id} not found") + + # Cancel any scheduled cleanup + self._cancel_cleanup(task_id) + + # Remove the task + self._tasks.pop(task_id) + + def _schedule_cleanup(self, task_id: str, delay_seconds: float) -> None: + """Schedule automatic cleanup of a task after the specified delay.""" + + async def cleanup() -> None: + await asyncio.sleep(delay_seconds) + self._tasks.pop(task_id, None) + self._cleanup_tasks.pop(task_id, None) + + task = asyncio.create_task(cleanup()) + self._cleanup_tasks[task_id] = task + + def _cancel_cleanup(self, task_id: str) -> None: + """Cancel any scheduled cleanup for a task.""" + cleanup_task = self._cleanup_tasks.pop(task_id, None) + if cleanup_task is not None and not cleanup_task.done(): + cleanup_task.cancel() + + def cleanup(self) -> None: + """Cleanup all timers and tasks (useful for testing or graceful shutdown).""" + for task in self._cleanup_tasks.values(): + if not task.done(): + task.cancel() + self._cleanup_tasks.clear() + self._tasks.clear() + + def get_all_tasks(self) -> list[Task]: + """Get all tasks (useful for debugging). Returns copies to prevent modification.""" + return [Task(**stored.task.model_dump()) for stored in self._tasks.values()] diff --git a/examples/snippets/servers/__init__.py b/examples/snippets/servers/__init__.py index b9865e822..9c9a41360 100644 --- a/examples/snippets/servers/__init__.py +++ b/examples/snippets/servers/__init__.py @@ -22,7 +22,7 @@ def run_server(): print("Usage: server [transport]") print("Available servers: basic_tool, basic_resource, basic_prompt, tool_progress,") print(" sampling, elicitation, completion, notifications,") - print(" fastmcp_quickstart, structured_output, images") + print(" fastmcp_quickstart, structured_output, images, task_based_tool") print("Available transports: stdio (default), sse, streamable-http") sys.exit(1) diff --git a/examples/snippets/servers/task_based_tool.py b/examples/snippets/servers/task_based_tool.py new file mode 100644 index 000000000..e4f8ac983 --- /dev/null +++ b/examples/snippets/servers/task_based_tool.py @@ -0,0 +1,32 @@ +"""Example server demonstrating task-based execution with long-running tools.""" + +import asyncio + +from examples.shared.in_memory_task_store import InMemoryTaskStore +from mcp.server.fastmcp import FastMCP + +# Create a task store to enable task-based execution +task_store = InMemoryTaskStore() +mcp = FastMCP(name="Task-Based Tool Example", task_store=task_store) + + +@mcp.tool() +async def long_running_computation(data: str, delay_seconds: float = 2.0) -> str: + """ + Simulate a long-running computation that benefits from task-based execution. + + This tool demonstrates the 'call-now, fetch-later' pattern where clients can: + 1. Initiate the task without waiting + 2. Disconnect and reconnect later + 3. Poll for status and retrieve results when ready + + Args: + data: Input data to process + delay_seconds: Simulated processing time + """ + # Simulate long-running work + await asyncio.sleep(delay_seconds) + + # Return processed result + result = f"Processed: {data.upper()} (took {delay_seconds}s)" + return result diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 9e9389ac1..b9e089ad3 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,19 +1,26 @@ +from __future__ import annotations + import logging from datetime import timedelta -from typing import Any, Protocol, overload +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload +from uuid import uuid4 import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from jsonschema import SchemaError, ValidationError, validate -from pydantic import AnyUrl, TypeAdapter +from pydantic import AnyUrl, BaseModel, TypeAdapter from typing_extensions import deprecated import mcp.types as types from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder +from mcp.shared.task import TaskStore from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +if TYPE_CHECKING: + from mcp.shared.request import PendingRequest + DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") logger = logging.getLogger("client") @@ -22,7 +29,7 @@ class SamplingFnT(Protocol): async def __call__( self, - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession, Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: ... @@ -30,14 +37,14 @@ async def __call__( class ElicitationFnT(Protocol): async def __call__( self, - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession, Any], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: ... class ListRootsFnT(Protocol): async def __call__( - self, context: RequestContext["ClientSession", Any] + self, context: RequestContext[ClientSession, Any] ) -> types.ListRootsResult | types.ErrorData: ... @@ -62,7 +69,7 @@ async def _default_message_handler( async def _default_sampling_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession, Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: return types.ErrorData( @@ -72,7 +79,7 @@ async def _default_sampling_callback( async def _default_elicitation_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession, Any], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: return types.ErrorData( @@ -82,7 +89,7 @@ async def _default_elicitation_callback( async def _default_list_roots_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession, Any], ) -> types.ListRootsResult | types.ErrorData: return types.ErrorData( code=types.INVALID_REQUEST, @@ -96,6 +103,7 @@ async def _default_logging_callback( pass +ClientResultT = TypeVar("ClientResultT", BaseModel, types.ClientResult) ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData) @@ -119,6 +127,7 @@ def __init__( logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, + task_store: TaskStore | None = None, ) -> None: super().__init__( read_stream, @@ -126,6 +135,7 @@ def __init__( types.ServerRequest, types.ServerNotification, read_timeout_seconds=read_timeout_seconds, + task_store=task_store, ) self._client_info = client_info or DEFAULT_CLIENT_INFO self._sampling_callback = sampling_callback or _default_sampling_callback @@ -149,6 +159,18 @@ async def initialize(self) -> types.InitializeResult: else None ) + # Build tasks capability - only if task store is configured + tasks = None + if self._task_store is not None: + tasks = types.ClientTasksCapability( + requests=types.ClientTasksRequestsCapability( + sampling=types.TaskSamplingCapability(createMessage=True), + elicitation=types.TaskElicitationCapability(create=True), + roots=types.TaskRootsCapability(list=True), + tasks=types.TasksOperationsCapability(get=True, list=True, result=True, delete=True), + ) + ) + result = await self.send_request( types.ClientRequest( types.InitializeRequest( @@ -159,6 +181,7 @@ async def initialize(self) -> types.InitializeResult: elicitation=elicitation, experimental=None, roots=roots, + tasks=tasks, ), clientInfo=self._client_info, ), @@ -322,6 +345,52 @@ async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: types.EmptyResult, ) + def begin_call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, + *, + task: types.TaskMetadata | None = None, + meta: dict[str, Any] | None = None, + ) -> PendingRequest[types.CallToolResult]: + """ + Begin a tool call and return a PendingRequest for granular control over task-based execution. + + This is useful when you want to create a task for a long-running tool call and poll for results later. + + Args: + name: The tool name + arguments: Optional tool arguments + read_timeout_seconds: Optional timeout for reading response + progress_callback: Optional callback for progress notifications + task: Optional task metadata for task-based execution + meta: Optional additional metadata + + Returns: + A PendingRequest object that can be used to wait for the result + """ + _meta: types.RequestParams.Meta | None = None + if meta is not None: + _meta = types.RequestParams.Meta(**meta) + + # Automatically add task metadata if not provided + if task is None: + task = types.TaskMetadata(taskId=str(uuid4())) + + return self.begin_send_request( + types.ClientRequest( + types.CallToolRequest( + params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=_meta), + ) + ), + types.CallToolResult, + request_read_timeout_seconds=read_timeout_seconds, + progress_callback=progress_callback, + task=task, + ) + async def call_tool( self, name: str, @@ -331,7 +400,11 @@ async def call_tool( *, meta: dict[str, Any] | None = None, ) -> types.CallToolResult: - """Send a tools/call request with optional progress callback support.""" + """ + Send a tools/call request with optional progress callback support. + + For task-based execution with granular control, use begin_call_tool() instead. + """ _meta: types.RequestParams.Meta | None = None if meta is not None: @@ -495,6 +568,42 @@ async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" await self.send_notification(types.ClientNotification(types.RootsListChangedNotification())) + async def get_task(self, task_id: str) -> types.GetTaskResult: + """Get the current status of a task.""" + return await self.send_request( + types.ClientRequest(types.GetTaskRequest(method="tasks/get", params=types.GetTaskParams(taskId=task_id))), + types.GetTaskResult, + ) + + async def get_task_result(self, task_id: str, result_type: type[ClientResultT]) -> ClientResultT: + """Retrieve the result of a completed task.""" + return await self.send_request( + types.ClientRequest( + types.GetTaskPayloadRequest(method="tasks/result", params=types.GetTaskPayloadParams(taskId=task_id)) + ), + result_type, + ) + + async def list_tasks(self, cursor: str | None = None) -> types.ListTasksResult: + """List tasks, optionally starting from a pagination cursor.""" + return await self.send_request( + types.ClientRequest( + types.ListTasksRequest( + method="tasks/list", params=types.PaginatedRequestParams(cursor=cursor) if cursor else None + ) + ), + types.ListTasksResult, + ) + + async def delete_task(self, task_id: str) -> types.EmptyResult: + """Delete a specific task.""" + return await self.send_request( + types.ClientRequest( + types.DeleteTaskRequest(method="tasks/delete", params=types.DeleteTaskParams(taskId=task_id)) + ), + types.EmptyResult, + ) + async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: ctx = RequestContext[ClientSession, Any]( request_id=responder.request_id, @@ -526,6 +635,98 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques with responder: return await responder.respond(types.ClientResult(root=types.EmptyResult())) + case types.GetTaskRequest(params=params): + # Handle get task requests if task store is available + if self._task_store: + task = await self._task_store.get_task(params.taskId) + if task is None: + with responder: + await responder.respond( + types.ErrorData( + code=types.INVALID_PARAMS, message="Failed to retrieve task: Task not found" + ) + ) + else: + with responder: + result = types.GetTaskResult( + taskId=task.taskId, + status=task.status, + keepAlive=task.keepAlive, + pollInterval=task.pollInterval, + error=task.error, + _meta={types.RELATED_TASK_META_KEY: {"taskId": params.taskId}}, + ) + await responder.respond(types.ClientResult(result)) + else: + with responder: + await responder.respond( + types.ErrorData(code=types.INVALID_REQUEST, message="Task store not configured") + ) + + case types.GetTaskPayloadRequest(params=params): + # Handle get task result requests if task store is available + if self._task_store: + task = await self._task_store.get_task(params.taskId) + if task is None: + with responder: + await responder.respond( + types.ErrorData( + code=types.INVALID_PARAMS, message="Failed to retrieve task: Task not found" + ) + ) + elif task.status != "completed": + with responder: + await responder.respond( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"Cannot retrieve result: Task status is '{task.status}', not 'completed'", + ) + ) + else: + result = await self._task_store.get_task_result(params.taskId) + # Add related-task metadata + result_dict = result.model_dump(by_alias=True, mode="json", exclude_none=True) + if "_meta" not in result_dict: + result_dict["_meta"] = {} + result_dict["_meta"][types.RELATED_TASK_META_KEY] = {"taskId": params.taskId} + with responder: + await responder.respond(types.ClientResult.model_validate(result_dict)) + else: + with responder: + await responder.respond( + types.ErrorData(code=types.INVALID_REQUEST, message="Task store not configured") + ) + + case types.ListTasksRequest(params=params): + # Handle list tasks requests if task store is available + if self._task_store: + try: + result = await self._task_store.list_tasks(params.cursor if params else None) + with responder: + await responder.respond( + types.ClientResult( + types.ListTasksResult( + tasks=result["tasks"], # type: ignore[arg-type] + nextCursor=result.get("nextCursor"), # type: ignore[arg-type] + _meta={}, + ) + ) + ) + except Exception as e: + with responder: + await responder.respond( + types.ErrorData(code=types.INVALID_PARAMS, message=f"Failed to list tasks: {e}") + ) + else: + with responder: + await responder.respond( + types.ErrorData(code=types.INVALID_REQUEST, message="Task store not configured") + ) + + case _: + # Other request types are not expected to be received by the client + pass + async def _handle_incoming( self, req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 719595916..520fc0ab5 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -61,6 +61,7 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.context import LifespanContextT, RequestContext, RequestT +from mcp.shared.task import TaskStore from mcp.types import Annotations, AnyFunction, ContentBlock, GetPromptResult, Icon, ToolAnnotations from mcp.types import Prompt as MCPPrompt from mcp.types import PromptArgument as MCPPromptArgument @@ -168,6 +169,7 @@ def __init__( # noqa: PLR0913 lifespan: (Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None) = None, auth: AuthSettings | None = None, transport_security: TransportSecuritySettings | None = None, + task_store: TaskStore | None = None, ): self.settings = Settings( debug=debug, @@ -197,6 +199,7 @@ def __init__( # noqa: PLR0913 # TODO(Marcelo): It seems there's a type mismatch between the lifespan type from an FastMCP and Server. # We need to create a Lifespan type that is a generic on the server type, like Starlette does. lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore + task_store=task_store, ) self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 9cec31bab..5060c51f9 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -90,6 +90,7 @@ async def main(): from mcp.shared.exceptions import McpError from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder +from mcp.shared.task import TaskStore logger = logging.getLogger(__name__) @@ -111,10 +112,12 @@ def __init__( prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False, + tasks_changed: bool = False, ): self.prompts_changed = prompts_changed self.resources_changed = resources_changed self.tools_changed = tools_changed + self.tasks_changed = tasks_changed @asynccontextmanager @@ -142,6 +145,7 @@ def __init__( [Server[LifespanResultT, RequestT]], AbstractAsyncContextManager[LifespanResultT], ] = lifespan, + task_store: TaskStore | None = None, ): self.name = name self.version = version @@ -149,6 +153,7 @@ def __init__( self.website_url = website_url self.icons = icons self.lifespan = lifespan + self.task_store = task_store self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = { types.PingRequest: _ping_handler, } @@ -196,6 +201,7 @@ def get_capabilities( tools_capability = None logging_capability = None completions_capability = None + tasks_capability = None # Set prompt capabilities if handler exists if types.ListPromptsRequest in self.request_handlers: @@ -219,6 +225,51 @@ def get_capabilities( if types.CompleteRequest in self.request_handlers: completions_capability = types.CompletionsCapability() + # Set tasks capabilities if task store is configured + if self.task_store is not None: + # Build nested request capabilities based on available handlers + tools_req_cap = None + resources_req_cap = None + prompts_req_cap = None + tasks_ops_cap = None + + # Check for tool capabilities + has_call_tool = types.CallToolRequest in self.request_handlers + has_list_tools = types.ListToolsRequest in self.request_handlers + if has_call_tool or has_list_tools: + tools_req_cap = types.TaskToolsCapability( + call=True if has_call_tool else None, list=True if has_list_tools else None + ) + + # Check for resource capabilities + has_read_resource = types.ReadResourceRequest in self.request_handlers + has_list_resources = types.ListResourcesRequest in self.request_handlers + if has_read_resource or has_list_resources: + resources_req_cap = types.TaskResourcesCapability( + read=True if has_read_resource else None, list=True if has_list_resources else None + ) + + # Check for prompt capabilities + has_get_prompt = types.GetPromptRequest in self.request_handlers + has_list_prompts = types.ListPromptsRequest in self.request_handlers + if has_get_prompt or has_list_prompts: + prompts_req_cap = types.TaskPromptsCapability( + get=True if has_get_prompt else None, list=True if has_list_prompts else None + ) + + # Task operations are always available if task_store is configured + tasks_ops_cap = types.TasksOperationsCapability(get=True, list=True, result=True, delete=True) + + # Build the nested tasks capability + tasks_capability = types.ServerTasksCapability( + requests=types.ServerTasksRequestsCapability( + tools=tools_req_cap, + resources=resources_req_cap, + prompts=prompts_req_cap, + tasks=tasks_ops_cap, + ) + ) + return types.ServerCapabilities( prompts=prompts_capability, resources=resources_capability, @@ -226,6 +277,7 @@ def get_capabilities( logging=logging_capability, experimental=experimental_capabilities, completions=completions_capability, + tasks=tasks_capability, ) @property @@ -621,6 +673,7 @@ async def run( write_stream, initialization_options, stateless=stateless, + task_store=self.task_store, ) ) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 7a99218fa..e2dff998c 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -43,7 +43,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import anyio import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import AnyUrl +from pydantic import AnyUrl, BaseModel import mcp.types as types from mcp.server.models import InitializationOptions @@ -52,6 +52,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: BaseSession, RequestResponder, ) +from mcp.shared.task import TaskStore from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -61,6 +62,7 @@ class InitializationState(Enum): Initialized = 3 +ServerResultT = TypeVar("ServerResultT", BaseModel, types.ServerResult) ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") ServerRequestResponder = ( @@ -86,8 +88,11 @@ def __init__( write_stream: MemoryObjectSendStream[SessionMessage], init_options: InitializationOptions, stateless: bool = False, + task_store: TaskStore | None = None, ) -> None: - super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification) + super().__init__( + read_stream, write_stream, types.ClientRequest, types.ClientNotification, task_store=task_store + ) self._initialization_state = ( InitializationState.Initialized if stateless else InitializationState.NotInitialized ) @@ -102,6 +107,52 @@ def __init__( def client_params(self) -> types.InitializeRequestParams | None: return self._client_params + def _check_tasks_capability( + self, required: types.ClientTasksCapability, client: types.ClientTasksCapability + ) -> bool: + """Check if client supports required tasks capabilities.""" + if required.requests is None: + return True + if client.requests is None: + return False + + req_cap = required.requests + client_req_cap = client.requests + + # Check sampling requests + if req_cap.sampling is not None and ( + client_req_cap.sampling is None + or (req_cap.sampling.createMessage and not client_req_cap.sampling.createMessage) + ): + return False + + # Check elicitation requests + if req_cap.elicitation is not None and ( + client_req_cap.elicitation is None or (req_cap.elicitation.create and not client_req_cap.elicitation.create) + ): + return False + + # Check roots requests + if req_cap.roots is not None and ( + client_req_cap.roots is None or (req_cap.roots.list and not client_req_cap.roots.list) + ): + return False + + # Check tasks operations + if req_cap.tasks is not None: + if client_req_cap.tasks is None: + return False + tasks_checks = [ + not req_cap.tasks.get or client_req_cap.tasks.get, + not req_cap.tasks.list or client_req_cap.tasks.list, + not req_cap.tasks.result or client_req_cap.tasks.result, + not req_cap.tasks.delete or client_req_cap.tasks.delete, + ] + if not all(tasks_checks): + return False + + return True + def check_client_capability(self, capability: types.ClientCapabilities) -> bool: """Check if the client supports a specific capability.""" if self._client_params is None: @@ -133,13 +184,39 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool: if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value: return False + if capability.tasks is not None: + if client_caps.tasks is None: + return False + if not self._check_tasks_capability(capability.tasks, client_caps.tasks): + return False + return True async def _receive_loop(self) -> None: async with self._incoming_message_stream_writer: await super()._receive_loop() - async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]): + async def _received_request( # noqa: PLR0912 + self, responder: RequestResponder[types.ClientRequest, types.ServerResult] + ): + # Handle task creation if task metadata is present + if responder.request_meta and responder.request_meta.task and self._task_store: + task_meta = responder.request_meta.task + # Create the task in the task store + await self._task_store.create_task(task_meta, responder.request_id, responder.request.root) # type: ignore[arg-type] + # Send task created notification with related task metadata + notification_params = types.TaskCreatedNotificationParams( + _meta=types.NotificationParams.Meta( + **{types.RELATED_TASK_META_KEY: types.RelatedTaskMetadata(taskId=task_meta.taskId)} + ) + ) + await self.send_notification( + types.ServerNotification( + types.TaskCreatedNotification(method="notifications/tasks/created", params=notification_params) + ), + related_request_id=responder.request_id, + ) + match responder.request.root: case types.InitializeRequest(params=params): requested_version = params.protocolVersion @@ -167,6 +244,144 @@ async def _received_request(self, responder: RequestResponder[types.ClientReques case types.PingRequest(): # Ping requests are allowed at any time pass + case types.GetTaskRequest(params=params): + # Check if client has announced tasks capability + if self._client_params is None or self._client_params.capabilities.tasks is None: + with responder: + await responder.respond( + types.ErrorData( + code=types.INVALID_REQUEST, + message="Client has not announced tasks capability", + ) + ) + # Handle get task requests if task store is available + elif self._task_store: + task = await self._task_store.get_task(params.taskId) + if task is None: + with responder: + await responder.respond( + types.ErrorData( + code=types.INVALID_PARAMS, message="Failed to retrieve task: Task not found" + ) + ) + else: + with responder: + result = types.GetTaskResult( + taskId=task.taskId, + status=task.status, + keepAlive=task.keepAlive, + pollInterval=task.pollInterval, + error=task.error, + _meta={types.RELATED_TASK_META_KEY: {"taskId": params.taskId}}, + ) + await responder.respond(types.ServerResult(result)) + else: + with responder: + await responder.respond( + types.ErrorData(code=types.INVALID_REQUEST, message="Task store not configured") + ) + case types.GetTaskPayloadRequest(params=params): + # Check if client has announced tasks capability + if self._client_params is None or self._client_params.capabilities.tasks is None: + with responder: + await responder.respond( + types.ErrorData( + code=types.INVALID_REQUEST, + message="Client has not announced tasks capability", + ) + ) + # Handle get task result requests if task store is available + elif self._task_store: + task = await self._task_store.get_task(params.taskId) + if task is None: + with responder: + await responder.respond( + types.ErrorData( + code=types.INVALID_PARAMS, message="Failed to retrieve task: Task not found" + ) + ) + elif task.status != "completed": + with responder: + await responder.respond( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"Cannot retrieve result: Task status is '{task.status}', not 'completed'", + ) + ) + else: + result = await self._task_store.get_task_result(params.taskId) + # Add related-task metadata + result_dict = result.model_dump(by_alias=True, mode="json", exclude_none=True) + if "_meta" not in result_dict: + result_dict["_meta"] = {} + result_dict["_meta"][types.RELATED_TASK_META_KEY] = {"taskId": params.taskId} + with responder: + await responder.respond(types.ServerResult.model_validate(result_dict)) + else: + with responder: + await responder.respond( + types.ErrorData(code=types.INVALID_REQUEST, message="Task store not configured") + ) + case types.ListTasksRequest(params=params): + # Check if client has announced tasks capability + if self._client_params is None or self._client_params.capabilities.tasks is None: + with responder: + await responder.respond( + types.ErrorData( + code=types.INVALID_REQUEST, + message="Client has not announced tasks capability", + ) + ) + # Handle list tasks requests if task store is available + elif self._task_store: + try: + result = await self._task_store.list_tasks(params.cursor if params else None) + with responder: + await responder.respond( + types.ServerResult( + types.ListTasksResult( + tasks=result["tasks"], # type: ignore[arg-type] + nextCursor=result.get("nextCursor"), # type: ignore[arg-type] + _meta={}, + ) + ) + ) + except Exception as e: + with responder: + await responder.respond( + types.ErrorData(code=types.INVALID_PARAMS, message=f"Failed to list tasks: {e}") + ) + else: + with responder: + await responder.respond( + types.ErrorData(code=types.INVALID_REQUEST, message="Task store not configured") + ) + case types.DeleteTaskRequest(params=params): + # Check if client has announced tasks capability + if self._client_params is None or self._client_params.capabilities.tasks is None: + with responder: + await responder.respond( + types.ErrorData( + code=types.INVALID_REQUEST, + message="Client has not announced tasks capability", + ) + ) + # Handle delete task requests if task store is available + elif self._task_store: + try: + await self._task_store.delete_task(params.taskId) + with responder: + await responder.respond(types.ServerResult(types.EmptyResult(_meta={}))) + except Exception as e: + with responder: + await responder.respond( + types.ErrorData(code=types.INVALID_PARAMS, message=f"Failed to delete task: {e}") + ) + else: + with responder: + await responder.respond( + types.ErrorData(code=types.INVALID_REQUEST, message="Task store not configured") + ) case _: if self._initialization_state != InitializationState.Initialized: raise RuntimeError("Received request before initialization was complete") @@ -324,6 +539,42 @@ async def send_prompt_list_changed(self) -> None: """Send a prompt list changed notification.""" await self.send_notification(types.ServerNotification(types.PromptListChangedNotification())) + async def get_task(self, task_id: str) -> types.GetTaskResult: + """Get the current status of a task.""" + return await self.send_request( + types.ServerRequest(types.GetTaskRequest(method="tasks/get", params=types.GetTaskParams(taskId=task_id))), + types.GetTaskResult, + ) + + async def get_task_result(self, task_id: str, result_type: type[ServerResultT]) -> ServerResultT: + """Retrieve the result of a completed task.""" + return await self.send_request( + types.ServerRequest( + types.GetTaskPayloadRequest(method="tasks/result", params=types.GetTaskPayloadParams(taskId=task_id)) + ), + result_type, + ) + + async def list_tasks(self, cursor: str | None = None) -> types.ListTasksResult: + """List tasks, optionally starting from a pagination cursor.""" + return await self.send_request( + types.ServerRequest( + types.ListTasksRequest( + method="tasks/list", params=types.PaginatedRequestParams(cursor=cursor) if cursor else None + ) + ), + types.ListTasksResult, + ) + + async def delete_task(self, task_id: str) -> types.EmptyResult: + """Delete a specific task.""" + return await self.send_request( + types.ServerRequest( + types.DeleteTaskRequest(method="tasks/delete", params=types.DeleteTaskParams(taskId=task_id)) + ), + types.EmptyResult, + ) + async def _handle_incoming(self, req: ServerRequestResponder) -> None: await self._incoming_message_stream_writer.send(req) diff --git a/src/mcp/shared/request.py b/src/mcp/shared/request.py new file mode 100644 index 000000000..a521b1d78 --- /dev/null +++ b/src/mcp/shared/request.py @@ -0,0 +1,188 @@ +"""Pending request handling for task-based execution.""" + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +import anyio +from pydantic import BaseModel + +from mcp.shared.exceptions import McpError +from mcp.shared.task import is_terminal +from mcp.types import INVALID_PARAMS, GetTaskResult + +if TYPE_CHECKING: + from mcp.shared.session import BaseSession + +ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) + +DEFAULT_POLLING_INTERVAL = 5.0 # 5 seconds + + +@dataclass +class TaskHandlerOptions: + """Options for handling task status updates during result polling.""" + + on_task_created: Callable[[], Awaitable[None]] | None = None + """Callback invoked when the task is created.""" + + on_task_status: Callable[[GetTaskResult], Awaitable[None]] | None = None + """Callback invoked each time task status is polled.""" + + +async def _default_handler(_: Any = None) -> None: + """Default no-op handler.""" + pass + + +class PendingRequest(Generic[ReceiveResultT]): + """ + Represents a pending request that may involve task-based execution. + + This class provides methods to wait for the result of a request, + with optional task polling and status callbacks. + """ + + def __init__( + self, + session: "BaseSession[Any, Any, Any, Any, Any]", + task_created_handle: Awaitable[None], + result_handle: Awaitable[ReceiveResultT], + result_type: type[ReceiveResultT], + task_id: str | None = None, + ) -> None: + """ + Initialize a PendingRequest. + + Args: + session: The session to use for task queries + task_created_handle: Awaitable that completes when task is created + result_handle: Awaitable that completes with the request result + task_id: Optional task ID if this is a task-based request + """ + self.session = session + self.task_created_handle = task_created_handle + self.result_handle = result_handle + self.result_type = result_type + self.task_id = task_id + + async def result(self, options: TaskHandlerOptions | None = None) -> ReceiveResultT: + """ + Wait for a result, calling callbacks if provided and a task was created. + + Args: + options: Optional callbacks for task creation and status updates + + Returns: + The result of the request + + Raises: + Any exception raised during request execution or task polling + """ + options = options or TaskHandlerOptions() + on_task_created = options.on_task_created or _default_handler + on_task_status = options.on_task_status or _default_handler + + if self.task_id is None: + # No task ID provided, just block for the result + return await self.result_handle + + # Race between task-based polling and direct result + # Whichever completes first (or fails last) is returned + exceptions: list[Exception] = [] + completed = 0 + result: ReceiveResultT | None = None + result_event = anyio.Event() + + async def wrapper(task: Callable[[], Awaitable[ReceiveResultT]]): + nonlocal result, completed + try: + value = await task() + if not result_event.is_set(): + result = value + result_event.set() # Task completed successfully + except Exception as e: + exceptions.append(e) + finally: + completed += 1 + if completed == 2 and not result_event.is_set(): + # All tasks completed, none succeeded + result_event.set() + + async def _wait_for_result_task() -> ReceiveResultT: + assert self.task_id + return await self._task_handler(self.task_id, on_task_status) + + async def _wait_for_task_creation() -> None: + # Wait for task creation notification + await self.task_created_handle + await on_task_created() + + async with anyio.create_task_group() as tg: + tg.start_soon(_wait_for_task_creation) + tg.start_soon(wrapper, _wait_for_result_task) + tg.start_soon(wrapper, self._wait_for_result) + + # Wait for first success + await result_event.wait() + + # Wait for first success or all completions + await result_event.wait() + + # If we got a result, cancel remaining tasks + if result is not None: + tg.cancel_scope.cancel() + + # If no result but we have exceptions, raise them + if result is None: + if len(exceptions) == 1: + raise exceptions[0] + else: + raise RuntimeError("All tasks failed", exceptions) + + return result + + async def _wait_for_result(self) -> ReceiveResultT: + return await self.result_handle + + async def _task_handler( + self, + task_id: str, + on_task_status: Callable[[GetTaskResult], Awaitable[None]], + ) -> ReceiveResultT: + """ + Encapsulate polling for a result, calling on_task_status after querying the task. + + Args: + task_id: The task ID to poll + on_task_status: Callback invoked on each status poll + + Returns: + The result of the task + + Raises: + Exception: If task polling or result retrieval fails + """ + # Poll for completion + task: GetTaskResult + poll_interval = DEFAULT_POLLING_INTERVAL * 1000.0 + while True: + try: + task = await self.session.get_task(task_id) + except McpError as e: + if e.error.code == INVALID_PARAMS: + # Task may not exist yet + await anyio.sleep(poll_interval / 1000.0) + continue + raise + await on_task_status(task) + + if is_terminal(task.status): + break + + # Wait before polling again + poll_interval = task.pollInterval if task.pollInterval is not None else DEFAULT_POLLING_INTERVAL * 1000 + await anyio.sleep(poll_interval / 1000.0) + + # Retrieve and return the result + return await self.session.get_task_result(task_id, self.result_type) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 4e774984d..8c917452f 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -3,7 +3,7 @@ from contextlib import AsyncExitStack from datetime import timedelta from types import TracebackType -from typing import Any, Generic, Protocol, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar import anyio import httpx @@ -13,26 +13,37 @@ from mcp.shared.exceptions import McpError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.task import TaskStore from mcp.types import ( CONNECTION_CLOSED, INVALID_PARAMS, + RELATED_TASK_META_KEY, + TASK_META_KEY, CancelledNotification, ClientNotification, ClientRequest, ClientResult, ErrorData, + GetTaskResult, JSONRPCError, JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + ListTasksResult, ProgressNotification, + RelatedTaskMetadata, RequestParams, ServerNotification, ServerRequest, ServerResult, + TaskCreatedNotification, + TaskMetadata, ) +if TYPE_CHECKING: + from mcp.shared.request import PendingRequest + SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) @@ -177,6 +188,9 @@ class BaseSession( _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _progress_callbacks: dict[RequestId, ProgressFnT] + _pending_task_creations: dict[str, anyio.Event] + _request_id_to_task_id: dict[RequestId, str] + _task_store: TaskStore | None def __init__( self, @@ -186,6 +200,7 @@ def __init__( receive_notification_type: type[ReceiveNotificationT], # If none, reading will never time out read_timeout_seconds: timedelta | None = None, + task_store: TaskStore | None = None, ) -> None: self._read_stream = read_stream self._write_stream = write_stream @@ -196,6 +211,9 @@ def __init__( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._progress_callbacks = {} + self._pending_task_creations = {} + self._request_id_to_task_id = {} + self._task_store = task_store self._exit_stack = AsyncExitStack() async def __aenter__(self) -> Self: @@ -224,6 +242,8 @@ async def send_request( request_read_timeout_seconds: timedelta | None = None, metadata: MessageMetadata = None, progress_callback: ProgressFnT | None = None, + task: TaskMetadata | None = None, + related_task: RelatedTaskMetadata | None = None, ) -> ReceiveResultT: """ Sends a request and wait for a response. Raises an McpError if the @@ -232,6 +252,15 @@ async def send_request( Do not use this method to emit notifications! Use send_notification() instead. + + Args: + request: The request to send + result_type: The expected result type + request_read_timeout_seconds: Optional timeout for reading response + metadata: Optional message metadata + progress_callback: Optional callback for progress notifications + task: Optional task metadata for task-based execution + related_task: Optional related task metadata """ request_id = self._request_id self._request_id = request_id + 1 @@ -251,6 +280,26 @@ async def send_request( # Store the callback for this request self._progress_callbacks[request_id] = progress_callback + # Inject task metadata if provided + if task is not None: + if "params" not in request_data: + request_data["params"] = {} + if "_meta" not in request_data["params"]: + request_data["params"]["_meta"] = {} + request_data["params"]["_meta"][TASK_META_KEY] = task.model_dump(by_alias=True, exclude_none=True) + # Track this request's task ID + self._request_id_to_task_id[request_id] = task.taskId + + # Inject related task metadata if provided + if related_task is not None: + if "params" not in request_data: + request_data["params"] = {} + if "_meta" not in request_data["params"]: + request_data["params"]["_meta"] = {} + request_data["params"]["_meta"][RELATED_TASK_META_KEY] = related_task.model_dump( + by_alias=True, exclude_none=True + ) + try: jsonrpc_request = JSONRPCRequest( jsonrpc="2.0", @@ -290,9 +339,69 @@ async def send_request( finally: self._response_streams.pop(request_id, None) self._progress_callbacks.pop(request_id, None) + # Clean up task tracking + task_id = self._request_id_to_task_id.pop(request_id, None) + if task_id: + self._pending_task_creations.pop(task_id, None) await response_stream.aclose() await response_stream_reader.aclose() + def begin_send_request( + self, + request: SendRequestT, + result_type: type[ReceiveResultT], + request_read_timeout_seconds: timedelta | None = None, + metadata: MessageMetadata = None, + progress_callback: ProgressFnT | None = None, + task: TaskMetadata | None = None, + related_task: RelatedTaskMetadata | None = None, + ) -> "PendingRequest[ReceiveResultT]": + """ + Begin a request and return a PendingRequest for granular control over task-based execution. + + This is useful when you want to create a task for a long-running request and poll for results later. + + Args: + request: The request to send + result_type: The expected result type + request_read_timeout_seconds: Optional timeout for reading response + metadata: Optional message metadata + progress_callback: Optional callback for progress notifications + task: Optional task metadata for task-based execution + related_task: Optional related task metadata + + Returns: + A PendingRequest object that can be used to wait for the result + """ + from mcp.shared.request import PendingRequest + + # Create an event for task creation notification if task is provided + task_created_event = anyio.Event() + if task: + self._pending_task_creations[task.taskId] = task_created_event + + # Create the actual request coroutine + result_coro = self.send_request( + request, + result_type, + request_read_timeout_seconds, + metadata, + progress_callback, + task, + related_task, + ) + + async def wait_for_task_creation() -> None: + await task_created_event.wait() + + return PendingRequest( + session=self, + task_created_handle=wait_for_task_creation(), + result_handle=result_coro, + result_type=result_type, + task_id=task.taskId if task else None, + ) + async def send_notification( self, notification: SendNotificationT, @@ -381,28 +490,14 @@ async def _receive_loop(self) -> None: ) # Handle cancellation notifications if isinstance(notification.root, CancelledNotification): - cancelled_id = notification.root.params.requestId - if cancelled_id in self._in_flight: - await self._in_flight[cancelled_id].cancel() + await self._handle_cancellation_notification(notification.root.params.requestId) else: # Handle progress notifications callback if isinstance(notification.root, ProgressNotification): - progress_token = notification.root.params.progressToken - # If there is a progress callback for this token, - # call it with the progress information - if progress_token in self._progress_callbacks: - callback = self._progress_callbacks[progress_token] - try: - await callback( - notification.root.params.progress, - notification.root.params.total, - notification.root.params.message, - ) - except Exception as e: - logging.error( - "Progress callback raised an exception: %s", - e, - ) + await self._handle_progress_notification(notification.root) + # Handle task created notifications + elif isinstance(notification.root, TaskCreatedNotification): + await self._handle_task_created_notification(notification.root) await self._received_notification(notification) await self._handle_incoming(notification) except Exception as e: @@ -450,6 +545,40 @@ async def _received_request(self, responder: RequestResponder[ReceiveRequestT, S forwarded on to the message stream. """ + async def _handle_cancellation_notification(self, cancelled_id: RequestId) -> None: + """Handle a cancellation notification for a request.""" + if cancelled_id in self._in_flight: + await self._in_flight[cancelled_id].cancel() + # If this request had a task, mark it as cancelled in storage + task_id: str | None = self._request_id_to_task_id.get(cancelled_id) + if task_id and self._task_store: + try: + await self._task_store.update_task_status(task_id, "cancelled") + except Exception as e: + logging.error(f"Failed to cancel task {task_id}: {e}") + + async def _handle_progress_notification(self, notification: ProgressNotification) -> None: + """Handle a progress notification by calling the registered callback.""" + progress_token = notification.params.progressToken + # If there is a progress callback for this token, call it with the progress information + if progress_token in self._progress_callbacks: + callback = self._progress_callbacks[progress_token] + try: + await callback( + notification.params.progress, + notification.params.total, + notification.params.message, + ) + except Exception as e: + logging.error("Progress callback raised an exception: %s", e) + + async def _handle_task_created_notification(self, notification: TaskCreatedNotification) -> None: + """Handle a task created notification by signaling pending task creations.""" + if notification.params.meta and notification.params.meta.related_task: + task_id = notification.params.meta.related_task.taskId + if task_id in self._pending_task_creations: + self._pending_task_creations[task_id].set() + async def _received_notification(self, notification: ReceiveNotificationT) -> None: """ Can be overridden by subclasses to handle a notification without needing @@ -468,6 +597,18 @@ async def send_progress_notification( processed. """ + async def get_task(self, task_id: str) -> GetTaskResult: + """Get the current status of a task.""" + ... + + async def get_task_result(self, task_id: str, result_type: type[ReceiveResultT]) -> ReceiveResultT: + """Retrieve the result of a completed task.""" + ... + + async def list_tasks(self, cursor: str | None = None) -> ListTasksResult: + """List tasks, optionally starting from a pagination cursor.""" + ... + async def _handle_incoming( self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, diff --git a/src/mcp/shared/task.py b/src/mcp/shared/task.py new file mode 100644 index 000000000..6d304909d --- /dev/null +++ b/src/mcp/shared/task.py @@ -0,0 +1,128 @@ +"""Task storage interface and utilities for MCP task-based execution.""" + +from abc import ABC, abstractmethod +from typing import Any, Literal + +from mcp.types import Request, RequestId, Result, Task, TaskMetadata + +TaskStatus = Literal["submitted", "working", "input_required", "completed", "failed", "cancelled", "unknown"] + + +class TaskStore(ABC): + """ + Interface for storing and retrieving task state and results. + + Similar to Transport, this allows pluggable task storage implementations + (in-memory, database, distributed cache, etc.). + """ + + @abstractmethod + async def create_task(self, task: TaskMetadata, request_id: RequestId, request: Request[Any, Any]) -> None: + """ + Create a new task with the given metadata and original request. + + Args: + task: The task creation metadata from the request + request_id: The JSON-RPC request ID + request: The original request that triggered task creation + """ + ... + + @abstractmethod + async def get_task(self, task_id: str) -> Task | None: + """ + Get the current status of a task. + + Args: + task_id: The task identifier + + Returns: + The task state including status, keepAlive, pollInterval, and optional error, + or None if task not found + """ + ... + + @abstractmethod + async def store_task_result(self, task_id: str, result: Result) -> None: + """ + Store the result of a completed task. + + Args: + task_id: The task identifier + result: The result to store + """ + ... + + @abstractmethod + async def get_task_result(self, task_id: str) -> Result: + """ + Retrieve the stored result of a task. + + Args: + task_id: The task identifier + + Returns: + The stored result + + Raises: + Exception: If task not found or has no result + """ + ... + + @abstractmethod + async def update_task_status(self, task_id: str, status: TaskStatus, error: str | None = None) -> None: + """ + Update a task's status (e.g., to 'cancelled', 'failed', 'completed'). + + Args: + task_id: The task identifier + status: The new status + error: Optional error message if status is 'failed' or 'cancelled' + """ + ... + + @abstractmethod + async def list_tasks(self, cursor: str | None = None) -> dict[str, list[Task] | str | None]: + """ + List tasks, optionally starting from a pagination cursor. + + Args: + cursor: Optional cursor for pagination + + Returns: + A dictionary containing: + - 'tasks': list of Task objects + - 'nextCursor': optional string for next page (None if no more pages) + + Raises: + Exception: If cursor is invalid + """ + ... + + @abstractmethod + async def delete_task(self, task_id: str) -> None: + """ + Delete a task from storage. + + Args: + task_id: The task identifier + + Raises: + Exception: If task not found + """ + ... + + +def is_terminal(status: TaskStatus) -> bool: + """ + Check if a task status represents a terminal state. + + Terminal states are those where the task has finished and will not change. + + Args: + status: The task status to check + + Returns: + True if the status is terminal (completed, failed, cancelled, or unknown) + """ + return status in ("completed", "failed", "cancelled", "unknown") diff --git a/src/mcp/types.py b/src/mcp/types.py index 871322740..3fbc9b717 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -33,6 +33,9 @@ """ DEFAULT_NEGOTIATED_VERSION = "2025-03-26" +TASK_META_KEY = "modelcontextprotocol.io/task" +RELATED_TASK_META_KEY = "modelcontextprotocol.io/related-task" + ProgressToken = str | int Cursor = str Role = Literal["user", "assistant"] @@ -40,8 +43,36 @@ AnyFunction: TypeAlias = Callable[..., Any] +class TaskMetadata(BaseModel): + """Task creation metadata, used to ask that the server create a task to represent a request.""" + + taskId: str + """The task ID to use as a reference to the created task.""" + + keepAlive: int | None = None + """Time in milliseconds to ask to keep task results available after completion. Only used with taskId.""" + + model_config = ConfigDict(extra="allow") + + +class RelatedTaskMetadata(BaseModel): + """Task association metadata, used to signal which task a message originated from.""" + + taskId: str + """The task ID this message is related to.""" + + model_config = ConfigDict(extra="allow") + + class RequestParams(BaseModel): class Meta(BaseModel): + """ + Request metadata that can contain various optional fields. + + Includes typed access to task-related metadata fields that are serialized + with their full keys in the wire format. + """ + progressToken: ProgressToken | None = None """ If specified, the caller requests out-of-band progress notifications for @@ -50,6 +81,12 @@ class Meta(BaseModel): notifications. The receiver is not obligated to provide these notifications. """ + task: TaskMetadata | None = Field(alias=TASK_META_KEY, default=None) + """Task creation metadata for task-based execution.""" + + related_task: RelatedTaskMetadata | None = Field(alias=RELATED_TASK_META_KEY, default=None) + """Related task metadata for linking requests to parent tasks.""" + model_config = ConfigDict(extra="allow") meta: Meta | None = Field(alias="_meta", default=None) @@ -65,6 +102,16 @@ class PaginatedRequestParams(RequestParams): class NotificationParams(BaseModel): class Meta(BaseModel): + """ + Notification metadata that can contain various optional fields. + + Includes typed access to related task metadata that is serialized + with its full key in the wire format. + """ + + related_task: RelatedTaskMetadata | None = Field(alias=RELATED_TASK_META_KEY, default=None) + """Related task metadata for linking notifications to parent tasks.""" + model_config = ConfigDict(extra="allow") meta: Meta | None = Field(alias="_meta", default=None) @@ -262,6 +309,66 @@ class ElicitationCapability(BaseModel): model_config = ConfigDict(extra="allow") +class TasksOperationsCapability(BaseModel): + """Capability for task operations shared by client and server.""" + + get: bool | None = None + """Whether tasks/get is supported.""" + list: bool | None = None + """Whether tasks/list is supported.""" + result: bool | None = None + """Whether tasks/result is supported.""" + delete: bool | None = None + """Whether tasks/delete is supported.""" + model_config = ConfigDict(extra="allow") + + +class TaskSamplingCapability(BaseModel): + """Capability for sampling requests within tasks.""" + + createMessage: bool | None = None + """Whether sampling/createMessage can be requested during task execution.""" + model_config = ConfigDict(extra="allow") + + +class TaskElicitationCapability(BaseModel): + """Capability for elicitation requests within tasks.""" + + create: bool | None = None + """Whether elicitation/create can be requested during task execution.""" + model_config = ConfigDict(extra="allow") + + +class TaskRootsCapability(BaseModel): + """Capability for roots requests within tasks.""" + + list: bool | None = None + """Whether roots/list can be requested during task execution.""" + model_config = ConfigDict(extra="allow") + + +class ClientTasksRequestsCapability(BaseModel): + """Requests that the client can make when executing tasks.""" + + sampling: TaskSamplingCapability | None = None + """Sampling requests capability during task execution.""" + elicitation: TaskElicitationCapability | None = None + """Elicitation requests capability during task execution.""" + roots: TaskRootsCapability | None = None + """Roots requests capability during task execution.""" + tasks: TasksOperationsCapability | None = None + """Task operations capability.""" + model_config = ConfigDict(extra="allow") + + +class ClientTasksCapability(BaseModel): + """Capability for client task operations.""" + + requests: ClientTasksRequestsCapability | None = None + """Requests the client can make when executing tasks.""" + model_config = ConfigDict(extra="allow") + + class ClientCapabilities(BaseModel): """Capabilities a client may support.""" @@ -273,6 +380,8 @@ class ClientCapabilities(BaseModel): """Present if the client supports elicitation from the user.""" roots: RootsCapability | None = None """Present if the client supports listing roots.""" + tasks: ClientTasksCapability | None = None + """Present if the client supports task operations.""" model_config = ConfigDict(extra="allow") @@ -314,6 +423,58 @@ class CompletionsCapability(BaseModel): model_config = ConfigDict(extra="allow") +class TaskToolsCapability(BaseModel): + """Capability for tools requests within tasks.""" + + call: bool | None = None + """Whether tools/call can be requested during task execution.""" + list: bool | None = None + """Whether tools/list can be requested during task execution.""" + model_config = ConfigDict(extra="allow") + + +class TaskResourcesCapability(BaseModel): + """Capability for resources requests within tasks.""" + + read: bool | None = None + """Whether resources/read can be requested during task execution.""" + list: bool | None = None + """Whether resources/list can be requested during task execution.""" + model_config = ConfigDict(extra="allow") + + +class TaskPromptsCapability(BaseModel): + """Capability for prompts requests within tasks.""" + + get: bool | None = None + """Whether prompts/get can be requested during task execution.""" + list: bool | None = None + """Whether prompts/list can be requested during task execution.""" + model_config = ConfigDict(extra="allow") + + +class ServerTasksRequestsCapability(BaseModel): + """Requests that the server can provide when executing tasks.""" + + tools: TaskToolsCapability | None = None + """Tools requests capability during task execution.""" + resources: TaskResourcesCapability | None = None + """Resources requests capability during task execution.""" + prompts: TaskPromptsCapability | None = None + """Prompts requests capability during task execution.""" + tasks: TasksOperationsCapability | None = None + """Task operations capability.""" + model_config = ConfigDict(extra="allow") + + +class ServerTasksCapability(BaseModel): + """Capability for server task operations.""" + + requests: ServerTasksRequestsCapability | None = None + """Requests the server can provide when executing tasks.""" + model_config = ConfigDict(extra="allow") + + class ServerCapabilities(BaseModel): """Capabilities that a server may support.""" @@ -329,6 +490,8 @@ class ServerCapabilities(BaseModel): """Present if the server offers any tools to call.""" completions: CompletionsCapability | None = None """Present if the server offers autocompletion suggestions for prompts and resources.""" + tasks: ServerTasksCapability | None = None + """Present if the server supports task operations.""" model_config = ConfigDict(extra="allow") @@ -416,6 +579,143 @@ class ProgressNotification(Notification[ProgressNotificationParams, Literal["not params: ProgressNotificationParams +# Tasks + + +class Task(BaseModel): + """A pollable state object associated with a request.""" + + taskId: str + """The unique identifier for this task.""" + + status: Literal["submitted", "working", "input_required", "completed", "failed", "cancelled", "unknown"] + """ + The current status of the task: + - submitted: Task has been created and queued + - working: Task is actively being processed + - input_required: Task is waiting for additional input (e.g., from elicitation) + - completed: Task finished successfully + - failed: Task encountered an error + - cancelled: Task was cancelled by the client + - unknown: Task status could not be determined (terminal state, rarely occurs) + """ + + keepAlive: int | None + """ + Time in milliseconds to keep task results available after completion. + None means the task will not be automatically cleaned up. + """ + + pollInterval: int | None = None + """Recommended polling frequency in milliseconds for checking task status.""" + + error: str | None = None + """Error message if status is 'failed' or 'cancelled'.""" + + model_config = ConfigDict(extra="allow") + + +class TaskCreatedNotificationParams(NotificationParams): + """Parameters for task created notification.""" + + model_config = ConfigDict(extra="allow") + + +class TaskCreatedNotification(Notification[TaskCreatedNotificationParams, Literal["notifications/tasks/created"]]): + """An out-of-band notification used to inform the receiver of a task being created.""" + + method: Literal["notifications/tasks/created"] = "notifications/tasks/created" + params: TaskCreatedNotificationParams + + +class GetTaskParams(RequestParams): + """Parameters for getting task status.""" + + taskId: str + """The task identifier.""" + + model_config = ConfigDict(extra="allow") + + +class GetTaskRequest(Request[GetTaskParams, Literal["tasks/get"]]): + """A request to get the state of a specific task.""" + + method: Literal["tasks/get"] = "tasks/get" + params: GetTaskParams + + +class GetTaskResult(Result): + """The response to a tasks/get request.""" + + taskId: str + """The unique identifier for this task.""" + + status: Literal["submitted", "working", "input_required", "completed", "failed", "cancelled", "unknown"] + """The current status of the task.""" + + keepAlive: int | None = None + """Time in milliseconds to keep task results available after completion.""" + + pollInterval: int | None = None + """Recommended polling frequency in milliseconds for checking task status.""" + + error: str | None = None + """Error message if status is 'failed' or 'cancelled'.""" + + model_config = ConfigDict(extra="allow") + + +class GetTaskPayloadParams(RequestParams): + """Parameters for getting task result payload.""" + + taskId: str + """The task identifier.""" + + model_config = ConfigDict(extra="allow") + + +class GetTaskPayloadRequest(Request[GetTaskPayloadParams, Literal["tasks/result"]]): + """A request to get the result of a specific task.""" + + method: Literal["tasks/result"] = "tasks/result" + params: GetTaskPayloadParams + + +class ListTasksRequest(PaginatedRequest[Literal["tasks/list"]]): + """A request to list tasks.""" + + method: Literal["tasks/list"] = "tasks/list" + params: PaginatedRequestParams | None = None + + +class ListTasksResult(Result): + """The response to a tasks/list request.""" + + tasks: list[Task] + """List of tasks.""" + + nextCursor: Cursor | None = None + """Opaque token for pagination.""" + + model_config = ConfigDict(extra="allow") + + +class DeleteTaskParams(RequestParams): + """Parameters for deleting a task.""" + + taskId: str + """The task identifier.""" + + model_config = ConfigDict(extra="allow") + + +class DeleteTaskRequest(Request[DeleteTaskParams, Literal["tasks/delete"]]): + """A request to delete a specific task.""" + + method: Literal["tasks/delete"] = "tasks/delete" + params: DeleteTaskParams + + class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]): """Sent from the client to request a list of resources the server has.""" @@ -865,6 +1165,14 @@ class ToolAnnotations(BaseModel): of a memory tool is not. Default: true """ + + taskHint: bool | None = None + """ + If true, this tool is expected to support task-augmented execution. + This allows clients to handle long-running operations through polling + the task system. + Default: false + """ model_config = ConfigDict(extra="allow") @@ -1262,13 +1570,23 @@ class ClientRequest( | UnsubscribeRequest | CallToolRequest | ListToolsRequest + | GetTaskRequest + | GetTaskPayloadRequest + | ListTasksRequest + | DeleteTaskRequest ] ): pass class ClientNotification( - RootModel[CancelledNotification | ProgressNotification | InitializedNotification | RootsListChangedNotification] + RootModel[ + CancelledNotification + | ProgressNotification + | InitializedNotification + | RootsListChangedNotification + | TaskCreatedNotification + ] ): pass @@ -1311,11 +1629,24 @@ class ElicitResult(Result): """ -class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult | ElicitResult]): +class ClientResult( + RootModel[EmptyResult | CreateMessageResult | ListRootsResult | ElicitResult | GetTaskResult | ListTasksResult] +): pass -class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest | ElicitRequest]): +class ServerRequest( + RootModel[ + PingRequest + | CreateMessageRequest + | ListRootsRequest + | ElicitRequest + | GetTaskRequest + | GetTaskPayloadRequest + | ListTasksRequest + | DeleteTaskRequest + ] +): pass @@ -1328,6 +1659,7 @@ class ServerNotification( | ResourceListChangedNotification | ToolListChangedNotification | PromptListChangedNotification + | TaskCreatedNotification ] ): pass @@ -1345,6 +1677,8 @@ class ServerResult( | ReadResourceResult | CallToolResult | ListToolsResult + | GetTaskResult + | ListTasksResult ] ): pass diff --git a/tests/client/test_session.py b/tests/client/test_session.py index f2135e455..3e2140043 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -504,6 +504,141 @@ async def mock_server(): assert received_capabilities.roots.listChanged is True +@pytest.mark.anyio +async def test_client_capabilities_without_task_store(): + """Test that client does not announce tasks capability without task_store""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + received_capabilities = None + + async def mock_server(): + nonlocal received_capabilities + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_capabilities = request.root.params.capabilities + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + # Receive initialized notification + await client_to_server_receive.receive() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Assert that tasks capability is not announced without task_store + assert received_capabilities is not None + assert received_capabilities.tasks is None + + +@pytest.mark.anyio +async def test_client_capabilities_with_task_store(): + """Test that client announces tasks capability with task_store""" + from examples.shared.in_memory_task_store import InMemoryTaskStore + + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + received_capabilities = None + + async def mock_server(): + nonlocal received_capabilities + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_capabilities = request.root.params.capabilities + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + # Receive initialized notification + await client_to_server_receive.receive() + + task_store = InMemoryTaskStore() + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + task_store=task_store, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Assert that tasks capability is announced with task_store + assert received_capabilities is not None + assert received_capabilities.tasks is not None + assert isinstance(received_capabilities.tasks, types.ClientTasksCapability) + assert received_capabilities.tasks.requests is not None + # Verify all expected request capabilities are present + assert received_capabilities.tasks.requests.sampling is not None + assert received_capabilities.tasks.requests.elicitation is not None + assert received_capabilities.tasks.requests.roots is not None + assert received_capabilities.tasks.requests.tasks is not None + + @pytest.mark.anyio @pytest.mark.parametrize(argnames="meta", argvalues=[None, {"toolMeta": "value"}]) async def test_client_tool_call_with_meta(meta: dict[str, Any] | None): diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 618d7bc61..1e3475027 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -20,6 +20,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl +from examples.shared.in_memory_task_store import InMemoryTaskStore from examples.snippets.servers import ( basic_prompt, basic_resource, @@ -30,6 +31,7 @@ notifications, sampling, structured_output, + task_based_tool, tool_progress, ) from mcp.client.session import ClientSession @@ -37,6 +39,7 @@ from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage +from mcp.shared.request import TaskHandlerOptions from mcp.shared.session import RequestResponder from mcp.types import ( ClientResult, @@ -45,6 +48,7 @@ ElicitRequestParams, ElicitResult, GetPromptResult, + GetTaskResult, InitializeResult, LoggingMessageNotification, LoggingMessageNotificationParams, @@ -124,6 +128,8 @@ def run_server_with_transport(module_name: str, port: int, transport: str) -> No mcp = fastmcp_quickstart.mcp elif module_name == "structured_output": mcp = structured_output.mcp + elif module_name == "task_based_tool": + mcp = task_based_tool.mcp else: raise ImportError(f"Unknown module: {module_name}") @@ -215,6 +221,7 @@ async def sampling_callback( context: RequestContext[ClientSession, None], params: CreateMessageRequestParams ) -> CreateMessageResult: """Sampling callback for tests.""" + del context, params # Unused but required by protocol return CreateMessageResult( role="assistant", content=TextContent( @@ -227,6 +234,7 @@ async def sampling_callback( async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): """Elicitation callback for tests.""" + del context # Unused but required by protocol # For restaurant booking test if "No tables available" in params.message: return ElicitResult( @@ -686,3 +694,78 @@ async def test_structured_output(server_transport: str, server_url: str) -> None assert "sunny" in result_text # condition assert "45" in result_text # humidity assert "5.2" in result_text # wind_speed + + +# Test task-based execution +@pytest.mark.anyio +@pytest.mark.parametrize( + "server_transport", + [ + ("task_based_tool", "sse"), + ("task_based_tool", "streamable-http"), + ], + indirect=True, +) +async def test_task_based_tool(server_transport: str, server_url: str) -> None: + """Test task-based execution with begin_call_tool.""" + transport = server_transport + client_cm = create_client_for_transport(transport, server_url) + + async with client_cm as client_streams: + read_stream, write_stream = unpack_streams(client_streams) + # Create a task store for the client to support task-based execution + task_store = InMemoryTaskStore() + async with ClientSession(read_stream, write_stream, task_store=task_store) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "Task-Based Tool Example" + + # Track callback invocations + task_created_called = False + task_status_updates: list[str] = [] + + async def on_task_created() -> None: + nonlocal task_created_called + task_created_called = True + + async def on_task_status(task_result: GetTaskResult) -> None: + task_status_updates.append(task_result.status) + + # Test begin_call_tool for task-based execution + pending_request = session.begin_call_tool( + "long_running_computation", + arguments={"data": "test_data", "delay_seconds": 1}, + ) + + # Wait for the result with callbacks + tool_result = await pending_request.result( + TaskHandlerOptions(on_task_created=on_task_created, on_task_status=on_task_status) + ) + + # Verify the result + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert "Processed: TEST_DATA" in tool_result.content[0].text + assert "1s" in tool_result.content[0].text or "1.0s" in tool_result.content[0].text + + # Verify callbacks were invoked + assert task_created_called, "on_task_created callback was not invoked" + assert len(task_status_updates) > 0, "on_task_status callback was never invoked" + + # Due to the race between direct result and task polling: + # - With 1s delay and 5s default polling interval, direct result usually wins + # - We should see at least one status update (typically "submitted") + # - We may or may not see "completed" depending on timing + + # Verify we got at least one valid status + valid_statuses = ["submitted", "working", "completed"] + assert all(status in valid_statuses for status in task_status_updates), ( + f"Got invalid status in updates: {task_status_updates}" + ) + + # If direct result won the race, we may only see submitted/working; if polling won, we'll see completed + last_status = task_status_updates[-1] + assert last_status in ["submitted", "working", "completed"], ( + f"Unexpected last status: {last_status} from {task_status_updates}" + ) diff --git a/tests/shared/test_pending_request.py b/tests/shared/test_pending_request.py new file mode 100644 index 000000000..07a2bc5b3 --- /dev/null +++ b/tests/shared/test_pending_request.py @@ -0,0 +1,623 @@ +"""Unit tests for PendingRequest implementation.""" + +import asyncio +from collections.abc import Awaitable +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from mcp.shared.request import DEFAULT_POLLING_INTERVAL, PendingRequest, TaskHandlerOptions +from mcp.types import CallToolResult, GetTaskResult, TextContent + + +@pytest.fixture +def mock_session() -> MagicMock: + """Create a mock session for testing.""" + session = MagicMock() + session.get_task = AsyncMock() + session.get_task_result = AsyncMock() + return session + + +@pytest.fixture +def sample_result() -> CallToolResult: + """Create a sample result.""" + return CallToolResult(content=[TextContent(type="text", text="Success!")]) + + +class TestPendingRequestWithoutTask: + """Tests for PendingRequest without task ID (direct execution).""" + + @pytest.mark.anyio + async def test_result_without_task_id(self, mock_session: MagicMock, sample_result: CallToolResult): + """Test that request without task ID returns result directly.""" + + async def get_result(): + return sample_result + + # Create a never-completing coroutine for task_created_handle + # It won't be awaited since task_id is None + never_completes_future: Awaitable[None] = asyncio.Future() + + pending = PendingRequest( + session=mock_session, + task_created_handle=never_completes_future, + result_handle=get_result(), + result_type=CallToolResult, + task_id=None, + ) + + result = await pending.result() + assert result == sample_result + # Task methods should never be called + mock_session.get_task.assert_not_called() + mock_session.get_task_result.assert_not_called() + + @pytest.mark.anyio + async def test_callbacks_not_invoked_without_task_id(self, mock_session: MagicMock, sample_result: CallToolResult): + """Test that callbacks are not invoked when no task ID is provided.""" + + async def get_result(): + return sample_result + + # Create a never-completing future for task_created_handle + # It won't be awaited since task_id is None + never_completes_future: Awaitable[None] = asyncio.Future() + + on_task_created = AsyncMock() + on_task_status = AsyncMock() + + pending = PendingRequest( + session=mock_session, + task_created_handle=never_completes_future, + result_handle=get_result(), + result_type=CallToolResult, + task_id=None, + ) + + await pending.result(TaskHandlerOptions(on_task_created=on_task_created, on_task_status=on_task_status)) + + # Callbacks should not be invoked + on_task_created.assert_not_called() + on_task_status.assert_not_called() + + +class TestPendingRequestTaskPolling: + """Tests for PendingRequest with task-based execution.""" + + @pytest.mark.anyio + async def test_task_polling_basic_flow(self, mock_session: MagicMock, sample_result: CallToolResult): + """Test basic task polling flow: submitted -> working -> completed.""" + task_created_event = asyncio.Event() + + async def task_created(): + await task_created_event.wait() + + async def never_completes(): + await asyncio.Future() # Never completes + return sample_result + + # Set up task status progression + mock_session.get_task.side_effect = [ + GetTaskResult(taskId="task-1", status="submitted", pollInterval=100), + GetTaskResult(taskId="task-1", status="working", pollInterval=100), + GetTaskResult(taskId="task-1", status="completed", pollInterval=100), + ] + mock_session.get_task_result.return_value = sample_result + + pending = PendingRequest( + session=mock_session, + task_created_handle=task_created(), + result_handle=never_completes(), + result_type=CallToolResult, + task_id="task-1", + ) + + # Trigger task created notification after a short delay + async def notify_created(): + await asyncio.sleep(0.01) + task_created_event.set() + + asyncio.create_task(notify_created()) + + result = await pending.result() + assert result == sample_result + assert mock_session.get_task.call_count == 3 + mock_session.get_task_result.assert_called_once_with("task-1", CallToolResult) + + @pytest.mark.anyio + async def test_callback_invocation_timing(self, mock_session: MagicMock, sample_result: CallToolResult): + """Test that callbacks are invoked at the correct times.""" + task_created_event = asyncio.Event() + + async def task_created(): + await task_created_event.wait() + + async def never_completes(): + await asyncio.Future() # Never completes + return sample_result + + on_task_created = AsyncMock() + on_task_status = AsyncMock() + + # Set up task status progression + task_statuses = [ + GetTaskResult(taskId="task-2", status="submitted", pollInterval=50), + GetTaskResult(taskId="task-2", status="completed", pollInterval=50), + ] + mock_session.get_task.side_effect = task_statuses + mock_session.get_task_result.return_value = sample_result + + pending = PendingRequest( + session=mock_session, + task_created_handle=task_created(), + result_handle=never_completes(), + result_type=CallToolResult, + task_id="task-2", + ) + + # Trigger task created notification + async def notify_created(): + await asyncio.sleep(0.01) + task_created_event.set() + + asyncio.create_task(notify_created()) + + await pending.result(TaskHandlerOptions(on_task_created=on_task_created, on_task_status=on_task_status)) + + # Verify callback invocations + on_task_created.assert_called_once() + assert on_task_status.call_count == 2 + # Check that status callback was called with each status + assert on_task_status.call_args_list[0][0][0].status == "submitted" + assert on_task_status.call_args_list[1][0][0].status == "completed" + + @pytest.mark.anyio + async def test_polling_interval_respects_poll_frequency( + self, mock_session: MagicMock, sample_result: CallToolResult + ): + """Test that polling interval respects pollInterval from task.""" + task_created_event = asyncio.Event() + task_created_event.set() + + async def task_created(): + pass + + async def never_completes(): + await asyncio.Future() # Never completes + return sample_result + + # Track polling timestamps + poll_times: list[float] = [] + + async def mock_get_task(task_id: str): + poll_times.append(asyncio.get_event_loop().time()) + if len(poll_times) == 1: + return GetTaskResult(taskId=task_id, status="submitted", pollInterval=100) # 100ms + else: + return GetTaskResult(taskId=task_id, status="completed", pollInterval=100) + + mock_session.get_task.side_effect = mock_get_task + mock_session.get_task_result.return_value = sample_result + + pending = PendingRequest( + session=mock_session, + task_created_handle=task_created(), + result_handle=never_completes(), + result_type=CallToolResult, + task_id="task-3", + ) + + await pending.result() + + # Verify polling interval (should be ~100ms between polls) + assert len(poll_times) == 2 + interval = poll_times[1] - poll_times[0] + # Allow some tolerance for timing variance + assert 0.08 < interval < 0.15 + + @pytest.mark.anyio + async def test_polling_uses_default_interval_when_not_specified( + self, mock_session: MagicMock, sample_result: CallToolResult + ): + """Test that default polling interval is used when pollInterval is None.""" + task_created_event = asyncio.Event() + task_created_event.set() + + async def task_created(): + pass + + async def never_completes(): + await asyncio.Future() # Never completes + return sample_result + + poll_times: list[float] = [] + + async def mock_get_task(task_id: str): + poll_times.append(asyncio.get_event_loop().time()) + if len(poll_times) == 1: + return GetTaskResult(taskId=task_id, status="submitted", pollInterval=None) + else: + return GetTaskResult(taskId=task_id, status="completed", pollInterval=None) + + mock_session.get_task.side_effect = mock_get_task + mock_session.get_task_result.return_value = sample_result + + pending = PendingRequest( + session=mock_session, + task_created_handle=task_created(), + result_handle=never_completes(), + result_type=CallToolResult, + task_id="task-4", + ) + + await pending.result() + + # Verify default polling interval (DEFAULT_POLLING_INTERVAL = 5 seconds) + assert len(poll_times) == 2 + interval = poll_times[1] - poll_times[0] + # Default is 5 seconds, allow some tolerance + assert DEFAULT_POLLING_INTERVAL - 0.5 < interval < DEFAULT_POLLING_INTERVAL + 0.5 + + +class TestPendingRequestRaceCondition: + """Tests for race condition handling between task polling and direct result.""" + + @pytest.mark.anyio + async def test_direct_result_wins_race(self, mock_session: MagicMock, sample_result: CallToolResult): + """Test that direct result path can complete before task polling.""" + result_event = asyncio.Event() + task_created_event = asyncio.Event() + + async def task_created(): + await task_created_event.wait() + + async def get_result(): + await result_event.wait() + return sample_result + + # Set up task polling to be slow + async def slow_get_task(task_id: str): + await asyncio.sleep(0.2) + return GetTaskResult(taskId=task_id, status="submitted", pollInterval=100) + + mock_session.get_task.side_effect = slow_get_task + + pending = PendingRequest( + session=mock_session, + task_created_handle=task_created(), + result_handle=get_result(), + result_type=CallToolResult, + task_id="task-5", + ) + + # Trigger task created and direct result quickly + async def complete_directly(): + await asyncio.sleep(0.01) + task_created_event.set() + await asyncio.sleep(0.01) + result_event.set() + + asyncio.create_task(complete_directly()) + + result = await pending.result() + assert result == sample_result + # get_task_result should not be called when direct path wins + mock_session.get_task_result.assert_not_called() + + @pytest.mark.anyio + async def test_task_polling_wins_race(self, mock_session: MagicMock, sample_result: CallToolResult): + """Test that task polling path can complete before direct result.""" + task_created_event = asyncio.Event() + task_created_event.set() + + async def task_created(): + pass + + async def never_completes(): + await asyncio.Future() # Never completes + return sample_result + + # Set up task polling to complete quickly + mock_session.get_task.return_value = GetTaskResult(taskId="task-6", status="completed", pollInterval=100) + mock_session.get_task_result.return_value = sample_result + + pending = PendingRequest( + session=mock_session, + task_created_handle=task_created(), + result_handle=never_completes(), + result_type=CallToolResult, + task_id="task-6", + ) + + result = await pending.result() + assert result == sample_result + mock_session.get_task.assert_called_once_with("task-6") + mock_session.get_task_result.assert_called_once_with("task-6", CallToolResult) + + +class TestPendingRequestErrorHandling: + """Tests for error propagation and handling.""" + + @pytest.mark.anyio + async def test_error_from_direct_result_path(self, mock_session: MagicMock): + """Test that errors from direct result path are propagated.""" + result_event = asyncio.Event() + task_created_event = asyncio.Event() + + async def task_created(): + await task_created_event.wait() + + async def get_result_with_error(): + await result_event.wait() + raise RuntimeError("Direct result failed") + + # Mock get_task to also fail after a delay, so both paths fail + async def slow_get_task(task_id: str): + await asyncio.sleep(0.1) + raise RuntimeError("Task polling also failed") + + mock_session.get_task.side_effect = slow_get_task + + pending = PendingRequest( + session=mock_session, + task_created_handle=task_created(), + result_handle=get_result_with_error(), + result_type=CallToolResult, + task_id="task-7", + ) + + # Trigger task created and error + async def trigger_error(): + await asyncio.sleep(0.01) + task_created_event.set() + await asyncio.sleep(0.01) + result_event.set() + + asyncio.create_task(trigger_error()) + + with pytest.raises(RuntimeError, match="Direct result failed"): + await pending.result() + + @pytest.mark.anyio + async def test_error_from_task_polling_path(self, mock_session: MagicMock, sample_result: CallToolResult): + """Test that errors from task polling are propagated.""" + task_created_event = asyncio.Event() + task_created_event.set() + + async def task_created(): + pass + + async def slow_result_with_error(): + await asyncio.sleep(0.1) # Takes longer than task polling error + raise RuntimeError("Direct result also failed") + + # Set up task polling to fail immediately + mock_session.get_task.side_effect = RuntimeError("Task polling failed") + + pending = PendingRequest( + session=mock_session, + task_created_handle=task_created(), + result_handle=slow_result_with_error(), + result_type=CallToolResult, + task_id="task-8", + ) + + with pytest.raises(RuntimeError, match="Task polling failed"): + await pending.result() + + @pytest.mark.anyio + async def test_error_from_task_result_retrieval(self, mock_session: MagicMock, sample_result: CallToolResult): + """Test that errors from get_task_result are propagated.""" + task_created_event = asyncio.Event() + task_created_event.set() + + async def task_created(): + pass + + async def slow_result_with_error(): + await asyncio.sleep(0.1) # Takes longer than result retrieval error + raise RuntimeError("Direct result also failed") + + # Task polling succeeds but result retrieval fails + mock_session.get_task.return_value = GetTaskResult(taskId="task-9", status="completed", pollInterval=100) + mock_session.get_task_result.side_effect = RuntimeError("Failed to retrieve result") + + pending = PendingRequest( + session=mock_session, + task_created_handle=task_created(), + result_handle=slow_result_with_error(), + result_type=CallToolResult, + task_id="task-9", + ) + + with pytest.raises(RuntimeError, match="Failed to retrieve result"): + await pending.result() + + +class TestPendingRequestCancellation: + """Tests for proper cleanup when pending requests are cancelled.""" + + @pytest.mark.anyio + async def test_cancellation_during_polling(self, mock_session: MagicMock, sample_result: CallToolResult): + """Test that cancelling result() properly cleans up tasks.""" + task_created_event = asyncio.Event() + task_created_event.set() + + async def task_created(): + pass + + async def never_completes(): + await asyncio.Future() # Never completes + return sample_result + + # Set up task polling to never complete + async def never_complete_get_task(task_id: str): + await asyncio.sleep(10) + return GetTaskResult(taskId=task_id, status="submitted", pollInterval=100) + + mock_session.get_task.side_effect = never_complete_get_task + + pending = PendingRequest( + session=mock_session, + task_created_handle=task_created(), + result_handle=never_completes(), + result_type=CallToolResult, + task_id="task-10", + ) + + # Create task and cancel it after a short delay + result_task = asyncio.create_task(pending.result()) + await asyncio.sleep(0.05) + result_task.cancel() + + with pytest.raises(asyncio.CancelledError): + await result_task + + @pytest.mark.anyio + async def test_losing_path_is_cancelled(self, mock_session: MagicMock, sample_result: CallToolResult): + """Test that the losing path in the race is properly cancelled.""" + task_created_event = asyncio.Event() + result_event = asyncio.Event() + + async def task_created(): + await task_created_event.wait() + + async def get_result(): + await result_event.wait() + return sample_result + + # Track if get_task is cancelled + get_task_cancelled = False + + async def cancellable_get_task(task_id: str): + nonlocal get_task_cancelled + try: + await asyncio.sleep(10) # Will be cancelled + return GetTaskResult(taskId=task_id, status="submitted", pollInterval=100) + except asyncio.CancelledError: + get_task_cancelled = True + raise + + mock_session.get_task.side_effect = cancellable_get_task + + pending = PendingRequest( + session=mock_session, + task_created_handle=task_created(), + result_handle=get_result(), + result_type=CallToolResult, + task_id="task-11", + ) + + # Set up direct result to win quickly + async def complete_directly(): + await asyncio.sleep(0.01) + task_created_event.set() + await asyncio.sleep(0.01) + result_event.set() + + asyncio.create_task(complete_directly()) + + result = await pending.result() + assert result == sample_result + + # Give cancellation time to propagate + await asyncio.sleep(0.1) + assert get_task_cancelled + + +class TestPendingRequestStatusTransitions: + """Tests for various task status transition scenarios.""" + + @pytest.mark.anyio + async def test_immediate_completion(self, mock_session: MagicMock, sample_result: CallToolResult): + """Test task that is already completed on first poll.""" + task_created_event = asyncio.Event() + task_created_event.set() + + async def task_created(): + pass + + async def never_completes(): + await asyncio.Future() # Never completes + return sample_result + + mock_session.get_task.return_value = GetTaskResult(taskId="task-12", status="completed", pollInterval=100) + mock_session.get_task_result.return_value = sample_result + + pending = PendingRequest( + session=mock_session, + task_created_handle=task_created(), + result_handle=never_completes(), + result_type=CallToolResult, + task_id="task-12", + ) + + result = await pending.result() + assert result == sample_result + # Should only poll once + mock_session.get_task.assert_called_once_with("task-12") + + @pytest.mark.anyio + async def test_failed_status(self, mock_session: MagicMock, sample_result: CallToolResult): + """Test task that transitions to failed status.""" + task_created_event = asyncio.Event() + task_created_event.set() + + async def task_created(): + pass + + async def never_completes(): + await asyncio.Future() # Never completes + return sample_result + + mock_session.get_task.side_effect = [ + GetTaskResult(taskId="task-13", status="submitted", pollInterval=50), + GetTaskResult(taskId="task-13", status="failed", pollInterval=50, error="Something went wrong"), + ] + mock_session.get_task_result.return_value = sample_result + + pending = PendingRequest( + session=mock_session, + task_created_handle=task_created(), + result_handle=never_completes(), + task_id="task-13", + result_type=CallToolResult, + ) + + # Failed is a terminal state, so it should stop polling and retrieve result + result = await pending.result() + assert result == sample_result + assert mock_session.get_task.call_count == 2 + + @pytest.mark.anyio + async def test_cancelled_status(self, mock_session: MagicMock, sample_result: CallToolResult): + """Test task that transitions to cancelled status.""" + task_created_event = asyncio.Event() + task_created_event.set() + + async def task_created(): + pass + + async def never_completes(): + await asyncio.Future() # Never completes + return sample_result + + mock_session.get_task.side_effect = [ + GetTaskResult(taskId="task-14", status="working", pollInterval=50), + GetTaskResult(taskId="task-14", status="cancelled", pollInterval=50, error="User cancelled"), + ] + mock_session.get_task_result.return_value = sample_result + + pending = PendingRequest( + session=mock_session, + task_created_handle=task_created(), + result_handle=never_completes(), + result_type=CallToolResult, + task_id="task-14", + ) + + # Cancelled is a terminal state + result = await pending.result() + assert result == sample_result + assert mock_session.get_task.call_count == 2 diff --git a/tests/shared/test_task_store.py b/tests/shared/test_task_store.py new file mode 100644 index 000000000..d6c052987 --- /dev/null +++ b/tests/shared/test_task_store.py @@ -0,0 +1,355 @@ +"""Unit tests for TaskStore implementation.""" + +import asyncio + +import pytest + +from examples.shared.in_memory_task_store import InMemoryTaskStore +from mcp.types import CallToolRequest, CallToolRequestParams, RequestId, TaskMetadata + + +@pytest.fixture +def task_store() -> InMemoryTaskStore: + """Create a fresh InMemoryTaskStore for each test.""" + return InMemoryTaskStore() + + +@pytest.fixture +def sample_task_metadata() -> TaskMetadata: + """Create sample task metadata.""" + return TaskMetadata(taskId="test-task-123", keepAlive=5000) + + +@pytest.fixture +def sample_request() -> CallToolRequest: + """Create a sample request.""" + return CallToolRequest(params=CallToolRequestParams(name="test_tool", arguments={"arg": "value"})) + + +class TestCreateTask: + """Tests for TaskStore.create_task().""" + + @pytest.mark.anyio + async def test_create_task_basic( + self, task_store: InMemoryTaskStore, sample_task_metadata: TaskMetadata, sample_request: CallToolRequest + ): + """Test creating a basic task.""" + request_id: RequestId = "req-1" + await task_store.create_task(sample_task_metadata, request_id, sample_request) + + # Verify task was created with correct initial state + task = await task_store.get_task("test-task-123") + assert task is not None + assert task.taskId == "test-task-123" + assert task.status == "submitted" + assert task.keepAlive == 5000 + assert task.pollInterval == 500 # Default value + + @pytest.mark.anyio + async def test_create_task_without_keep_alive(self, task_store: InMemoryTaskStore, sample_request: CallToolRequest): + """Test creating a task without keepAlive.""" + task_metadata = TaskMetadata(taskId="test-task-no-keepalive") + request_id: RequestId = "req-2" + await task_store.create_task(task_metadata, request_id, sample_request) + + task = await task_store.get_task("test-task-no-keepalive") + assert task is not None + assert task.keepAlive is None + # Should not schedule cleanup if no keepAlive + assert "test-task-no-keepalive" not in task_store._cleanup_tasks + + @pytest.mark.anyio + async def test_create_task_schedules_cleanup( + self, task_store: InMemoryTaskStore, sample_task_metadata: TaskMetadata, sample_request: CallToolRequest + ): + """Test that creating a task with keepAlive schedules cleanup.""" + request_id: RequestId = "req-3" + await task_store.create_task(sample_task_metadata, request_id, sample_request) + + # Verify cleanup task was scheduled + assert "test-task-123" in task_store._cleanup_tasks + cleanup_task = task_store._cleanup_tasks["test-task-123"] + assert not cleanup_task.done() + + @pytest.mark.anyio + async def test_create_duplicate_task_raises_error( + self, task_store: InMemoryTaskStore, sample_task_metadata: TaskMetadata, sample_request: CallToolRequest + ): + """Test that creating a task with duplicate ID raises ValueError.""" + request_id: RequestId = "req-4" + await task_store.create_task(sample_task_metadata, request_id, sample_request) + + # Attempt to create another task with same ID + with pytest.raises(ValueError, match="Task with ID test-task-123 already exists"): + await task_store.create_task(sample_task_metadata, "req-5", sample_request) + + +class TestGetTask: + """Tests for TaskStore.get_task().""" + + @pytest.mark.anyio + async def test_get_existing_task( + self, task_store: InMemoryTaskStore, sample_task_metadata: TaskMetadata, sample_request: CallToolRequest + ): + """Test retrieving an existing task.""" + request_id: RequestId = "req-6" + await task_store.create_task(sample_task_metadata, request_id, sample_request) + + task = await task_store.get_task("test-task-123") + assert task is not None + assert task.taskId == "test-task-123" + assert task.status == "submitted" + + @pytest.mark.anyio + async def test_get_nonexistent_task_returns_none(self, task_store: InMemoryTaskStore): + """Test that getting a non-existent task returns None.""" + task = await task_store.get_task("nonexistent-task") + assert task is None + + +class TestStoreTaskResult: + """Tests for TaskStore.store_task_result().""" + + @pytest.mark.anyio + async def test_store_result_for_completed_task( + self, task_store: InMemoryTaskStore, sample_task_metadata: TaskMetadata, sample_request: CallToolRequest + ): + """Test storing a result for a completed task.""" + from mcp.types import CallToolResult, TextContent + + request_id: RequestId = "req-7" + await task_store.create_task(sample_task_metadata, request_id, sample_request) + + # Update task to completed status + await task_store.update_task_status("test-task-123", "completed") + + # Store result + result = CallToolResult(content=[TextContent(type="text", text="Success!")]) + await task_store.store_task_result("test-task-123", result) + + # Verify result was stored + retrieved_result = await task_store.get_task_result("test-task-123") + retrieved_result = CallToolResult.model_validate(retrieved_result) + assert isinstance(retrieved_result.content[0], TextContent) + assert retrieved_result.content[0].text == "Success!" + + +class TestGetTaskResult: + """Tests for TaskStore.get_task_result().""" + + @pytest.mark.anyio + async def test_get_result_for_completed_task( + self, task_store: InMemoryTaskStore, sample_task_metadata: TaskMetadata, sample_request: CallToolRequest + ): + """Test retrieving result for a completed task.""" + from mcp.types import CallToolResult, TextContent + + request_id: RequestId = "req-8" + await task_store.create_task(sample_task_metadata, request_id, sample_request) + await task_store.update_task_status("test-task-123", "completed") + + result = CallToolResult(content=[TextContent(type="text", text="Result!")]) + await task_store.store_task_result("test-task-123", result) + + retrieved_result = await task_store.get_task_result("test-task-123") + retrieved_result = CallToolResult.model_validate(retrieved_result) + assert isinstance(retrieved_result.content[0], TextContent) + assert retrieved_result.content[0].text == "Result!" + + @pytest.mark.anyio + async def test_get_result_for_incomplete_task_raises_error( + self, task_store: InMemoryTaskStore, sample_task_metadata: TaskMetadata, sample_request: CallToolRequest + ): + """Test that getting result for incomplete task raises ValueError.""" + request_id: RequestId = "req-9" + await task_store.create_task(sample_task_metadata, request_id, sample_request) + + # Task is still in 'submitted' status + with pytest.raises(ValueError, match="Task test-task-123 has no result stored"): + await task_store.get_task_result("test-task-123") + + @pytest.mark.anyio + async def test_get_result_for_nonexistent_task_raises_error(self, task_store: InMemoryTaskStore): + """Test that getting result for non-existent task raises ValueError.""" + with pytest.raises(ValueError, match="Task with ID nonexistent not found"): + await task_store.get_task_result("nonexistent") + + +class TestUpdateTaskStatus: + """Tests for TaskStore.update_task_status().""" + + @pytest.mark.anyio + async def test_update_status_to_working( + self, task_store: InMemoryTaskStore, sample_task_metadata: TaskMetadata, sample_request: CallToolRequest + ): + """Test updating task status to working.""" + request_id: RequestId = "req-10" + await task_store.create_task(sample_task_metadata, request_id, sample_request) + + await task_store.update_task_status("test-task-123", "working") + + task = await task_store.get_task("test-task-123") + assert task + assert task.status == "working" + + @pytest.mark.anyio + async def test_update_status_to_completed( + self, task_store: InMemoryTaskStore, sample_task_metadata: TaskMetadata, sample_request: CallToolRequest + ): + """Test updating task status to completed.""" + request_id: RequestId = "req-11" + await task_store.create_task(sample_task_metadata, request_id, sample_request) + + await task_store.update_task_status("test-task-123", "completed") + + task = await task_store.get_task("test-task-123") + assert task + assert task.status == "completed" + assert task.error is None + + @pytest.mark.anyio + async def test_update_status_to_failed_with_error( + self, task_store: InMemoryTaskStore, sample_task_metadata: TaskMetadata, sample_request: CallToolRequest + ): + """Test updating task status to failed with error message.""" + request_id: RequestId = "req-12" + await task_store.create_task(sample_task_metadata, request_id, sample_request) + + await task_store.update_task_status("test-task-123", "failed", error="Something went wrong") + + task = await task_store.get_task("test-task-123") + assert task + assert task.status == "failed" + assert task.error == "Something went wrong" + + @pytest.mark.anyio + async def test_update_status_to_terminal_reschedules_cleanup( + self, task_store: InMemoryTaskStore, sample_task_metadata: TaskMetadata, sample_request: CallToolRequest + ): + """Test that updating status to terminal state reschedules cleanup.""" + request_id: RequestId = "req-13" + await task_store.create_task(sample_task_metadata, request_id, sample_request) + + # Verify cleanup was scheduled + assert "test-task-123" in task_store._cleanup_tasks + original_cleanup = task_store._cleanup_tasks["test-task-123"] + + # Update status to 'completed' (terminal state, should cancel old and reschedule new cleanup) + await task_store.update_task_status("test-task-123", "completed") + + # Give the cancellation a moment to complete + await asyncio.sleep(0) + + # Original cleanup should be cancelled + assert original_cleanup.cancelled() + # New cleanup should be scheduled + assert "test-task-123" in task_store._cleanup_tasks + new_cleanup = task_store._cleanup_tasks["test-task-123"] + assert new_cleanup != original_cleanup + assert not new_cleanup.done() + + @pytest.mark.anyio + async def test_update_nonexistent_task_raises_error(self, task_store: InMemoryTaskStore): + """Test that updating non-existent task raises ValueError.""" + with pytest.raises(ValueError, match="Task with ID nonexistent not found"): + await task_store.update_task_status("nonexistent", "completed") + + +class TestListTasks: + """Tests for TaskStore.list_tasks().""" + + @pytest.mark.anyio + async def test_list_tasks_empty(self, task_store: InMemoryTaskStore): + """Test listing tasks when store is empty.""" + result = await task_store.list_tasks() + assert result["tasks"] == [] + assert result.get("nextCursor") is None + + @pytest.mark.anyio + async def test_list_tasks_single_page(self, task_store: InMemoryTaskStore, sample_request: CallToolRequest): + """Test listing tasks that fit on a single page.""" + # Create 3 tasks + for i in range(3): + task_meta = TaskMetadata(taskId=f"task-{i}") + await task_store.create_task(task_meta, f"req-{i}", sample_request) + + result = await task_store.list_tasks() + assert len(result["tasks"]) == 3 + assert result.get("nextCursor") is None + + @pytest.mark.anyio + async def test_list_tasks_pagination(self, task_store: InMemoryTaskStore, sample_request: CallToolRequest): + """Test listing tasks with pagination.""" + # Create 15 tasks (more than PAGE_SIZE=10) + for i in range(15): + task_meta = TaskMetadata(taskId=f"task-{i:02d}") + await task_store.create_task(task_meta, f"req-{i}", sample_request) + + # First page + result = await task_store.list_tasks() + assert len(result["tasks"]) == 10 + assert result["nextCursor"] is not None + + # Second page + result2 = await task_store.list_tasks(cursor=result["nextCursor"]) + assert len(result2["tasks"]) == 5 + assert result2.get("nextCursor") is None + + @pytest.mark.anyio + async def test_list_tasks_invalid_cursor_raises_error(self, task_store: InMemoryTaskStore): + """Test that invalid cursor raises ValueError.""" + with pytest.raises(ValueError, match="Invalid cursor"): + await task_store.list_tasks(cursor="invalid-cursor") + + +class TestTaskCleanup: + """Tests for task cleanup functionality.""" + + @pytest.mark.anyio + async def test_cleanup_removes_completed_task(self, task_store: InMemoryTaskStore, sample_request: CallToolRequest): + """Test that cleanup removes task after keepAlive expires.""" + # Create task with very short keepAlive + task_meta = TaskMetadata(taskId="cleanup-task", keepAlive=100) # 100ms + await task_store.create_task(task_meta, "req-cleanup", sample_request) + + # Update to completed to trigger cleanup timer + await task_store.update_task_status("cleanup-task", "completed") + + # Wait for cleanup (100ms + small buffer) + await asyncio.sleep(0.15) + + # Task should be removed + task = await task_store.get_task("cleanup-task") + assert task is None + + @pytest.mark.anyio + async def test_cleanup_rescheduled_on_terminal_status( + self, task_store: InMemoryTaskStore, sample_request: CallToolRequest + ): + """Test that cleanup is rescheduled when updating to terminal status.""" + task_meta = TaskMetadata(taskId="cancel-cleanup-task", keepAlive=5000) + await task_store.create_task(task_meta, "req-cancel", sample_request) + + # Update to completed (starts cleanup timer) + await task_store.update_task_status("cancel-cleanup-task", "completed") + + cleanup_task = task_store._cleanup_tasks.get("cancel-cleanup-task") + assert cleanup_task is not None + assert not cleanup_task.done() + + # Keep a reference to the first cleanup task + first_cleanup = cleanup_task + + # Update status again to 'failed' (terminal state, should reschedule cleanup) + await task_store.update_task_status("cancel-cleanup-task", "failed") + + # Give the cancellation a moment to complete + await asyncio.sleep(0) + + # First cleanup should be cancelled + assert first_cleanup.cancelled() + # New cleanup should be scheduled + assert "cancel-cleanup-task" in task_store._cleanup_tasks + second_cleanup = task_store._cleanup_tasks["cancel-cleanup-task"] + assert second_cleanup != first_cleanup + assert not second_cleanup.done() diff --git a/tests/test_task_capabilities.py b/tests/test_task_capabilities.py new file mode 100644 index 000000000..6b9e894a3 --- /dev/null +++ b/tests/test_task_capabilities.py @@ -0,0 +1,183 @@ +"""Tests for task execution capabilities.""" + +from typing import Any + +from pydantic import AnyUrl + +from examples.shared.in_memory_task_store import InMemoryTaskStore +from mcp import types +from mcp.server.lowlevel import NotificationOptions, Server + + +class TestCapabilitySerialization: + """Test that task capabilities serialize/deserialize correctly.""" + + def test_client_tasks_capability_full(self): + """Test full client tasks capability serialization.""" + cap = types.ClientTasksCapability( + requests=types.ClientTasksRequestsCapability( + sampling=types.TaskSamplingCapability(createMessage=True), + elicitation=types.TaskElicitationCapability(create=True), + roots=types.TaskRootsCapability(list=True), + tasks=types.TasksOperationsCapability(get=True, list=True, result=True, delete=True), + ) + ) + # Serialize and deserialize + data = cap.model_dump(by_alias=True, mode="json", exclude_none=True) + deserialized = types.ClientTasksCapability.model_validate(data) + assert deserialized.requests is not None + assert deserialized.requests.sampling is not None + assert deserialized.requests.sampling.createMessage is True + assert deserialized.requests.elicitation is not None + assert deserialized.requests.elicitation.create is True + assert deserialized.requests.roots is not None + assert deserialized.requests.roots.list is True + assert deserialized.requests.tasks is not None + assert deserialized.requests.tasks.get is True + assert deserialized.requests.tasks.list is True + assert deserialized.requests.tasks.result is True + assert deserialized.requests.tasks.delete is True + + def test_server_tasks_capability_full(self): + """Test full server tasks capability serialization.""" + cap = types.ServerTasksCapability( + requests=types.ServerTasksRequestsCapability( + tools=types.TaskToolsCapability(call=True, list=True), + resources=types.TaskResourcesCapability(read=True, list=True), + prompts=types.TaskPromptsCapability(get=True, list=True), + tasks=types.TasksOperationsCapability(get=True, list=True, result=True, delete=True), + ) + ) + # Serialize and deserialize + data = cap.model_dump(by_alias=True, mode="json", exclude_none=True) + deserialized = types.ServerTasksCapability.model_validate(data) + assert deserialized.requests is not None + assert deserialized.requests.tools is not None + assert deserialized.requests.tools.call is True + assert deserialized.requests.tools.list is True + assert deserialized.requests.resources is not None + assert deserialized.requests.resources.read is True + assert deserialized.requests.resources.list is True + assert deserialized.requests.prompts is not None + assert deserialized.requests.prompts.get is True + assert deserialized.requests.prompts.list is True + assert deserialized.requests.tasks is not None + assert deserialized.requests.tasks.get is True + assert deserialized.requests.tasks.list is True + assert deserialized.requests.tasks.result is True + assert deserialized.requests.tasks.delete is True + + def test_client_capabilities_with_tasks(self): + """Test ClientCapabilities with tasks field.""" + caps = types.ClientCapabilities( + sampling=types.SamplingCapability(), + tasks=types.ClientTasksCapability( + requests=types.ClientTasksRequestsCapability(tasks=types.TasksOperationsCapability(get=True, list=True)) + ), + ) + data = caps.model_dump(by_alias=True, mode="json", exclude_none=True) + deserialized = types.ClientCapabilities.model_validate(data) + assert deserialized.tasks is not None + assert deserialized.tasks.requests is not None + assert deserialized.tasks.requests.tasks is not None + assert deserialized.tasks.requests.tasks.get is True + assert deserialized.tasks.requests.tasks.list is True + + def test_server_capabilities_with_tasks(self): + """Test ServerCapabilities with tasks field.""" + caps = types.ServerCapabilities( + logging=types.LoggingCapability(), + tasks=types.ServerTasksCapability( + requests=types.ServerTasksRequestsCapability( + tools=types.TaskToolsCapability(call=True), + tasks=types.TasksOperationsCapability(get=True, delete=True), + ) + ), + ) + data = caps.model_dump(by_alias=True, mode="json", exclude_none=True) + deserialized = types.ServerCapabilities.model_validate(data) + assert deserialized.tasks is not None + assert deserialized.tasks.requests is not None + assert deserialized.tasks.requests.tools is not None + assert deserialized.tasks.requests.tools.call is True + assert deserialized.tasks.requests.tasks is not None + assert deserialized.tasks.requests.tasks.get is True + assert deserialized.tasks.requests.tasks.delete is True + + +class TestServerCapabilityAdvertisement: + """Test that server advertises task capabilities correctly.""" + + def test_no_tasks_capability_without_task_store(self): + """Server should not advertise tasks capability without task store.""" + server = Server("test") + caps = server.get_capabilities(NotificationOptions(), {}) + assert caps.tasks is None + + def test_tasks_capability_with_task_store(self): + """Server should advertise tasks capability with task store.""" + task_store = InMemoryTaskStore() + server = Server("test", task_store=task_store) + caps = server.get_capabilities(NotificationOptions(), {}) + assert caps.tasks is not None + assert caps.tasks.requests is not None + assert caps.tasks.requests.tasks is not None + # All task operations should be available + assert caps.tasks.requests.tasks.get is True + assert caps.tasks.requests.tasks.list is True + assert caps.tasks.requests.tasks.result is True + assert caps.tasks.requests.tasks.delete is True + + def test_tasks_capability_includes_tools_when_available(self): + """Server should include tools in tasks capability when handler exists.""" + task_store = InMemoryTaskStore() + server = Server("test", task_store=task_store) + + # Register tool handler + @server.call_tool() + async def my_tool(arguments: dict[str, Any]) -> list[types.TextContent]: + return [types.TextContent(type="text", text="test")] + + caps = server.get_capabilities(NotificationOptions(), {}) + assert caps.tasks is not None + assert caps.tasks.requests is not None + assert caps.tasks.requests.tools is not None + assert caps.tasks.requests.tools.call is True + + def test_tasks_capability_includes_resources_when_available(self): + """Server should include resources in tasks capability when handler exists.""" + task_store = InMemoryTaskStore() + server = Server("test", task_store=task_store) + + # Register resource handler + @server.read_resource() + async def read_resource(uri: AnyUrl) -> str: + return "test" + + caps = server.get_capabilities(NotificationOptions(), {}) + assert caps.tasks is not None + assert caps.tasks.requests is not None + assert caps.tasks.requests.resources is not None + assert caps.tasks.requests.resources.read is True + + def test_tasks_capability_includes_prompts_when_available(self): + """Server should include prompts in tasks capability when handler exists.""" + task_store = InMemoryTaskStore() + server = Server("test", task_store=task_store) + + # Register prompt handler + @server.get_prompt() + async def get_prompt(name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult: + return types.GetPromptResult( + messages=[types.PromptMessage(role="user", content=types.TextContent(type="text", text="test"))] + ) + + caps = server.get_capabilities(NotificationOptions(), {}) + assert caps.tasks is not None + assert caps.tasks.requests is not None + assert caps.tasks.requests.prompts is not None + assert caps.tasks.requests.prompts.get is True + + +# Note: Additional integration tests for client capability announcement and validation +# are covered by the existing test suite which uses create_connected_server_and_client_session()