From 8e5d8d7acce130d1235a3375304a4c401677779e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Dunglas?= Date: Mon, 15 May 2023 17:59:56 +0200 Subject: [PATCH 1/3] net/http/httptest: add support for ResponseController to ResponseRecorder CL #54136 (implemented in Go 1.20) added the "http".ResponseController type, which allows manipulating per-request timeouts. This is especially useful for programs managing long-running HTTP connections such as Mercure. However, testing HTTP handlers leveraging per-request timeouts is currently cumbersome (even if doable) because "net/http/httptest".ResponseRecorder isn't compatible yet with "http".ResponseController. This patch makes ResponseRecorder compatible with "http".ResponseController. All new methods are part of the contract that response types must honor to be usable with "http".ResponseController. NewRecorderWithDeadlineAwareRequest() is necessary to test read deadlines, as calling rw.SetReadDeadline() must change the deadline on the request body. Fixes #60229. --- src/net/http/httptest/example_test.go | 28 ++++++++ src/net/http/httptest/recorder.go | 93 +++++++++++++++++++++++++- src/net/http/httptest/recorder_test.go | 66 ++++++++++++++++++ 3 files changed, 185 insertions(+), 2 deletions(-) diff --git a/src/net/http/httptest/example_test.go b/src/net/http/httptest/example_test.go index a6738432ebf306..4fc286c0e62e82 100644 --- a/src/net/http/httptest/example_test.go +++ b/src/net/http/httptest/example_test.go @@ -10,6 +10,8 @@ import ( "log" "net/http" "net/http/httptest" + "strings" + "time" ) func ExampleResponseRecorder() { @@ -34,6 +36,32 @@ func ExampleResponseRecorder() { // Hello World! } +func ExampleResponseRecorder_requestController() { + handler := func(w http.ResponseWriter, r *http.Request) { + rc := http.NewResponseController(w) + rc.SetReadDeadline(time.Now().Add(1 * time.Second)) + rc.SetWriteDeadline(time.Now().Add(3 * time.Second)) + + io.WriteString(w, "Hello, with deadlines!") + } + + req := httptest.NewRequest("GET", "http://example.com/bar", strings.NewReader("bar")) + w, req := httptest.NewRecorderWithDeadlineAwareRequest(req) + handler(w, req) + + resp := w.Result() + body, _ := io.ReadAll(resp.Body) + + fmt.Println(resp.StatusCode) + fmt.Println(resp.Header.Get("Content-Type")) + fmt.Println(string(body)) + + // Output: + // 200 + // text/html; charset=utf-8 + // Hello, with deadlines! +} + func ExampleServer() { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello, client") diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go index 1c1d8801558ed7..afe795b2c4f245 100644 --- a/src/net/http/httptest/recorder.go +++ b/src/net/http/httptest/recorder.go @@ -6,12 +6,15 @@ package httptest import ( "bytes" + "errors" "fmt" "io" "net/http" "net/textproto" + "os" "strconv" "strings" + "time" "golang.org/x/net/http/httpguts" ) @@ -42,9 +45,18 @@ type ResponseRecorder struct { // Flushed is whether the Handler called Flush. Flushed bool + // ReadDeadline is the write deadline that has been set using + // "net/http".ResponseController + ReadDeadline time.Time + + // WriteDeadline is the write deadline that has been set using + // "net/http".ResponseController + WriteDeadline time.Time + result *http.Response // cache of Result's return value snapHeader http.Header // snapshot of HeaderMap at first Write wroteHeader bool + requestBody *deadlineBodyReader } // NewRecorder returns an initialized ResponseRecorder. @@ -56,6 +68,21 @@ func NewRecorder() *ResponseRecorder { } } +// NewRecorderWithDeadlineAwareRequest returns an initialized ResponseRecorder +// and wraps the body of the HTTP request passed as parameter in a special "io".ReadCloser +// that supports read deadlines. +// The request read deadline can be set using ResponseRecorder.SetReadDeadline +// and "http".ResponseController. +// The body of returned the HTTP request returns an error when the read deadline is reached. +// The read deadline can be inspected by reading ResponseRecorder.ReadDeadline. +func NewRecorderWithDeadlineAwareRequest(r *http.Request) (*ResponseRecorder, *http.Request) { + rw := NewRecorder() + rw.requestBody = &deadlineBodyReader{r.Body, time.Time{}} + r.Body = rw.requestBody + + return rw, r +} + // DefaultRemoteAddr is the default remote address to return in RemoteAddr if // an explicit DefaultRemoteAddr isn't set on ResponseRecorder. const DefaultRemoteAddr = "1.2.3.4" @@ -105,6 +132,10 @@ func (rw *ResponseRecorder) writeHeader(b []byte, str string) { // Write implements http.ResponseWriter. The data in buf is written to // rw.Body, if not nil. func (rw *ResponseRecorder) Write(buf []byte) (int, error) { + if !rw.WriteDeadline.IsZero() && time.Now().After(rw.WriteDeadline) { + return 0, os.ErrDeadlineExceeded + } + rw.writeHeader(buf, "") if rw.Body != nil { rw.Body.Write(buf) @@ -115,6 +146,10 @@ func (rw *ResponseRecorder) Write(buf []byte) (int, error) { // WriteString implements io.StringWriter. The data in str is written // to rw.Body, if not nil. func (rw *ResponseRecorder) WriteString(str string) (int, error) { + if !rw.WriteDeadline.IsZero() && time.Now().After(rw.WriteDeadline) { + return 0, os.ErrDeadlineExceeded + } + rw.writeHeader(nil, str) if rw.Body != nil { rw.Body.WriteString(str) @@ -154,13 +189,54 @@ func (rw *ResponseRecorder) WriteHeader(code int) { rw.snapHeader = rw.HeaderMap.Clone() } -// Flush implements http.Flusher. To test whether Flush was +// FlushError allows using "net/http".ResponseController.Flush() +// with the recorder. To test whether Flush was // called, see rw.Flushed. -func (rw *ResponseRecorder) Flush() { +func (rw *ResponseRecorder) FlushError() error { + if !rw.WriteDeadline.IsZero() && time.Now().After(rw.WriteDeadline) { + return os.ErrDeadlineExceeded + } + if !rw.wroteHeader { rw.WriteHeader(200) } rw.Flushed = true + + return nil +} + +// Flush implements http.Flusher. To test whether Flush was +// called, see rw.Flushed. +func (rw *ResponseRecorder) Flush() { + rw.FlushError() +} + +// SetReadDeadline allows using "net/http".ResponseController.SetReadDeadline() +// with the recorder. +// To retrieve the deadline, use rw.ReadDeadline. +// To use this method, be sure NewRecorderWithDeadlineAwareRequest +func (rw *ResponseRecorder) SetReadDeadline(deadline time.Time) error { + if rw.requestBody == nil { + return errors.New("The request has not been created using NewRecorderWithDeadlineAwareRequest()") + } + + if deadline.After(rw.ReadDeadline) { + rw.ReadDeadline = deadline + rw.requestBody.deadline = deadline + } + + return nil +} + +// SetWriteDeadline allows using "net/http".ResponseController.SetWriteDeadline() +// with the recorder. +// To retrieve the deadline, use rw.WriteDeadline. +func (rw *ResponseRecorder) SetWriteDeadline(deadline time.Time) error { + if deadline.After(rw.WriteDeadline) { + rw.WriteDeadline = deadline + } + + return nil } // Result returns the response generated by the handler. @@ -253,3 +329,16 @@ func parseContentLength(cl string) int64 { } return int64(n) } + +type deadlineBodyReader struct { + io.ReadCloser + deadline time.Time +} + +func (r *deadlineBodyReader) Read(p []byte) (n int, err error) { + if time.Now().After(r.deadline) { + return 0, os.ErrDeadlineExceeded + } + + return r.ReadCloser.Read(p) +} diff --git a/src/net/http/httptest/recorder_test.go b/src/net/http/httptest/recorder_test.go index 4782eced43e6ce..45c8ebb5c3b1a1 100644 --- a/src/net/http/httptest/recorder_test.go +++ b/src/net/http/httptest/recorder_test.go @@ -5,10 +5,13 @@ package httptest import ( + "errors" "fmt" "io" "net/http" + "os" "testing" + "time" ) func TestRecorder(t *testing.T) { @@ -369,3 +372,66 @@ func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) { }) } } + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + +func TestSetWriteDeadline(t *testing.T) { + rw := NewRecorder() + rc := http.NewResponseController(rw) + + expected := time.Now().Add(1 * time.Millisecond) + if err := rc.SetWriteDeadline(expected); err != nil { + t.Errorf(`"ResponseController.WriteDeadline(): got unexpected error %q`, err) + } + + if rw.WriteDeadline != expected { + t.Errorf(`"ResponseRecorder.WriteDeadline: got %q want %q`, rw.WriteDeadline, expected) + } + + if _, err := io.Copy(rw, neverEnding('a')); !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf(`"ResponseRecorder.Write(): got %q want %q`, err, os.ErrDeadlineExceeded) + } + + if _, err := rw.WriteString("a"); !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf(`"ResponseRecorder.WriteString(): got %q want %q`, err, os.ErrDeadlineExceeded) + } + + if err := rw.FlushError(); !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf(`"ResponseRecorder.FlushError(): got %q want %q`, err, os.ErrDeadlineExceeded) + } + + if b, _ := rw.Body.ReadByte(); b != 'a' { + t.Errorf(`"ResponseRecorder.Body starts with %q ; want "a"`, b) + } +} + +func TestSetReadDeadline(t *testing.T) { + req, _ := http.NewRequest("GET", "https://example.com", neverEnding('a')) + rw, req := NewRecorderWithDeadlineAwareRequest(req) + rc := http.NewResponseController(rw) + + expected := time.Now().Add(1 * time.Millisecond) + if err := rc.SetReadDeadline(expected); err != nil { + t.Errorf(`"ResponseController.SetReadDeadline(): got unexpected error %q`, err) + } + + if rw.ReadDeadline != expected { + t.Errorf(`"ResponseRecorder.ReadDeadline: got %q want %q`, rw.ReadDeadline, expected) + } + + data, err := io.ReadAll(req.Body) + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf(`"ResponseRecorder.GetDeadlineRequestBody(): got %q want %q`, err, os.ErrDeadlineExceeded) + } + + if b := data[0]; b != 'a' { + t.Errorf(`Request Body starts with %q ; want "a"`, b) + } +} From 577f6cf626c2993728417b8e2bea5e4cbb582eb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Dunglas?= Date: Tue, 16 May 2023 18:50:25 +0200 Subject: [PATCH 2/3] Update src/net/http/httptest/recorder.go Co-authored-by: Pascal Borreli --- src/net/http/httptest/recorder.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go index afe795b2c4f245..eb4091c242f5dc 100644 --- a/src/net/http/httptest/recorder.go +++ b/src/net/http/httptest/recorder.go @@ -45,7 +45,7 @@ type ResponseRecorder struct { // Flushed is whether the Handler called Flush. Flushed bool - // ReadDeadline is the write deadline that has been set using + // ReadDeadline is the read deadline that has been set using // "net/http".ResponseController ReadDeadline time.Time From 82674e66c28f0ac2c6d889162ef034213887385b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Dunglas?= Date: Wed, 17 May 2023 12:01:26 +0200 Subject: [PATCH 3/3] don't enforce deadlines --- src/net/http/httptest/example_test.go | 2 +- src/net/http/httptest/recorder.go | 71 +++++--------------------- src/net/http/httptest/recorder_test.go | 52 +++---------------- 3 files changed, 22 insertions(+), 103 deletions(-) diff --git a/src/net/http/httptest/example_test.go b/src/net/http/httptest/example_test.go index 4fc286c0e62e82..fda68a9c9a90b2 100644 --- a/src/net/http/httptest/example_test.go +++ b/src/net/http/httptest/example_test.go @@ -46,7 +46,7 @@ func ExampleResponseRecorder_requestController() { } req := httptest.NewRequest("GET", "http://example.com/bar", strings.NewReader("bar")) - w, req := httptest.NewRecorderWithDeadlineAwareRequest(req) + w := httptest.NewRecorder() handler(w, req) resp := w.Result() diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go index eb4091c242f5dc..0d807b0c1df9fb 100644 --- a/src/net/http/httptest/recorder.go +++ b/src/net/http/httptest/recorder.go @@ -6,12 +6,10 @@ package httptest import ( "bytes" - "errors" "fmt" "io" "net/http" "net/textproto" - "os" "strconv" "strings" "time" @@ -45,18 +43,17 @@ type ResponseRecorder struct { // Flushed is whether the Handler called Flush. Flushed bool - // ReadDeadline is the read deadline that has been set using + // ReadDeadline is the last read deadline that has been set using // "net/http".ResponseController ReadDeadline time.Time - // WriteDeadline is the write deadline that has been set using + // WriteDeadline is the last write deadline that has been set using // "net/http".ResponseController WriteDeadline time.Time result *http.Response // cache of Result's return value snapHeader http.Header // snapshot of HeaderMap at first Write wroteHeader bool - requestBody *deadlineBodyReader } // NewRecorder returns an initialized ResponseRecorder. @@ -68,21 +65,6 @@ func NewRecorder() *ResponseRecorder { } } -// NewRecorderWithDeadlineAwareRequest returns an initialized ResponseRecorder -// and wraps the body of the HTTP request passed as parameter in a special "io".ReadCloser -// that supports read deadlines. -// The request read deadline can be set using ResponseRecorder.SetReadDeadline -// and "http".ResponseController. -// The body of returned the HTTP request returns an error when the read deadline is reached. -// The read deadline can be inspected by reading ResponseRecorder.ReadDeadline. -func NewRecorderWithDeadlineAwareRequest(r *http.Request) (*ResponseRecorder, *http.Request) { - rw := NewRecorder() - rw.requestBody = &deadlineBodyReader{r.Body, time.Time{}} - r.Body = rw.requestBody - - return rw, r -} - // DefaultRemoteAddr is the default remote address to return in RemoteAddr if // an explicit DefaultRemoteAddr isn't set on ResponseRecorder. const DefaultRemoteAddr = "1.2.3.4" @@ -132,10 +114,6 @@ func (rw *ResponseRecorder) writeHeader(b []byte, str string) { // Write implements http.ResponseWriter. The data in buf is written to // rw.Body, if not nil. func (rw *ResponseRecorder) Write(buf []byte) (int, error) { - if !rw.WriteDeadline.IsZero() && time.Now().After(rw.WriteDeadline) { - return 0, os.ErrDeadlineExceeded - } - rw.writeHeader(buf, "") if rw.Body != nil { rw.Body.Write(buf) @@ -146,10 +124,6 @@ func (rw *ResponseRecorder) Write(buf []byte) (int, error) { // WriteString implements io.StringWriter. The data in str is written // to rw.Body, if not nil. func (rw *ResponseRecorder) WriteString(str string) (int, error) { - if !rw.WriteDeadline.IsZero() && time.Now().After(rw.WriteDeadline) { - return 0, os.ErrDeadlineExceeded - } - rw.writeHeader(nil, str) if rw.Body != nil { rw.Body.WriteString(str) @@ -193,10 +167,6 @@ func (rw *ResponseRecorder) WriteHeader(code int) { // with the recorder. To test whether Flush was // called, see rw.Flushed. func (rw *ResponseRecorder) FlushError() error { - if !rw.WriteDeadline.IsZero() && time.Now().After(rw.WriteDeadline) { - return os.ErrDeadlineExceeded - } - if !rw.wroteHeader { rw.WriteHeader(200) } @@ -213,28 +183,28 @@ func (rw *ResponseRecorder) Flush() { // SetReadDeadline allows using "net/http".ResponseController.SetReadDeadline() // with the recorder. +// +// The deadline is recorded but is not enforced. +// To prevent flaky tests reads made after the deadline will work +// as if no deadline was set. +// // To retrieve the deadline, use rw.ReadDeadline. -// To use this method, be sure NewRecorderWithDeadlineAwareRequest func (rw *ResponseRecorder) SetReadDeadline(deadline time.Time) error { - if rw.requestBody == nil { - return errors.New("The request has not been created using NewRecorderWithDeadlineAwareRequest()") - } - - if deadline.After(rw.ReadDeadline) { - rw.ReadDeadline = deadline - rw.requestBody.deadline = deadline - } + rw.ReadDeadline = deadline return nil } // SetWriteDeadline allows using "net/http".ResponseController.SetWriteDeadline() // with the recorder. +// +// The deadline is recorded but is not enforced. +// To prevent flaky tests writes made after the deadline will work +// as if no deadline was set. +// // To retrieve the deadline, use rw.WriteDeadline. func (rw *ResponseRecorder) SetWriteDeadline(deadline time.Time) error { - if deadline.After(rw.WriteDeadline) { - rw.WriteDeadline = deadline - } + rw.WriteDeadline = deadline return nil } @@ -329,16 +299,3 @@ func parseContentLength(cl string) int64 { } return int64(n) } - -type deadlineBodyReader struct { - io.ReadCloser - deadline time.Time -} - -func (r *deadlineBodyReader) Read(p []byte) (n int, err error) { - if time.Now().After(r.deadline) { - return 0, os.ErrDeadlineExceeded - } - - return r.ReadCloser.Read(p) -} diff --git a/src/net/http/httptest/recorder_test.go b/src/net/http/httptest/recorder_test.go index 45c8ebb5c3b1a1..a23096a6a98568 100644 --- a/src/net/http/httptest/recorder_test.go +++ b/src/net/http/httptest/recorder_test.go @@ -5,11 +5,9 @@ package httptest import ( - "errors" "fmt" "io" "net/http" - "os" "testing" "time" ) @@ -373,20 +371,11 @@ func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) { } } -type neverEnding byte - -func (b neverEnding) Read(p []byte) (n int, err error) { - for i := range p { - p[i] = byte(b) - } - return len(p), nil -} - -func TestSetWriteDeadline(t *testing.T) { +func TestRecorderSetWriteDeadline(t *testing.T) { rw := NewRecorder() rc := http.NewResponseController(rw) - expected := time.Now().Add(1 * time.Millisecond) + expected := time.Now().Add(1 * time.Second) if err := rc.SetWriteDeadline(expected); err != nil { t.Errorf(`"ResponseController.WriteDeadline(): got unexpected error %q`, err) } @@ -394,44 +383,17 @@ func TestSetWriteDeadline(t *testing.T) { if rw.WriteDeadline != expected { t.Errorf(`"ResponseRecorder.WriteDeadline: got %q want %q`, rw.WriteDeadline, expected) } - - if _, err := io.Copy(rw, neverEnding('a')); !errors.Is(err, os.ErrDeadlineExceeded) { - t.Errorf(`"ResponseRecorder.Write(): got %q want %q`, err, os.ErrDeadlineExceeded) - } - - if _, err := rw.WriteString("a"); !errors.Is(err, os.ErrDeadlineExceeded) { - t.Errorf(`"ResponseRecorder.WriteString(): got %q want %q`, err, os.ErrDeadlineExceeded) - } - - if err := rw.FlushError(); !errors.Is(err, os.ErrDeadlineExceeded) { - t.Errorf(`"ResponseRecorder.FlushError(): got %q want %q`, err, os.ErrDeadlineExceeded) - } - - if b, _ := rw.Body.ReadByte(); b != 'a' { - t.Errorf(`"ResponseRecorder.Body starts with %q ; want "a"`, b) - } } -func TestSetReadDeadline(t *testing.T) { - req, _ := http.NewRequest("GET", "https://example.com", neverEnding('a')) - rw, req := NewRecorderWithDeadlineAwareRequest(req) - rc := http.NewResponseController(rw) +func TestRecorderSetReadDeadline(t *testing.T) { + rw := NewRecorder() - expected := time.Now().Add(1 * time.Millisecond) - if err := rc.SetReadDeadline(expected); err != nil { - t.Errorf(`"ResponseController.SetReadDeadline(): got unexpected error %q`, err) + expected := time.Now().Add(1 * time.Second) + if err := rw.SetReadDeadline(expected); err != nil { + t.Errorf(`"ResponseRecorder.SetReadDeadline(): got unexpected error %q`, err) } if rw.ReadDeadline != expected { t.Errorf(`"ResponseRecorder.ReadDeadline: got %q want %q`, rw.ReadDeadline, expected) } - - data, err := io.ReadAll(req.Body) - if !errors.Is(err, os.ErrDeadlineExceeded) { - t.Errorf(`"ResponseRecorder.GetDeadlineRequestBody(): got %q want %q`, err, os.ErrDeadlineExceeded) - } - - if b := data[0]; b != 'a' { - t.Errorf(`Request Body starts with %q ; want "a"`, b) - } }