diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index 464e0f734df465..66a6c11911d051 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -6757,3 +6757,124 @@ func TestProcessing(t *testing.T) { t.Errorf("unexpected response; got %q; should start by %q", got, expected) } } + +type wrapConnListener struct { + net.Listener + wrap func(net.Conn) net.Conn +} + +func (l *wrapConnListener) Accept() (c net.Conn, err error) { + c, err = l.Listener.Accept() + if err != nil { + return nil, err + } + c = l.wrap(c) + return +} + +type trackCloseConn struct { + net.Conn + closes chan<- error +} + +func (p *trackCloseConn) Close() error { + err := p.Conn.Close() + p.closes <- err + return err +} + +// Issue 48642: close accepted connection when Closed +func TestAcceptedConnectionWhenClosed(t *testing.T) { + defer afterTest(t) + connCloses := make(chan error, 1) + srvClosed := make(chan error, 1) + + setup := func(ts *httptest.Server) { + ts.Listener = &wrapConnListener{ts.Listener, func(c net.Conn) net.Conn { + return &trackCloseConn{c, connCloses} + }} + srv := ts.Config + // Use ConnContext to close the server after connection is accepted but before it is tracked + srv.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + srvClosed <- srv.Close() + return ctx + } + } + handler := HandlerFunc(func(w ResponseWriter, r *Request) { + io.WriteString(w, "test content") + }) + + cst := newClientServerTest(t, h1Mode /* the test is protocol-agnostic */, handler, setup) + defer cst.close() + + res, err := cst.c.Get(cst.ts.URL) + if err == nil { + res.Body.Close() + t.Fatalf("expected error from Get") + } + + <-srvClosed + + select { + case err := <-connCloses: + if err != nil { + t.Fatal("expected no error from connection Close") + } + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for connection Close") + } +} + +// Issue 33313 and 36819: serve accepted connection when Shutdown +func TestAcceptedConnectionWhenShutdown(t *testing.T) { + defer afterTest(t) + connCloses := make(chan error, 1) + srvShutdown := make(chan error, 1) + + setup := func(ts *httptest.Server) { + ts.Listener = &wrapConnListener{ts.Listener, func(c net.Conn) net.Conn { + return &trackCloseConn{c, connCloses} + }} + srv := ts.Config + inShutdown := make(chan struct{}) + srv.RegisterOnShutdown(func() { close(inShutdown) }) + // Use ConnContext to shutdown the server after connection is accepted but before it is tracked + srv.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + go func() { srvShutdown <- srv.Shutdown(context.Background()) }() + <-inShutdown + return ctx + } + } + handler := HandlerFunc(func(w ResponseWriter, r *Request) { + io.WriteString(w, "test content") + }) + + cst := newClientServerTest(t, h1Mode /* the test is protocol-agnostic */, handler, setup) + defer cst.close() + + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatalf("Get error: %v", err) + } + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("ReadAll error: %v", err) + } + res.Body.Close() + + content := string(body) + if content != "test content" { + t.Fatalf("unexpected content: %s", content) + } + + <-srvShutdown + + select { + case err := <-connCloses: + if err != nil { + t.Fatal("expected no error from connection Close") + } + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for connection Close") + } +} diff --git a/src/net/http/server.go b/src/net/http/server.go index bc3a4633da8b15..3af6993b13a253 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -1778,10 +1778,10 @@ const ( func (c *conn) setState(nc net.Conn, state ConnState, runHook bool) { srv := c.server switch state { - case StateNew: - srv.trackConn(c, true) case StateHijacked, StateClosed: - srv.trackConn(c, false) + srv.mu.Lock() + delete(srv.activeConn, c) + srv.mu.Unlock() } if state > 0xff || state < 0 { panic("internal error") @@ -2679,7 +2679,8 @@ type Server struct { // value. ConnContext func(ctx context.Context, c net.Conn) context.Context - inShutdown atomicBool // true when server is in shutdown + inShutdown atomicBool // true if Shutdown was called + inClose atomicBool // true if Close was called disableKeepAlives int32 // accessed atomically. nextProtoOnce sync.Once // guards setupHTTP2_* init @@ -2727,7 +2728,7 @@ func (s *Server) closeDoneChanLocked() { // Close returns any error returned from closing the Server's // underlying Listener(s). func (srv *Server) Close() error { - srv.inShutdown.setTrue() + srv.inClose.setTrue() srv.mu.Lock() defer srv.mu.Unlock() srv.closeDoneChanLocked() @@ -2979,7 +2980,7 @@ func AllowQuerySemicolons(h Handler) Handler { // ListenAndServe always returns a non-nil error. After Shutdown or Close, // the returned error is ErrServerClosed. func (srv *Server) ListenAndServe() error { - if srv.shuttingDown() { + if srv.inShutdown.isSet() || srv.inClose.isSet() { return ErrServerClosed } addr := srv.Addr @@ -3092,6 +3093,10 @@ func (srv *Server) Serve(l net.Listener) error { } tempDelay = 0 c := srv.newConn(rw) + if !srv.trackConn(c) { + rw.Close() + return ErrServerClosed + } c.setState(c.rwc, StateNew, runHooks) // before Serve can return go c.serve(connCtx) } @@ -3153,7 +3158,7 @@ func (s *Server) trackListener(ln *net.Listener, add bool) bool { s.listeners = make(map[*net.Listener]struct{}) } if add { - if s.shuttingDown() { + if s.inShutdown.isSet() || s.inClose.isSet() { return false } s.listeners[ln] = struct{}{} @@ -3163,17 +3168,19 @@ func (s *Server) trackListener(ln *net.Listener, add bool) bool { return true } -func (s *Server) trackConn(c *conn, add bool) { +// trackConn adds a connection to the set of tracked connections. +// It reports whether the server is still up or in graceful shutdown (not Closed). +func (s *Server) trackConn(c *conn) bool { s.mu.Lock() defer s.mu.Unlock() + if s.inClose.isSet() { + return false + } if s.activeConn == nil { s.activeConn = make(map[*conn]struct{}) } - if add { - s.activeConn[c] = struct{}{} - } else { - delete(s.activeConn, c) - } + s.activeConn[c] = struct{}{} + return true } func (s *Server) idleTimeout() time.Duration { @@ -3191,11 +3198,7 @@ func (s *Server) readHeaderTimeout() time.Duration { } func (s *Server) doKeepAlives() bool { - return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !s.shuttingDown() -} - -func (s *Server) shuttingDown() bool { - return s.inShutdown.isSet() + return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !(s.inShutdown.isSet() || s.inClose.isSet()) } // SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled. @@ -3273,7 +3276,7 @@ func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error { // ListenAndServeTLS always returns a non-nil error. After Shutdown or // Close, the returned error is ErrServerClosed. func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { - if srv.shuttingDown() { + if srv.inShutdown.isSet() || srv.inClose.isSet() { return ErrServerClosed } addr := srv.Addr