Skip to content

Commit b77217e

Browse files
authored
[disable-stream] Add WithDisableStreaming option to StreamableHTTP server to allow disabling streaming (#613)
* [disable-stream] Add `WithDisableStreaming` option to StreamableHTTP server to allow disabling streaming [Per the spec](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server), a server is allowed to respond with 405 in response to a request for streaming. In our use case, we do not need streaming, and do not want to support it at a network layer. * [disable-stream] bounce * [disable-stream] tests * [disable-stream] fix docstring * [disable-stream] logging * [disable-stream] bounce
1 parent ddf1299 commit b77217e

File tree

2 files changed

+110
-0
lines changed

2 files changed

+110
-0
lines changed

server/streamable_http.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,16 @@ func WithHeartbeatInterval(interval time.Duration) StreamableHTTPOption {
7272
}
7373
}
7474

75+
// WithDisableStreaming prevents the server from responding to GET requests with
76+
// a streaming response. Instead, it will respond with a 405 Method Not Allowed status.
77+
// This can be useful in scenarios where streaming is not desired or supported.
78+
// The default is false, meaning streaming is enabled.
79+
func WithDisableStreaming(disable bool) StreamableHTTPOption {
80+
return func(s *StreamableHTTPServer) {
81+
s.disableStreaming = disable
82+
}
83+
}
84+
7585
// WithHTTPContextFunc sets a function that will be called to customise the context
7686
// to the server using the incoming request.
7787
// This can be used to inject context values from headers, for example.
@@ -145,6 +155,7 @@ type StreamableHTTPServer struct {
145155
listenHeartbeatInterval time.Duration
146156
logger util.Logger
147157
sessionLogLevels *sessionLogLevelsStore
158+
disableStreaming bool
148159

149160
tlsCertFile string
150161
tlsKeyFile string
@@ -451,6 +462,11 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
451462
func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) {
452463
// get request is for listening to notifications
453464
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
465+
if s.disableStreaming {
466+
s.logger.Infof("Rejected GET request: streaming is disabled (session: %s)", r.Header.Get(HeaderKeySessionID))
467+
http.Error(w, "Streaming is disabled on this server", http.StatusMethodNotAllowed)
468+
return
469+
}
454470

455471
sessionID := r.Header.Get(HeaderKeySessionID)
456472
// the specification didn't say we should validate the session id

server/streamable_http_test.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,6 +1285,100 @@ func TestStreamableHTTPServer_TLS(t *testing.T) {
12851285
})
12861286
}
12871287

1288+
func TestStreamableHTTPServer_WithDisableStreaming(t *testing.T) {
1289+
t.Run("WithDisableStreaming blocks GET requests", func(t *testing.T) {
1290+
mcpServer := NewMCPServer("test-mcp-server", "1.0.0")
1291+
server := NewTestStreamableHTTPServer(mcpServer, WithDisableStreaming(true))
1292+
defer server.Close()
1293+
1294+
// Attempt a GET request (which should be blocked)
1295+
req, err := http.NewRequest(http.MethodGet, server.URL, nil)
1296+
if err != nil {
1297+
t.Fatalf("Failed to create request: %v", err)
1298+
}
1299+
req.Header.Set("Content-Type", "text/event-stream")
1300+
1301+
resp, err := server.Client().Do(req)
1302+
if err != nil {
1303+
t.Fatalf("Failed to send request: %v", err)
1304+
}
1305+
defer resp.Body.Close()
1306+
1307+
// Verify the request is rejected with 405 Method Not Allowed
1308+
if resp.StatusCode != http.StatusMethodNotAllowed {
1309+
t.Errorf("Expected status 405 Method Not Allowed, got %d", resp.StatusCode)
1310+
}
1311+
1312+
// Verify the error message
1313+
bodyBytes, err := io.ReadAll(resp.Body)
1314+
if err != nil {
1315+
t.Fatalf("Failed to read response body: %v", err)
1316+
}
1317+
1318+
expectedMessage := "Streaming is disabled on this server"
1319+
if !strings.Contains(string(bodyBytes), expectedMessage) {
1320+
t.Errorf("Expected error message to contain '%s', got '%s'", expectedMessage, string(bodyBytes))
1321+
}
1322+
})
1323+
1324+
t.Run("POST requests still work with WithDisableStreaming", func(t *testing.T) {
1325+
mcpServer := NewMCPServer("test-mcp-server", "1.0.0")
1326+
server := NewTestStreamableHTTPServer(mcpServer, WithDisableStreaming(true))
1327+
defer server.Close()
1328+
1329+
// POST requests should still work
1330+
resp, err := postJSON(server.URL, initRequest)
1331+
if err != nil {
1332+
t.Fatalf("Failed to send message: %v", err)
1333+
}
1334+
defer resp.Body.Close()
1335+
1336+
if resp.StatusCode != http.StatusOK {
1337+
t.Errorf("Expected status 200, got %d", resp.StatusCode)
1338+
}
1339+
1340+
// Verify the response is valid
1341+
bodyBytes, _ := io.ReadAll(resp.Body)
1342+
var responseMessage jsonRPCResponse
1343+
if err := json.Unmarshal(bodyBytes, &responseMessage); err != nil {
1344+
t.Fatalf("Failed to unmarshal response: %v", err)
1345+
}
1346+
if responseMessage.Result["protocolVersion"] != mcp.LATEST_PROTOCOL_VERSION {
1347+
t.Errorf("Expected protocol version %s, got %s", mcp.LATEST_PROTOCOL_VERSION, responseMessage.Result["protocolVersion"])
1348+
}
1349+
})
1350+
1351+
t.Run("Streaming works when WithDisableStreaming is false", func(t *testing.T) {
1352+
mcpServer := NewMCPServer("test-mcp-server", "1.0.0")
1353+
server := NewTestStreamableHTTPServer(mcpServer, WithDisableStreaming(false))
1354+
defer server.Close()
1355+
1356+
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
1357+
defer cancel()
1358+
1359+
// GET request should work when streaming is enabled
1360+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
1361+
if err != nil {
1362+
t.Fatalf("Failed to create request: %v", err)
1363+
}
1364+
req.Header.Set("Content-Type", "text/event-stream")
1365+
1366+
resp, err := server.Client().Do(req)
1367+
if err != nil {
1368+
t.Fatalf("Failed to send request: %v", err)
1369+
}
1370+
defer resp.Body.Close()
1371+
1372+
if resp.StatusCode != http.StatusOK {
1373+
t.Errorf("Expected status 200, got %d", resp.StatusCode)
1374+
}
1375+
1376+
if resp.Header.Get("content-type") != "text/event-stream" {
1377+
t.Errorf("Expected content-type text/event-stream, got %s", resp.Header.Get("content-type"))
1378+
}
1379+
})
1380+
}
1381+
12881382
func postJSON(url string, bodyObject any) (*http.Response, error) {
12891383
jsonBody, _ := json.Marshal(bodyObject)
12901384
req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody))

0 commit comments

Comments
 (0)