Skip to content

Support handlers that return io.Reader #472

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 23, 2022
3 changes: 3 additions & 0 deletions lambda/entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ import (
//
// Where "TIn" and "TOut" are types compatible with the "encoding/json" standard library.
// See https://golang.org/pkg/encoding/json/#Unmarshal for how deserialization behaves
//
// "TOut" may also implement the io.Reader interface.
// If "TOut" is both json serializable and implements io.Reader, then the json serialization is used.
func Start(handler interface{}) {
StartWithOptions(handler)
}
Expand Down
94 changes: 73 additions & 21 deletions lambda/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil" // nolint:staticcheck
"reflect"
"strings"

"github.com/aws/aws-lambda-go/lambda/handlertrace"
)
Expand All @@ -18,7 +21,7 @@ type Handler interface {
}

type handlerOptions struct {
Handler
handlerFunc
baseContext context.Context
jsonResponseEscapeHTML bool
jsonResponseIndentPrefix string
Expand Down Expand Up @@ -168,32 +171,68 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions {
if h.enableSIGTERM {
enableSIGTERM(h.sigtermCallbacks)
}
h.Handler = reflectHandler(handlerFunc, h)
h.handlerFunc = reflectHandler(handlerFunc, h)
return h
}

type bytesHandlerFunc func(context.Context, []byte) ([]byte, error)
type handlerFunc func(context.Context, []byte) (io.Reader, error)

func (h bytesHandlerFunc) Invoke(ctx context.Context, payload []byte) ([]byte, error) {
return h(ctx, payload)
// back-compat for the rpc mode
func (h handlerFunc) Invoke(ctx context.Context, payload []byte) ([]byte, error) {
response, err := h(ctx, payload)
if err != nil {
return nil, err
}
// if the response needs to be closed (ex: net.Conn, os.File), ensure it's closed before the next invoke to prevent a resource leak
if response, ok := response.(io.Closer); ok {
defer response.Close()
}
// optimization: if the response is a *bytes.Buffer, a copy can be eliminated
switch response := response.(type) {
case *jsonOutBuffer:
return response.Bytes(), nil
case *bytes.Buffer:
return response.Bytes(), nil
}
b, err := ioutil.ReadAll(response)
if err != nil {
return nil, err
}
return b, nil
}
func errorHandler(err error) Handler {
return bytesHandlerFunc(func(_ context.Context, _ []byte) ([]byte, error) {

func errorHandler(err error) handlerFunc {
return func(_ context.Context, _ []byte) (io.Reader, error) {
return nil, err
})
}
}

type jsonOutBuffer struct {
*bytes.Buffer
}

func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler {
if handlerFunc == nil {
func (j *jsonOutBuffer) ContentType() string {
return contentTypeJSON
}

func reflectHandler(f interface{}, h *handlerOptions) handlerFunc {
if f == nil {
return errorHandler(errors.New("handler is nil"))
}

if handler, ok := handlerFunc.(Handler); ok {
return handler
// back-compat: types with reciever `Invoke(context.Context, []byte) ([]byte, error)` need the return bytes wrapped
if handler, ok := f.(Handler); ok {
return func(ctx context.Context, payload []byte) (io.Reader, error) {
b, err := handler.Invoke(ctx, payload)
if err != nil {
return nil, err
}
return bytes.NewBuffer(b), nil
}
}

handler := reflect.ValueOf(handlerFunc)
handlerType := reflect.TypeOf(handlerFunc)
handler := reflect.ValueOf(f)
handlerType := reflect.TypeOf(f)
if handlerType.Kind() != reflect.Func {
return errorHandler(fmt.Errorf("handler kind %s is not %s", handlerType.Kind(), reflect.Func))
}
Expand All @@ -207,9 +246,10 @@ func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler {
return errorHandler(err)
}

return bytesHandlerFunc(func(ctx context.Context, payload []byte) ([]byte, error) {
out := &jsonOutBuffer{bytes.NewBuffer(nil)}
return func(ctx context.Context, payload []byte) (io.Reader, error) {
out.Reset()
in := bytes.NewBuffer(payload)
out := bytes.NewBuffer(nil)
decoder := json.NewDecoder(in)
encoder := json.NewEncoder(out)
encoder.SetEscapeHTML(h.jsonResponseEscapeHTML)
Expand Down Expand Up @@ -250,16 +290,28 @@ func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler {
trace.ResponseEvent(ctx, val)
}
}

// encode to JSON
if err := encoder.Encode(val); err != nil {
// if response is not JSON serializable, but the response type is a reader, return it as-is
if reader, ok := val.(io.Reader); ok {
return reader, nil
}
return nil, err
}

responseBytes := out.Bytes()
// if response value is an io.Reader, return it as-is
if reader, ok := val.(io.Reader); ok {
// back-compat, don't return the reader if the value serialized to a non-empty json
if strings.HasPrefix(out.String(), "{}") {
return reader, nil
}
}

// back-compat, strip the encoder's trailing newline unless WithSetIndent was used
if h.jsonResponseIndentValue == "" && h.jsonResponseIndentPrefix == "" {
return responseBytes[:len(responseBytes)-1], nil
out.Truncate(out.Len() - 1)
}

return responseBytes, nil
})
return out, nil
}
}
118 changes: 104 additions & 14 deletions lambda/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@
package lambda

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"io/ioutil" //nolint: staticcheck
"strings"
"testing"

"github.com/aws/aws-lambda-go/lambda/handlertrace"
"github.com/aws/aws-lambda-go/lambda/messages"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestInvalidHandlers(t *testing.T) {
Expand Down Expand Up @@ -83,6 +88,23 @@ func TestInvalidHandlers(t *testing.T) {
}
}

type arbitraryJSON struct {
json []byte
err error
}

func (a arbitraryJSON) MarshalJSON() ([]byte, error) {
return a.json, a.err
}

type staticHandler struct {
body []byte
}

func (h *staticHandler) Invoke(_ context.Context, _ []byte) ([]byte, error) {
return h.body, nil
}

type expected struct {
val string
err error
Expand All @@ -106,10 +128,8 @@ func TestInvokes(t *testing.T) {
}{
{
input: `"Lambda"`,
expected: expected{`"Hello Lambda!"`, nil},
handler: func(name string) (string, error) {
return hello(name), nil
},
expected: expected{`null`, nil},
handler: func(_ string) {},
},
{
input: `"Lambda"`,
Expand All @@ -118,6 +138,12 @@ func TestInvokes(t *testing.T) {
return hello(name), nil
},
},
{
expected: expected{`"Hello No Value!"`, nil},
handler: func(ctx context.Context) (string, error) {
return hello("No Value"), nil
},
},
{
input: `"Lambda"`,
expected: expected{`"Hello Lambda!"`, nil},
Expand Down Expand Up @@ -232,22 +258,86 @@ func TestInvokes(t *testing.T) {
{
name: "Handler interface implementations are passthrough",
expected: expected{`<xml>hello</xml>`, nil},
handler: bytesHandlerFunc(func(_ context.Context, _ []byte) ([]byte, error) {
return []byte(`<xml>hello</xml>`), nil
}),
handler: &staticHandler{body: []byte(`<xml>hello</xml>`)},
},
{
name: "io.Reader responses are passthrough",
expected: expected{`<yolo>yolo</yolo>`, nil},
handler: func() (io.Reader, error) {
return strings.NewReader(`<yolo>yolo</yolo>`), nil
},
},
{
name: "io.Reader responses that are byte buffers are passthrough",
expected: expected{`<yolo>yolo</yolo>`, nil},
handler: func() (*bytes.Buffer, error) {
return bytes.NewBuffer([]byte(`<yolo>yolo</yolo>`)), nil
},
},
{
name: "io.Reader responses that are also json serializable, handler returns the json, ignoring the reader",
expected: expected{`{"Yolo":"yolo"}`, nil},
handler: func() (io.Reader, error) {
return struct {
io.Reader `json:"-"`
Yolo string
}{
Reader: strings.NewReader(`<yolo>yolo</yolo>`),
Yolo: "yolo",
}, nil
},
},
{
name: "types that are not json serializable result in an error",
expected: expected{``, errors.New("json: error calling MarshalJSON for type struct { lambda.arbitraryJSON }: barf")},
handler: func() (interface{}, error) {
return struct {
arbitraryJSON
}{
arbitraryJSON{nil, errors.New("barf")},
}, nil
},
},
{
name: "io.Reader responses that not json serializable remain passthrough",
expected: expected{`wat`, nil},
handler: func() (io.Reader, error) {
return struct {
arbitraryJSON
io.Reader
}{
arbitraryJSON{nil, errors.New("barf")},
strings.NewReader("wat"),
}, nil
},
},
}
for i, testCase := range testCases {
testCase := testCase
t.Run(fmt.Sprintf("testCase[%d] %s", i, testCase.name), func(t *testing.T) {
lambdaHandler := newHandler(testCase.handler, testCase.options...)
response, err := lambdaHandler.Invoke(context.TODO(), []byte(testCase.input))
if testCase.expected.err != nil {
assert.Equal(t, testCase.expected.err, err)
} else {
assert.NoError(t, err)
assert.Equal(t, testCase.expected.val, string(response))
}
t.Run("via Handler.Invoke", func(t *testing.T) {
response, err := lambdaHandler.Invoke(context.TODO(), []byte(testCase.input))
if testCase.expected.err != nil {
assert.EqualError(t, err, testCase.expected.err.Error())
} else {
assert.NoError(t, err)
assert.Equal(t, testCase.expected.val, string(response))
}
})
t.Run("via handlerOptions.handlerFunc", func(t *testing.T) {
response, err := lambdaHandler.handlerFunc(context.TODO(), []byte(testCase.input))
if testCase.expected.err != nil {
assert.EqualError(t, err, testCase.expected.err.Error())
} else {
assert.NoError(t, err)
require.NotNil(t, response)
responseBytes, err := ioutil.ReadAll(response)
assert.NoError(t, err)
assert.Equal(t, testCase.expected.val, string(responseBytes))
}
})

})
}
}
Expand Down
22 changes: 18 additions & 4 deletions lambda/invoke_loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
package lambda

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"os"
"strconv"
Expand Down Expand Up @@ -70,7 +72,7 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error {
ctx = context.WithValue(ctx, "x-amzn-trace-id", traceID)

// call the handler, marshal any returned error
response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload, handler.Handler.Invoke)
response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload, handler.handlerFunc)
if invokeErr != nil {
if err := reportFailure(invoke, invokeErr); err != nil {
return err
Expand All @@ -80,7 +82,19 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error {
}
return nil
}
if err := invoke.success(response, contentTypeJSON); err != nil {
// if the response needs to be closed (ex: net.Conn, os.File), ensure it's closed before the next invoke to prevent a resource leak
if response, ok := response.(io.Closer); ok {
defer response.Close()
}

// if the response defines a content-type, plumb it through
contentType := contentTypeBytes
type ContentType interface{ ContentType() string }
if response, ok := response.(ContentType); ok {
contentType = response.ContentType()
}

if err := invoke.success(response, contentType); err != nil {
return fmt.Errorf("unexpected error occurred when sending the function functionResponse to the API: %v", err)
}

Expand All @@ -90,13 +104,13 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error {
func reportFailure(invoke *invoke, invokeErr *messages.InvokeResponse_Error) error {
errorPayload := safeMarshal(invokeErr)
log.Printf("%s", errorPayload)
if err := invoke.failure(errorPayload, contentTypeJSON); err != nil {
if err := invoke.failure(bytes.NewReader(errorPayload), contentTypeJSON); err != nil {
return fmt.Errorf("unexpected error occurred when sending the function error to the API: %v", err)
}
return nil
}

func callBytesHandlerFunc(ctx context.Context, payload []byte, handler bytesHandlerFunc) (response []byte, invokeErr *messages.InvokeResponse_Error) {
func callBytesHandlerFunc(ctx context.Context, payload []byte, handler handlerFunc) (response io.Reader, invokeErr *messages.InvokeResponse_Error) {
defer func() {
if err := recover(); err != nil {
invokeErr = lambdaPanicResponse(err)
Expand Down
Loading