diff --git a/lambda/entry.go b/lambda/entry.go index 6c1d7194..d865e472 100644 --- a/lambda/entry.go +++ b/lambda/entry.go @@ -35,6 +35,9 @@ import ( // // Where "TIn" and "TOut" are types compatible with the "encoding/json" standard library. // See https://golang.org/pkg/encoding/json/#Unmarshal for how deserialization behaves +// +// "TOut" may also implement the io.Reader interface. +// If "TOut" is both json serializable and implements io.Reader, then the json serialization is used. func Start(handler interface{}) { StartWithOptions(handler) } diff --git a/lambda/handler.go b/lambda/handler.go index 0fc82d6e..ee273577 100644 --- a/lambda/handler.go +++ b/lambda/handler.go @@ -8,7 +8,10 @@ import ( "encoding/json" "errors" "fmt" + "io" + "io/ioutil" // nolint:staticcheck "reflect" + "strings" "github.com/aws/aws-lambda-go/lambda/handlertrace" ) @@ -18,7 +21,7 @@ type Handler interface { } type handlerOptions struct { - Handler + handlerFunc baseContext context.Context jsonResponseEscapeHTML bool jsonResponseIndentPrefix string @@ -168,32 +171,68 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions { if h.enableSIGTERM { enableSIGTERM(h.sigtermCallbacks) } - h.Handler = reflectHandler(handlerFunc, h) + h.handlerFunc = reflectHandler(handlerFunc, h) return h } -type bytesHandlerFunc func(context.Context, []byte) ([]byte, error) +type handlerFunc func(context.Context, []byte) (io.Reader, error) -func (h bytesHandlerFunc) Invoke(ctx context.Context, payload []byte) ([]byte, error) { - return h(ctx, payload) +// back-compat for the rpc mode +func (h handlerFunc) Invoke(ctx context.Context, payload []byte) ([]byte, error) { + response, err := h(ctx, payload) + if err != nil { + return nil, err + } + // if the response needs to be closed (ex: net.Conn, os.File), ensure it's closed before the next invoke to prevent a resource leak + if response, ok := response.(io.Closer); ok { + defer response.Close() + } + // optimization: if the response is a *bytes.Buffer, a copy can be eliminated + switch response := response.(type) { + case *jsonOutBuffer: + return response.Bytes(), nil + case *bytes.Buffer: + return response.Bytes(), nil + } + b, err := ioutil.ReadAll(response) + if err != nil { + return nil, err + } + return b, nil } -func errorHandler(err error) Handler { - return bytesHandlerFunc(func(_ context.Context, _ []byte) ([]byte, error) { + +func errorHandler(err error) handlerFunc { + return func(_ context.Context, _ []byte) (io.Reader, error) { return nil, err - }) + } +} + +type jsonOutBuffer struct { + *bytes.Buffer } -func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler { - if handlerFunc == nil { +func (j *jsonOutBuffer) ContentType() string { + return contentTypeJSON +} + +func reflectHandler(f interface{}, h *handlerOptions) handlerFunc { + if f == nil { return errorHandler(errors.New("handler is nil")) } - if handler, ok := handlerFunc.(Handler); ok { - return handler + // back-compat: types with reciever `Invoke(context.Context, []byte) ([]byte, error)` need the return bytes wrapped + if handler, ok := f.(Handler); ok { + return func(ctx context.Context, payload []byte) (io.Reader, error) { + b, err := handler.Invoke(ctx, payload) + if err != nil { + return nil, err + } + return bytes.NewBuffer(b), nil + } } - handler := reflect.ValueOf(handlerFunc) - handlerType := reflect.TypeOf(handlerFunc) + handler := reflect.ValueOf(f) + handlerType := reflect.TypeOf(f) if handlerType.Kind() != reflect.Func { return errorHandler(fmt.Errorf("handler kind %s is not %s", handlerType.Kind(), reflect.Func)) } @@ -207,9 +246,10 @@ func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler { return errorHandler(err) } - return bytesHandlerFunc(func(ctx context.Context, payload []byte) ([]byte, error) { + out := &jsonOutBuffer{bytes.NewBuffer(nil)} + return func(ctx context.Context, payload []byte) (io.Reader, error) { + out.Reset() in := bytes.NewBuffer(payload) - out := bytes.NewBuffer(nil) decoder := json.NewDecoder(in) encoder := json.NewEncoder(out) encoder.SetEscapeHTML(h.jsonResponseEscapeHTML) @@ -250,16 +290,28 @@ func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler { trace.ResponseEvent(ctx, val) } } + + // encode to JSON if err := encoder.Encode(val); err != nil { + // if response is not JSON serializable, but the response type is a reader, return it as-is + if reader, ok := val.(io.Reader); ok { + return reader, nil + } return nil, err } - responseBytes := out.Bytes() + // if response value is an io.Reader, return it as-is + if reader, ok := val.(io.Reader); ok { + // back-compat, don't return the reader if the value serialized to a non-empty json + if strings.HasPrefix(out.String(), "{}") { + return reader, nil + } + } + // back-compat, strip the encoder's trailing newline unless WithSetIndent was used if h.jsonResponseIndentValue == "" && h.jsonResponseIndentPrefix == "" { - return responseBytes[:len(responseBytes)-1], nil + out.Truncate(out.Len() - 1) } - - return responseBytes, nil - }) + return out, nil + } } diff --git a/lambda/handler_test.go b/lambda/handler_test.go index 3c3c51d4..17303827 100644 --- a/lambda/handler_test.go +++ b/lambda/handler_test.go @@ -3,14 +3,19 @@ package lambda import ( + "bytes" "context" "errors" "fmt" + "io" + "io/ioutil" //nolint: staticcheck + "strings" "testing" "github.com/aws/aws-lambda-go/lambda/handlertrace" "github.com/aws/aws-lambda-go/lambda/messages" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestInvalidHandlers(t *testing.T) { @@ -83,6 +88,23 @@ func TestInvalidHandlers(t *testing.T) { } } +type arbitraryJSON struct { + json []byte + err error +} + +func (a arbitraryJSON) MarshalJSON() ([]byte, error) { + return a.json, a.err +} + +type staticHandler struct { + body []byte +} + +func (h *staticHandler) Invoke(_ context.Context, _ []byte) ([]byte, error) { + return h.body, nil +} + type expected struct { val string err error @@ -106,10 +128,8 @@ func TestInvokes(t *testing.T) { }{ { input: `"Lambda"`, - expected: expected{`"Hello Lambda!"`, nil}, - handler: func(name string) (string, error) { - return hello(name), nil - }, + expected: expected{`null`, nil}, + handler: func(_ string) {}, }, { input: `"Lambda"`, @@ -118,6 +138,12 @@ func TestInvokes(t *testing.T) { return hello(name), nil }, }, + { + expected: expected{`"Hello No Value!"`, nil}, + handler: func(ctx context.Context) (string, error) { + return hello("No Value"), nil + }, + }, { input: `"Lambda"`, expected: expected{`"Hello Lambda!"`, nil}, @@ -232,22 +258,86 @@ func TestInvokes(t *testing.T) { { name: "Handler interface implementations are passthrough", expected: expected{`hello`, nil}, - handler: bytesHandlerFunc(func(_ context.Context, _ []byte) ([]byte, error) { - return []byte(`hello`), nil - }), + handler: &staticHandler{body: []byte(`hello`)}, + }, + { + name: "io.Reader responses are passthrough", + expected: expected{`yolo`, nil}, + handler: func() (io.Reader, error) { + return strings.NewReader(`yolo`), nil + }, + }, + { + name: "io.Reader responses that are byte buffers are passthrough", + expected: expected{`yolo`, nil}, + handler: func() (*bytes.Buffer, error) { + return bytes.NewBuffer([]byte(`yolo`)), nil + }, + }, + { + name: "io.Reader responses that are also json serializable, handler returns the json, ignoring the reader", + expected: expected{`{"Yolo":"yolo"}`, nil}, + handler: func() (io.Reader, error) { + return struct { + io.Reader `json:"-"` + Yolo string + }{ + Reader: strings.NewReader(`yolo`), + Yolo: "yolo", + }, nil + }, + }, + { + name: "types that are not json serializable result in an error", + expected: expected{``, errors.New("json: error calling MarshalJSON for type struct { lambda.arbitraryJSON }: barf")}, + handler: func() (interface{}, error) { + return struct { + arbitraryJSON + }{ + arbitraryJSON{nil, errors.New("barf")}, + }, nil + }, + }, + { + name: "io.Reader responses that not json serializable remain passthrough", + expected: expected{`wat`, nil}, + handler: func() (io.Reader, error) { + return struct { + arbitraryJSON + io.Reader + }{ + arbitraryJSON{nil, errors.New("barf")}, + strings.NewReader("wat"), + }, nil + }, }, } for i, testCase := range testCases { testCase := testCase t.Run(fmt.Sprintf("testCase[%d] %s", i, testCase.name), func(t *testing.T) { lambdaHandler := newHandler(testCase.handler, testCase.options...) - response, err := lambdaHandler.Invoke(context.TODO(), []byte(testCase.input)) - if testCase.expected.err != nil { - assert.Equal(t, testCase.expected.err, err) - } else { - assert.NoError(t, err) - assert.Equal(t, testCase.expected.val, string(response)) - } + t.Run("via Handler.Invoke", func(t *testing.T) { + response, err := lambdaHandler.Invoke(context.TODO(), []byte(testCase.input)) + if testCase.expected.err != nil { + assert.EqualError(t, err, testCase.expected.err.Error()) + } else { + assert.NoError(t, err) + assert.Equal(t, testCase.expected.val, string(response)) + } + }) + t.Run("via handlerOptions.handlerFunc", func(t *testing.T) { + response, err := lambdaHandler.handlerFunc(context.TODO(), []byte(testCase.input)) + if testCase.expected.err != nil { + assert.EqualError(t, err, testCase.expected.err.Error()) + } else { + assert.NoError(t, err) + require.NotNil(t, response) + responseBytes, err := ioutil.ReadAll(response) + assert.NoError(t, err) + assert.Equal(t, testCase.expected.val, string(responseBytes)) + } + }) + }) } } diff --git a/lambda/invoke_loop.go b/lambda/invoke_loop.go index f73689ba..9e2d6598 100644 --- a/lambda/invoke_loop.go +++ b/lambda/invoke_loop.go @@ -3,9 +3,11 @@ package lambda import ( + "bytes" "context" "encoding/json" "fmt" + "io" "log" "os" "strconv" @@ -70,7 +72,7 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error { ctx = context.WithValue(ctx, "x-amzn-trace-id", traceID) // call the handler, marshal any returned error - response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload, handler.Handler.Invoke) + response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload, handler.handlerFunc) if invokeErr != nil { if err := reportFailure(invoke, invokeErr); err != nil { return err @@ -80,7 +82,19 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error { } return nil } - if err := invoke.success(response, contentTypeJSON); err != nil { + // if the response needs to be closed (ex: net.Conn, os.File), ensure it's closed before the next invoke to prevent a resource leak + if response, ok := response.(io.Closer); ok { + defer response.Close() + } + + // if the response defines a content-type, plumb it through + contentType := contentTypeBytes + type ContentType interface{ ContentType() string } + if response, ok := response.(ContentType); ok { + contentType = response.ContentType() + } + + if err := invoke.success(response, contentType); err != nil { return fmt.Errorf("unexpected error occurred when sending the function functionResponse to the API: %v", err) } @@ -90,13 +104,13 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error { func reportFailure(invoke *invoke, invokeErr *messages.InvokeResponse_Error) error { errorPayload := safeMarshal(invokeErr) log.Printf("%s", errorPayload) - if err := invoke.failure(errorPayload, contentTypeJSON); err != nil { + if err := invoke.failure(bytes.NewReader(errorPayload), contentTypeJSON); err != nil { return fmt.Errorf("unexpected error occurred when sending the function error to the API: %v", err) } return nil } -func callBytesHandlerFunc(ctx context.Context, payload []byte, handler bytesHandlerFunc) (response []byte, invokeErr *messages.InvokeResponse_Error) { +func callBytesHandlerFunc(ctx context.Context, payload []byte, handler handlerFunc) (response io.Reader, invokeErr *messages.InvokeResponse_Error) { defer func() { if err := recover(); err != nil { invokeErr = lambdaPanicResponse(err) diff --git a/lambda/invoke_loop_test.go b/lambda/invoke_loop_test.go index 54ec96cf..fab800b9 100644 --- a/lambda/invoke_loop_test.go +++ b/lambda/invoke_loop_test.go @@ -86,6 +86,7 @@ func TestCustomErrorMarshaling(t *testing.T) { assert.EqualError(t, startRuntimeAPILoop(endpoint, handler), expectedError) for i := range errors { assert.JSONEq(t, expected[i], string(record.responses[i])) + assert.Equal(t, contentTypeJSON, record.contentTypes[i]) } } @@ -156,7 +157,64 @@ func TestReadPayload(t *testing.T) { endpoint := strings.Split(ts.URL, "://")[1] _ = startRuntimeAPILoop(endpoint, handler) assert.Equal(t, `"socat gnivarc ma I"`, string(record.responses[0])) + assert.Equal(t, contentTypeJSON, record.contentTypes[0]) +} + +type readCloser struct { + closed bool + reader *strings.Reader +} + +func (r *readCloser) Read(p []byte) (int, error) { + return r.reader.Read(p) +} + +func (r *readCloser) Close() error { + r.closed = true + return nil +} + +func TestBinaryResponseDefaultContentType(t *testing.T) { + ts, record := runtimeAPIServer(`{"message": "I am craving tacos"}`, 1) + defer ts.Close() + + handler := NewHandler(func(event struct{ Message string }) (io.Reader, error) { + length := utf8.RuneCountInString(event.Message) + reversed := make([]rune, length) + for i, v := range event.Message { + reversed[length-i-1] = v + } + return strings.NewReader(string(reversed)), nil + }) + endpoint := strings.Split(ts.URL, "://")[1] + _ = startRuntimeAPILoop(endpoint, handler) + assert.Equal(t, `socat gnivarc ma I`, string(record.responses[0])) + assert.Equal(t, contentTypeBytes, record.contentTypes[0]) +} + +func TestBinaryResponseDoesNotLeakResources(t *testing.T) { + numResponses := 3 + responses := make([]*readCloser, numResponses) + for i := 0; i < numResponses; i++ { + responses[i] = &readCloser{closed: false, reader: strings.NewReader(fmt.Sprintf("hello %d", i))} + } + timesCalled := 0 + handler := NewHandler(func() (res io.Reader, _ error) { + res = responses[timesCalled] + timesCalled++ + return + }) + ts, record := runtimeAPIServer(`{}`, numResponses) + defer ts.Close() + endpoint := strings.Split(ts.URL, "://")[1] + _ = startRuntimeAPILoop(endpoint, handler) + + for i := 0; i < numResponses; i++ { + assert.Equal(t, contentTypeBytes, record.contentTypes[i]) + assert.Equal(t, fmt.Sprintf("hello %d", i), string(record.responses[i])) + assert.True(t, responses[i].closed) + } } func TestContextDeserializationErrors(t *testing.T) { @@ -209,9 +267,10 @@ func TestSafeMarshal_SerializationError(t *testing.T) { } type requestRecord struct { - nGets int - nPosts int - responses [][]byte + nGets int + nPosts int + responses [][]byte + contentTypes []string } type eventMetadata struct { @@ -276,6 +335,7 @@ func runtimeAPIServer(eventPayload string, failAfter int, overrides ...eventMeta _ = r.Body.Close() w.WriteHeader(http.StatusAccepted) record.responses = append(record.responses, response.Bytes()) + record.contentTypes = append(record.contentTypes, r.Header.Get("Content-Type")) default: w.WriteHeader(http.StatusBadRequest) } diff --git a/lambda/rpc_function_test.go b/lambda/rpc_function_test.go index 6935084d..515fc62a 100644 --- a/lambda/rpc_function_test.go +++ b/lambda/rpc_function_test.go @@ -9,7 +9,10 @@ import ( "context" "encoding/json" "errors" + "io" "os" + "strconv" + "strings" "testing" "time" @@ -63,14 +66,13 @@ func TestInvoke(t *testing.T) { func TestInvokeWithContext(t *testing.T) { key := struct{}{} srv := NewFunction(&handlerOptions{ - Handler: testWrapperHandler( - func(ctx context.Context, input []byte) (interface{}, error) { - assert.Equal(t, "dummy", ctx.Value(key)) - if deadline, ok := ctx.Deadline(); ok { - return deadline.UnixNano(), nil - } - return nil, errors.New("!?!?!?!?!") - }), + handlerFunc: func(ctx context.Context, _ []byte) (io.Reader, error) { + assert.Equal(t, "dummy", ctx.Value(key)) + if deadline, ok := ctx.Deadline(); ok { + return strings.NewReader(strconv.FormatInt(deadline.UnixNano(), 10)), nil + } + return nil, errors.New("!?!?!?!?!") + }, baseContext: context.WithValue(context.Background(), key, "dummy"), }) deadline := time.Now() @@ -231,3 +233,56 @@ func TestXAmznTraceID(t *testing.T) { } } + +type closeableResponse struct { + reader io.Reader + closed bool +} + +func (c *closeableResponse) Read(p []byte) (int, error) { + return c.reader.Read(p) +} + +func (c *closeableResponse) Close() error { + c.closed = true + return nil +} + +type readerError struct { + err error +} + +func (r *readerError) Read(_ []byte) (int, error) { + return 0, r.err +} + +func TestRPCModeInvokeClosesCloserIfResponseIsCloser(t *testing.T) { + handlerResource := &closeableResponse{ + reader: strings.NewReader(""), + closed: false, + } + srv := NewFunction(newHandler(func() (interface{}, error) { + return handlerResource, nil + })) + var response messages.InvokeResponse + err := srv.Invoke(&messages.InvokeRequest{}, &response) + require.NoError(t, err) + assert.Equal(t, "", string(response.Payload)) + assert.True(t, handlerResource.closed) +} + +func TestRPCModeInvokeReaderErrorPropogated(t *testing.T) { + handlerResource := &closeableResponse{ + reader: &readerError{errors.New("yolo")}, + closed: false, + } + srv := NewFunction(newHandler(func() (interface{}, error) { + return handlerResource, nil + })) + var response messages.InvokeResponse + err := srv.Invoke(&messages.InvokeRequest{}, &response) + require.NoError(t, err) + assert.Equal(t, "", string(response.Payload)) + assert.Equal(t, "yolo", response.Error.Message) + assert.True(t, handlerResource.closed) +} diff --git a/lambda/runtime_api_client.go b/lambda/runtime_api_client.go index 843f7ace..a83c3ce8 100644 --- a/lambda/runtime_api_client.go +++ b/lambda/runtime_api_client.go @@ -22,6 +22,7 @@ const ( headerClientContext = "Lambda-Runtime-Client-Context" headerInvokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn" contentTypeJSON = "application/json" + contentTypeBytes = "application/octet-stream" apiVersion = "2018-06-01" ) @@ -51,9 +52,9 @@ type invoke struct { // success sends the response payload for an in-progress invocation. // Notes: // - An invoke is not complete until next() is called again! -func (i *invoke) success(payload []byte, contentType string) error { +func (i *invoke) success(body io.Reader, contentType string) error { url := i.client.baseURL + i.id + "/response" - return i.client.post(url, payload, contentType) + return i.client.post(url, body, contentType) } // failure sends the payload to the Runtime API. This marks the function's invoke as a failure. @@ -61,9 +62,9 @@ func (i *invoke) success(payload []byte, contentType string) error { // - The execution of the function process continues, and is billed, until next() is called again! // - A Lambda Function continues to be re-used for future invokes even after a failure. // If the error is fatal (panic, unrecoverable state), exit the process immediately after calling failure() -func (i *invoke) failure(payload []byte, contentType string) error { +func (i *invoke) failure(body io.Reader, contentType string) error { url := i.client.baseURL + i.id + "/error" - return i.client.post(url, payload, contentType) + return i.client.post(url, body, contentType) } // next connects to the Runtime API and waits for a new invoke Request to be available. @@ -104,8 +105,8 @@ func (c *runtimeAPIClient) next() (*invoke, error) { }, nil } -func (c *runtimeAPIClient) post(url string, payload []byte, contentType string) error { - req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(payload)) +func (c *runtimeAPIClient) post(url string, body io.Reader, contentType string) error { + req, err := http.NewRequest(http.MethodPost, url, body) if err != nil { return fmt.Errorf("failed to construct POST request to %s: %v", url, err) } diff --git a/lambda/runtime_api_client_test.go b/lambda/runtime_api_client_test.go index 4693048b..7ccd47fb 100644 --- a/lambda/runtime_api_client_test.go +++ b/lambda/runtime_api_client_test.go @@ -3,6 +3,7 @@ package lambda import ( + "bytes" "fmt" "io/ioutil" //nolint: staticcheck "net/http" @@ -87,11 +88,11 @@ func TestClientDoneAndError(t *testing.T) { client: client, } t.Run(fmt.Sprintf("happy Done with payload[%d]", i), func(t *testing.T) { - err := invoke.success(payload, contentTypeJSON) + err := invoke.success(bytes.NewReader(payload), contentTypeJSON) assert.NoError(t, err) }) t.Run(fmt.Sprintf("happy Error with payload[%d]", i), func(t *testing.T) { - err := invoke.failure(payload, contentTypeJSON) + err := invoke.failure(bytes.NewReader(payload), contentTypeJSON) assert.NoError(t, err) }) }