diff --git a/mcp/types.go b/mcp/types.go index 241b55ce9..13aee511e 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -761,6 +761,26 @@ const ( LoggingLevelEmergency LoggingLevel = "emergency" ) +var levelToInt = map[LoggingLevel]int{ + LoggingLevelDebug: 0, + LoggingLevelInfo: 1, + LoggingLevelNotice: 2, + LoggingLevelWarning: 3, + LoggingLevelError: 4, + LoggingLevelCritical: 5, + LoggingLevelAlert: 6, + LoggingLevelEmergency: 7, +} + +func (l LoggingLevel) ShouldSendTo(minLevel LoggingLevel) bool { + ia, oka := levelToInt[l] + ib, okb := levelToInt[minLevel] + if !oka || !okb { + return false + } + return ia >= ib +} + /* Sampling */ const ( diff --git a/server/session.go b/server/session.go index a79da22ca..165ecaedd 100644 --- a/server/session.go +++ b/server/session.go @@ -96,35 +96,38 @@ func (s *MCPServer) RegisterSession( return nil } -// UnregisterSession removes from storage session that is shut down. -func (s *MCPServer) UnregisterSession( - ctx context.Context, - sessionID string, -) { - sessionValue, ok := s.sessions.LoadAndDelete(sessionID) - if !ok { - return - } - if session, ok := sessionValue.(ClientSession); ok { - s.hooks.UnregisterSession(ctx, session) - } -} - -// SendNotificationToAllClients sends a notification to all the currently active clients. -func (s *MCPServer) SendNotificationToAllClients( - method string, - params map[string]any, -) { - notification := mcp.JSONRPCNotification{ +func (s *MCPServer) buildLogNotification(notification mcp.LoggingMessageNotification) mcp.JSONRPCNotification { + return mcp.JSONRPCNotification{ JSONRPC: mcp.JSONRPC_VERSION, Notification: mcp.Notification{ - Method: method, + Method: notification.Method, Params: mcp.NotificationParams{ - AdditionalFields: params, + AdditionalFields: map[string]any{ + "level": notification.Params.Level, + "logger": notification.Params.Logger, + "data": notification.Params.Data, + }, }, }, } +} +func (s *MCPServer) SendLogMessageToClient(ctx context.Context, notification mcp.LoggingMessageNotification) error { + session := ClientSessionFromContext(ctx) + if session == nil || !session.Initialized() { + return ErrNotificationNotInitialized + } + sessionLogging, ok := session.(SessionWithLogging) + if !ok { + return ErrSessionDoesNotSupportLogging + } + if !notification.Params.Level.ShouldSendTo(sessionLogging.GetLogLevel()) { + return nil + } + return s.sendNotificationCore(ctx, session, s.buildLogNotification(notification)) +} + +func (s *MCPServer) sendNotificationToAllClients(notification mcp.JSONRPCNotification) { s.sessions.Range(func(k, v any) bool { if session, ok := v.(ClientSession); ok && session.Initialized() { select { @@ -140,7 +143,7 @@ func (s *MCPServer) SendNotificationToAllClients( ctx := context.Background() // Use the error hook to report the blocked channel hooks.onError(ctx, nil, "notification", map[string]any{ - "method": method, + "method": notification.Method, "sessionID": sessionID, }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err)) }(session.SessionID(), hooks) @@ -151,22 +154,71 @@ func (s *MCPServer) SendNotificationToAllClients( }) } -// SendNotificationToClient sends a notification to the current client -func (s *MCPServer) SendNotificationToClient( - ctx context.Context, - method string, - params map[string]any, -) error { - session := ClientSessionFromContext(ctx) - if session == nil || !session.Initialized() { - return ErrNotificationNotInitialized - } - +func (s *MCPServer) sendNotificationToSpecificClient(session ClientSession, notification mcp.JSONRPCNotification) error { // upgrades the client-server communication to SSE stream when the server sends notifications to the client if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() } + select { + case session.NotificationChannel() <- notification: + return nil + default: + // Channel is blocked, if there's an error hook, use it + if s.hooks != nil && len(s.hooks.OnError) > 0 { + err := ErrNotificationChannelBlocked + ctx := context.Background() + // Copy hooks pointer to local variable to avoid race condition + hooks := s.hooks + go func(sID string, hooks *Hooks) { + // Use the error hook to report the blocked channel + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": notification.Method, + "sessionID": sID, + }, fmt.Errorf("notification channel blocked for session %s: %w", sID, err)) + }(session.SessionID(), hooks) + } + return ErrNotificationChannelBlocked + } +} + +func (s *MCPServer) SendLogMessageToSpecificClient(sessionID string, notification mcp.LoggingMessageNotification) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + session, ok := sessionValue.(ClientSession) + if !ok || !session.Initialized() { + return ErrSessionNotInitialized + } + sessionLogging, ok := session.(SessionWithLogging) + if !ok { + return ErrSessionDoesNotSupportLogging + } + if !notification.Params.Level.ShouldSendTo(sessionLogging.GetLogLevel()) { + return nil + } + return s.sendNotificationToSpecificClient(session, s.buildLogNotification(notification)) +} + +// UnregisterSession removes from storage session that is shut down. +func (s *MCPServer) UnregisterSession( + ctx context.Context, + sessionID string, +) { + sessionValue, ok := s.sessions.LoadAndDelete(sessionID) + if !ok { + return + } + if session, ok := sessionValue.(ClientSession); ok { + s.hooks.UnregisterSession(ctx, session) + } +} +// SendNotificationToAllClients sends a notification to all the currently active clients. +func (s *MCPServer) SendNotificationToAllClients( + method string, + params map[string]any, +) { notification := mcp.JSONRPCNotification{ JSONRPC: mcp.JSONRPC_VERSION, Notification: mcp.Notification{ @@ -176,13 +228,26 @@ func (s *MCPServer) SendNotificationToClient( }, }, } + s.sendNotificationToAllClients(notification) +} +// SendNotificationToClient sends a notification to the current client +func (s *MCPServer) sendNotificationCore( + ctx context.Context, + session ClientSession, + notification mcp.JSONRPCNotification, +) error { + // upgrades the client-server communication to SSE stream when the server sends notifications to the client + if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { + sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() + } select { case session.NotificationChannel() <- notification: return nil default: // Channel is blocked, if there's an error hook, use it if s.hooks != nil && len(s.hooks.OnError) > 0 { + method := notification.Method err := ErrNotificationChannelBlocked // Copy hooks pointer to local variable to avoid race condition hooks := s.hooks @@ -198,6 +263,28 @@ func (s *MCPServer) SendNotificationToClient( } } +// SendNotificationToClient sends a notification to the current client +func (s *MCPServer) SendNotificationToClient( + ctx context.Context, + method string, + params map[string]any, +) error { + session := ClientSessionFromContext(ctx) + if session == nil || !session.Initialized() { + return ErrNotificationNotInitialized + } + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + return s.sendNotificationCore(ctx, session, notification) +} + // SendNotificationToSpecificClient sends a notification to a specific client by session ID func (s *MCPServer) SendNotificationToSpecificClient( sessionID string, @@ -208,17 +295,10 @@ func (s *MCPServer) SendNotificationToSpecificClient( if !ok { return ErrSessionNotFound } - session, ok := sessionValue.(ClientSession) if !ok || !session.Initialized() { return ErrSessionNotInitialized } - - // upgrades the client-server communication to SSE stream when the server sends notifications to the client - if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { - sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() - } - notification := mcp.JSONRPCNotification{ JSONRPC: mcp.JSONRPC_VERSION, Notification: mcp.Notification{ @@ -228,27 +308,7 @@ func (s *MCPServer) SendNotificationToSpecificClient( }, }, } - - select { - case session.NotificationChannel() <- notification: - return nil - default: - // Channel is blocked, if there's an error hook, use it - if s.hooks != nil && len(s.hooks.OnError) > 0 { - err := ErrNotificationChannelBlocked - ctx := context.Background() - // Copy hooks pointer to local variable to avoid race condition - hooks := s.hooks - go func(sID string, hooks *Hooks) { - // Use the error hook to report the blocked channel - hooks.onError(ctx, nil, "notification", map[string]any{ - "method": method, - "sessionID": sID, - }, fmt.Errorf("notification channel blocked for session %s: %w", sID, err)) - }(sessionID, hooks) - } - return ErrNotificationChannelBlocked - } + return s.sendNotificationToSpecificClient(session, notification) } // AddSessionTool adds a tool for a specific session diff --git a/server/session_test.go b/server/session_test.go index 3067f4e9c..22da95714 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -1126,3 +1126,384 @@ func TestSessionWithClientInfo_Integration(t *testing.T) { assert.Equal(t, clientInfo.Name, storedClientInfo.Name, "Client name should match") assert.Equal(t, clientInfo.Version, storedClientInfo.Version, "Client version should match") } + +// New test function to cover log notification functionality +func TestMCPServer_SendLogMessageToClient(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithLogging()) + ctx := context.Background() + + // Create a session that supports logging + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithLogging{ + sessionID: "session-1", + notificationChannel: sessionChan, + } + session.Initialize() + + // Set log level to Info + session.SetLogLevel(mcp.LoggingLevelInfo) + + // Register session + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + // Create session context + sessionCtx := server.WithContext(ctx, session) + + // Test cases + tests := []struct { + name string + level mcp.LoggingLevel + expectSent bool + expectError bool + }{ + { + name: "higher level log should be sent", + level: mcp.LoggingLevelWarning, // Higher than Info + expectSent: true, + }, + { + name: "same level log should be sent", + level: mcp.LoggingLevelInfo, + expectSent: true, + }, + { + name: "lower level log should not be sent", + level: mcp.LoggingLevelDebug, // Lower than Info + expectSent: false, + }, + { + name: "uninitialized session should return error", + level: mcp.LoggingLevelError, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.expectError { + // Create uninitialized session + uninitSession := &sessionTestClientWithLogging{ + sessionID: "uninit-session", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: false, + } + uninitCtx := server.WithContext(ctx, uninitSession) + notification := mcp.NewLoggingMessageNotification(tt.level, "test-logger", "test message") + err := server.SendLogMessageToClient(uninitCtx, notification) + require.Error(t, err) + assert.Equal(t, ErrNotificationNotInitialized, err) + return + } + notification := mcp.NewLoggingMessageNotification(tt.level, "test-logger", "test message") + err := server.SendLogMessageToClient(sessionCtx, notification) + require.NoError(t, err) + + if tt.expectSent { + select { + case notif := <-sessionChan: + assert.Equal(t, "notifications/message", notif.Method) + assert.Equal(t, tt.level, notif.Params.AdditionalFields["level"]) + assert.Equal(t, "test-logger", notif.Params.AdditionalFields["logger"]) + assert.Equal(t, "test message", notif.Params.AdditionalFields["data"]) + case <-time.After(500 * time.Millisecond): + t.Error("Expected log notification not received") + } + } else { + select { + case <-sessionChan: + t.Error("Unexpected log notification received") + case <-time.After(50 * time.Millisecond): + // No notification expected + } + } + }) + } +} + +func TestMCPServer_SendLogMessageToSpecificClient(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithLogging()) + ctx := context.Background() + + // Create two sessions + session1Chan := make(chan mcp.JSONRPCNotification, 10) + session1 := &sessionTestClientWithLogging{ + sessionID: "session-1", + notificationChannel: session1Chan, + } + session1.Initialize() + session1.SetLogLevel(mcp.LoggingLevelInfo) + + session2Chan := make(chan mcp.JSONRPCNotification, 10) + session2 := &sessionTestClientWithLogging{ + sessionID: "session-2", + notificationChannel: session2Chan, + } + session2.Initialize() + session2.SetLogLevel(mcp.LoggingLevelWarning) // Higher log level + + // Register sessions + require.NoError(t, server.RegisterSession(ctx, session1)) + require.NoError(t, server.RegisterSession(ctx, session2)) + + // Test cases + tests := []struct { + name string + sessionID string + level mcp.LoggingLevel + expectSent bool + expectError bool + errorType error + }{ + { + name: "valid session and level should be sent", + sessionID: session1.SessionID(), + level: mcp.LoggingLevelInfo, + expectSent: true, + }, + { + name: "log below session level should not be sent", + sessionID: session1.SessionID(), + level: mcp.LoggingLevelDebug, + expectSent: false, + }, + { + name: "valid session with higher level should be sent", + sessionID: session2.SessionID(), + level: mcp.LoggingLevelError, + expectSent: true, + }, + { + name: "non-existent session should return error", + sessionID: "non-existent", + level: mcp.LoggingLevelError, + expectError: true, + errorType: ErrSessionNotFound, + }, + { + name: "uninitialized session should return error", + sessionID: "uninitialized-session", + level: mcp.LoggingLevelError, + expectError: true, + errorType: ErrSessionNotInitialized, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.sessionID == "uninitialized-session" { + uninitSession := &sessionTestClientWithLogging{ + sessionID: "uninitialized-session", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: false, + } + require.NoError(t, server.RegisterSession(ctx, uninitSession)) + } + + notification := mcp.NewLoggingMessageNotification(tt.level, "test-logger", "test message") + + err := server.SendLogMessageToSpecificClient(tt.sessionID, notification) + + if tt.expectError { + require.Error(t, err) + if tt.errorType != nil { + assert.ErrorIs(t, err, tt.errorType) + } + return + } + + require.NoError(t, err) + + var targetChan chan mcp.JSONRPCNotification + if tt.sessionID == session1.SessionID() { + targetChan = session1Chan + } else if tt.sessionID == session2.SessionID() { + targetChan = session2Chan + } + + if tt.expectSent && targetChan != nil { + select { + case notif := <-targetChan: + assert.Equal(t, "notifications/message", notif.Method) + assert.Equal(t, tt.level, notif.Params.AdditionalFields["level"]) + assert.Equal(t, "test-logger", notif.Params.AdditionalFields["logger"]) + assert.Equal(t, "test message", notif.Params.AdditionalFields["data"]) + case <-time.After(100 * time.Millisecond): + t.Error("Expected log notification not received") + } + } else if targetChan != nil { + select { + case <-targetChan: + t.Error("Unexpected log notification received") + case <-time.After(50 * time.Millisecond): + // No notification expected + } + } + }) + } +} + +func TestMCPServer_LoggingWithUnsupportedSessions(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithLogging()) + ctx := context.Background() + + // Create three types of sessions: + // 1. Logging-supported session + // 2. Logging-unsupported session + // 3. Uninitialized session + + // Logging-supported session + loggingSessionChan := make(chan mcp.JSONRPCNotification, 10) + loggingSession := &sessionTestClientWithLogging{ + sessionID: "logging-session", + notificationChannel: loggingSessionChan, + } + loggingSession.Initialize() + loggingSession.SetLogLevel(mcp.LoggingLevelInfo) + + // Logging-unsupported session + nonLoggingSessionChan := make(chan mcp.JSONRPCNotification, 10) + nonLoggingSession := &sessionTestClient{ + sessionID: "non-logging-session", + notificationChannel: nonLoggingSessionChan, + } + nonLoggingSession.Initialize() + + // Uninitialized session + uninitializedSessionChan := make(chan mcp.JSONRPCNotification, 10) + uninitializedSession := &sessionTestClientWithLogging{ + sessionID: "uninitialized-session", + notificationChannel: uninitializedSessionChan, + initialized: false, + } + + // Register all sessions + require.NoError(t, server.RegisterSession(ctx, loggingSession)) + require.NoError(t, server.RegisterSession(ctx, nonLoggingSession)) + require.NoError(t, server.RegisterSession(ctx, uninitializedSession)) + + // Info-level log notification + notification := mcp.NewLoggingMessageNotification(mcp.LoggingLevelInfo, "test-logger", "test message for ") + + t.Run("SendLogMessageToClient", func(t *testing.T) { + // Logging-supported session + loggingCtx := server.WithContext(ctx, loggingSession) + err := server.SendLogMessageToClient(loggingCtx, notification) + require.NoError(t, err) + select { + case notif := <-loggingSessionChan: + assert.Equal(t, "notifications/message", notif.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Expected log notification not received") + } + + // Logging-unsupported session + nonLoggingCtx := server.WithContext(ctx, nonLoggingSession) + err = server.SendLogMessageToClient(nonLoggingCtx, notification) + require.Error(t, err) + assert.Equal(t, ErrSessionDoesNotSupportLogging, err) + + // Uninitialized session + uninitCtx := server.WithContext(ctx, uninitializedSession) + err = server.SendLogMessageToClient(uninitCtx, notification) + require.Error(t, err) + assert.Equal(t, ErrNotificationNotInitialized, err) + }) + + t.Run("SendLogMessageToSpecificClient", func(t *testing.T) { + err := server.SendLogMessageToSpecificClient(loggingSession.SessionID(), notification) + require.NoError(t, err) + select { + case notif := <-loggingSessionChan: + assert.Equal(t, "notifications/message", notif.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Expected log notification not received") + } + + err = server.SendLogMessageToSpecificClient(nonLoggingSession.SessionID(), notification) + require.Error(t, err) + assert.Equal(t, ErrSessionDoesNotSupportLogging, err) + + err = server.SendLogMessageToSpecificClient(uninitializedSession.SessionID(), notification) + require.Error(t, err) + assert.Equal(t, ErrSessionNotInitialized, err) + }) +} + +func TestMCPServer_LoggingNotificationFormat(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithLogging()) + ctx := context.Background() + + // Create a session + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithLogging{ + sessionID: "session-1", + notificationChannel: sessionChan, + } + session.Initialize() + session.SetLogLevel(mcp.LoggingLevelDebug) + + // Register session + require.NoError(t, server.RegisterSession(ctx, session)) + + // Send log messages with different formats + testCases := []struct { + name string + data interface{} + expected interface{} + }{ + { + name: "string data", + data: "simple log message", + expected: "simple log message", + }, + { + name: "structured data", + data: map[string]interface{}{"key": "value", "num": 42}, + expected: map[string]interface{}{"key": "value", "num": 42}, + }, + { + name: "error data", + data: errors.New("error message"), + expected: errors.New("error message"), + }, + { + name: "nil data", + data: nil, + expected: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + notification := mcp.NewLoggingMessageNotification(mcp.LoggingLevelInfo, "test-logger", tc.data) + + err := server.SendLogMessageToSpecificClient(session.SessionID(), notification) + require.NoError(t, err) + + select { + case notif := <-sessionChan: + assert.Equal(t, "notifications/message", notif.Method) + assert.Equal(t, mcp.LoggingLevelInfo, notif.Params.AdditionalFields["level"]) + assert.Equal(t, "test-logger", notif.Params.AdditionalFields["logger"]) + + // Validate log data format + dataField := notif.Params.AdditionalFields["data"] + switch expected := tc.expected.(type) { + case string: + assert.Equal(t, expected, dataField) + case map[string]interface{}: + assert.IsType(t, map[string]interface{}{}, dataField) + dataMap := dataField.(map[string]interface{}) + for k, v := range expected { + assert.Equal(t, v, dataMap[k]) + } + case nil: + assert.Nil(t, dataField) + } + case <-time.After(100 * time.Millisecond): + t.Error("Expected log notification not received") + } + }) + } +} diff --git a/server/streamable_http.go b/server/streamable_http.go index 1312c9753..4a2bdb63e 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -129,6 +129,7 @@ type StreamableHTTPServer struct { sessionIdManager SessionIdManager listenHeartbeatInterval time.Duration logger util.Logger + sessionLogLevels *sessionLogLevelsStore } // NewStreamableHTTPServer creates a new streamable-http server instance @@ -136,6 +137,7 @@ func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *S s := &StreamableHTTPServer{ server: server, sessionTools: newSessionToolsStore(), + sessionLogLevels: newSessionLogLevelsStore(), endpointPath: "/mcp", sessionIdManager: &InsecureStatefulSessionIdManager{}, logger: util.DefaultLogger(), @@ -255,7 +257,7 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request } } - session := newStreamableHttpSession(sessionID, s.sessionTools) + session := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels) // Set the client context before handling the message ctx := s.server.WithContext(r.Context(), session) @@ -365,7 +367,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) sessionID = uuid.New().String() } - session := newStreamableHttpSession(sessionID, s.sessionTools) + session := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels) if err := s.server.RegisterSession(r.Context(), session); err != nil { http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest) return @@ -468,7 +470,7 @@ func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Reque // remove the session relateddata from the sessionToolsStore s.sessionTools.delete(sessionID) - + s.sessionLogLevels.delete(sessionID) // remove current session's requstID information s.sessionRequestIDs.Delete(sessionID) @@ -511,6 +513,38 @@ func (s *StreamableHTTPServer) nextRequestID(sessionID string) int64 { } // --- session --- +type sessionLogLevelsStore struct { + mu sync.RWMutex + logs map[string]mcp.LoggingLevel +} + +func newSessionLogLevelsStore() *sessionLogLevelsStore { + return &sessionLogLevelsStore{ + logs: make(map[string]mcp.LoggingLevel), + } +} + +func (s *sessionLogLevelsStore) get(sessionID string) mcp.LoggingLevel { + s.mu.RLock() + defer s.mu.RUnlock() + val, ok := s.logs[sessionID] + if !ok { + return mcp.LoggingLevelError + } + return val +} + +func (s *sessionLogLevelsStore) set(sessionID string, level mcp.LoggingLevel) { + s.mu.Lock() + defer s.mu.Unlock() + s.logs[sessionID] = level +} + +func (s *sessionLogLevelsStore) delete(sessionID string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.logs, sessionID) +} type sessionToolsStore struct { mu sync.RWMutex @@ -549,14 +583,17 @@ type streamableHttpSession struct { notificationChannel chan mcp.JSONRPCNotification // server -> client notifications tools *sessionToolsStore upgradeToSSE atomic.Bool + logLevels *sessionLogLevelsStore } -func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore) *streamableHttpSession { - return &streamableHttpSession{ +func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, levels *sessionLogLevelsStore) *streamableHttpSession { + s := &streamableHttpSession{ sessionID: sessionID, notificationChannel: make(chan mcp.JSONRPCNotification, 100), tools: toolStore, + logLevels: levels, } + return s } func (s *streamableHttpSession) SessionID() string { @@ -577,6 +614,14 @@ func (s *streamableHttpSession) Initialized() bool { return true } +func (s *streamableHttpSession) SetLogLevel(level mcp.LoggingLevel) { + s.logLevels.set(s.sessionID, level) +} + +func (s *streamableHttpSession) GetLogLevel() mcp.LoggingLevel { + return s.logLevels.get(s.sessionID) +} + var _ ClientSession = (*streamableHttpSession)(nil) func (s *streamableHttpSession) GetSessionTools() map[string]ServerTool { @@ -587,7 +632,10 @@ func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) { s.tools.set(s.sessionID, tools) } -var _ SessionWithTools = (*streamableHttpSession)(nil) +var ( + _ SessionWithTools = (*streamableHttpSession)(nil) + _ SessionWithLogging = (*streamableHttpSession)(nil) +) func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() { s.upgradeToSSE.Store(true) @@ -615,10 +663,12 @@ type StatelessSessionIdManager struct{} func (s *StatelessSessionIdManager) Generate() string { return "" } + func (s *StatelessSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) { // In stateless mode, ignore session IDs completely - don't validate or reject them return false, nil } + func (s *StatelessSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) { return false, nil } @@ -633,6 +683,7 @@ const idPrefix = "mcp-session-" func (s *InsecureStatefulSessionIdManager) Generate() string { return idPrefix + uuid.New().String() } + func (s *InsecureStatefulSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) { // validate the session id is a valid uuid if !strings.HasPrefix(sessionID, idPrefix) { @@ -643,6 +694,7 @@ func (s *InsecureStatefulSessionIdManager) Validate(sessionID string) (isTermina } return false, nil } + func (s *InsecureStatefulSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) { return false, nil } diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index 0e1c7a65b..f45fa983f 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -725,6 +725,74 @@ func TestStreamableHTTP_SessionWithTools(t *testing.T) { }) } +func TestStreamableHTTP_SessionWithLogging(t *testing.T) { + t.Run("SessionWithLogging implementation", func(t *testing.T) { + hooks := &Hooks{} + var logSession *streamableHttpSession + var mu sync.Mutex + + hooks.AddAfterSetLevel(func(ctx context.Context, id any, message *mcp.SetLevelRequest, result *mcp.EmptyResult) { + if s, ok := ClientSessionFromContext(ctx).(*streamableHttpSession); ok { + mu.Lock() + logSession = s + mu.Unlock() + } + }) + + mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks), WithLogging()) + testServer := NewTestStreamableHTTPServer(mcpServer) + defer testServer.Close() + + // obtain a valid session ID first + initResp, err := postJSON(testServer.URL, initRequest) + if err != nil { + t.Fatalf("Failed to send init request: %v", err) + } + defer initResp.Body.Close() + sessionID := initResp.Header.Get(headerKeySessionID) + if sessionID == "" { + t.Fatal("Expected session id in header") + } + + setLevelRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "logging/setLevel", + "params": map[string]any{ + "level": mcp.LoggingLevelCritical, + }, + } + + reqBody, _ := json.Marshal(setLevelRequest) + req, err := http.NewRequest(http.MethodPost, testServer.URL, bytes.NewBuffer(reqBody)) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set(headerKeySessionID, sessionID) + + resp, err := testServer.Client().Do(req) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + mu.Lock() + if logSession == nil { + mu.Unlock() + t.Fatal("Session was not captured") + } + if logSession.GetLogLevel() != mcp.LoggingLevelCritical { + t.Errorf("Expected critical level, got %v", logSession.GetLogLevel()) + } + mu.Unlock() + }) +} + func TestStreamableHTTPServer_WithOptions(t *testing.T) { t.Run("WithStreamableHTTPServer sets httpServer field", func(t *testing.T) { mcpServer := NewMCPServer("test", "1.0.0")