diff --git a/context.go b/context.go index 505a43d..99725eb 100644 --- a/context.go +++ b/context.go @@ -94,11 +94,14 @@ type Context interface { type sshContext struct { context.Context *sync.Mutex + + values map[interface{}]interface{} + valuesMu sync.Mutex } func newContext(srv *Server) (*sshContext, context.CancelFunc) { innerCtx, cancel := context.WithCancel(context.Background()) - ctx := &sshContext{innerCtx, &sync.Mutex{}} + ctx := &sshContext{Context: innerCtx, Mutex: &sync.Mutex{}, values: make(map[interface{}]interface{})} ctx.SetValue(ContextKeyServer, srv) perms := &Permissions{&gossh.Permissions{}} ctx.SetValue(ContextKeyPermissions, perms) @@ -119,8 +122,19 @@ func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) { ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr()) } +func (ctx *sshContext) Value(key interface{}) interface{} { + ctx.valuesMu.Lock() + defer ctx.valuesMu.Unlock() + if v, ok := ctx.values[key]; ok { + return v + } + return ctx.Context.Value(key) +} + func (ctx *sshContext) SetValue(key, value interface{}) { - ctx.Context = context.WithValue(ctx.Context, key, value) + ctx.valuesMu.Lock() + defer ctx.valuesMu.Unlock() + ctx.values[key] = value } func (ctx *sshContext) User() string { diff --git a/context_test.go b/context_test.go index f5a9315..0b49abf 100644 --- a/context_test.go +++ b/context_test.go @@ -1,6 +1,9 @@ package ssh -import "testing" +import ( + "testing" + "time" +) func TestSetPermissions(t *testing.T) { t.Parallel() @@ -45,3 +48,38 @@ func TestSetValue(t *testing.T) { t.Fatal(err) } } + +func TestSetValueConcurrency(t *testing.T) { + ctx, cancel := newContext(nil) + defer cancel() + + go func() { + for { // use a loop to access context.Context functions to make sure they are thread-safe with SetValue + _, _ = ctx.Deadline() + _ = ctx.Err() + _ = ctx.Value("foo") + select { + case <-ctx.Done(): + break + default: + } + } + }() + ctx.SetValue("bar", -1) // a context value which never changes + now := time.Now() + var cnt int64 + go func() { + for time.Since(now) < 100*time.Millisecond { + cnt++ + ctx.SetValue("foo", cnt) // a context value which changes a lot + } + cancel() + }() + <-ctx.Done() + if ctx.Value("foo") != cnt { + t.Fatal("context.Value(foo) doesn't match latest SetValue") + } + if ctx.Value("bar") != -1 { + t.Fatal("context.Value(bar) doesn't match latest SetValue") + } +}