Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions mcp/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,26 @@ const (
LoggingLevelEmergency LoggingLevel = "emergency"
)

var levelToInt = map[LoggingLevel]int{
LoggingLevelDebug: 0,
LoggingLevelInfo: 1,
LoggingLevelNotice: 2,
LoggingLevelWarning: 3,
LoggingLevelError: 4,
LoggingLevelCritical: 5,
LoggingLevelAlert: 6,
LoggingLevelEmergency: 7,
}

func (l LoggingLevel) ShouldSendTo(minLevel LoggingLevel) bool {
ia, oka := levelToInt[l]
ib, okb := levelToInt[minLevel]
if !oka || !okb {
return false
}
return ia >= ib
}

/* Sampling */

const (
Expand Down
184 changes: 122 additions & 62 deletions server/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,35 +96,38 @@ func (s *MCPServer) RegisterSession(
return nil
}

// UnregisterSession removes from storage session that is shut down.
func (s *MCPServer) UnregisterSession(
ctx context.Context,
sessionID string,
) {
sessionValue, ok := s.sessions.LoadAndDelete(sessionID)
if !ok {
return
}
if session, ok := sessionValue.(ClientSession); ok {
s.hooks.UnregisterSession(ctx, session)
}
}

// SendNotificationToAllClients sends a notification to all the currently active clients.
func (s *MCPServer) SendNotificationToAllClients(
method string,
params map[string]any,
) {
notification := mcp.JSONRPCNotification{
func (s *MCPServer) buildLogNotification(notification mcp.LoggingMessageNotification) mcp.JSONRPCNotification {
return mcp.JSONRPCNotification{
JSONRPC: mcp.JSONRPC_VERSION,
Notification: mcp.Notification{
Method: method,
Method: notification.Method,
Params: mcp.NotificationParams{
AdditionalFields: params,
AdditionalFields: map[string]any{
"level": notification.Params.Level,
"logger": notification.Params.Logger,
"data": notification.Params.Data,
},
},
},
}
}

func (s *MCPServer) SendLogMessageToClient(ctx context.Context, notification mcp.LoggingMessageNotification) error {
session := ClientSessionFromContext(ctx)
if session == nil || !session.Initialized() {
return ErrNotificationNotInitialized
}
sessionLogging, ok := session.(SessionWithLogging)
if !ok {
return ErrSessionDoesNotSupportLogging
}
if !notification.Params.Level.ShouldSendTo(sessionLogging.GetLogLevel()) {
return nil
}
return s.sendNotificationCore(ctx, session, s.buildLogNotification(notification))
}

func (s *MCPServer) sendNotificationToAllClients(notification mcp.JSONRPCNotification) {
s.sessions.Range(func(k, v any) bool {
if session, ok := v.(ClientSession); ok && session.Initialized() {
select {
Expand All @@ -140,7 +143,7 @@ func (s *MCPServer) SendNotificationToAllClients(
ctx := context.Background()
// Use the error hook to report the blocked channel
hooks.onError(ctx, nil, "notification", map[string]any{
"method": method,
"method": notification.Method,
"sessionID": sessionID,
}, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err))
}(session.SessionID(), hooks)
Expand All @@ -151,22 +154,71 @@ func (s *MCPServer) SendNotificationToAllClients(
})
}

// SendNotificationToClient sends a notification to the current client
func (s *MCPServer) SendNotificationToClient(
ctx context.Context,
method string,
params map[string]any,
) error {
session := ClientSessionFromContext(ctx)
if session == nil || !session.Initialized() {
return ErrNotificationNotInitialized
}

func (s *MCPServer) sendNotificationToSpecificClient(session ClientSession, notification mcp.JSONRPCNotification) error {
// upgrades the client-server communication to SSE stream when the server sends notifications to the client
if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok {
sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification()
}
select {
case session.NotificationChannel() <- notification:
return nil
default:
// Channel is blocked, if there's an error hook, use it
if s.hooks != nil && len(s.hooks.OnError) > 0 {
err := ErrNotificationChannelBlocked
ctx := context.Background()
// Copy hooks pointer to local variable to avoid race condition
hooks := s.hooks
go func(sID string, hooks *Hooks) {
// Use the error hook to report the blocked channel
hooks.onError(ctx, nil, "notification", map[string]any{
"method": notification.Method,
"sessionID": sID,
}, fmt.Errorf("notification channel blocked for session %s: %w", sID, err))
}(session.SessionID(), hooks)
}
return ErrNotificationChannelBlocked
}
}

func (s *MCPServer) SendLogMessageToSpecificClient(sessionID string, notification mcp.LoggingMessageNotification) error {
sessionValue, ok := s.sessions.Load(sessionID)
if !ok {
return ErrSessionNotFound
}
session, ok := sessionValue.(ClientSession)
if !ok || !session.Initialized() {
return ErrSessionNotInitialized
}
sessionLogging, ok := session.(SessionWithLogging)
if !ok {
return ErrSessionDoesNotSupportLogging
}
if !notification.Params.Level.ShouldSendTo(sessionLogging.GetLogLevel()) {
return nil
}
return s.sendNotificationToSpecificClient(session, s.buildLogNotification(notification))
}

// UnregisterSession removes from storage session that is shut down.
func (s *MCPServer) UnregisterSession(
ctx context.Context,
sessionID string,
) {
sessionValue, ok := s.sessions.LoadAndDelete(sessionID)
if !ok {
return
}
if session, ok := sessionValue.(ClientSession); ok {
s.hooks.UnregisterSession(ctx, session)
}
}

// SendNotificationToAllClients sends a notification to all the currently active clients.
func (s *MCPServer) SendNotificationToAllClients(
method string,
params map[string]any,
) {
notification := mcp.JSONRPCNotification{
JSONRPC: mcp.JSONRPC_VERSION,
Notification: mcp.Notification{
Expand All @@ -176,13 +228,26 @@ func (s *MCPServer) SendNotificationToClient(
},
},
}
s.sendNotificationToAllClients(notification)
}

// SendNotificationToClient sends a notification to the current client
func (s *MCPServer) sendNotificationCore(
ctx context.Context,
session ClientSession,
notification mcp.JSONRPCNotification,
) error {
// upgrades the client-server communication to SSE stream when the server sends notifications to the client
if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok {
sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification()
}
select {
case session.NotificationChannel() <- notification:
return nil
default:
// Channel is blocked, if there's an error hook, use it
if s.hooks != nil && len(s.hooks.OnError) > 0 {
method := notification.Method
err := ErrNotificationChannelBlocked
// Copy hooks pointer to local variable to avoid race condition
hooks := s.hooks
Expand All @@ -198,6 +263,28 @@ func (s *MCPServer) SendNotificationToClient(
}
}

// SendNotificationToClient sends a notification to the current client
func (s *MCPServer) SendNotificationToClient(
ctx context.Context,
method string,
params map[string]any,
) error {
session := ClientSessionFromContext(ctx)
if session == nil || !session.Initialized() {
return ErrNotificationNotInitialized
}
notification := mcp.JSONRPCNotification{
JSONRPC: mcp.JSONRPC_VERSION,
Notification: mcp.Notification{
Method: method,
Params: mcp.NotificationParams{
AdditionalFields: params,
},
},
}
return s.sendNotificationCore(ctx, session, notification)
}

// SendNotificationToSpecificClient sends a notification to a specific client by session ID
func (s *MCPServer) SendNotificationToSpecificClient(
sessionID string,
Expand All @@ -208,17 +295,10 @@ func (s *MCPServer) SendNotificationToSpecificClient(
if !ok {
return ErrSessionNotFound
}

session, ok := sessionValue.(ClientSession)
if !ok || !session.Initialized() {
return ErrSessionNotInitialized
}

// upgrades the client-server communication to SSE stream when the server sends notifications to the client
if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok {
sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification()
}

notification := mcp.JSONRPCNotification{
JSONRPC: mcp.JSONRPC_VERSION,
Notification: mcp.Notification{
Expand All @@ -228,27 +308,7 @@ func (s *MCPServer) SendNotificationToSpecificClient(
},
},
}

select {
case session.NotificationChannel() <- notification:
return nil
default:
// Channel is blocked, if there's an error hook, use it
if s.hooks != nil && len(s.hooks.OnError) > 0 {
err := ErrNotificationChannelBlocked
ctx := context.Background()
// Copy hooks pointer to local variable to avoid race condition
hooks := s.hooks
go func(sID string, hooks *Hooks) {
// Use the error hook to report the blocked channel
hooks.onError(ctx, nil, "notification", map[string]any{
"method": method,
"sessionID": sID,
}, fmt.Errorf("notification channel blocked for session %s: %w", sID, err))
}(sessionID, hooks)
}
return ErrNotificationChannelBlocked
}
return s.sendNotificationToSpecificClient(session, notification)
}

// AddSessionTool adds a tool for a specific session
Expand Down
Loading