Skip to content

Commit baa7153

Browse files
authored
Http headers (#480)
* Add HTTP headers to Request - Pass headers from request via ctx - Add headers from ctx to CallToolRequest to further pass it to handlers. * Add http-headers field to all Request Structs * Use request_handler.tmpl to add http headers * Add tests and documentation for Header passthrough * Chore: fix comment in CallToolRequest struct * Fix Race condition while setting context
1 parent ffea75f commit baa7153

File tree

12 files changed

+367
-1
lines changed

12 files changed

+367
-1
lines changed

mcp/prompts.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package mcp
22

3+
import "net/http"
4+
35
/* Prompts */
46

57
// ListPromptsRequest is sent from the client to request a list of prompts and
68
// prompt templates the server has.
79
type ListPromptsRequest struct {
810
PaginatedRequest
11+
Header http.Header `json:"-"`
912
}
1013

1114
// ListPromptsResult is the server's response to a prompts/list request from
@@ -20,6 +23,7 @@ type ListPromptsResult struct {
2023
type GetPromptRequest struct {
2124
Request
2225
Params GetPromptParams `json:"params"`
26+
Header http.Header `json:"-"`
2327
}
2428

2529
type GetPromptParams struct {

mcp/tools.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"encoding/json"
55
"errors"
66
"fmt"
7+
"net/http"
78
"reflect"
89
"strconv"
910
)
@@ -14,6 +15,7 @@ var errToolSchemaConflict = errors.New("provide either InputSchema or RawInputSc
1415
// server has.
1516
type ListToolsRequest struct {
1617
PaginatedRequest
18+
Header http.Header `json:"-"`
1719
}
1820

1921
// ListToolsResult is the server's response to a tools/list request from the
@@ -45,6 +47,7 @@ type CallToolResult struct {
4547
// CallToolRequest is used by the client to invoke a tool provided by the server.
4648
type CallToolRequest struct {
4749
Request
50+
Header http.Header `json:"-"` // HTTP headers from the original request
4851
Params CallToolParams `json:"params"`
4952
}
5053

mcp/types.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"strconv"
1010

1111
"github.com/yosida95/uritemplate/v3"
12+
"net/http"
1213
)
1314

1415
type MCPMethod string
@@ -399,6 +400,7 @@ type CancelledNotificationParams struct {
399400
type InitializeRequest struct {
400401
Request
401402
Params InitializeParams `json:"params"`
403+
Header http.Header `json:"-"`
402404
}
403405

404406
type InitializeParams struct {
@@ -489,6 +491,7 @@ type Implementation struct {
489491
// or else may be disconnected.
490492
type PingRequest struct {
491493
Request
494+
Header http.Header `json:"-"`
492495
}
493496

494497
/* Progress notifications */
@@ -541,6 +544,7 @@ type PaginatedResult struct {
541544
// the server has.
542545
type ListResourcesRequest struct {
543546
PaginatedRequest
547+
Header http.Header `json:"-"`
544548
}
545549

546550
// ListResourcesResult is the server's response to a resources/list request
@@ -554,6 +558,7 @@ type ListResourcesResult struct {
554558
// resource templates the server has.
555559
type ListResourceTemplatesRequest struct {
556560
PaginatedRequest
561+
Header http.Header `json:"-"`
557562
}
558563

559564
// ListResourceTemplatesResult is the server's response to a
@@ -567,6 +572,7 @@ type ListResourceTemplatesResult struct {
567572
// specific resource URI.
568573
type ReadResourceRequest struct {
569574
Request
575+
Header http.Header `json:"-"`
570576
Params ReadResourceParams `json:"params"`
571577
}
572578

@@ -598,6 +604,7 @@ type ResourceListChangedNotification struct {
598604
type SubscribeRequest struct {
599605
Request
600606
Params SubscribeParams `json:"params"`
607+
Header http.Header `json:"-"`
601608
}
602609

603610
type SubscribeParams struct {
@@ -612,6 +619,7 @@ type SubscribeParams struct {
612619
type UnsubscribeRequest struct {
613620
Request
614621
Params UnsubscribeParams `json:"params"`
622+
Header http.Header `json:"-"`
615623
}
616624

617625
type UnsubscribeParams struct {
@@ -717,6 +725,7 @@ func (BlobResourceContents) isResourceContents() {}
717725
type SetLevelRequest struct {
718726
Request
719727
Params SetLevelParams `json:"params"`
728+
Header http.Header `json:"-"`
720729
}
721730

722731
type SetLevelParams struct {
@@ -980,6 +989,7 @@ type ModelHint struct {
980989
type CompleteRequest struct {
981990
Request
982991
Params CompleteParams `json:"params"`
992+
Header http.Header `json:"-"`
983993
}
984994

985995
type CompleteParams struct {
@@ -1032,6 +1042,7 @@ type PromptReference struct {
10321042
// structure or access specific locations that the client has permission to read from.
10331043
type ListRootsRequest struct {
10341044
Request
1045+
Header http.Header `json:"-"`
10351046
}
10361047

10371048
// ListRootsResult is the client's response to a roots/list request from the server.

server/ctx.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package server
2+
3+
type contextKey int
4+
5+
const (
6+
// This const is used as key for context value lookup
7+
requestHeader contextKey = iota
8+
)

server/internal/gen/request_handler.go.tmpl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ func (s *MCPServer) HandleMessage(
7272
)
7373
}
7474

75+
// Get request header from ctx
76+
h := ctx.Value(requestHeader)
77+
headers, ok := h.(http.Header)
78+
79+
if headers == nil || !ok {
80+
headers = make(http.Header)
81+
}
82+
7583
switch baseMessage.Method {
7684
{{- range .}}
7785
case mcp.{{.MethodName}}:
@@ -90,6 +98,7 @@ func (s *MCPServer) HandleMessage(
9098
err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
9199
}
92100
} else {
101+
request.Header = headers
93102
s.hooks.before{{.HookName}}(ctx, baseMessage.ID, &request)
94103
result, err = s.{{.HandlerFunc}}(ctx, baseMessage.ID, request)
95104
}

server/request_handler.go

Lines changed: 19 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

server/sse.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,8 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
504504
w.WriteHeader(http.StatusAccepted)
505505

506506
// Create a new context for handling the message that will be canceled when the message handling is done
507-
messageCtx, cancel := context.WithCancel(detachedCtx)
507+
messageCtx := context.WithValue(detachedCtx, requestHeader, r.Header)
508+
messageCtx, cancel := context.WithCancel(messageCtx)
508509

509510
go func(ctx context.Context) {
510511
defer cancel()

0 commit comments

Comments
 (0)