Skip to content

Commit 0b80659

Browse files
committed
net/http/httptest: restore historic ResponseRecorder.HeaderMap behavior
In Go versions 1 up to and including Go 1.6, ResponseRecorder.HeaderMap was both the map that handlers got access to, and was the map tests checked their results against. That did not mimic the behavior of the real HTTP server (Issue #8857), so HeaderMap was changed to be a snapshot at the first write in https://golang.org/cl/20047. But that broke cases where the Handler never did a write (#15560), so revert the behavior. Instead, introduce the ResponseWriter.Result method, returning an *http.Response. It subsumes ResponseWriter.Trailers which was added for Go 1.7 in CL 20047. Result().Header now contains the correct answer, and HeaderMap is unchanged in behavior from previous Go releases, so we don't break people's tests. People wanting the correct behavior can use ResponseWriter.Result. Fixes #15560 Updates #8857 Change-Id: I7ea9b56a6b843103784553d67f67847b5315b3d2 Reviewed-on: https://go-review.googlesource.com/23257 Reviewed-by: Damien Neil <[email protected]> Run-TryBot: Brad Fitzpatrick <[email protected]> TryBot-Result: Gobot Gobot <[email protected]>
1 parent 3b50adb commit 0b80659

File tree

2 files changed

+124
-39
lines changed

2 files changed

+124
-39
lines changed

src/net/http/httptest/recorder.go

+68-33
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package httptest
66

77
import (
88
"bytes"
9+
"io/ioutil"
910
"net/http"
1011
)
1112

@@ -17,9 +18,8 @@ type ResponseRecorder struct {
1718
Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
1819
Flushed bool
1920

20-
stagingMap http.Header // map that handlers manipulate to set headers
21-
trailerMap http.Header // lazily filled when Trailers() is called
22-
21+
result *http.Response // cache of Result's return value
22+
snapHeader http.Header // snapshot of HeaderMap at first Write
2323
wroteHeader bool
2424
}
2525

@@ -38,10 +38,10 @@ const DefaultRemoteAddr = "1.2.3.4"
3838

3939
// Header returns the response headers.
4040
func (rw *ResponseRecorder) Header() http.Header {
41-
m := rw.stagingMap
41+
m := rw.HeaderMap
4242
if m == nil {
4343
m = make(http.Header)
44-
rw.stagingMap = m
44+
rw.HeaderMap = m
4545
}
4646
return m
4747
}
@@ -104,11 +104,17 @@ func (rw *ResponseRecorder) WriteHeader(code int) {
104104
if rw.HeaderMap == nil {
105105
rw.HeaderMap = make(http.Header)
106106
}
107-
for k, vv := range rw.stagingMap {
107+
rw.snapHeader = cloneHeader(rw.HeaderMap)
108+
}
109+
110+
func cloneHeader(h http.Header) http.Header {
111+
h2 := make(http.Header, len(h))
112+
for k, vv := range h {
108113
vv2 := make([]string, len(vv))
109114
copy(vv2, vv)
110-
rw.HeaderMap[k] = vv2
115+
h2[k] = vv2
111116
}
117+
return h2
112118
}
113119

114120
// Flush sets rw.Flushed to true.
@@ -119,32 +125,61 @@ func (rw *ResponseRecorder) Flush() {
119125
rw.Flushed = true
120126
}
121127

122-
// Trailers returns any trailers set by the handler. It must be called
123-
// after the handler finished running.
124-
func (rw *ResponseRecorder) Trailers() http.Header {
125-
if rw.trailerMap != nil {
126-
return rw.trailerMap
127-
}
128-
trailers, ok := rw.HeaderMap["Trailer"]
129-
if !ok {
130-
rw.trailerMap = make(http.Header)
131-
return rw.trailerMap
132-
}
133-
rw.trailerMap = make(http.Header, len(trailers))
134-
for _, k := range trailers {
135-
switch k {
136-
case "Transfer-Encoding", "Content-Length", "Trailer":
137-
// Ignore since forbidden by RFC 2616 14.40.
138-
continue
139-
}
140-
k = http.CanonicalHeaderKey(k)
141-
vv, ok := rw.stagingMap[k]
142-
if !ok {
143-
continue
128+
// Result returns the response generated by the handler.
129+
//
130+
// The returned Response will have at least its StatusCode,
131+
// Header, Body, and optionally Trailer populated.
132+
// More fields may be populated in the future, so callers should
133+
// not DeepEqual the result in tests.
134+
//
135+
// The Response.Header is a snapshot of the headers at the time of the
136+
// first write call, or at the time of this call, if the handler never
137+
// did a write.
138+
//
139+
// Result must only be called after the handler has finished running.
140+
func (rw *ResponseRecorder) Result() *http.Response {
141+
if rw.result != nil {
142+
return rw.result
143+
}
144+
if rw.snapHeader == nil {
145+
rw.snapHeader = cloneHeader(rw.HeaderMap)
146+
}
147+
res := &http.Response{
148+
Proto: "HTTP/1.1",
149+
ProtoMajor: 1,
150+
ProtoMinor: 1,
151+
StatusCode: rw.Code,
152+
Header: rw.snapHeader,
153+
}
154+
rw.result = res
155+
if res.StatusCode == 0 {
156+
res.StatusCode = 200
157+
}
158+
res.Status = http.StatusText(res.StatusCode)
159+
if rw.Body != nil {
160+
res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
161+
}
162+
163+
if trailers, ok := rw.snapHeader["Trailer"]; ok {
164+
res.Trailer = make(http.Header, len(trailers))
165+
for _, k := range trailers {
166+
// TODO: use http2.ValidTrailerHeader, but we can't
167+
// get at it easily because it's bundled into net/http
168+
// unexported. This is good enough for now:
169+
switch k {
170+
case "Transfer-Encoding", "Content-Length", "Trailer":
171+
// Ignore since forbidden by RFC 2616 14.40.
172+
continue
173+
}
174+
k = http.CanonicalHeaderKey(k)
175+
vv, ok := rw.HeaderMap[k]
176+
if !ok {
177+
continue
178+
}
179+
vv2 := make([]string, len(vv))
180+
copy(vv2, vv)
181+
res.Trailer[k] = vv2
144182
}
145-
vv2 := make([]string, len(vv))
146-
copy(vv2, vv)
147-
rw.trailerMap[k] = vv2
148183
}
149-
return rw.trailerMap
184+
return res
150185
}

src/net/http/httptest/recorder_test.go

+56-6
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ func TestRecorder(t *testing.T) {
2323
return nil
2424
}
2525
}
26+
hasResultStatus := func(wantCode int) checkFunc {
27+
return func(rec *ResponseRecorder) error {
28+
if rec.Result().StatusCode != wantCode {
29+
return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode)
30+
}
31+
return nil
32+
}
33+
}
2634
hasContents := func(want string) checkFunc {
2735
return func(rec *ResponseRecorder) error {
2836
if rec.Body.String() != want {
@@ -39,36 +47,44 @@ func TestRecorder(t *testing.T) {
3947
return nil
4048
}
4149
}
42-
hasHeader := func(key, want string) checkFunc {
50+
hasOldHeader := func(key, want string) checkFunc {
4351
return func(rec *ResponseRecorder) error {
4452
if got := rec.HeaderMap.Get(key); got != want {
45-
return fmt.Errorf("header %s = %q; want %q", key, got, want)
53+
return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want)
54+
}
55+
return nil
56+
}
57+
}
58+
hasHeader := func(key, want string) checkFunc {
59+
return func(rec *ResponseRecorder) error {
60+
if got := rec.Result().Header.Get(key); got != want {
61+
return fmt.Errorf("final header %s = %q; want %q", key, got, want)
4662
}
4763
return nil
4864
}
4965
}
5066
hasNotHeaders := func(keys ...string) checkFunc {
5167
return func(rec *ResponseRecorder) error {
5268
for _, k := range keys {
53-
_, ok := rec.HeaderMap[http.CanonicalHeaderKey(k)]
69+
v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)]
5470
if ok {
55-
return fmt.Errorf("unexpected header %s", k)
71+
return fmt.Errorf("unexpected header %s with value %q", k, v)
5672
}
5773
}
5874
return nil
5975
}
6076
}
6177
hasTrailer := func(key, want string) checkFunc {
6278
return func(rec *ResponseRecorder) error {
63-
if got := rec.Trailers().Get(key); got != want {
79+
if got := rec.Result().Trailer.Get(key); got != want {
6480
return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
6581
}
6682
return nil
6783
}
6884
}
6985
hasNotTrailers := func(keys ...string) checkFunc {
7086
return func(rec *ResponseRecorder) error {
71-
trailers := rec.Trailers()
87+
trailers := rec.Result().Trailer
7288
for _, k := range keys {
7389
_, ok := trailers[http.CanonicalHeaderKey(k)]
7490
if ok {
@@ -194,6 +210,40 @@ func TestRecorder(t *testing.T) {
194210
hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
195211
),
196212
},
213+
{
214+
"Header set without any write", // Issue 15560
215+
func(w http.ResponseWriter, r *http.Request) {
216+
w.Header().Set("X-Foo", "1")
217+
218+
// Simulate somebody using
219+
// new(ResponseRecorder) instead of
220+
// using the constructor which sets
221+
// this to 200
222+
w.(*ResponseRecorder).Code = 0
223+
},
224+
check(
225+
hasOldHeader("X-Foo", "1"),
226+
hasStatus(0),
227+
hasHeader("X-Foo", "1"),
228+
hasResultStatus(200),
229+
),
230+
},
231+
{
232+
"HeaderMap vs FinalHeaders", // more for Issue 15560
233+
func(w http.ResponseWriter, r *http.Request) {
234+
h := w.Header()
235+
h.Set("X-Foo", "1")
236+
w.Write([]byte("hi"))
237+
h.Set("X-Foo", "2")
238+
h.Set("X-Bar", "2")
239+
},
240+
check(
241+
hasOldHeader("X-Foo", "2"),
242+
hasOldHeader("X-Bar", "2"),
243+
hasHeader("X-Foo", "1"),
244+
hasNotHeaders("X-Bar"),
245+
),
246+
},
197247
}
198248
r, _ := http.NewRequest("GET", "http://foo.com/", nil)
199249
for _, tt := range tests {

0 commit comments

Comments
 (0)