Skip to content

Commit c69e686

Browse files
cmarcelobradfitz
authored andcommitted
net/http/httptest: record trailing headers in ResponseRecorder
Trailers() returns the headers that were set by the handler after the headers were written "to the wire" (in this case HeaderMap) and that were also specified in a proper header called "Trailer". Neither HeaderMap or trailerMap (used for Trailers()) are manipulated by the handler code, instead a third stagingMap is given to the handler. This avoid a reference kept by handler to affect the recorded results. If a handler just modify the header but doesn't call any Write or Flush method from ResponseWriter (or Flusher) interface, HeaderMap will not be updated. In this case, calling Flush in the recorder is enough to get the HeaderMap filled. Fixes #14531. Fixes #8857. Change-Id: I42842341ec3e95c7b87d7e6f178c65cd03d63cc3 Reviewed-on: https://go-review.googlesource.com/20047 Reviewed-by: Brad Fitzpatrick <[email protected]>
1 parent 4c8589c commit c69e686

File tree

2 files changed

+118
-12
lines changed

2 files changed

+118
-12
lines changed

src/net/http/httptest/recorder.go

+54-12
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ type ResponseRecorder struct {
1818
Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
1919
Flushed bool
2020

21+
stagingMap http.Header // map that handlers manipulate to set headers
22+
trailerMap http.Header // lazily filled when Trailers() is called
23+
2124
wroteHeader bool
2225
}
2326

@@ -36,10 +39,10 @@ const DefaultRemoteAddr = "1.2.3.4"
3639

3740
// Header returns the response headers.
3841
func (rw *ResponseRecorder) Header() http.Header {
39-
m := rw.HeaderMap
42+
m := rw.stagingMap
4043
if m == nil {
4144
m = make(http.Header)
42-
rw.HeaderMap = m
45+
rw.stagingMap = m
4346
}
4447
return m
4548
}
@@ -59,16 +62,15 @@ func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
5962
str = str[:512]
6063
}
6164

62-
_, hasType := rw.HeaderMap["Content-Type"]
63-
hasTE := rw.HeaderMap.Get("Transfer-Encoding") != ""
65+
m := rw.Header()
66+
67+
_, hasType := m["Content-Type"]
68+
hasTE := m.Get("Transfer-Encoding") != ""
6469
if !hasType && !hasTE {
6570
if b == nil {
6671
b = []byte(str)
6772
}
68-
if rw.HeaderMap == nil {
69-
rw.HeaderMap = make(http.Header)
70-
}
71-
rw.HeaderMap.Set("Content-Type", http.DetectContentType(b))
73+
m.Set("Content-Type", http.DetectContentType(b))
7274
}
7375

7476
rw.WriteHeader(200)
@@ -92,11 +94,21 @@ func (rw *ResponseRecorder) WriteString(str string) (int, error) {
9294
return len(str), nil
9395
}
9496

95-
// WriteHeader sets rw.Code.
97+
// WriteHeader sets rw.Code. After it is called, changing rw.Header
98+
// will not affect rw.HeaderMap.
9699
func (rw *ResponseRecorder) WriteHeader(code int) {
97-
if !rw.wroteHeader {
98-
rw.Code = code
99-
rw.wroteHeader = true
100+
if rw.wroteHeader {
101+
return
102+
}
103+
rw.Code = code
104+
rw.wroteHeader = true
105+
if rw.HeaderMap == nil {
106+
rw.HeaderMap = make(http.Header)
107+
}
108+
for k, vv := range rw.stagingMap {
109+
vv2 := make([]string, len(vv))
110+
copy(vv2, vv)
111+
rw.HeaderMap[k] = vv2
100112
}
101113
}
102114

@@ -107,3 +119,33 @@ func (rw *ResponseRecorder) Flush() {
107119
}
108120
rw.Flushed = true
109121
}
122+
123+
// Trailers returns any trailers set by the handler. It must be called
124+
// after the handler finished running.
125+
func (rw *ResponseRecorder) Trailers() http.Header {
126+
if rw.trailerMap != nil {
127+
return rw.trailerMap
128+
}
129+
trailers, ok := rw.HeaderMap["Trailer"]
130+
if !ok {
131+
rw.trailerMap = make(http.Header)
132+
return rw.trailerMap
133+
}
134+
rw.trailerMap = make(http.Header, len(trailers))
135+
for _, k := range trailers {
136+
switch k {
137+
case "Transfer-Encoding", "Content-Length", "Trailer":
138+
// Ignore since forbidden by RFC 2616 14.40.
139+
continue
140+
}
141+
k = http.CanonicalHeaderKey(k)
142+
vv, ok := rw.stagingMap[k]
143+
if !ok {
144+
continue
145+
}
146+
vv2 := make([]string, len(vv))
147+
copy(vv2, vv)
148+
rw.trailerMap[k] = vv2
149+
}
150+
return rw.trailerMap
151+
}

src/net/http/httptest/recorder_test.go

+64
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,37 @@ func TestRecorder(t *testing.T) {
4747
return nil
4848
}
4949
}
50+
hasNotHeaders := func(keys ...string) checkFunc {
51+
return func(rec *ResponseRecorder) error {
52+
for _, k := range keys {
53+
_, ok := rec.HeaderMap[http.CanonicalHeaderKey(k)]
54+
if ok {
55+
return fmt.Errorf("unexpected header %s", k)
56+
}
57+
}
58+
return nil
59+
}
60+
}
61+
hasTrailer := func(key, want string) checkFunc {
62+
return func(rec *ResponseRecorder) error {
63+
if got := rec.Trailers().Get(key); got != want {
64+
return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
65+
}
66+
return nil
67+
}
68+
}
69+
hasNotTrailers := func(keys ...string) checkFunc {
70+
return func(rec *ResponseRecorder) error {
71+
trailers := rec.Trailers()
72+
for _, k := range keys {
73+
_, ok := trailers[http.CanonicalHeaderKey(k)]
74+
if ok {
75+
return fmt.Errorf("unexpected trailer %s", k)
76+
}
77+
}
78+
return nil
79+
}
80+
}
5081

5182
tests := []struct {
5283
name string
@@ -130,6 +161,39 @@ func TestRecorder(t *testing.T) {
130161
},
131162
check(hasHeader("Content-Type", "text/html; charset=utf-8")),
132163
},
164+
{
165+
"Header is not changed after write",
166+
func(w http.ResponseWriter, r *http.Request) {
167+
hdr := w.Header()
168+
hdr.Set("Key", "correct")
169+
w.WriteHeader(200)
170+
hdr.Set("Key", "incorrect")
171+
},
172+
check(hasHeader("Key", "correct")),
173+
},
174+
{
175+
"Trailer headers are correctly recorded",
176+
func(w http.ResponseWriter, r *http.Request) {
177+
w.Header().Set("Non-Trailer", "correct")
178+
w.Header().Set("Trailer", "Trailer-A")
179+
w.Header().Add("Trailer", "Trailer-B")
180+
w.Header().Add("Trailer", "Trailer-C")
181+
io.WriteString(w, "<html>")
182+
w.Header().Set("Non-Trailer", "incorrect")
183+
w.Header().Set("Trailer-A", "valuea")
184+
w.Header().Set("Trailer-C", "valuec")
185+
w.Header().Set("Trailer-NotDeclared", "should be omitted")
186+
},
187+
check(
188+
hasStatus(200),
189+
hasHeader("Content-Type", "text/html; charset=utf-8"),
190+
hasHeader("Non-Trailer", "correct"),
191+
hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"),
192+
hasTrailer("Trailer-A", "valuea"),
193+
hasTrailer("Trailer-C", "valuec"),
194+
hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
195+
),
196+
},
133197
}
134198
r, _ := http.NewRequest("GET", "http://foo.com/", nil)
135199
for _, tt := range tests {

0 commit comments

Comments
 (0)