diff --git a/lambda/entry.go b/lambda/entry.go index c4a78522..f71ac455 100644 --- a/lambda/entry.go +++ b/lambda/entry.go @@ -37,12 +37,14 @@ import ( // Where "TIn" and "TOut" are types compatible with the "encoding/json" standard library. // See https://golang.org/pkg/encoding/json/#Unmarshal for how deserialization behaves func Start(handler interface{}) { - StartWithContext(context.Background(), handler) + StartWithOptions(handler) } // StartWithContext is the same as Start except sets the base context for the function. +// +// Deprecated: use lambda.StartWithOptions(handler, lambda.WithContext(ctx)) instead func StartWithContext(ctx context.Context, handler interface{}) { - StartHandlerWithContext(ctx, NewHandler(handler)) + StartWithOptions(handler, WithContext(ctx)) } // StartHandler takes in a Handler wrapper interface which can be implemented either by a @@ -51,13 +53,20 @@ func StartWithContext(ctx context.Context, handler interface{}) { // Handler implementation requires a single "Invoke()" function: // // func Invoke(context.Context, []byte) ([]byte, error) +// +// Deprecated: use lambda.Start(handler) instead func StartHandler(handler Handler) { - StartHandlerWithContext(context.Background(), handler) + StartWithOptions(handler) +} + +// StartWithOptions is the same as Start after the application of any handler options specified +func StartWithOptions(handler interface{}, options ...Option) { + start(newHandler(handler, options...)) } type startFunction struct { env string - f func(ctx context.Context, envValue string, handler Handler) error + f func(envValue string, handler Handler) error } var ( @@ -66,7 +75,7 @@ var ( // To drop the rpc dependencies, compile with `-tags lambda.norpc` rpcStartFunction = &startFunction{ env: "_LAMBDA_SERVER_PORT", - f: func(c context.Context, p string, h Handler) error { + f: func(_ string, _ Handler) error { return errors.New("_LAMBDA_SERVER_PORT was present but the function was compiled without RPC support") }, } @@ -85,17 +94,24 @@ var ( // Handler implementation requires a single "Invoke()" function: // // func Invoke(context.Context, []byte) ([]byte, error) +// +// Deprecated: use lambda.StartWithOptions(handler, lambda.WithContext(ctx)) instead func StartHandlerWithContext(ctx context.Context, handler Handler) { + StartWithOptions(handler, WithContext(ctx)) +} + +func start(handler *handlerOptions) { var keys []string for _, start := range startFunctions { config := os.Getenv(start.env) if config != "" { // in normal operation, the start function never returns // if it does, exit!, this triggers a restart of the lambda function - err := start.f(ctx, config, handler) + err := start.f(config, handler) logFatalf("%v", err) } keys = append(keys, start.env) } logFatalf("expected AWS Lambda environment variables %s are not defined", keys) + } diff --git a/lambda/function.go b/lambda/function.go index d03aff76..e6fe464f 100644 --- a/lambda/function.go +++ b/lambda/function.go @@ -13,14 +13,17 @@ import ( ) // Function struct which wrap the Handler +// +// Deprecated: The Function type is public for the go1.x runtime internal use of the net/rpc package type Function struct { - handler Handler - ctx context.Context + handler *handlerOptions } // NewFunction which creates a Function with a given Handler +// +// Deprecated: The Function type is public for the go1.x runtime internal use of the net/rpc package func NewFunction(handler Handler) *Function { - return &Function{handler: handler} + return &Function{newHandler(handler)} } // Ping method which given a PingRequest and a PingResponse parses the PingResponse @@ -38,7 +41,7 @@ func (fn *Function) Invoke(req *messages.InvokeRequest, response *messages.Invok }() deadline := time.Unix(req.Deadline.Seconds, req.Deadline.Nanos).UTC() - invokeContext, cancel := context.WithDeadline(fn.context(), deadline) + invokeContext, cancel := context.WithDeadline(fn.baseContext(), deadline) defer cancel() lc := &lambdacontext.LambdaContext{ @@ -70,26 +73,9 @@ func (fn *Function) Invoke(req *messages.InvokeRequest, response *messages.Invok return nil } -// context returns the base context used for the fn. -func (fn *Function) context() context.Context { - if fn.ctx == nil { - return context.Background() +func (fn *Function) baseContext() context.Context { + if fn.handler.baseContext != nil { + return fn.handler.baseContext } - - return fn.ctx -} - -// withContext returns a shallow copy of Function with its context changed -// to the provided ctx. If the provided ctx is non-nil a Background context is set. -func (fn *Function) withContext(ctx context.Context) *Function { - if ctx == nil { - ctx = context.Background() - } - - fn2 := new(Function) - *fn2 = *fn - - fn2.ctx = ctx - - return fn2 + return context.Background() } diff --git a/lambda/function_test.go b/lambda/function_test.go index 08d4fb8f..a02c25b8 100644 --- a/lambda/function_test.go +++ b/lambda/function_test.go @@ -36,14 +36,14 @@ func (h testWrapperHandler) Invoke(ctx context.Context, payload []byte) ([]byte, var _ Handler = (testWrapperHandler)(nil) func TestInvoke(t *testing.T) { - srv := &Function{handler: testWrapperHandler( + srv := NewFunction(testWrapperHandler( func(ctx context.Context, input []byte) (interface{}, error) { if deadline, ok := ctx.Deadline(); ok { return deadline.UnixNano(), nil } return nil, errors.New("!?!?!?!?!") }, - )} + )) deadline := time.Now() var response messages.InvokeResponse err := srv.Invoke(&messages.InvokeRequest{ @@ -59,15 +59,17 @@ func TestInvoke(t *testing.T) { func TestInvokeWithContext(t *testing.T) { key := struct{}{} - srv := NewFunction(testWrapperHandler( - func(ctx context.Context, input []byte) (interface{}, error) { - assert.Equal(t, "dummy", ctx.Value(key)) - if deadline, ok := ctx.Deadline(); ok { - return deadline.UnixNano(), nil - } - return nil, errors.New("!?!?!?!?!") - })) - srv = srv.withContext(context.WithValue(context.Background(), key, "dummy")) + srv := NewFunction(&handlerOptions{ + Handler: testWrapperHandler( + func(ctx context.Context, input []byte) (interface{}, error) { + assert.Equal(t, "dummy", ctx.Value(key)) + if deadline, ok := ctx.Deadline(); ok { + return deadline.UnixNano(), nil + } + return nil, errors.New("!?!?!?!?!") + }), + baseContext: context.WithValue(context.Background(), key, "dummy"), + }) deadline := time.Now() var response messages.InvokeResponse err := srv.Invoke(&messages.InvokeRequest{ @@ -86,12 +88,11 @@ type CustomError struct{} func (e CustomError) Error() string { return "Something bad happened!" } func TestCustomError(t *testing.T) { - - srv := &Function{handler: testWrapperHandler( + srv := NewFunction(testWrapperHandler( func(ctx context.Context, input []byte) (interface{}, error) { return nil, CustomError{} }, - )} + )) var response messages.InvokeResponse err := srv.Invoke(&messages.InvokeRequest{}, &response) assert.NoError(t, err) @@ -106,11 +107,11 @@ func (e *CustomError2) Error() string { return "Something bad happened!" } func TestCustomErrorRef(t *testing.T) { - srv := &Function{handler: testWrapperHandler( + srv := NewFunction(testWrapperHandler( func(ctx context.Context, input []byte) (interface{}, error) { return nil, &CustomError2{} }, - )} + )) var response messages.InvokeResponse err := srv.Invoke(&messages.InvokeRequest{}, &response) assert.NoError(t, err) @@ -120,12 +121,12 @@ func TestCustomErrorRef(t *testing.T) { } func TestContextPlumbing(t *testing.T) { - srv := &Function{handler: testWrapperHandler( + srv := NewFunction(testWrapperHandler( func(ctx context.Context, input []byte) (interface{}, error) { lc, _ := lambdacontext.FromContext(ctx) return lc, nil }, - )} + )) var response messages.InvokeResponse err := srv.Invoke(&messages.InvokeRequest{ CognitoIdentityId: "dummyident", @@ -172,14 +173,14 @@ func TestXAmznTraceID(t *testing.T) { Ctx string } - srv := &Function{handler: testWrapperHandler( + srv := NewFunction(testWrapperHandler( func(ctx context.Context, input []byte) (interface{}, error) { return &XRayResponse{ Env: os.Getenv("_X_AMZN_TRACE_ID"), Ctx: ctx.Value("x-amzn-trace-id").(string), }, nil }, - )} + )) sequence := []struct { Input string diff --git a/lambda/handler.go b/lambda/handler.go index 39b8d6d8..354459b1 100644 --- a/lambda/handler.go +++ b/lambda/handler.go @@ -3,8 +3,10 @@ package lambda import ( + "bytes" "context" "encoding/json" + "errors" "fmt" "reflect" @@ -15,29 +17,60 @@ type Handler interface { Invoke(ctx context.Context, payload []byte) ([]byte, error) } -// lambdaHandler is the generic function type -type lambdaHandler func(context.Context, []byte) (interface{}, error) - -// Invoke calls the handler, and serializes the response. -// If the underlying handler returned an error, or an error occurs during serialization, error is returned. -func (handler lambdaHandler) Invoke(ctx context.Context, payload []byte) ([]byte, error) { - response, err := handler(ctx, payload) - if err != nil { - return nil, err - } +type handlerOptions struct { + Handler + baseContext context.Context + jsonResponseEscapeHTML bool + jsonResponseIndentPrefix string + jsonResponseIndentValue string +} - responseBytes, err := json.Marshal(response) - if err != nil { - return nil, err - } +type Option func(*handlerOptions) + +// WithContext is a HandlerOption that sets the base context for all invocations of the handler. +// +// Usage: +// lambda.StartWithOptions( +// func (ctx context.Context) (string, error) { +// return ctx.Value("foo"), nil +// }, +// lambda.WithContext(context.WithValue(context.Background(), "foo", "bar")) +// ) +func WithContext(ctx context.Context) Option { + return Option(func(h *handlerOptions) { + h.baseContext = ctx + }) +} - return responseBytes, nil +// WithSetEscapeHTML sets the SetEscapeHTML argument on the underlying json encoder +// +// Usage: +// lambda.StartWithOptions( +// func () (string, error) { +// return "hello!>", nil +// }, +// lambda.WithSetEscapeHTML(true), +// ) +func WithSetEscapeHTML(escapeHTML bool) Option { + return Option(func(h *handlerOptions) { + h.jsonResponseEscapeHTML = escapeHTML + }) } -func errorHandler(e error) lambdaHandler { - return func(ctx context.Context, event []byte) (interface{}, error) { - return nil, e - } +// WithSetIndent sets the SetIndent argument on the underling json encoder +// +// Usage: +// lambda.StartWithOptions( +// func (event any) (any, error) { +// return event, nil +// }, +// lambda.WithSetIndent(">"," "), +// ) +func WithSetIndent(prefix, indent string) Option { + return Option(func(h *handlerOptions) { + h.jsonResponseIndentPrefix = prefix + h.jsonResponseIndentValue = indent + }) } func validateArguments(handler reflect.Type) (bool, error) { @@ -77,13 +110,59 @@ func validateReturns(handler reflect.Type) error { // NewHandler creates a base lambda handler from the given handler function. The // returned Handler performs JSON serialization and deserialization, and -// delegates to the input handler function. The handler function parameter must -// satisfy the rules documented by Start. If handlerFunc is not a valid +// delegates to the input handler function. The handler function parameter must +// satisfy the rules documented by Start. If handlerFunc is not a valid // handler, the returned Handler simply reports the validation error. func NewHandler(handlerFunc interface{}) Handler { + return NewHandlerWithOptions(handlerFunc) +} + +// NewHandlerWithOptions creates a base lambda handler from the given handler function. The +// returned Handler performs JSON serialization and deserialization, and +// delegates to the input handler function. The handler function parameter must +// satisfy the rules documented by Start. If handlerFunc is not a valid +// handler, the returned Handler simply reports the validation error. +func NewHandlerWithOptions(handlerFunc interface{}, options ...Option) Handler { + return newHandler(handlerFunc, options...) +} + +func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions { + if h, ok := handlerFunc.(*handlerOptions); ok { + return h + } + h := &handlerOptions{ + baseContext: context.Background(), + jsonResponseEscapeHTML: false, + jsonResponseIndentPrefix: "", + jsonResponseIndentValue: "", + } + for _, option := range options { + option(h) + } + h.Handler = reflectHandler(handlerFunc, h) + return h +} + +type bytesHandlerFunc func(context.Context, []byte) ([]byte, error) + +func (h bytesHandlerFunc) Invoke(ctx context.Context, payload []byte) ([]byte, error) { + return h(ctx, payload) +} +func errorHandler(err error) Handler { + return bytesHandlerFunc(func(_ context.Context, _ []byte) ([]byte, error) { + return nil, err + }) +} + +func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler { if handlerFunc == nil { - return errorHandler(fmt.Errorf("handler is nil")) + return errorHandler(errors.New("handler is nil")) + } + + if handler, ok := handlerFunc.(Handler); ok { + return handler } + handler := reflect.ValueOf(handlerFunc) handlerType := reflect.TypeOf(handlerFunc) if handlerType.Kind() != reflect.Func { @@ -99,7 +178,13 @@ func NewHandler(handlerFunc interface{}) Handler { return errorHandler(err) } - return lambdaHandler(func(ctx context.Context, payload []byte) (interface{}, error) { + return bytesHandlerFunc(func(ctx context.Context, payload []byte) ([]byte, error) { + in := bytes.NewBuffer(payload) + out := bytes.NewBuffer(nil) + decoder := json.NewDecoder(in) + encoder := json.NewEncoder(out) + encoder.SetEscapeHTML(h.jsonResponseEscapeHTML) + encoder.SetIndent(h.jsonResponseIndentPrefix, h.jsonResponseIndentValue) trace := handlertrace.FromContext(ctx) @@ -111,8 +196,7 @@ func NewHandler(handlerFunc interface{}) Handler { if (handlerType.NumIn() == 1 && !takesContext) || handlerType.NumIn() == 2 { eventType := handlerType.In(handlerType.NumIn() - 1) event := reflect.New(eventType) - - if err := json.Unmarshal(payload, event.Interface()); err != nil { + if err := decoder.Decode(event.Interface()); err != nil { return nil, err } if nil != trace.RequestEvent { @@ -123,22 +207,30 @@ func NewHandler(handlerFunc interface{}) Handler { response := handler.Call(args) - // convert return values into (interface{}, error) - var err error + // return the error, if any if len(response) > 0 { - if errVal, ok := response[len(response)-1].Interface().(error); ok { - err = errVal + if errVal, ok := response[len(response)-1].Interface().(error); ok && errVal != nil { + return nil, errVal } } + // set the response value, if any var val interface{} if len(response) > 1 { val = response[0].Interface() - if nil != trace.ResponseEvent { trace.ResponseEvent(ctx, val) } } + if err := encoder.Encode(val); err != nil { + return nil, err + } + + responseBytes := out.Bytes() + // back-compat, strip the encoder's trailing newline unless WithSetIndent was used + if h.jsonResponseIndentValue == "" && h.jsonResponseIndentPrefix == "" { + return responseBytes[:len(responseBytes)-1], nil + } - return val, err + return responseBytes, nil }) } diff --git a/lambda/handler_test.go b/lambda/handler_test.go index 157a28cc..3c3c51d4 100644 --- a/lambda/handler_test.go +++ b/lambda/handler_test.go @@ -102,6 +102,7 @@ func TestInvokes(t *testing.T) { input string expected expected handler interface{} + options []Option }{ { input: `"Lambda"`, @@ -196,11 +197,50 @@ func TestInvokes(t *testing.T) { return nil, messages.InvokeResponse_Error{Message: "message", Type: "type"} }, }, + { + name: "WithSetEscapeHTML(false)", + expected: expected{`"html in json string!"`, nil}, + handler: func() (string, error) { + return "html in json string!", nil + }, + options: []Option{WithSetEscapeHTML(false)}, + }, + { + name: "WithSetEscapeHTML(true)", + expected: expected{`"\u003chtml\u003e\u003cbody\u003ehtml in json string!\u003c/body\u003e\u003c/html\u003e"`, nil}, + handler: func() (string, error) { + return "html in json string!", nil + }, + options: []Option{WithSetEscapeHTML(true)}, + }, + { + name: `WithSetIndent(">>", " ")`, + expected: expected{"{\n>> \"Foo\": \"Bar\"\n>>}\n", nil}, + handler: func() (interface{}, error) { + return struct{ Foo string }{"Bar"}, nil + }, + options: []Option{WithSetIndent(">>", " ")}, + }, + { + name: "bytes are base64 encoded strings", + input: `"aGVsbG8="`, + expected: expected{`"aGVsbG95b2xv"`, nil}, + handler: func(_ context.Context, req []byte) ([]byte, error) { + return append(req, []byte("yolo")...), nil + }, + }, + { + name: "Handler interface implementations are passthrough", + expected: expected{`hello`, nil}, + handler: bytesHandlerFunc(func(_ context.Context, _ []byte) ([]byte, error) { + return []byte(`hello`), nil + }), + }, } for i, testCase := range testCases { testCase := testCase t.Run(fmt.Sprintf("testCase[%d] %s", i, testCase.name), func(t *testing.T) { - lambdaHandler := NewHandler(testCase.handler) + lambdaHandler := newHandler(testCase.handler, testCase.options...) response, err := lambdaHandler.Invoke(context.TODO(), []byte(testCase.input)) if testCase.expected.err != nil { assert.Equal(t, testCase.expected.err, err) @@ -215,7 +255,7 @@ func TestInvokes(t *testing.T) { func TestInvalidJsonInput(t *testing.T) { lambdaHandler := NewHandler(func(s string) error { return nil }) _, err := lambdaHandler.Invoke(context.TODO(), []byte(`{"invalid json`)) - assert.Equal(t, "unexpected end of JSON input", err.Error()) + assert.Equal(t, "unexpected EOF", err.Error()) } func TestHandlerTrace(t *testing.T) { diff --git a/lambda/invoke_loop.go b/lambda/invoke_loop.go index f9f7005a..dca2bf68 100644 --- a/lambda/invoke_loop.go +++ b/lambda/invoke_loop.go @@ -3,7 +3,6 @@ package lambda import ( - "context" "encoding/json" "fmt" "log" @@ -19,9 +18,9 @@ const ( ) // startRuntimeAPILoop will return an error if handling a particular invoke resulted in a non-recoverable error -func startRuntimeAPILoop(ctx context.Context, api string, handler Handler) error { +func startRuntimeAPILoop(api string, handler Handler) error { client := newRuntimeAPIClient(api) - function := NewFunction(handler).withContext(ctx) + function := NewFunction(handler) for { invoke, err := client.next() if err != nil { diff --git a/lambda/invoke_loop_test.go b/lambda/invoke_loop_test.go index 8e76ac2f..9e94d74d 100644 --- a/lambda/invoke_loop_test.go +++ b/lambda/invoke_loop_test.go @@ -26,7 +26,7 @@ func TestFatalErrors(t *testing.T) { }) endpoint := strings.Split(ts.URL, "://")[1] expectedErrorMessage := "calling the handler function resulted in a panic, the process should exit" - assert.EqualError(t, startRuntimeAPILoop(context.Background(), endpoint, handler), expectedErrorMessage) + assert.EqualError(t, startRuntimeAPILoop(endpoint, handler), expectedErrorMessage) assert.Equal(t, 1, record.nGets) assert.Equal(t, 1, record.nGets) } @@ -47,7 +47,7 @@ func TestRuntimeAPILoop(t *testing.T) { }) endpoint := strings.Split(ts.URL, "://")[1] expectedError := fmt.Sprintf("failed to GET http://%s/2018-06-01/runtime/invocation/next: got unexpected status code: 410", endpoint) - assert.EqualError(t, startRuntimeAPILoop(context.Background(), endpoint, handler), expectedError) + assert.EqualError(t, startRuntimeAPILoop(endpoint, handler), expectedError) assert.Equal(t, nInvokes+1, record.nGets) assert.Equal(t, nInvokes, record.nPosts) } @@ -63,7 +63,7 @@ func TestRuntimeAPIContextPlumbing(t *testing.T) { endpoint := strings.Split(ts.URL, "://")[1] expectedError := fmt.Sprintf("failed to GET http://%s/2018-06-01/runtime/invocation/next: got unexpected status code: 410", endpoint) - assert.EqualError(t, startRuntimeAPILoop(context.Background(), endpoint, handler), expectedError) + assert.EqualError(t, startRuntimeAPILoop(endpoint, handler), expectedError) expected := ` { @@ -101,7 +101,7 @@ func TestReadPayload(t *testing.T) { return string(reversed), nil }) endpoint := strings.Split(ts.URL, "://")[1] - _ = startRuntimeAPILoop(context.Background(), endpoint, handler) + _ = startRuntimeAPILoop(endpoint, handler) assert.Equal(t, `"socat gnivarc ma I"`, string(record.responses[0])) } diff --git a/lambda/rpc.go b/lambda/rpc.go index 05e92836..8c232a98 100644 --- a/lambda/rpc.go +++ b/lambda/rpc.go @@ -6,7 +6,6 @@ package lambda import ( - "context" "errors" "log" "net" @@ -20,12 +19,12 @@ func init() { rpcStartFunction.f = startFunctionRPC } -func startFunctionRPC(ctx context.Context, port string, handler Handler) error { +func startFunctionRPC(port string, handler Handler) error { lis, err := net.Listen("tcp", "localhost:"+port) if err != nil { log.Fatal(err) } - err = rpc.Register(NewFunction(handler).withContext(ctx)) + err = rpc.Register(NewFunction(handler)) if err != nil { log.Fatal("failed to register handler function") }