Skip to content

Commit 540bb53

Browse files
committed
http2: add Transport.WriteByteTimeout
Add a Transport-level knob to set a timeout for writes to net.Conns. If a write exceeds the timeout without making any progress (at least one byte written), the connection is closed. Fixes golang/go#48830. Change-Id: If0f57996d11c92bced30e07d1e238cbf8994acb4 Reviewed-on: https://go-review.googlesource.com/c/net/+/354431 Trust: Damien Neil <[email protected]> Run-TryBot: Damien Neil <[email protected]> TryBot-Result: Go Bot <[email protected]> Reviewed-by: Brad Fitzpatrick <[email protected]>
1 parent d418f37 commit 540bb53

File tree

2 files changed

+100
-6
lines changed

2 files changed

+100
-6
lines changed

http2/transport.go

+30-6
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"net/http"
2525
"net/http/httptrace"
2626
"net/textproto"
27+
"os"
2728
"sort"
2829
"strconv"
2930
"strings"
@@ -130,6 +131,11 @@ type Transport struct {
130131
// Defaults to 15s.
131132
PingTimeout time.Duration
132133

134+
// WriteByteTimeout is the timeout after which the connection will be
135+
// closed no data can be written to it. The timeout begins when data is
136+
// available to write, and is extended whenever any bytes are written.
137+
WriteByteTimeout time.Duration
138+
133139
// CountError, if non-nil, is called on HTTP/2 transport errors.
134140
// It's intended to increment a metric for monitoring, such
135141
// as an expvar or Prometheus metric.
@@ -393,17 +399,31 @@ func (cs *clientStream) abortRequestBodyWrite() {
393399
}
394400

395401
type stickyErrWriter struct {
396-
w io.Writer
397-
err *error
402+
conn net.Conn
403+
timeout time.Duration
404+
err *error
398405
}
399406

400407
func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
401408
if *sew.err != nil {
402409
return 0, *sew.err
403410
}
404-
n, err = sew.w.Write(p)
405-
*sew.err = err
406-
return
411+
for {
412+
if sew.timeout != 0 {
413+
sew.conn.SetWriteDeadline(time.Now().Add(sew.timeout))
414+
}
415+
nn, err := sew.conn.Write(p[n:])
416+
n += nn
417+
if n < len(p) && nn > 0 && errors.Is(err, os.ErrDeadlineExceeded) {
418+
// Keep extending the deadline so long as we're making progress.
419+
continue
420+
}
421+
if sew.timeout != 0 {
422+
sew.conn.SetWriteDeadline(time.Time{})
423+
}
424+
*sew.err = err
425+
return n, err
426+
}
407427
}
408428

409429
// noCachedConnError is the concrete type of ErrNoCachedConn, which
@@ -658,7 +678,11 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
658678

659679
// TODO: adjust this writer size to account for frame size +
660680
// MTU + crypto/tls record padding.
661-
cc.bw = bufio.NewWriter(stickyErrWriter{c, &cc.werr})
681+
cc.bw = bufio.NewWriter(stickyErrWriter{
682+
conn: c,
683+
timeout: t.WriteByteTimeout,
684+
err: &cc.werr,
685+
})
662686
cc.br = bufio.NewReader(c)
663687
cc.fr = NewFramer(cc.bw, cc.br)
664688
if t.CountError != nil {

http2/transport_test.go

+70
Original file line numberDiff line numberDiff line change
@@ -5736,3 +5736,73 @@ func TestTransport300ResponseBody(t *testing.T) {
57365736
res.Body.Close()
57375737
pw.Close()
57385738
}
5739+
5740+
func TestTransportWriteByteTimeout(t *testing.T) {
5741+
st := newServerTester(t,
5742+
func(w http.ResponseWriter, r *http.Request) {},
5743+
optOnlyServer,
5744+
)
5745+
defer st.Close()
5746+
tr := &Transport{
5747+
TLSClientConfig: tlsConfigInsecure,
5748+
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
5749+
_, c := net.Pipe()
5750+
return c, nil
5751+
},
5752+
WriteByteTimeout: 1 * time.Millisecond,
5753+
}
5754+
defer tr.CloseIdleConnections()
5755+
c := &http.Client{Transport: tr}
5756+
5757+
_, err := c.Get(st.ts.URL)
5758+
if !errors.Is(err, os.ErrDeadlineExceeded) {
5759+
t.Fatalf("Get on unresponsive connection: got %q; want ErrDeadlineExceeded", err)
5760+
}
5761+
}
5762+
5763+
type slowWriteConn struct {
5764+
net.Conn
5765+
hasWriteDeadline bool
5766+
}
5767+
5768+
func (c *slowWriteConn) SetWriteDeadline(t time.Time) error {
5769+
c.hasWriteDeadline = !t.IsZero()
5770+
return nil
5771+
}
5772+
5773+
func (c *slowWriteConn) Write(b []byte) (n int, err error) {
5774+
if c.hasWriteDeadline && len(b) > 1 {
5775+
n, err = c.Conn.Write(b[:1])
5776+
if err != nil {
5777+
return n, err
5778+
}
5779+
return n, fmt.Errorf("slow write: %w", os.ErrDeadlineExceeded)
5780+
}
5781+
return c.Conn.Write(b)
5782+
}
5783+
5784+
func TestTransportSlowWrites(t *testing.T) {
5785+
st := newServerTester(t,
5786+
func(w http.ResponseWriter, r *http.Request) {},
5787+
optOnlyServer,
5788+
)
5789+
defer st.Close()
5790+
tr := &Transport{
5791+
TLSClientConfig: tlsConfigInsecure,
5792+
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
5793+
cfg.InsecureSkipVerify = true
5794+
c, err := tls.Dial(network, addr, cfg)
5795+
return &slowWriteConn{Conn: c}, err
5796+
},
5797+
WriteByteTimeout: 1 * time.Millisecond,
5798+
}
5799+
defer tr.CloseIdleConnections()
5800+
c := &http.Client{Transport: tr}
5801+
5802+
const bodySize = 1 << 20
5803+
resp, err := c.Post(st.ts.URL, "text/foo", io.LimitReader(neverEnding('A'), bodySize))
5804+
if err != nil {
5805+
t.Fatal(err)
5806+
}
5807+
resp.Body.Close()
5808+
}

0 commit comments

Comments
 (0)