Skip to content

Commit 70ee525

Browse files
committed
net/http: fix Transport crash when abandoning dial which upgrades protos
When the Transport was creating an bound HTTP connection (protocol unknown initially) and then ends up deciding it doesn't need it, a goroutine sits around to clean up whatever the result was. That goroutine made the false assumption that the result was always an HTTP/1 connection or an error. It may also be an alternate protocol in which case the *persistConn.conn net.Conn field is nil, and the alt field is non-nil. Fixes #13839 Change-Id: Ia4972e5eb1ad53fa00410b3466d4129c753e0871 Reviewed-on: https://go-review.googlesource.com/18573 Reviewed-by: Russ Cox <[email protected]> Run-TryBot: Brad Fitzpatrick <[email protected]> TryBot-Result: Gobot Gobot <[email protected]>
1 parent 4525571 commit 70ee525

File tree

3 files changed

+118
-6
lines changed

3 files changed

+118
-6
lines changed

src/net/http/export_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ var (
2121
ExportServerNewConn = (*Server).newConn
2222
ExportCloseWriteAndWait = (*conn).closeWriteAndWait
2323
ExportErrRequestCanceled = errRequestCanceled
24+
ExportErrRequestCanceledConn = errRequestCanceledConn
2425
ExportServeFile = serveFile
2526
ExportHttp2ConfigureTransport = http2ConfigureTransport
2627
ExportHttp2ConfigureServer = http2ConfigureServer

src/net/http/transport.go

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -618,9 +618,13 @@ func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool {
618618
return true
619619
}
620620

621-
func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
621+
func (t *Transport) dial(network, addr string) (net.Conn, error) {
622622
if t.Dial != nil {
623-
return t.Dial(network, addr)
623+
c, err := t.Dial(network, addr)
624+
if c == nil && err == nil {
625+
err = errors.New("net/http: Transport.Dial hook returned (nil, nil)")
626+
}
627+
return c, err
624628
}
625629
return net.Dial(network, addr)
626630
}
@@ -682,10 +686,10 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error
682686
return pc, nil
683687
case <-req.Cancel:
684688
handlePendingDial()
685-
return nil, errors.New("net/http: request canceled while waiting for connection")
689+
return nil, errRequestCanceledConn
686690
case <-cancelc:
687691
handlePendingDial()
688-
return nil, errors.New("net/http: request canceled while waiting for connection")
692+
return nil, errRequestCanceledConn
689693
}
690694
}
691695

@@ -705,6 +709,9 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
705709
if err != nil {
706710
return nil, err
707711
}
712+
if pconn.conn == nil {
713+
return nil, errors.New("net/http: Transport.DialTLS returned (nil, nil)")
714+
}
708715
if tc, ok := pconn.conn.(*tls.Conn); ok {
709716
cs := tc.ConnectionState()
710717
pconn.tlsState = &cs
@@ -1326,6 +1333,7 @@ func (e *httpError) Temporary() bool { return true }
13261333
var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true}
13271334
var errClosed error = &httpError{err: "net/http: server closed connection before response was received"}
13281335
var errRequestCanceled = errors.New("net/http: request canceled")
1336+
var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify?
13291337

13301338
func nop() {}
13311339

@@ -1502,9 +1510,19 @@ func (pc *persistConn) closeLocked(err error) {
15021510
}
15031511
pc.broken = true
15041512
if pc.closed == nil {
1505-
pc.conn.Close()
15061513
pc.closed = err
1507-
close(pc.closech)
1514+
if pc.alt != nil {
1515+
// Do nothing; can only get here via getConn's
1516+
// handlePendingDial's putOrCloseIdleConn when
1517+
// it turns out the abandoned connection in
1518+
// flight ended up negotiating an alternate
1519+
// protocol. We don't use the connection
1520+
// freelist for http2. That's done by the
1521+
// alternate protocol's RoundTripper.
1522+
} else {
1523+
pc.conn.Close()
1524+
close(pc.closech)
1525+
}
15081526
}
15091527
pc.mutateHeaderFunc = nil
15101528
}

src/net/http/transport_test.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
. "net/http"
2525
"net/http/httptest"
2626
"net/http/httputil"
27+
"net/http/internal"
2728
"net/url"
2829
"os"
2930
"reflect"
@@ -2939,6 +2940,98 @@ func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
29392940
}
29402941
}
29412942

2943+
// Issue 13839
2944+
func TestNoCrashReturningTransportAltConn(t *testing.T) {
2945+
cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey)
2946+
if err != nil {
2947+
t.Fatal(err)
2948+
}
2949+
ln := newLocalListener(t)
2950+
defer ln.Close()
2951+
2952+
handledPendingDial := make(chan bool, 1)
2953+
SetPendingDialHooks(nil, func() { handledPendingDial <- true })
2954+
defer SetPendingDialHooks(nil, nil)
2955+
2956+
testDone := make(chan struct{})
2957+
defer close(testDone)
2958+
go func() {
2959+
tln := tls.NewListener(ln, &tls.Config{
2960+
NextProtos: []string{"foo"},
2961+
Certificates: []tls.Certificate{cert},
2962+
})
2963+
sc, err := tln.Accept()
2964+
if err != nil {
2965+
t.Error(err)
2966+
return
2967+
}
2968+
if err := sc.(*tls.Conn).Handshake(); err != nil {
2969+
t.Error(err)
2970+
return
2971+
}
2972+
<-testDone
2973+
sc.Close()
2974+
}()
2975+
2976+
addr := ln.Addr().String()
2977+
2978+
req, _ := NewRequest("GET", "https://fake.tld/", nil)
2979+
cancel := make(chan struct{})
2980+
req.Cancel = cancel
2981+
2982+
doReturned := make(chan bool, 1)
2983+
madeRoundTripper := make(chan bool, 1)
2984+
2985+
tr := &Transport{
2986+
DisableKeepAlives: true,
2987+
TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
2988+
"foo": func(authority string, c *tls.Conn) RoundTripper {
2989+
madeRoundTripper <- true
2990+
return funcRoundTripper(func() {
2991+
t.Error("foo RoundTripper should not be called")
2992+
})
2993+
},
2994+
},
2995+
Dial: func(_, _ string) (net.Conn, error) {
2996+
panic("shouldn't be called")
2997+
},
2998+
DialTLS: func(_, _ string) (net.Conn, error) {
2999+
tc, err := tls.Dial("tcp", addr, &tls.Config{
3000+
InsecureSkipVerify: true,
3001+
NextProtos: []string{"foo"},
3002+
})
3003+
if err != nil {
3004+
return nil, err
3005+
}
3006+
if err := tc.Handshake(); err != nil {
3007+
return nil, err
3008+
}
3009+
close(cancel)
3010+
<-doReturned
3011+
return tc, nil
3012+
},
3013+
}
3014+
c := &Client{Transport: tr}
3015+
3016+
_, err = c.Do(req)
3017+
if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
3018+
t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
3019+
}
3020+
3021+
doReturned <- true
3022+
<-madeRoundTripper
3023+
<-handledPendingDial
3024+
}
3025+
3026+
var errFakeRoundTrip = errors.New("fake roundtrip")
3027+
3028+
type funcRoundTripper func()
3029+
3030+
func (fn funcRoundTripper) RoundTrip(*Request) (*Response, error) {
3031+
fn()
3032+
return nil, errFakeRoundTrip
3033+
}
3034+
29423035
func wantBody(res *Response, err error, want string) error {
29433036
if err != nil {
29443037
return err

0 commit comments

Comments
 (0)