Skip to content

net/http: close accepted connection when Closed #53118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions src/net/http/serve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6757,3 +6757,124 @@ func TestProcessing(t *testing.T) {
t.Errorf("unexpected response; got %q; should start by %q", got, expected)
}
}

type wrapConnListener struct {
net.Listener
wrap func(net.Conn) net.Conn
}

func (l *wrapConnListener) Accept() (c net.Conn, err error) {
c, err = l.Listener.Accept()
if err != nil {
return nil, err
}
c = l.wrap(c)
return
}

type trackCloseConn struct {
net.Conn
closes chan<- error
}

func (p *trackCloseConn) Close() error {
err := p.Conn.Close()
p.closes <- err
return err
}

// Issue 48642: close accepted connection when Closed
func TestAcceptedConnectionWhenClosed(t *testing.T) {
defer afterTest(t)
connCloses := make(chan error, 1)
srvClosed := make(chan error, 1)

setup := func(ts *httptest.Server) {
ts.Listener = &wrapConnListener{ts.Listener, func(c net.Conn) net.Conn {
return &trackCloseConn{c, connCloses}
}}
srv := ts.Config
// Use ConnContext to close the server after connection is accepted but before it is tracked
srv.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
srvClosed <- srv.Close()
return ctx
}
}
handler := HandlerFunc(func(w ResponseWriter, r *Request) {
io.WriteString(w, "test content")
})

cst := newClientServerTest(t, h1Mode /* the test is protocol-agnostic */, handler, setup)
defer cst.close()

res, err := cst.c.Get(cst.ts.URL)
if err == nil {
res.Body.Close()
t.Fatalf("expected error from Get")
}

<-srvClosed

select {
case err := <-connCloses:
if err != nil {
t.Fatal("expected no error from connection Close")
}
case <-time.After(10 * time.Second):
t.Fatal("timeout waiting for connection Close")
}
}

// Issue 33313 and 36819: serve accepted connection when Shutdown
func TestAcceptedConnectionWhenShutdown(t *testing.T) {
defer afterTest(t)
connCloses := make(chan error, 1)
srvShutdown := make(chan error, 1)

setup := func(ts *httptest.Server) {
ts.Listener = &wrapConnListener{ts.Listener, func(c net.Conn) net.Conn {
return &trackCloseConn{c, connCloses}
}}
srv := ts.Config
inShutdown := make(chan struct{})
srv.RegisterOnShutdown(func() { close(inShutdown) })
// Use ConnContext to shutdown the server after connection is accepted but before it is tracked
srv.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
go func() { srvShutdown <- srv.Shutdown(context.Background()) }()
<-inShutdown
return ctx
}
}
handler := HandlerFunc(func(w ResponseWriter, r *Request) {
io.WriteString(w, "test content")
})

cst := newClientServerTest(t, h1Mode /* the test is protocol-agnostic */, handler, setup)
defer cst.close()

res, err := cst.c.Get(cst.ts.URL)
if err != nil {
t.Fatalf("Get error: %v", err)
}
body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("ReadAll error: %v", err)
}
res.Body.Close()

content := string(body)
if content != "test content" {
t.Fatalf("unexpected content: %s", content)
}

<-srvShutdown

select {
case err := <-connCloses:
if err != nil {
t.Fatal("expected no error from connection Close")
}
case <-time.After(10 * time.Second):
t.Fatal("timeout waiting for connection Close")
}
}
41 changes: 22 additions & 19 deletions src/net/http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1778,10 +1778,10 @@ const (
func (c *conn) setState(nc net.Conn, state ConnState, runHook bool) {
srv := c.server
switch state {
case StateNew:
srv.trackConn(c, true)
case StateHijacked, StateClosed:
srv.trackConn(c, false)
srv.mu.Lock()
delete(srv.activeConn, c)
srv.mu.Unlock()
}
if state > 0xff || state < 0 {
panic("internal error")
Expand Down Expand Up @@ -2679,7 +2679,8 @@ type Server struct {
// value.
ConnContext func(ctx context.Context, c net.Conn) context.Context

inShutdown atomicBool // true when server is in shutdown
inShutdown atomicBool // true if Shutdown was called
inClose atomicBool // true if Close was called

disableKeepAlives int32 // accessed atomically.
nextProtoOnce sync.Once // guards setupHTTP2_* init
Expand Down Expand Up @@ -2727,7 +2728,7 @@ func (s *Server) closeDoneChanLocked() {
// Close returns any error returned from closing the Server's
// underlying Listener(s).
func (srv *Server) Close() error {
srv.inShutdown.setTrue()
srv.inClose.setTrue()
srv.mu.Lock()
defer srv.mu.Unlock()
srv.closeDoneChanLocked()
Expand Down Expand Up @@ -2979,7 +2980,7 @@ func AllowQuerySemicolons(h Handler) Handler {
// ListenAndServe always returns a non-nil error. After Shutdown or Close,
// the returned error is ErrServerClosed.
func (srv *Server) ListenAndServe() error {
if srv.shuttingDown() {
if srv.inShutdown.isSet() || srv.inClose.isSet() {
return ErrServerClosed
}
addr := srv.Addr
Expand Down Expand Up @@ -3092,6 +3093,10 @@ func (srv *Server) Serve(l net.Listener) error {
}
tempDelay = 0
c := srv.newConn(rw)
if !srv.trackConn(c) {
rw.Close()
return ErrServerClosed
}
c.setState(c.rwc, StateNew, runHooks) // before Serve can return
go c.serve(connCtx)
}
Expand Down Expand Up @@ -3153,7 +3158,7 @@ func (s *Server) trackListener(ln *net.Listener, add bool) bool {
s.listeners = make(map[*net.Listener]struct{})
}
if add {
if s.shuttingDown() {
if s.inShutdown.isSet() || s.inClose.isSet() {
return false
}
s.listeners[ln] = struct{}{}
Expand All @@ -3163,17 +3168,19 @@ func (s *Server) trackListener(ln *net.Listener, add bool) bool {
return true
}

func (s *Server) trackConn(c *conn, add bool) {
// trackConn adds a connection to the set of tracked connections.
// It reports whether the server is still up or in graceful shutdown (not Closed).
func (s *Server) trackConn(c *conn) bool {
s.mu.Lock()
defer s.mu.Unlock()
if s.inClose.isSet() {
return false
}
if s.activeConn == nil {
s.activeConn = make(map[*conn]struct{})
}
if add {
s.activeConn[c] = struct{}{}
} else {
delete(s.activeConn, c)
}
s.activeConn[c] = struct{}{}
return true
}

func (s *Server) idleTimeout() time.Duration {
Expand All @@ -3191,11 +3198,7 @@ func (s *Server) readHeaderTimeout() time.Duration {
}

func (s *Server) doKeepAlives() bool {
return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !s.shuttingDown()
}

func (s *Server) shuttingDown() bool {
return s.inShutdown.isSet()
return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !(s.inShutdown.isSet() || s.inClose.isSet())
}

// SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled.
Expand Down Expand Up @@ -3273,7 +3276,7 @@ func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
// ListenAndServeTLS always returns a non-nil error. After Shutdown or
// Close, the returned error is ErrServerClosed.
func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
if srv.shuttingDown() {
if srv.inShutdown.isSet() || srv.inClose.isSet() {
return ErrServerClosed
}
addr := srv.Addr
Expand Down