diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go index 1c1d8801558ed7..4e38f1f4496441 100644 --- a/src/net/http/httptest/recorder.go +++ b/src/net/http/httptest/recorder.go @@ -16,6 +16,17 @@ import ( "golang.org/x/net/http/httpguts" ) +// InformationalResponse is an HTTP response sent with a [1xx status code]. +// +// [1xx status code]: https://httpwg.org/specs/rfc9110.html#status.1xx +type InformationalResponse struct { + // Code is the 1xx HTTP response code of this informational response. + Code int + + // Header contains the headers of this informational response. + Header http.Header +} + // ResponseRecorder is an implementation of http.ResponseWriter that // records its mutations for later inspection in tests. type ResponseRecorder struct { @@ -27,6 +38,9 @@ type ResponseRecorder struct { // method. Code int + // Informational HTTP responses (1xx status code) sent before the main response. + InformationalResponses []InformationalResponse + // HeaderMap contains the headers explicitly set by the Handler. // It is an internal detail. // @@ -146,11 +160,20 @@ func (rw *ResponseRecorder) WriteHeader(code int) { } checkWriteHeaderCode(code) - rw.Code = code - rw.wroteHeader = true + if rw.HeaderMap == nil { rw.HeaderMap = make(http.Header) } + + if code >= 100 && code < 200 { + ir := InformationalResponse{code, rw.HeaderMap.Clone()} + rw.InformationalResponses = append(rw.InformationalResponses, ir) + + return + } + + rw.Code = code + rw.wroteHeader = true rw.snapHeader = rw.HeaderMap.Clone() } diff --git a/src/net/http/httptest/recorder_test.go b/src/net/http/httptest/recorder_test.go index 4782eced43e6ce..5fd48be78c17a9 100644 --- a/src/net/http/httptest/recorder_test.go +++ b/src/net/http/httptest/recorder_test.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "reflect" "testing" ) @@ -123,6 +124,15 @@ func TestRecorder(t *testing.T) { return nil } } + hasInformationalResponses := func(ir []InformationalResponse) checkFunc { + return func(rec *ResponseRecorder) error { + if !reflect.DeepEqual(ir, rec.InformationalResponses) { + return fmt.Errorf("InformationalResponses = %v; want %v", rec.InformationalResponses, ir) + } + + return nil + } + } for _, tt := range [...]struct { name string @@ -294,6 +304,26 @@ func TestRecorder(t *testing.T) { check(hasResultContents("")), // check we don't crash reading the body }, + { + "1xx status code", + func(rw http.ResponseWriter, _ *http.Request) { + rw.WriteHeader(http.StatusContinue) + rw.Header().Add("Foo", "bar") + + rw.WriteHeader(http.StatusEarlyHints) + rw.Header().Add("Baz", "bat") + + rw.Header().Del("Foo") + }, + check( + hasInformationalResponses([]InformationalResponse{ + InformationalResponse{100, http.Header{}}, + InformationalResponse{103, http.Header{"Foo": []string{"bar"}}}, + }), + hasHeader("Baz", "bat"), + hasNotHeaders("Foo"), + ), + }, } { t.Run(tt.name, func(t *testing.T) { r, _ := http.NewRequest("GET", "http://foo.com/", nil)