@@ -23,35 +23,59 @@ import (
23
23
// while sending requests over regular HTTP POST calls. The client handles
24
24
// automatic reconnection and message routing between requests and responses.
25
25
type SSEMCPClient struct {
26
- baseURL * url.URL
27
- endpoint * url.URL
28
- httpClient * http.Client
29
- requestID atomic.Int64
30
- responses map [int64 ]chan RPCResponse
31
- mu sync.RWMutex
32
- done chan struct {}
33
- initialized bool
34
- notifications []func (mcp.JSONRPCNotification )
35
- notifyMu sync.RWMutex
36
- endpointChan chan struct {}
37
- capabilities mcp.ServerCapabilities
26
+ baseURL * url.URL
27
+ endpoint * url.URL
28
+ httpClient * http.Client
29
+ requestID atomic.Int64
30
+ responses map [int64 ]chan RPCResponse
31
+ mu sync.RWMutex
32
+ done chan struct {}
33
+ initialized bool
34
+ notifications []func (mcp.JSONRPCNotification )
35
+ notifyMu sync.RWMutex
36
+ endpointChan chan struct {}
37
+ capabilities mcp.ServerCapabilities
38
+ headers map [string ]string
39
+ sseReadTimeout time.Duration
40
+ }
41
+
42
+ type ClientOption func (* SSEMCPClient )
43
+
44
+ func WithHeaders (headers map [string ]string ) ClientOption {
45
+ return func (sc * SSEMCPClient ) {
46
+ sc .headers = headers
47
+ }
48
+ }
49
+
50
+ func WithSSEReadTimeout (timeout time.Duration ) ClientOption {
51
+ return func (sc * SSEMCPClient ) {
52
+ sc .sseReadTimeout = timeout
53
+ }
38
54
}
39
55
40
56
// NewSSEMCPClient creates a new SSE-based MCP client with the given base URL.
41
57
// Returns an error if the URL is invalid.
42
- func NewSSEMCPClient (baseURL string ) (* SSEMCPClient , error ) {
58
+ func NewSSEMCPClient (baseURL string , options ... ClientOption ) (* SSEMCPClient , error ) {
43
59
parsedURL , err := url .Parse (baseURL )
44
60
if err != nil {
45
61
return nil , fmt .Errorf ("invalid URL: %w" , err )
46
62
}
47
63
48
- return & SSEMCPClient {
49
- baseURL : parsedURL ,
50
- httpClient : & http.Client {},
51
- responses : make (map [int64 ]chan RPCResponse ),
52
- done : make (chan struct {}),
53
- endpointChan : make (chan struct {}),
54
- }, nil
64
+ smc := & SSEMCPClient {
65
+ baseURL : parsedURL ,
66
+ httpClient : & http.Client {},
67
+ responses : make (map [int64 ]chan RPCResponse ),
68
+ done : make (chan struct {}),
69
+ endpointChan : make (chan struct {}),
70
+ sseReadTimeout : 30 * time .Second ,
71
+ headers : make (map [string ]string ),
72
+ }
73
+
74
+ for _ , opt := range options {
75
+ opt (smc )
76
+ }
77
+
78
+ return smc , nil
55
79
}
56
80
57
81
// Start initiates the SSE connection to the server and waits for the endpoint information.
@@ -104,41 +128,49 @@ func (c *SSEMCPClient) readSSE(reader io.ReadCloser) {
104
128
br := bufio .NewReader (reader )
105
129
var event , data string
106
130
131
+ ctx , cancel := context .WithTimeout (context .Background (), c .sseReadTimeout )
132
+ defer cancel ()
133
+
107
134
for {
108
- line , err := br .ReadString ('\n' )
109
- if err != nil {
110
- if err == io .EOF {
111
- // Process any pending event before exit
135
+ select {
136
+ case <- ctx .Done ():
137
+ return
138
+ default :
139
+ line , err := br .ReadString ('\n' )
140
+ if err != nil {
141
+ if err == io .EOF {
142
+ // Process any pending event before exit
143
+ if event != "" && data != "" {
144
+ c .handleSSEEvent (event , data )
145
+ }
146
+ break
147
+ }
148
+ select {
149
+ case <- c .done :
150
+ return
151
+ default :
152
+ fmt .Printf ("SSE stream error: %v\n " , err )
153
+ return
154
+ }
155
+ }
156
+
157
+ // Remove only newline markers
158
+ line = strings .TrimRight (line , "\r \n " )
159
+ if line == "" {
160
+ // Empty line means end of event
112
161
if event != "" && data != "" {
113
162
c .handleSSEEvent (event , data )
163
+ event = ""
164
+ data = ""
114
165
}
115
- break
116
- }
117
- select {
118
- case <- c .done :
119
- return
120
- default :
121
- fmt .Printf ("SSE stream error: %v\n " , err )
122
- return
166
+ continue
123
167
}
124
- }
125
168
126
- // Remove only newline markers
127
- line = strings .TrimRight (line , "\r \n " )
128
- if line == "" {
129
- // Empty line means end of event
130
- if event != "" && data != "" {
131
- c .handleSSEEvent (event , data )
132
- event = ""
133
- data = ""
169
+ if strings .HasPrefix (line , "event:" ) {
170
+ event = strings .TrimSpace (strings .TrimPrefix (line , "event:" ))
171
+ } else if strings .HasPrefix (line , "data:" ) {
172
+ data = strings .TrimSpace (strings .TrimPrefix (line , "data:" ))
134
173
}
135
- continue
136
- }
137
-
138
- if strings .HasPrefix (line , "event:" ) {
139
- event = strings .TrimSpace (strings .TrimPrefix (line , "event:" ))
140
- } else if strings .HasPrefix (line , "data:" ) {
141
- data = strings .TrimSpace (strings .TrimPrefix (line , "data:" ))
142
174
}
143
175
}
144
176
}
@@ -269,6 +301,10 @@ func (c *SSEMCPClient) sendRequest(
269
301
}
270
302
271
303
req .Header .Set ("Content-Type" , "application/json" )
304
+ // set custom http headers
305
+ for k , v := range c .headers {
306
+ req .Header .Set (k , v )
307
+ }
272
308
273
309
resp , err := c .httpClient .Do (req )
274
310
if err != nil {
0 commit comments