diff --git a/server/session.go b/server/session.go index 0ded99fb..48fd52d7 100644 --- a/server/session.go +++ b/server/session.go @@ -171,6 +171,9 @@ func (s *MCPServer) SendLogMessageToClient(ctx context.Context, notification mcp func (s *MCPServer) sendNotificationToAllClients(notification mcp.JSONRPCNotification) { s.sessions.Range(func(k, v any) bool { if session, ok := v.(ClientSession); ok && session.Initialized() { + if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { + sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() + } select { case session.NotificationChannel() <- notification: // Successfully sent notification diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index 9e444c53..32dccc3a 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -2251,3 +2251,91 @@ func TestStreamableHTTP_SendNotificationToSpecificClient(t *testing.T) { } }) } + +// TestStreamableHTTP_AddToolDuringToolCall tests that adding a tool while a tool call +// is in progress doesn't break the client's response. +// This is a regression test for issue #638 where notifications sent via +// sendNotificationToAllClients during an in-progress request would cause +// the response to fail with "unexpected nil response". +func TestStreamableHTTP_AddToolDuringToolCall(t *testing.T) { + mcpServer := NewMCPServer("test-mcp-server", "1.0", + WithToolCapabilities(true), // Enable tool list change notifications + ) + // Add a tool that takes some time to complete + mcpServer.AddTool(mcp.NewTool("slow_tool", + mcp.WithDescription("A tool that takes time to complete"), + ), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Simulate work that takes some time + time.Sleep(100 * time.Millisecond) + return mcp.NewToolResultText("done"), nil + }) + server := NewTestStreamableHTTPServer(mcpServer, WithStateful(true)) + defer server.Close() + // Initialize to get session + resp, err := postJSON(server.URL, initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + sessionID := resp.Header.Get(HeaderKeySessionID) + resp.Body.Close() + if sessionID == "" { + t.Fatal("Expected session ID in response header") + } + // Start the tool call in a goroutine + resultChan := make(chan struct { + statusCode int + body string + err error + }) + go func() { + toolRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]any{ + "name": "slow_tool", + }, + } + toolBody, _ := json.Marshal(toolRequest) + req, _ := http.NewRequest("POST", server.URL, bytes.NewReader(toolBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set(HeaderKeySessionID, sessionID) + resp, err := server.Client().Do(req) + if err != nil { + resultChan <- struct { + statusCode int + body string + err error + }{0, "", err} + return + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + resultChan <- struct { + statusCode int + body string + err error + }{resp.StatusCode, string(body), nil} + }() + // Wait a bit then add a new tool while the slow_tool is executing + // This triggers sendNotificationToAllClients + time.Sleep(50 * time.Millisecond) + mcpServer.AddTool(mcp.NewTool("new_tool", + mcp.WithDescription("A new tool added during execution"), + ), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("new tool result"), nil + }) + // Wait for the tool call to complete + result := <-resultChan + if result.err != nil { + t.Fatalf("Tool call failed with error: %v", result.err) + } + if result.statusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d. Body: %s", result.statusCode, result.body) + } + // The response should contain the tool result + // It may be SSE format (text/event-stream) due to the notification upgrade + if !strings.Contains(result.body, "done") { + t.Errorf("Expected response to contain 'done', got: %s", result.body) + } +}