Skip to content

Commit 7c4fe62

Browse files
authored
Add per-session notifications handling (#46)
1 parent a73d7cf commit 7c4fe62

File tree

5 files changed

+330
-103
lines changed

5 files changed

+330
-103
lines changed

examples/everything/main.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ func handleSendNotification(
300300
server := server.ServerFromContext(ctx)
301301

302302
err := server.SendNotificationToClient(
303+
ctx,
303304
"notifications/progress",
304305
map[string]interface{}{
305306
"progress": 10,
@@ -336,6 +337,7 @@ func handleLongRunningOperationTool(
336337
time.Sleep(time.Duration(stepDuration * float64(time.Second)))
337338
if progressToken != nil {
338339
server.SendNotificationToClient(
340+
ctx,
339341
"notifications/progress",
340342
map[string]interface{}{
341343
"progress": i,

server/server.go

Lines changed: 71 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,23 @@ type ServerTool struct {
4646
Handler ToolHandlerFunc
4747
}
4848

49-
// NotificationContext provides client identification for notifications
50-
type NotificationContext struct {
51-
ClientID string
52-
SessionID string
49+
// ClientSession represents an active session that can be used by MCPServer to interact with client.
50+
type ClientSession interface {
51+
// NotificationChannel provides a channel suitable for sending notifications to client.
52+
NotificationChannel() chan<- mcp.JSONRPCNotification
53+
// SessionID is a unique identifier used to track user session.
54+
SessionID() string
5355
}
5456

55-
// ServerNotification combines the notification with client context
56-
type ServerNotification struct {
57-
Context NotificationContext
58-
Notification mcp.JSONRPCNotification
57+
// clientSessionKey is the context key for storing current client notification channel.
58+
type clientSessionKey struct{}
59+
60+
// ClientSessionFromContext retrieves current client notification context from context.
61+
func ClientSessionFromContext(ctx context.Context) ClientSession {
62+
if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok {
63+
return session
64+
}
65+
return nil
5966
}
6067

6168
// NotificationHandlerFunc handles incoming notifications.
@@ -75,9 +82,7 @@ type MCPServer struct {
7582
tools map[string]ServerTool
7683
notificationHandlers map[string]NotificationHandlerFunc
7784
capabilities serverCapabilities
78-
notifications chan ServerNotification
79-
clientMu sync.Mutex // Separate mutex for client context
80-
currentClient NotificationContext
85+
sessions sync.Map
8186
initialized atomic.Bool // Use atomic for the initialized flag
8287
}
8388

@@ -92,30 +97,70 @@ func ServerFromContext(ctx context.Context) *MCPServer {
9297
return nil
9398
}
9499

95-
// WithContext sets the current client context and returns the provided context
100+
// WithContext sets the current client session and returns the provided context
96101
func (s *MCPServer) WithContext(
97102
ctx context.Context,
98-
notifCtx NotificationContext,
103+
session ClientSession,
99104
) context.Context {
100-
s.clientMu.Lock()
101-
s.currentClient = notifCtx
102-
s.clientMu.Unlock()
103-
return ctx
105+
return context.WithValue(ctx, clientSessionKey{}, session)
106+
}
107+
108+
// RegisterSession saves session that should be notified in case if some server attributes changed.
109+
func (s *MCPServer) RegisterSession(
110+
session ClientSession,
111+
) error {
112+
sessionID := session.SessionID()
113+
if _, exists := s.sessions.LoadOrStore(sessionID, session); exists {
114+
return fmt.Errorf("session %s is already registered", sessionID)
115+
}
116+
return nil
117+
}
118+
119+
// UnregisterSession removes from storage session that is shut down.
120+
func (s *MCPServer) UnregisterSession(
121+
sessionID string,
122+
) {
123+
s.sessions.Delete(sessionID)
124+
}
125+
126+
// sendNotificationToAllClients sends a notification to all the currently active clients.
127+
func (s *MCPServer) sendNotificationToAllClients(
128+
method string,
129+
params map[string]any,
130+
) {
131+
notification := mcp.JSONRPCNotification{
132+
JSONRPC: mcp.JSONRPC_VERSION,
133+
Notification: mcp.Notification{
134+
Method: method,
135+
Params: mcp.NotificationParams{
136+
AdditionalFields: params,
137+
},
138+
},
139+
}
140+
141+
s.sessions.Range(func(k, v any) bool {
142+
if session, ok := v.(ClientSession); ok {
143+
select {
144+
case session.NotificationChannel() <- notification:
145+
default:
146+
// TODO: log blocked channel in the future versions
147+
}
148+
}
149+
return true
150+
})
104151
}
105152

106153
// SendNotificationToClient sends a notification to the current client
107154
func (s *MCPServer) SendNotificationToClient(
155+
ctx context.Context,
108156
method string,
109-
params map[string]interface{},
157+
params map[string]any,
110158
) error {
111-
if s.notifications == nil {
159+
session := ClientSessionFromContext(ctx)
160+
if session == nil {
112161
return fmt.Errorf("notification channel not initialized")
113162
}
114163

115-
s.clientMu.Lock()
116-
clientContext := s.currentClient
117-
s.clientMu.Unlock()
118-
119164
notification := mcp.JSONRPCNotification{
120165
JSONRPC: mcp.JSONRPC_VERSION,
121166
Notification: mcp.Notification{
@@ -127,10 +172,7 @@ func (s *MCPServer) SendNotificationToClient(
127172
}
128173

129174
select {
130-
case s.notifications <- ServerNotification{
131-
Context: clientContext,
132-
Notification: notification,
133-
}:
175+
case session.NotificationChannel() <- notification:
134176
return nil
135177
default:
136178
return fmt.Errorf("notification channel full or blocked")
@@ -220,7 +262,6 @@ func NewMCPServer(
220262
name: name,
221263
version: version,
222264
notificationHandlers: make(map[string]NotificationHandlerFunc),
223-
notifications: make(chan ServerNotification, 100),
224265
capabilities: serverCapabilities{
225266
tools: nil,
226267
resources: nil,
@@ -491,9 +532,7 @@ func (s *MCPServer) AddTools(tools ...ServerTool) {
491532

492533
// Send notification if server is already initialized
493534
if initialized {
494-
if err := s.SendNotificationToClient("notifications/tools/list_changed", nil); err != nil {
495-
// We can't return the error, but in a future version we could log it
496-
}
535+
s.sendNotificationToAllClients("notifications/tools/list_changed", nil)
497536
}
498537
}
499538

@@ -516,9 +555,7 @@ func (s *MCPServer) DeleteTools(names ...string) {
516555

517556
// Send notification if server is already initialized
518557
if initialized {
519-
if err := s.SendNotificationToClient("notifications/tools/list_changed", nil); err != nil {
520-
// We can't return the error, but in a future version we could log it
521-
}
558+
s.sendNotificationToAllClients("notifications/tools/list_changed", nil)
522559
}
523560
}
524561

0 commit comments

Comments
 (0)