Skip to content

Commit c994f08

Browse files
net/http: return an error if Write is called after WriteTimeout
response.Write now returns an error if the called happened after the configured server WriteTimeout. Fixes #21389
1 parent bd5595d commit c994f08

File tree

2 files changed

+81
-4
lines changed

2 files changed

+81
-4
lines changed

src/net/http/serve_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,47 @@ func TestOnlyWriteTimeout(t *testing.T) {
973973
}
974974
}
975975

976+
func TestErrorAfterWriteTimeout(t *testing.T) {
977+
setParallel(t)
978+
defer afterTest(t)
979+
writeTimeout := 200 * time.Millisecond
980+
var afterTimeoutErrc = make(chan error, 1)
981+
ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) {
982+
time.Sleep(2 * writeTimeout)
983+
984+
_, err := w.Write([]byte("test"))
985+
afterTimeoutErrc <- err
986+
}))
987+
ts.Config.WriteTimeout = writeTimeout
988+
ts.Start()
989+
defer ts.Close()
990+
991+
c := ts.Client()
992+
993+
errc := make(chan error, 1)
994+
go func() {
995+
res, err := c.Get(ts.URL)
996+
if err != nil {
997+
errc <- err
998+
return
999+
}
1000+
_, err = io.Copy(io.Discard, res.Body)
1001+
res.Body.Close()
1002+
errc <- err
1003+
}()
1004+
select {
1005+
case err := <-errc:
1006+
if err == nil {
1007+
t.Errorf("expected an error from Get request")
1008+
}
1009+
case <-time.After(10 * time.Second):
1010+
t.Fatal("timeout waiting for Get error")
1011+
}
1012+
if err := <-afterTimeoutErrc; err == nil {
1013+
t.Error("expected write error after timeout")
1014+
}
1015+
}
1016+
9761017
// trackLastConnListener tracks the last net.Conn that was accepted.
9771018
type trackLastConnListener struct {
9781019
net.Listener

src/net/http/server.go

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -395,11 +395,11 @@ func (cw *chunkWriter) Write(p []byte) (n int, err error) {
395395
return
396396
}
397397

398-
func (cw *chunkWriter) flush() {
398+
func (cw *chunkWriter) flush() error {
399399
if !cw.wroteHeader {
400400
cw.writeHeader(nil)
401401
}
402-
cw.res.conn.bufw.Flush()
402+
return cw.res.conn.bufw.Flush()
403403
}
404404

405405
func (cw *chunkWriter) close() {
@@ -443,6 +443,14 @@ type response struct {
443443
w *bufio.Writer // buffers output in chunks to chunkWriter
444444
cw chunkWriter
445445

446+
// writeTimeoutTimer is set when the server has a WriteTimeout configured
447+
// and triggers when a write timed out
448+
// writeDeadline is used to enable direct flushing of writes after the
449+
// timeout so writers receive an error and can handle it
450+
writeTimeoutTimer *time.Timer
451+
writeDeadline bool
452+
writeDeadlineMu sync.Mutex
453+
446454
// handlerHeader is the Header that Handlers get access to,
447455
// which may be retained and mutated even after WriteHeader.
448456
// handlerHeader is copied into cw.header at WriteHeader
@@ -1045,6 +1053,9 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) {
10451053
if isH2Upgrade {
10461054
w.closeAfterReply = true
10471055
}
1056+
if d := c.server.WriteTimeout; d > 0 {
1057+
w.setWriteTimeout(d)
1058+
}
10481059
w.cw.res = w
10491060
w.w = newBufioWriterSize(&w.cw, bufferBeforeChunkingSize)
10501061
return w, nil
@@ -1590,6 +1601,16 @@ func (w *response) WriteString(data string) (n int, err error) {
15901601
return w.write(len(data), nil, data)
15911602
}
15921603

1604+
// setWriteTimeout lets the response know if the write was supposed to be
1605+
// timed out, timed out requests will force be flushed on every write
1606+
func (w *response) setWriteTimeout(d time.Duration) {
1607+
w.writeTimeoutTimer = time.AfterFunc(d, func() {
1608+
w.writeDeadlineMu.Lock()
1609+
w.writeDeadline = true
1610+
w.writeDeadlineMu.Unlock()
1611+
})
1612+
}
1613+
15931614
// either dataB or dataS is non-zero.
15941615
func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) {
15951616
if w.conn.hijacked() {
@@ -1625,10 +1646,22 @@ func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err er
16251646
return 0, ErrContentLength
16261647
}
16271648
if dataB != nil {
1628-
return w.w.Write(dataB)
1649+
n, err = w.w.Write(dataB)
16291650
} else {
1630-
return w.w.WriteString(dataS)
1651+
n, err = w.w.WriteString(dataS)
16311652
}
1653+
if err == nil {
1654+
w.writeDeadlineMu.Lock()
1655+
wd := w.writeDeadline
1656+
w.writeDeadlineMu.Unlock()
1657+
1658+
if wd {
1659+
// r.Flush returns no errors, flush manually
1660+
w.w.Flush()
1661+
err = w.cw.flush()
1662+
}
1663+
}
1664+
return
16321665
}
16331666

16341667
func (w *response) finishRequest() {
@@ -1643,6 +1676,9 @@ func (w *response) finishRequest() {
16431676
w.cw.close()
16441677
w.conn.bufw.Flush()
16451678

1679+
if w.writeTimeoutTimer != nil {
1680+
w.writeTimeoutTimer.Stop()
1681+
}
16461682
w.conn.r.abortPendingRead()
16471683

16481684
// Close the body (regardless of w.closeAfterReply) so we can

0 commit comments

Comments
 (0)