diff --git a/lambda/entry.go b/lambda/entry.go index 581d9bcf..b3e52ea4 100644 --- a/lambda/entry.go +++ b/lambda/entry.go @@ -3,6 +3,7 @@ package lambda import ( + "context" "log" "net" "net/rpc" @@ -37,8 +38,12 @@ import ( // Where "TIn" and "TOut" are types compatible with the "encoding/json" standard library. // See https://golang.org/pkg/encoding/json/#Unmarshal for how deserialization behaves func Start(handler interface{}) { - wrappedHandler := NewHandler(handler) - StartHandler(wrappedHandler) + StartWithContext(context.Background(), handler) +} + +// StartWithContext is the same as Start except sets the base context for the function. +func StartWithContext(ctx context.Context, handler interface{}) { + StartHandlerWithContext(ctx, NewHandler(handler)) } // StartHandler takes in a Handler wrapper interface which can be implemented either by a @@ -48,15 +53,26 @@ func Start(handler interface{}) { // // func Invoke(context.Context, []byte) ([]byte, error) func StartHandler(handler Handler) { + StartHandlerWithContext(context.Background(), handler) +} + +// StartHandlerWithContext is the same as StartHandler except sets the base context for the function. +// +// Handler implementation requires a single "Invoke()" function: +// +// func Invoke(context.Context, []byte) ([]byte, error) +func StartHandlerWithContext(ctx context.Context, handler Handler) { port := os.Getenv("_LAMBDA_SERVER_PORT") lis, err := net.Listen("tcp", "localhost:"+port) if err != nil { log.Fatal(err) } - err = rpc.Register(NewFunction(handler)) - if err != nil { + + fn := NewFunction(handler).withContext(ctx) + if err := rpc.Register(fn); err != nil { log.Fatal("failed to register handler function") } + rpc.Accept(lis) log.Fatal("accept should not have returned") } diff --git a/lambda/function.go b/lambda/function.go index 1400dbcc..628deb03 100644 --- a/lambda/function.go +++ b/lambda/function.go @@ -16,6 +16,7 @@ import ( // Function struct which wrap the Handler type Function struct { handler Handler + ctx context.Context } // NewFunction which creates a Function with a given Handler @@ -44,7 +45,7 @@ func (fn *Function) Invoke(req *messages.InvokeRequest, response *messages.Invok }() deadline := time.Unix(req.Deadline.Seconds, req.Deadline.Nanos).UTC() - invokeContext, cancel := context.WithDeadline(context.Background(), deadline) + invokeContext, cancel := context.WithDeadline(fn.context(), deadline) defer cancel() lc := &lambdacontext.LambdaContext{ @@ -75,6 +76,30 @@ func (fn *Function) Invoke(req *messages.InvokeRequest, response *messages.Invok return nil } +// context returns the base context used for the fn. +func (fn *Function) context() context.Context { + if fn.ctx == nil { + return context.Background() + } + + return fn.ctx +} + +// withContext returns a shallow copy of Function with its context changed +// to the provided ctx. If the provided ctx is non-nil a Background context is set. +func (fn *Function) withContext(ctx context.Context) *Function { + if ctx == nil { + ctx = context.Background() + } + + fn2 := new(Function) + *fn2 = *fn + + fn2.ctx = ctx + + return fn2 +} + func getErrorType(err interface{}) string { errorType := reflect.TypeOf(err) if errorType.Kind() == reflect.Ptr { diff --git a/lambda/function_test.go b/lambda/function_test.go index 9531e5e4..ab7e7b27 100644 --- a/lambda/function_test.go +++ b/lambda/function_test.go @@ -58,6 +58,30 @@ func TestInvoke(t *testing.T) { assert.Equal(t, deadline.UnixNano(), responseValue) } +func TestInvokeWithContext(t *testing.T) { + key := struct{}{} + srv := NewFunction(testWrapperHandler( + func(ctx context.Context, input []byte) (interface{}, error) { + assert.Equal(t, "dummy", ctx.Value(key)) + if deadline, ok := ctx.Deadline(); ok { + return deadline.UnixNano(), nil + } + return nil, errors.New("!?!?!?!?!") + })) + srv = srv.withContext(context.WithValue(context.Background(), key, "dummy")) + deadline := time.Now() + var response messages.InvokeResponse + err := srv.Invoke(&messages.InvokeRequest{ + Deadline: messages.InvokeRequest_Timestamp{ + Seconds: deadline.Unix(), + Nanos: int64(deadline.Nanosecond()), + }}, &response) + assert.NoError(t, err) + var responseValue int64 + assert.NoError(t, json.Unmarshal(response.Payload, &responseValue)) + assert.Equal(t, deadline.UnixNano(), responseValue) +} + type CustomError struct{} func (e CustomError) Error() string { return fmt.Sprintf("Something bad happened!") }