diff --git a/context.go b/context.go index 505a43d..2344414 100644 --- a/context.go +++ b/context.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "net" "sync" + "time" gossh "golang.org/x/crypto/ssh" ) @@ -92,13 +93,17 @@ type Context interface { } type sshContext struct { - context.Context - *sync.Mutex + ctx context.Context + mtx *sync.RWMutex } +var _ context.Context = &sshContext{} + +var _ sync.Locker = &sshContext{} + func newContext(srv *Server) (*sshContext, context.CancelFunc) { innerCtx, cancel := context.WithCancel(context.Background()) - ctx := &sshContext{innerCtx, &sync.Mutex{}} + ctx := &sshContext{innerCtx, &sync.RWMutex{}} ctx.SetValue(ContextKeyServer, srv) perms := &Permissions{&gossh.Permissions{}} ctx.SetValue(ContextKeyPermissions, perms) @@ -120,7 +125,45 @@ func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) { } func (ctx *sshContext) SetValue(key, value interface{}) { - ctx.Context = context.WithValue(ctx.Context, key, value) + ctx.mtx.Lock() + defer ctx.mtx.Unlock() + ctx.ctx = context.WithValue(ctx.ctx, key, value) +} + +func (ctx *sshContext) Value(key interface{}) interface{} { + ctx.mtx.RLock() + defer ctx.mtx.RUnlock() + return ctx.ctx.Value(key) +} + +func (ctx *sshContext) Done() <-chan struct{} { + ctx.mtx.RLock() + defer ctx.mtx.RUnlock() + return ctx.ctx.Done() +} + +// Deadline implements context.Context. +func (ctx *sshContext) Deadline() (deadline time.Time, ok bool) { + ctx.mtx.RLock() + defer ctx.mtx.RUnlock() + return ctx.ctx.Deadline() +} + +// Err implements context.Context. +func (ctx *sshContext) Err() error { + ctx.mtx.RLock() + defer ctx.mtx.RUnlock() + return ctx.ctx.Err() +} + +// Lock implements sync.Locker. +func (ctx *sshContext) Lock() { + ctx.mtx.Lock() +} + +// Unlock implements sync.Locker. +func (ctx *sshContext) Unlock() { + ctx.mtx.Unlock() } func (ctx *sshContext) User() string { diff --git a/context_test.go b/context_test.go index f5a9315..58051d7 100644 --- a/context_test.go +++ b/context_test.go @@ -1,6 +1,9 @@ package ssh -import "testing" +import ( + "testing" + "time" +) func TestSetPermissions(t *testing.T) { t.Parallel() @@ -45,3 +48,63 @@ func TestSetValue(t *testing.T) { t.Fatal(err) } } + +func TestRaceRWIssue160(t *testing.T) { + value := "foo" + key := "bar" + session, _, cleanup := newTestSessionWithOptions(t, &Server{ + Handler: func(s Session) { + t.Run("test done", func(t *testing.T) { + t.Parallel() + go func() { + s.Context().SetValue(key, value) + }() + go func() { + select { + case <-s.Context().Done(): + } + }() + }) + }, + }, nil) + defer cleanup() + if err := session.Run(""); err != nil { + t.Fatal(err) + } +} + +// Taken from https://github.com/gliderlabs/ssh/pull/211/commits/02f9d573009f8c13755b6b90fa14a4f549b17b22 +func TestSetValueConcurrency(t *testing.T) { + ctx, cancel := newContext(nil) + defer cancel() + + go func() { + for { // use a loop to access context.Context functions to make sure they are thread-safe with SetValue + _, _ = ctx.Deadline() + _ = ctx.Err() + _ = ctx.Value("foo") + select { + case <-ctx.Done(): + break + default: + } + } + }() + ctx.SetValue("bar", -1) // a context value which never changes + now := time.Now() + var cnt int64 + go func() { + for time.Since(now) < 100*time.Millisecond { + cnt++ + ctx.SetValue("foo", cnt) // a context value which changes a lot + } + cancel() + }() + <-ctx.Done() + if ctx.Value("foo") != cnt { + t.Fatal("context.Value(foo) doesn't match latest SetValue") + } + if ctx.Value("bar") != -1 { + t.Fatal("context.Value(bar) doesn't match latest SetValue") + } +}