Skip to content

fix: handle Close() before Serve() #248

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 25 additions & 38 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net"
"sync"
"sync/atomic"
"time"

gossh "golang.org/x/crypto/ssh"
Expand Down Expand Up @@ -70,12 +71,12 @@ type Server struct {
// handlers, but handle named subsystems.
SubsystemHandlers map[string]SubsystemHandler

inShutdown atomic.Bool // true when server is in shutdown
listenerWg sync.WaitGroup
mu sync.RWMutex
listeners map[net.Listener]struct{}
conns map[*gossh.ServerConn]struct{}
connWg sync.WaitGroup
doneChan chan struct{}
}

func (srv *Server) ensureHostSigner() error {
Expand Down Expand Up @@ -191,11 +192,20 @@ func (srv *Server) Handle(fn Handler) {
// Close returns any error returned from closing the Server's
// underlying Listener(s).
func (srv *Server) Close() error {
srv.inShutdown.Store(true)
srv.mu.Lock()
defer srv.mu.Unlock()

srv.closeDoneChanLocked()
err := srv.closeListenersLocked()

// Unlock srv.mu while waiting for listenerWg.
// The group Add and Done calls are made with srv.mu held,
// to avoid adding a new listener in the window between
// us setting inShutdown above and waiting here.
srv.mu.Unlock()
srv.listenerWg.Wait()
srv.mu.Lock()

for c := range srv.conns {
c.Close()
delete(srv.conns, c)
Expand All @@ -209,9 +219,9 @@ func (srv *Server) Close() error {
// If the provided context expires before the shutdown is complete,
// then the context's error is returned.
func (srv *Server) Shutdown(ctx context.Context) error {
srv.inShutdown.Store(true)
srv.mu.Lock()
lnerr := srv.closeListenersLocked()
srv.closeDoneChanLocked()
srv.mu.Unlock()

finished := make(chan struct{}, 1)
Expand All @@ -229,6 +239,10 @@ func (srv *Server) Shutdown(ctx context.Context) error {
}
}

func (s *Server) shuttingDown() bool {
return s.inShutdown.Load()
}

// Serve accepts incoming connections on the Listener l, creating a new
// connection goroutine for each. The connection goroutines read requests and then
// calls srv.Handler to handle sessions.
Expand All @@ -245,15 +259,15 @@ func (srv *Server) Serve(l net.Listener) error {
}
var tempDelay time.Duration

srv.trackListener(l, true)
if !srv.trackListener(l, true) {
return ErrServerClosed
}
defer srv.trackListener(l, false)
for {
conn, e := l.Accept()
if e != nil {
select {
case <-srv.getDoneChan():
if srv.shuttingDown() {
return ErrServerClosed
default:
}
if ne, ok := e.(net.Error); ok && ne.Temporary() {
if tempDelay == 0 {
Expand Down Expand Up @@ -393,32 +407,6 @@ func (srv *Server) SetOption(option Option) error {
return option(srv)
}

func (srv *Server) getDoneChan() <-chan struct{} {
srv.mu.Lock()
defer srv.mu.Unlock()

return srv.getDoneChanLocked()
}

func (srv *Server) getDoneChanLocked() chan struct{} {
if srv.doneChan == nil {
srv.doneChan = make(chan struct{})
}
return srv.doneChan
}

func (srv *Server) closeDoneChanLocked() {
ch := srv.getDoneChanLocked()
select {
case <-ch:
// Already closed. Don't close again.
default:
// Safe to close here. We're the only closer, guarded
// by srv.mu.
close(ch)
}
}

func (srv *Server) closeListenersLocked() error {
var err error
for ln := range srv.listeners {
Expand All @@ -430,25 +418,24 @@ func (srv *Server) closeListenersLocked() error {
return err
}

func (srv *Server) trackListener(ln net.Listener, add bool) {
func (srv *Server) trackListener(ln net.Listener, add bool) bool {
srv.mu.Lock()
defer srv.mu.Unlock()

if srv.listeners == nil {
srv.listeners = make(map[net.Listener]struct{})
}
if add {
// If the *Server is being reused after a previous
// Close or Shutdown, reset its doneChan:
if len(srv.listeners) == 0 && len(srv.conns) == 0 {
srv.doneChan = nil
if srv.shuttingDown() {
return false
}
srv.listeners[ln] = struct{}{}
srv.listenerWg.Add(1)
} else {
delete(srv.listeners, ln)
srv.listenerWg.Done()
}
return true
}

func (srv *Server) trackConn(c *gossh.ServerConn, add bool) {
Expand Down
48 changes: 40 additions & 8 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestServerShutdown(t *testing.T) {
go func() {
err := s.Serve(l)
if err != nil && err != ErrServerClosed {
t.Fatal(err)
t.Error(err)
}
}()
sessDone := make(chan struct{})
Expand All @@ -52,10 +52,10 @@ func TestServerShutdown(t *testing.T) {
var stdout bytes.Buffer
sess.Stdout = &stdout
if err := sess.Run(""); err != nil {
t.Fatal(err)
t.Error(err)
}
if !bytes.Equal(stdout.Bytes(), testBytes) {
t.Fatalf("expected = %s; got %s", testBytes, stdout.Bytes())
t.Errorf("expected = %s; got %s", testBytes, stdout.Bytes())
}
}()

Expand All @@ -64,7 +64,7 @@ func TestServerShutdown(t *testing.T) {
defer close(srvDone)
err := s.Shutdown(context.Background())
if err != nil {
t.Fatal(err)
t.Error(err)
}
}()

Expand All @@ -90,7 +90,7 @@ func TestServerClose(t *testing.T) {
go func() {
err := s.Serve(l)
if err != nil && err != ErrServerClosed {
t.Fatal(err)
t.Error(err)
}
}()

Expand All @@ -103,14 +103,14 @@ func TestServerClose(t *testing.T) {
defer close(clientDoneChan)
<-closeDoneChan
if err := sess.Run(""); err != nil && err != io.EOF {
t.Fatal(err)
t.Error(err)
}
}()

go func() {
err := s.Close()
if err != nil {
t.Fatal(err)
t.Error(err)
}
close(closeDoneChan)
}()
Expand All @@ -120,12 +120,44 @@ func TestServerClose(t *testing.T) {
case <-timeout:
t.Error("timeout")
return
case <-s.getDoneChan():
case <-closeDoneChan:
<-clientDoneChan
return
}
}

func TestServerCloseBeforeServe(t *testing.T) {
l := newLocalListener()
s := &Server{}

serveDoneChan := make(chan struct{})
closeDoneChan := make(chan struct{})

go func() {
<-closeDoneChan
err := s.Serve(l)
if err != nil && err != ErrServerClosed {
t.Error(err)
}
close(serveDoneChan)
}()

err := s.Close()
if err != nil {
t.Error(err)
}
close(closeDoneChan)

timeout := time.After(1 * time.Second)
select {
case <-timeout:
t.Error("timeout")
return
case <-serveDoneChan:
return
}
}

func TestServerHandshakeTimeout(t *testing.T) {
l := newLocalListener()

Expand Down