Skip to content

Commit 42f634a

Browse files
committed
net/http/httptest: add support for http.ResponseController to
ResponseRecorder CL #54136 (implemented in Go 1.20) added the "http".ResponseController type, which allows manipulating per-request timeouts. This is especially useful for programs managing long-running HTTP connections such as Mercure. However, testing HTTP handlers leveraging per-request timeouts is currently cumbersome (even if doable) because "net/http/httptest".ResponseRecorder isn't compatible yet with "http".ResponseController. This patch makes ResponseRecorder compatible with "http".ResponseController. All new methods are part of the contract that response types must honor to be usable with "http".ResponseController. NewRecorderWithDeadlineAwareRequest() is necessary to test read deadlines, as calling rw.SetReadDeadline() must change the deadline on the request body. Fixes #60229.
1 parent 1b896bf commit 42f634a

File tree

4 files changed

+187
-4
lines changed

4 files changed

+187
-4
lines changed

src/net/http/httptest/example_test.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"log"
1111
"net/http"
1212
"net/http/httptest"
13+
"strings"
14+
"time"
1315
)
1416

1517
func ExampleResponseRecorder() {
@@ -34,6 +36,32 @@ func ExampleResponseRecorder() {
3436
// <html><body>Hello World!</body></html>
3537
}
3638

39+
func ExampleResponseRecorder_requestController() {
40+
handler := func(w http.ResponseWriter, r *http.Request) {
41+
rc := http.NewResponseController(w)
42+
rc.SetReadDeadline(time.Now().Add(1 * time.Second))
43+
rc.SetWriteDeadline(time.Now().Add(3 * time.Second))
44+
45+
io.WriteString(w, "<html><body>Hello, with deadlines!</body></html>")
46+
}
47+
48+
req := httptest.NewRequest("GET", "http://example.com/bar", strings.NewReader("bar"))
49+
w, req := httptest.NewRecorderWithDeadlineAwareRequest(req)
50+
handler(w, req)
51+
52+
resp := w.Result()
53+
body, _ := io.ReadAll(resp.Body)
54+
55+
fmt.Println(resp.StatusCode)
56+
fmt.Println(resp.Header.Get("Content-Type"))
57+
fmt.Println(string(body))
58+
59+
// Output:
60+
// 200
61+
// text/html; charset=utf-8
62+
// <html><body>Hello, with deadlines!</body></html>
63+
}
64+
3765
func ExampleServer() {
3866
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
3967
fmt.Fprintln(w, "Hello, client")

src/net/http/httptest/recorder.go

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,15 @@ package httptest
66

77
import (
88
"bytes"
9+
"errors"
910
"fmt"
1011
"io"
1112
"net/http"
1213
"net/textproto"
14+
"os"
1315
"strconv"
1416
"strings"
17+
"time"
1518

1619
"golang.org/x/net/http/httpguts"
1720
)
@@ -42,9 +45,18 @@ type ResponseRecorder struct {
4245
// Flushed is whether the Handler called Flush.
4346
Flushed bool
4447

48+
// ReadDeadline is the write deadline that has been set using
49+
// "net/http".ResponseController
50+
ReadDeadline time.Time
51+
52+
// WriteDeadline is the write deadline that has been set using
53+
// "net/http".ResponseController
54+
WriteDeadline time.Time
55+
4556
result *http.Response // cache of Result's return value
4657
snapHeader http.Header // snapshot of HeaderMap at first Write
4758
wroteHeader bool
59+
requestBody *deadlineBodyReader
4860
}
4961

5062
// NewRecorder returns an initialized ResponseRecorder.
@@ -56,6 +68,21 @@ func NewRecorder() *ResponseRecorder {
5668
}
5769
}
5870

71+
// NewRecorderWithDeadlineAwareRequest returns an initialized ResponseRecorder
72+
// and wraps the body of the HTTP request passed as parameter in a special "io".ReadCloser
73+
// that supports read deadlines.
74+
// The request read deadline can be set using ResponseRecorder.SetReadDeadline
75+
// and "http".ResponseController.
76+
// The body of returned the HTTP request returns an error when the read deadline is reached.
77+
// The read deadline can be inspected by reading ResponseRecorder.ReadDeadline.
78+
func NewRecorderWithDeadlineAwareRequest(r *http.Request) (*ResponseRecorder, *http.Request) {
79+
rw := NewRecorder()
80+
rw.requestBody = &deadlineBodyReader{r.Body, time.Time{}}
81+
r.Body = rw.requestBody
82+
83+
return rw, r
84+
}
85+
5986
// DefaultRemoteAddr is the default remote address to return in RemoteAddr if
6087
// an explicit DefaultRemoteAddr isn't set on ResponseRecorder.
6188
const DefaultRemoteAddr = "1.2.3.4"
@@ -105,6 +132,10 @@ func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
105132
// Write implements http.ResponseWriter. The data in buf is written to
106133
// rw.Body, if not nil.
107134
func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
135+
if !rw.WriteDeadline.IsZero() && time.Now().After(rw.WriteDeadline) {
136+
return 0, os.ErrDeadlineExceeded
137+
}
138+
108139
rw.writeHeader(buf, "")
109140
if rw.Body != nil {
110141
rw.Body.Write(buf)
@@ -115,6 +146,10 @@ func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
115146
// WriteString implements io.StringWriter. The data in str is written
116147
// to rw.Body, if not nil.
117148
func (rw *ResponseRecorder) WriteString(str string) (int, error) {
149+
if !rw.WriteDeadline.IsZero() && time.Now().After(rw.WriteDeadline) {
150+
return 0, os.ErrDeadlineExceeded
151+
}
152+
118153
rw.writeHeader(nil, str)
119154
if rw.Body != nil {
120155
rw.Body.WriteString(str)
@@ -154,13 +189,54 @@ func (rw *ResponseRecorder) WriteHeader(code int) {
154189
rw.snapHeader = rw.HeaderMap.Clone()
155190
}
156191

157-
// Flush implements http.Flusher. To test whether Flush was
192+
// FlushError allows using "net/http".ResponseController.Flush()
193+
// with the recorder. To test whether Flush was
158194
// called, see rw.Flushed.
159-
func (rw *ResponseRecorder) Flush() {
195+
func (rw *ResponseRecorder) FlushError() error {
196+
if !rw.WriteDeadline.IsZero() && time.Now().After(rw.WriteDeadline) {
197+
return os.ErrDeadlineExceeded
198+
}
199+
160200
if !rw.wroteHeader {
161201
rw.WriteHeader(200)
162202
}
163203
rw.Flushed = true
204+
205+
return nil
206+
}
207+
208+
// Flush implements http.Flusher. To test whether Flush was
209+
// called, see rw.Flushed.
210+
func (rw *ResponseRecorder) Flush() {
211+
rw.FlushError()
212+
}
213+
214+
// SetReadDeadline allows using "net/http".ResponseController.SetReadDeadline()
215+
// with the recorder.
216+
// To retrieve the deadline, use rw.ReadDeadline.
217+
// To use this method, be sure NewRecorderWithDeadlineAwareRequest
218+
func (rw *ResponseRecorder) SetReadDeadline(deadline time.Time) error {
219+
if rw.requestBody == nil {
220+
return errors.New("The request has not been created using NewRecorderWithDeadlineAwareRequest()")
221+
}
222+
223+
if deadline.After(rw.ReadDeadline) {
224+
rw.ReadDeadline = deadline
225+
rw.requestBody.deadline = deadline
226+
}
227+
228+
return nil
229+
}
230+
231+
// SetWriteDeadline allows using "net/http".ResponseController.SetWriteDeadline()
232+
// with the recorder.
233+
// To retrieve the deadline, use rw.WriteDeadline.
234+
func (rw *ResponseRecorder) SetWriteDeadline(deadline time.Time) error {
235+
if deadline.After(rw.WriteDeadline) {
236+
rw.WriteDeadline = deadline
237+
}
238+
239+
return nil
164240
}
165241

166242
// Result returns the response generated by the handler.
@@ -253,3 +329,16 @@ func parseContentLength(cl string) int64 {
253329
}
254330
return int64(n)
255331
}
332+
333+
type deadlineBodyReader struct {
334+
io.ReadCloser
335+
deadline time.Time
336+
}
337+
338+
func (r *deadlineBodyReader) Read(p []byte) (n int, err error) {
339+
if time.Now().After(r.deadline) {
340+
return 0, os.ErrDeadlineExceeded
341+
}
342+
343+
return r.ReadCloser.Read(p)
344+
}

src/net/http/httptest/recorder_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
package httptest
66

77
import (
8+
"errors"
89
"fmt"
910
"io"
1011
"net/http"
12+
"os"
1113
"testing"
14+
"time"
1215
)
1316

1417
func TestRecorder(t *testing.T) {
@@ -369,3 +372,66 @@ func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) {
369372
})
370373
}
371374
}
375+
376+
type neverEnding byte
377+
378+
func (b neverEnding) Read(p []byte) (n int, err error) {
379+
for i := range p {
380+
p[i] = byte(b)
381+
}
382+
return len(p), nil
383+
}
384+
385+
func TestSetWriteDeadline(t *testing.T) {
386+
rw := NewRecorder()
387+
rc := http.NewResponseController(rw)
388+
389+
expected := time.Now().Add(1 * time.Millisecond)
390+
if err := rc.SetWriteDeadline(expected); err != nil {
391+
t.Errorf(`"ResponseController.WriteDeadline(): got unexpected error %q`, err)
392+
}
393+
394+
if rw.WriteDeadline != expected {
395+
t.Errorf(`"ResponseRecorder.WriteDeadline: got %q want %q`, rw.WriteDeadline, expected)
396+
}
397+
398+
if _, err := io.Copy(rw, neverEnding('a')); !errors.Is(err, os.ErrDeadlineExceeded) {
399+
t.Errorf(`"ResponseRecorder.Write(): got %q want %q`, err, os.ErrDeadlineExceeded)
400+
}
401+
402+
if _, err := rw.WriteString("a"); !errors.Is(err, os.ErrDeadlineExceeded) {
403+
t.Errorf(`"ResponseRecorder.WriteString(): got %q want %q`, err, os.ErrDeadlineExceeded)
404+
}
405+
406+
if err := rw.FlushError(); !errors.Is(err, os.ErrDeadlineExceeded) {
407+
t.Errorf(`"ResponseRecorder.FlushError(): got %q want %q`, err, os.ErrDeadlineExceeded)
408+
}
409+
410+
if b, _ := rw.Body.ReadByte(); b != 'a' {
411+
t.Errorf(`"ResponseRecorder.Body starts with %q ; want "a"`, b)
412+
}
413+
}
414+
415+
func TestSetReadDeadline(t *testing.T) {
416+
req, _ := http.NewRequest("GET", "https://example.com", neverEnding('a'))
417+
rw, req := NewRecorderWithDeadlineAwareRequest(req)
418+
rc := http.NewResponseController(rw)
419+
420+
expected := time.Now().Add(1 * time.Millisecond)
421+
if err := rc.SetReadDeadline(expected); err != nil {
422+
t.Errorf(`"ResponseController.SetReadDeadline(): got unexpected error %q`, err)
423+
}
424+
425+
if rw.ReadDeadline != expected {
426+
t.Errorf(`"ResponseRecorder.ReadDeadline: got %q want %q`, rw.ReadDeadline, expected)
427+
}
428+
429+
data, err := io.ReadAll(req.Body)
430+
if !errors.Is(err, os.ErrDeadlineExceeded) {
431+
t.Errorf(`"ResponseRecorder.GetDeadlineRequestBody(): got %q want %q`, err, os.ErrDeadlineExceeded)
432+
}
433+
434+
if b := data[0]; b != 'a' {
435+
t.Errorf(`Request Body starts with %q ; want "a"`, b)
436+
}
437+
}

src/net/http/server.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,11 +1389,11 @@ func (cw *chunkWriter) writeHeader(p []byte) {
13891389
case bdy.unreadDataSizeLocked() >= maxPostHandlerReadBytes:
13901390
tooBig = true
13911391
default:
1392-
discard = true
1392+
discard = false
13931393
}
13941394
bdy.mu.Unlock()
13951395
default:
1396-
discard = true
1396+
discard = false
13971397
}
13981398

13991399
if discard {

0 commit comments

Comments
 (0)