Skip to content

Commit 6b9da27

Browse files
committed
net/http/httptest: add support for 1XX responses
The existing implementation doesn't allow tracing 1xx responses. This patch allows using net/http/httptrace to inspect 1XX responses. Updates golang#26089.
1 parent 9160e15 commit 6b9da27

File tree

2 files changed

+75
-0
lines changed

2 files changed

+75
-0
lines changed

src/net/http/httptest/recorder.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"fmt"
1010
"io"
1111
"net/http"
12+
"net/http/httptrace"
1213
"net/textproto"
1314
"strconv"
1415
"strings"
@@ -42,6 +43,9 @@ type ResponseRecorder struct {
4243
// Flushed is whether the Handler called Flush.
4344
Flushed bool
4445

46+
// ClientTrace is used to trace 1XX responses
47+
ClientTrace *httptrace.ClientTrace
48+
4549
result *http.Response // cache of Result's return value
4650
snapHeader http.Header // snapshot of HeaderMap at first Write
4751
wroteHeader bool
@@ -146,6 +150,20 @@ func (rw *ResponseRecorder) WriteHeader(code int) {
146150
}
147151

148152
checkWriteHeaderCode(code)
153+
154+
if rw.ClientTrace != nil && code >= 100 && code < 200 {
155+
if code == 100 {
156+
rw.ClientTrace.Got100Continue()
157+
}
158+
// treat 101 as a terminal status, see issue 26161
159+
if code != http.StatusSwitchingProtocols {
160+
if err := rw.ClientTrace.Got1xxResponse(code, textproto.MIMEHeader(rw.HeaderMap)); err != nil {
161+
panic(err)
162+
}
163+
return
164+
}
165+
}
166+
149167
rw.Code = code
150168
rw.wroteHeader = true
151169
if rw.HeaderMap == nil {

src/net/http/httptest/recorder_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"fmt"
99
"io"
1010
"net/http"
11+
"net/http/httptrace"
12+
"net/textproto"
1113
"testing"
1214
)
1315

@@ -369,3 +371,58 @@ func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) {
369371
})
370372
}
371373
}
374+
375+
func TestRecorderClientTrace(t *testing.T) {
376+
handler := func(rw http.ResponseWriter, _ *http.Request) {
377+
rw.WriteHeader(http.StatusContinue)
378+
379+
rw.Header().Add("Foo", "bar")
380+
rw.WriteHeader(http.StatusEarlyHints)
381+
382+
rw.Header().Add("Baz", "bat")
383+
}
384+
385+
var received100, received103 bool
386+
trace := &httptrace.ClientTrace{
387+
Got100Continue: func() {
388+
received100 = true
389+
},
390+
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
391+
switch code {
392+
case http.StatusContinue:
393+
case http.StatusEarlyHints:
394+
received103 = true
395+
if header.Get("Foo") != "bar" {
396+
t.Errorf(`Expected Foo=bar, got %s`, header.Get("Foo"))
397+
}
398+
if header.Get("Bar") != "" {
399+
t.Error("Unexpected Bar header")
400+
}
401+
default:
402+
t.Errorf("Unexpected status code %d", code)
403+
}
404+
405+
return nil
406+
},
407+
}
408+
409+
r, _ := http.NewRequest("GET", "http://example.org/", nil)
410+
rw := NewRecorder()
411+
rw.ClientTrace = trace
412+
handler(rw, r)
413+
414+
if !received100 {
415+
t.Error("Got100Continue not called")
416+
}
417+
if !received103 {
418+
t.Error("103 request not received")
419+
}
420+
421+
header := rw.Result().Header
422+
if header.Get("Foo") != "bar" {
423+
t.Errorf("Expected Foo=bar, got %s", header.Get("Foo"))
424+
}
425+
if header.Get("Baz") != "bat" {
426+
t.Errorf("Expected Baz=bat, got %s", header.Get("Baz"))
427+
}
428+
}

0 commit comments

Comments
 (0)