Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,17 @@ type Server struct {
RequestHandlers map[string]RequestHandler

listenerWg sync.WaitGroup
mu sync.Mutex
mu sync.RWMutex
listeners map[net.Listener]struct{}
conns map[*gossh.ServerConn]struct{}
connWg sync.WaitGroup
doneChan chan struct{}
}

func (srv *Server) ensureHostSigner() error {
srv.mu.Lock()
defer srv.mu.Unlock()

if len(srv.HostSigners) == 0 {
signer, err := generateSigner()
if err != nil {
Expand All @@ -79,6 +82,7 @@ func (srv *Server) ensureHostSigner() error {
func (srv *Server) ensureHandlers() {
srv.mu.Lock()
defer srv.mu.Unlock()

if srv.RequestHandlers == nil {
srv.RequestHandlers = map[string]RequestHandler{}
for k, v := range DefaultRequestHandlers {
Expand All @@ -94,6 +98,9 @@ func (srv *Server) ensureHandlers() {
}

func (srv *Server) config(ctx Context) *gossh.ServerConfig {
srv.mu.RLock()
defer srv.mu.RUnlock()

var config *gossh.ServerConfig
if srv.ServerConfigCallback == nil {
config = &gossh.ServerConfig{}
Expand Down Expand Up @@ -142,6 +149,9 @@ func (srv *Server) config(ctx Context) *gossh.ServerConfig {

// Handle sets the Handler for the server.
func (srv *Server) Handle(fn Handler) {
srv.mu.Lock()
defer srv.mu.Unlock()

srv.Handler = fn
}

Expand All @@ -153,6 +163,7 @@ func (srv *Server) Handle(fn Handler) {
func (srv *Server) Close() error {
srv.mu.Lock()
defer srv.mu.Unlock()

srv.closeDoneChanLocked()
err := srv.closeListenersLocked()
for c := range srv.conns {
Expand Down Expand Up @@ -313,19 +324,42 @@ func (srv *Server) ListenAndServe() error {
// with the same algorithm, it is overwritten. Each server config must have at
// least one host key.
func (srv *Server) AddHostKey(key Signer) {
srv.mu.Lock()
defer srv.mu.Unlock()

// these are later added via AddHostKey on ServerConfig, which performs the
// check for one of every algorithm.

// This check is based on the AddHostKey method from the x/crypto/ssh
// library. This allows us to only keep one active key for each type on a
// server at once. So, if you're dynamically updating keys at runtime, this
// list will not keep growing.
for i, k := range srv.HostSigners {
if k.PublicKey().Type() == key.PublicKey().Type() {
srv.HostSigners[i] = key
return
}
}

srv.HostSigners = append(srv.HostSigners, key)
}

// SetOption runs a functional option against the server.
func (srv *Server) SetOption(option Option) error {
// NOTE: there is a potential race here for any option that doesn't call an
// internal method. We can't actually lock here because if something calls
// (as an example) AddHostKey, it will deadlock.

//srv.mu.Lock()
//defer srv.mu.Unlock()

return option(srv)
}

func (srv *Server) getDoneChan() <-chan struct{} {
srv.mu.Lock()
defer srv.mu.Unlock()

return srv.getDoneChanLocked()
}

Expand Down Expand Up @@ -362,6 +396,7 @@ func (srv *Server) closeListenersLocked() error {
func (srv *Server) trackListener(ln net.Listener, add bool) {
srv.mu.Lock()
defer srv.mu.Unlock()

if srv.listeners == nil {
srv.listeners = make(map[net.Listener]struct{})
}
Expand All @@ -382,6 +417,7 @@ func (srv *Server) trackListener(ln net.Listener, add bool) {
func (srv *Server) trackConn(c *gossh.ServerConn, add bool) {
srv.mu.Lock()
defer srv.mu.Unlock()

if srv.conns == nil {
srv.conns = make(map[*gossh.ServerConn]struct{})
}
Expand Down
20 changes: 20 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,26 @@ import (
"time"
)

func TestAddHostKey(t *testing.T) {
s := Server{}
signer, err := generateSigner()
if err != nil {
t.Fatal(err)
}
s.AddHostKey(signer)
if len(s.HostSigners) != 1 {
t.Fatal("Key was not properly added")
}
signer, err = generateSigner()
if err != nil {
t.Fatal(err)
}
s.AddHostKey(signer)
if len(s.HostSigners) != 1 {
t.Fatal("Key was not properly replaced")
}
}

func TestServerShutdown(t *testing.T) {
l := newLocalListener()
testBytes := []byte("Hello world\n")
Expand Down
46 changes: 36 additions & 10 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,20 +289,40 @@ func TestPtyResize(t *testing.T) {
func TestSignals(t *testing.T) {
t.Parallel()

// errChan lets us get errors back from the session
errChan := make(chan error, 5)

// doneChan lets us specify that we should exit.
doneChan := make(chan interface{})

session, _, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
signals := make(chan Signal)
// We need to use a buffered channel here, otherwise it's possible for the
// second call to Signal to get discarded.
signals := make(chan Signal, 2)
s.Signals(signals)
if sig := <-signals; sig != SIGINT {
t.Fatalf("expected signal %v but got %v", SIGINT, sig)

select {
case sig := <-signals:
if sig != SIGINT {
errChan <- fmt.Errorf("expected signal %v but got %v", SIGINT, sig)
return
}
case <-doneChan:
errChan <- fmt.Errorf("Unexpected done")
return
}
exiter := make(chan bool)
go func() {
if sig := <-signals; sig == SIGKILL {
close(exiter)

select {
case sig := <-signals:
if sig != SIGKILL {
errChan <- fmt.Errorf("expected signal %v but got %v", SIGKILL, sig)
return
}
}()
<-exiter
case <-doneChan:
errChan <- fmt.Errorf("Unexpected done")
return
}
},
}, nil)
defer cleanup()
Expand All @@ -312,7 +332,13 @@ func TestSignals(t *testing.T) {
session.Signal(gossh.SIGKILL)
}()

err := session.Run("")
go func() {
errChan <- session.Run("")
}()

err := <-errChan
close(doneChan)

if err != nil {
t.Fatalf("expected nil but got %v", err)
}
Expand Down