diff --git a/server.go b/server.go index cad0402..359f967 100644 --- a/server.go +++ b/server.go @@ -58,7 +58,7 @@ type Server struct { RequestHandlers map[string]RequestHandler listenerWg sync.WaitGroup - mu sync.Mutex + mu sync.RWMutex listeners map[net.Listener]struct{} conns map[*gossh.ServerConn]struct{} connWg sync.WaitGroup @@ -66,6 +66,9 @@ type Server struct { } func (srv *Server) ensureHostSigner() error { + srv.mu.Lock() + defer srv.mu.Unlock() + if len(srv.HostSigners) == 0 { signer, err := generateSigner() if err != nil { @@ -79,6 +82,7 @@ func (srv *Server) ensureHostSigner() error { func (srv *Server) ensureHandlers() { srv.mu.Lock() defer srv.mu.Unlock() + if srv.RequestHandlers == nil { srv.RequestHandlers = map[string]RequestHandler{} for k, v := range DefaultRequestHandlers { @@ -94,6 +98,9 @@ func (srv *Server) ensureHandlers() { } func (srv *Server) config(ctx Context) *gossh.ServerConfig { + srv.mu.RLock() + defer srv.mu.RUnlock() + var config *gossh.ServerConfig if srv.ServerConfigCallback == nil { config = &gossh.ServerConfig{} @@ -142,6 +149,9 @@ func (srv *Server) config(ctx Context) *gossh.ServerConfig { // Handle sets the Handler for the server. func (srv *Server) Handle(fn Handler) { + srv.mu.Lock() + defer srv.mu.Unlock() + srv.Handler = fn } @@ -153,6 +163,7 @@ func (srv *Server) Handle(fn Handler) { func (srv *Server) Close() error { srv.mu.Lock() defer srv.mu.Unlock() + srv.closeDoneChanLocked() err := srv.closeListenersLocked() for c := range srv.conns { @@ -313,19 +324,42 @@ func (srv *Server) ListenAndServe() error { // with the same algorithm, it is overwritten. Each server config must have at // least one host key. func (srv *Server) AddHostKey(key Signer) { + srv.mu.Lock() + defer srv.mu.Unlock() + // these are later added via AddHostKey on ServerConfig, which performs the // check for one of every algorithm. + + // This check is based on the AddHostKey method from the x/crypto/ssh + // library. This allows us to only keep one active key for each type on a + // server at once. So, if you're dynamically updating keys at runtime, this + // list will not keep growing. + for i, k := range srv.HostSigners { + if k.PublicKey().Type() == key.PublicKey().Type() { + srv.HostSigners[i] = key + return + } + } + srv.HostSigners = append(srv.HostSigners, key) } // SetOption runs a functional option against the server. func (srv *Server) SetOption(option Option) error { + // NOTE: there is a potential race here for any option that doesn't call an + // internal method. We can't actually lock here because if something calls + // (as an example) AddHostKey, it will deadlock. + + //srv.mu.Lock() + //defer srv.mu.Unlock() + return option(srv) } func (srv *Server) getDoneChan() <-chan struct{} { srv.mu.Lock() defer srv.mu.Unlock() + return srv.getDoneChanLocked() } @@ -362,6 +396,7 @@ func (srv *Server) closeListenersLocked() error { func (srv *Server) trackListener(ln net.Listener, add bool) { srv.mu.Lock() defer srv.mu.Unlock() + if srv.listeners == nil { srv.listeners = make(map[net.Listener]struct{}) } @@ -382,6 +417,7 @@ func (srv *Server) trackListener(ln net.Listener, add bool) { func (srv *Server) trackConn(c *gossh.ServerConn, add bool) { srv.mu.Lock() defer srv.mu.Unlock() + if srv.conns == nil { srv.conns = make(map[*gossh.ServerConn]struct{}) } diff --git a/server_test.go b/server_test.go index 558f171..8028a3a 100644 --- a/server_test.go +++ b/server_test.go @@ -8,6 +8,26 @@ import ( "time" ) +func TestAddHostKey(t *testing.T) { + s := Server{} + signer, err := generateSigner() + if err != nil { + t.Fatal(err) + } + s.AddHostKey(signer) + if len(s.HostSigners) != 1 { + t.Fatal("Key was not properly added") + } + signer, err = generateSigner() + if err != nil { + t.Fatal(err) + } + s.AddHostKey(signer) + if len(s.HostSigners) != 1 { + t.Fatal("Key was not properly replaced") + } +} + func TestServerShutdown(t *testing.T) { l := newLocalListener() testBytes := []byte("Hello world\n") diff --git a/session_test.go b/session_test.go index f086792..786a661 100644 --- a/session_test.go +++ b/session_test.go @@ -289,20 +289,40 @@ func TestPtyResize(t *testing.T) { func TestSignals(t *testing.T) { t.Parallel() + // errChan lets us get errors back from the session + errChan := make(chan error, 5) + + // doneChan lets us specify that we should exit. + doneChan := make(chan interface{}) + session, _, cleanup := newTestSession(t, &Server{ Handler: func(s Session) { - signals := make(chan Signal) + // We need to use a buffered channel here, otherwise it's possible for the + // second call to Signal to get discarded. + signals := make(chan Signal, 2) s.Signals(signals) - if sig := <-signals; sig != SIGINT { - t.Fatalf("expected signal %v but got %v", SIGINT, sig) + + select { + case sig := <-signals: + if sig != SIGINT { + errChan <- fmt.Errorf("expected signal %v but got %v", SIGINT, sig) + return + } + case <-doneChan: + errChan <- fmt.Errorf("Unexpected done") + return } - exiter := make(chan bool) - go func() { - if sig := <-signals; sig == SIGKILL { - close(exiter) + + select { + case sig := <-signals: + if sig != SIGKILL { + errChan <- fmt.Errorf("expected signal %v but got %v", SIGKILL, sig) + return } - }() - <-exiter + case <-doneChan: + errChan <- fmt.Errorf("Unexpected done") + return + } }, }, nil) defer cleanup() @@ -312,7 +332,13 @@ func TestSignals(t *testing.T) { session.Signal(gossh.SIGKILL) }() - err := session.Run("") + go func() { + errChan <- session.Run("") + }() + + err := <-errChan + close(doneChan) + if err != nil { t.Fatalf("expected nil but got %v", err) }