diff --git a/server.go b/server.go index 7dbaa0f..70492b1 100644 --- a/server.go +++ b/server.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "sync" + "sync/atomic" "time" gossh "golang.org/x/crypto/ssh" @@ -70,12 +71,12 @@ type Server struct { // handlers, but handle named subsystems. SubsystemHandlers map[string]SubsystemHandler + inShutdown atomic.Bool // true when server is in shutdown listenerWg sync.WaitGroup mu sync.RWMutex listeners map[net.Listener]struct{} conns map[*gossh.ServerConn]struct{} connWg sync.WaitGroup - doneChan chan struct{} } func (srv *Server) ensureHostSigner() error { @@ -191,11 +192,20 @@ func (srv *Server) Handle(fn Handler) { // Close returns any error returned from closing the Server's // underlying Listener(s). func (srv *Server) Close() error { + srv.inShutdown.Store(true) srv.mu.Lock() defer srv.mu.Unlock() - srv.closeDoneChanLocked() err := srv.closeListenersLocked() + + // Unlock srv.mu while waiting for listenerWg. + // The group Add and Done calls are made with srv.mu held, + // to avoid adding a new listener in the window between + // us setting inShutdown above and waiting here. + srv.mu.Unlock() + srv.listenerWg.Wait() + srv.mu.Lock() + for c := range srv.conns { c.Close() delete(srv.conns, c) @@ -209,9 +219,9 @@ func (srv *Server) Close() error { // If the provided context expires before the shutdown is complete, // then the context's error is returned. func (srv *Server) Shutdown(ctx context.Context) error { + srv.inShutdown.Store(true) srv.mu.Lock() lnerr := srv.closeListenersLocked() - srv.closeDoneChanLocked() srv.mu.Unlock() finished := make(chan struct{}, 1) @@ -229,6 +239,10 @@ func (srv *Server) Shutdown(ctx context.Context) error { } } +func (s *Server) shuttingDown() bool { + return s.inShutdown.Load() +} + // Serve accepts incoming connections on the Listener l, creating a new // connection goroutine for each. The connection goroutines read requests and then // calls srv.Handler to handle sessions. @@ -245,15 +259,15 @@ func (srv *Server) Serve(l net.Listener) error { } var tempDelay time.Duration - srv.trackListener(l, true) + if !srv.trackListener(l, true) { + return ErrServerClosed + } defer srv.trackListener(l, false) for { conn, e := l.Accept() if e != nil { - select { - case <-srv.getDoneChan(): + if srv.shuttingDown() { return ErrServerClosed - default: } if ne, ok := e.(net.Error); ok && ne.Temporary() { if tempDelay == 0 { @@ -393,32 +407,6 @@ func (srv *Server) SetOption(option Option) error { return option(srv) } -func (srv *Server) getDoneChan() <-chan struct{} { - srv.mu.Lock() - defer srv.mu.Unlock() - - return srv.getDoneChanLocked() -} - -func (srv *Server) getDoneChanLocked() chan struct{} { - if srv.doneChan == nil { - srv.doneChan = make(chan struct{}) - } - return srv.doneChan -} - -func (srv *Server) closeDoneChanLocked() { - ch := srv.getDoneChanLocked() - select { - case <-ch: - // Already closed. Don't close again. - default: - // Safe to close here. We're the only closer, guarded - // by srv.mu. - close(ch) - } -} - func (srv *Server) closeListenersLocked() error { var err error for ln := range srv.listeners { @@ -430,7 +418,7 @@ func (srv *Server) closeListenersLocked() error { return err } -func (srv *Server) trackListener(ln net.Listener, add bool) { +func (srv *Server) trackListener(ln net.Listener, add bool) bool { srv.mu.Lock() defer srv.mu.Unlock() @@ -438,10 +426,8 @@ func (srv *Server) trackListener(ln net.Listener, add bool) { srv.listeners = make(map[net.Listener]struct{}) } if add { - // If the *Server is being reused after a previous - // Close or Shutdown, reset its doneChan: - if len(srv.listeners) == 0 && len(srv.conns) == 0 { - srv.doneChan = nil + if srv.shuttingDown() { + return false } srv.listeners[ln] = struct{}{} srv.listenerWg.Add(1) @@ -449,6 +435,7 @@ func (srv *Server) trackListener(ln net.Listener, add bool) { delete(srv.listeners, ln) srv.listenerWg.Done() } + return true } func (srv *Server) trackConn(c *gossh.ServerConn, add bool) { diff --git a/server_test.go b/server_test.go index 63fe694..7b0fd49 100644 --- a/server_test.go +++ b/server_test.go @@ -41,7 +41,7 @@ func TestServerShutdown(t *testing.T) { go func() { err := s.Serve(l) if err != nil && err != ErrServerClosed { - t.Fatal(err) + t.Error(err) } }() sessDone := make(chan struct{}) @@ -52,10 +52,10 @@ func TestServerShutdown(t *testing.T) { var stdout bytes.Buffer sess.Stdout = &stdout if err := sess.Run(""); err != nil { - t.Fatal(err) + t.Error(err) } if !bytes.Equal(stdout.Bytes(), testBytes) { - t.Fatalf("expected = %s; got %s", testBytes, stdout.Bytes()) + t.Errorf("expected = %s; got %s", testBytes, stdout.Bytes()) } }() @@ -64,7 +64,7 @@ func TestServerShutdown(t *testing.T) { defer close(srvDone) err := s.Shutdown(context.Background()) if err != nil { - t.Fatal(err) + t.Error(err) } }() @@ -90,7 +90,7 @@ func TestServerClose(t *testing.T) { go func() { err := s.Serve(l) if err != nil && err != ErrServerClosed { - t.Fatal(err) + t.Error(err) } }() @@ -103,14 +103,14 @@ func TestServerClose(t *testing.T) { defer close(clientDoneChan) <-closeDoneChan if err := sess.Run(""); err != nil && err != io.EOF { - t.Fatal(err) + t.Error(err) } }() go func() { err := s.Close() if err != nil { - t.Fatal(err) + t.Error(err) } close(closeDoneChan) }() @@ -120,12 +120,44 @@ func TestServerClose(t *testing.T) { case <-timeout: t.Error("timeout") return - case <-s.getDoneChan(): + case <-closeDoneChan: <-clientDoneChan return } } +func TestServerCloseBeforeServe(t *testing.T) { + l := newLocalListener() + s := &Server{} + + serveDoneChan := make(chan struct{}) + closeDoneChan := make(chan struct{}) + + go func() { + <-closeDoneChan + err := s.Serve(l) + if err != nil && err != ErrServerClosed { + t.Error(err) + } + close(serveDoneChan) + }() + + err := s.Close() + if err != nil { + t.Error(err) + } + close(closeDoneChan) + + timeout := time.After(1 * time.Second) + select { + case <-timeout: + t.Error("timeout") + return + case <-serveDoneChan: + return + } +} + func TestServerHandshakeTimeout(t *testing.T) { l := newLocalListener()