From c40cebcc1484f69a093158a7a4c72444e101a009 Mon Sep 17 00:00:00 2001 From: Chris Reeves Date: Mon, 27 Apr 2020 10:03:13 +0100 Subject: [PATCH 1/2] Support setting an optional base context for functions. --- lambda/entry.go | 20 +++++++++++++++++--- lambda/function.go | 20 +++++++++++++++++++- lambda/function_test.go | 26 ++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 4 deletions(-) diff --git a/lambda/entry.go b/lambda/entry.go index 581d9bcf..9aa6a4ab 100644 --- a/lambda/entry.go +++ b/lambda/entry.go @@ -3,6 +3,7 @@ package lambda import ( + "context" "log" "net" "net/rpc" @@ -37,8 +38,12 @@ import ( // Where "TIn" and "TOut" are types compatible with the "encoding/json" standard library. // See https://golang.org/pkg/encoding/json/#Unmarshal for how deserialization behaves func Start(handler interface{}) { - wrappedHandler := NewHandler(handler) - StartHandler(wrappedHandler) + StartWithContext(context.Background(), handler) +} + +// StartWithContext is the same as Start except sets the base context for the function. +func StartWithContext(ctx context.Context, handler interface{}) { + StartHandlerWithContext(ctx, NewHandler(handler)) } // StartHandler takes in a Handler wrapper interface which can be implemented either by a @@ -48,12 +53,21 @@ func Start(handler interface{}) { // // func Invoke(context.Context, []byte) ([]byte, error) func StartHandler(handler Handler) { + StartHandlerWithContext(context.Background(), handler) +} + +// StartHandlerWithContext is the same as StartHandler except sets the base context for the function. +// +// Handler implementation requires a single "Invoke()" function: +// +// func Invoke(context.Context, []byte) ([]byte, error) +func StartHandlerWithContext(ctx context.Context, handler Handler) { port := os.Getenv("_LAMBDA_SERVER_PORT") lis, err := net.Listen("tcp", "localhost:"+port) if err != nil { log.Fatal(err) } - err = rpc.Register(NewFunction(handler)) + err = rpc.Register(NewFunctionWithContext(ctx, handler)) if err != nil { log.Fatal("failed to register handler function") } diff --git a/lambda/function.go b/lambda/function.go index 1400dbcc..f6cc0281 100644 --- a/lambda/function.go +++ b/lambda/function.go @@ -16,6 +16,7 @@ import ( // Function struct which wrap the Handler type Function struct { handler Handler + context context.Context } // NewFunction which creates a Function with a given Handler @@ -23,6 +24,23 @@ func NewFunction(handler Handler) *Function { return &Function{handler: handler} } +// NewFunctionWithContext which creates a Function with a given Handler and sets the base Context. +func NewFunctionWithContext(ctx context.Context, handler Handler) *Function { + return &Function{ + context: ctx, + handler: handler, + } +} + +// Context returns the base context used for the fn. +func (fn *Function) Context() context.Context { + if fn.context == nil { + return context.Background() + } + + return fn.context +} + // Ping method which given a PingRequest and a PingResponse parses the PingResponse func (fn *Function) Ping(req *messages.PingRequest, response *messages.PingResponse) error { *response = messages.PingResponse{} @@ -44,7 +62,7 @@ func (fn *Function) Invoke(req *messages.InvokeRequest, response *messages.Invok }() deadline := time.Unix(req.Deadline.Seconds, req.Deadline.Nanos).UTC() - invokeContext, cancel := context.WithDeadline(context.Background(), deadline) + invokeContext, cancel := context.WithDeadline(fn.Context(), deadline) defer cancel() lc := &lambdacontext.LambdaContext{ diff --git a/lambda/function_test.go b/lambda/function_test.go index 9531e5e4..824ec774 100644 --- a/lambda/function_test.go +++ b/lambda/function_test.go @@ -58,6 +58,32 @@ func TestInvoke(t *testing.T) { assert.Equal(t, deadline.UnixNano(), responseValue) } +func TestInvokeWithContext(t *testing.T) { + key := struct{}{} + srv := &Function{ + handler: testWrapperHandler( + func(ctx context.Context, input []byte) (interface{}, error) { + assert.Equal(t, "dummy", ctx.Value(key)) + if deadline, ok := ctx.Deadline(); ok { + return deadline.UnixNano(), nil + } + return nil, errors.New("!?!?!?!?!") + }), + context: context.WithValue(context.Background(), key, "dummy"), + } + deadline := time.Now() + var response messages.InvokeResponse + err := srv.Invoke(&messages.InvokeRequest{ + Deadline: messages.InvokeRequest_Timestamp{ + Seconds: deadline.Unix(), + Nanos: int64(deadline.Nanosecond()), + }}, &response) + assert.NoError(t, err) + var responseValue int64 + assert.NoError(t, json.Unmarshal(response.Payload, &responseValue)) + assert.Equal(t, deadline.UnixNano(), responseValue) +} + type CustomError struct{} func (e CustomError) Error() string { return fmt.Sprintf("Something bad happened!") } From 0d6c14c7d9710c79b1dbe5e020046750a0751432 Mon Sep 17 00:00:00 2001 From: Chris Reeves Date: Sat, 23 May 2020 12:32:21 +0100 Subject: [PATCH 2/2] refactor: remove exported context related Function methods --- lambda/entry.go | 6 ++++-- lambda/function.go | 45 ++++++++++++++++++++++++----------------- lambda/function_test.go | 20 +++++++++--------- 3 files changed, 39 insertions(+), 32 deletions(-) diff --git a/lambda/entry.go b/lambda/entry.go index 9aa6a4ab..b3e52ea4 100644 --- a/lambda/entry.go +++ b/lambda/entry.go @@ -67,10 +67,12 @@ func StartHandlerWithContext(ctx context.Context, handler Handler) { if err != nil { log.Fatal(err) } - err = rpc.Register(NewFunctionWithContext(ctx, handler)) - if err != nil { + + fn := NewFunction(handler).withContext(ctx) + if err := rpc.Register(fn); err != nil { log.Fatal("failed to register handler function") } + rpc.Accept(lis) log.Fatal("accept should not have returned") } diff --git a/lambda/function.go b/lambda/function.go index f6cc0281..628deb03 100644 --- a/lambda/function.go +++ b/lambda/function.go @@ -16,7 +16,7 @@ import ( // Function struct which wrap the Handler type Function struct { handler Handler - context context.Context + ctx context.Context } // NewFunction which creates a Function with a given Handler @@ -24,23 +24,6 @@ func NewFunction(handler Handler) *Function { return &Function{handler: handler} } -// NewFunctionWithContext which creates a Function with a given Handler and sets the base Context. -func NewFunctionWithContext(ctx context.Context, handler Handler) *Function { - return &Function{ - context: ctx, - handler: handler, - } -} - -// Context returns the base context used for the fn. -func (fn *Function) Context() context.Context { - if fn.context == nil { - return context.Background() - } - - return fn.context -} - // Ping method which given a PingRequest and a PingResponse parses the PingResponse func (fn *Function) Ping(req *messages.PingRequest, response *messages.PingResponse) error { *response = messages.PingResponse{} @@ -62,7 +45,7 @@ func (fn *Function) Invoke(req *messages.InvokeRequest, response *messages.Invok }() deadline := time.Unix(req.Deadline.Seconds, req.Deadline.Nanos).UTC() - invokeContext, cancel := context.WithDeadline(fn.Context(), deadline) + invokeContext, cancel := context.WithDeadline(fn.context(), deadline) defer cancel() lc := &lambdacontext.LambdaContext{ @@ -93,6 +76,30 @@ func (fn *Function) Invoke(req *messages.InvokeRequest, response *messages.Invok return nil } +// context returns the base context used for the fn. +func (fn *Function) context() context.Context { + if fn.ctx == nil { + return context.Background() + } + + return fn.ctx +} + +// withContext returns a shallow copy of Function with its context changed +// to the provided ctx. If the provided ctx is non-nil a Background context is set. +func (fn *Function) withContext(ctx context.Context) *Function { + if ctx == nil { + ctx = context.Background() + } + + fn2 := new(Function) + *fn2 = *fn + + fn2.ctx = ctx + + return fn2 +} + func getErrorType(err interface{}) string { errorType := reflect.TypeOf(err) if errorType.Kind() == reflect.Ptr { diff --git a/lambda/function_test.go b/lambda/function_test.go index 824ec774..ab7e7b27 100644 --- a/lambda/function_test.go +++ b/lambda/function_test.go @@ -60,17 +60,15 @@ func TestInvoke(t *testing.T) { func TestInvokeWithContext(t *testing.T) { key := struct{}{} - srv := &Function{ - handler: testWrapperHandler( - func(ctx context.Context, input []byte) (interface{}, error) { - assert.Equal(t, "dummy", ctx.Value(key)) - if deadline, ok := ctx.Deadline(); ok { - return deadline.UnixNano(), nil - } - return nil, errors.New("!?!?!?!?!") - }), - context: context.WithValue(context.Background(), key, "dummy"), - } + srv := NewFunction(testWrapperHandler( + func(ctx context.Context, input []byte) (interface{}, error) { + assert.Equal(t, "dummy", ctx.Value(key)) + if deadline, ok := ctx.Deadline(); ok { + return deadline.UnixNano(), nil + } + return nil, errors.New("!?!?!?!?!") + })) + srv = srv.withContext(context.WithValue(context.Background(), key, "dummy")) deadline := time.Now() var response messages.InvokeResponse err := srv.Invoke(&messages.InvokeRequest{