diff --git a/lambda/entry.go b/lambda/entry.go
index 6c1d7194..d865e472 100644
--- a/lambda/entry.go
+++ b/lambda/entry.go
@@ -35,6 +35,9 @@ import (
//
// Where "TIn" and "TOut" are types compatible with the "encoding/json" standard library.
// See https://golang.org/pkg/encoding/json/#Unmarshal for how deserialization behaves
+//
+// "TOut" may also implement the io.Reader interface.
+// If "TOut" is both json serializable and implements io.Reader, then the json serialization is used.
func Start(handler interface{}) {
StartWithOptions(handler)
}
diff --git a/lambda/handler.go b/lambda/handler.go
index 0fc82d6e..ee273577 100644
--- a/lambda/handler.go
+++ b/lambda/handler.go
@@ -8,7 +8,10 @@ import (
"encoding/json"
"errors"
"fmt"
+ "io"
+ "io/ioutil" // nolint:staticcheck
"reflect"
+ "strings"
"github.com/aws/aws-lambda-go/lambda/handlertrace"
)
@@ -18,7 +21,7 @@ type Handler interface {
}
type handlerOptions struct {
- Handler
+ handlerFunc
baseContext context.Context
jsonResponseEscapeHTML bool
jsonResponseIndentPrefix string
@@ -168,32 +171,68 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions {
if h.enableSIGTERM {
enableSIGTERM(h.sigtermCallbacks)
}
- h.Handler = reflectHandler(handlerFunc, h)
+ h.handlerFunc = reflectHandler(handlerFunc, h)
return h
}
-type bytesHandlerFunc func(context.Context, []byte) ([]byte, error)
+type handlerFunc func(context.Context, []byte) (io.Reader, error)
-func (h bytesHandlerFunc) Invoke(ctx context.Context, payload []byte) ([]byte, error) {
- return h(ctx, payload)
+// back-compat for the rpc mode
+func (h handlerFunc) Invoke(ctx context.Context, payload []byte) ([]byte, error) {
+ response, err := h(ctx, payload)
+ if err != nil {
+ return nil, err
+ }
+ // if the response needs to be closed (ex: net.Conn, os.File), ensure it's closed before the next invoke to prevent a resource leak
+ if response, ok := response.(io.Closer); ok {
+ defer response.Close()
+ }
+ // optimization: if the response is a *bytes.Buffer, a copy can be eliminated
+ switch response := response.(type) {
+ case *jsonOutBuffer:
+ return response.Bytes(), nil
+ case *bytes.Buffer:
+ return response.Bytes(), nil
+ }
+ b, err := ioutil.ReadAll(response)
+ if err != nil {
+ return nil, err
+ }
+ return b, nil
}
-func errorHandler(err error) Handler {
- return bytesHandlerFunc(func(_ context.Context, _ []byte) ([]byte, error) {
+
+func errorHandler(err error) handlerFunc {
+ return func(_ context.Context, _ []byte) (io.Reader, error) {
return nil, err
- })
+ }
+}
+
+type jsonOutBuffer struct {
+ *bytes.Buffer
}
-func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler {
- if handlerFunc == nil {
+func (j *jsonOutBuffer) ContentType() string {
+ return contentTypeJSON
+}
+
+func reflectHandler(f interface{}, h *handlerOptions) handlerFunc {
+ if f == nil {
return errorHandler(errors.New("handler is nil"))
}
- if handler, ok := handlerFunc.(Handler); ok {
- return handler
+ // back-compat: types with reciever `Invoke(context.Context, []byte) ([]byte, error)` need the return bytes wrapped
+ if handler, ok := f.(Handler); ok {
+ return func(ctx context.Context, payload []byte) (io.Reader, error) {
+ b, err := handler.Invoke(ctx, payload)
+ if err != nil {
+ return nil, err
+ }
+ return bytes.NewBuffer(b), nil
+ }
}
- handler := reflect.ValueOf(handlerFunc)
- handlerType := reflect.TypeOf(handlerFunc)
+ handler := reflect.ValueOf(f)
+ handlerType := reflect.TypeOf(f)
if handlerType.Kind() != reflect.Func {
return errorHandler(fmt.Errorf("handler kind %s is not %s", handlerType.Kind(), reflect.Func))
}
@@ -207,9 +246,10 @@ func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler {
return errorHandler(err)
}
- return bytesHandlerFunc(func(ctx context.Context, payload []byte) ([]byte, error) {
+ out := &jsonOutBuffer{bytes.NewBuffer(nil)}
+ return func(ctx context.Context, payload []byte) (io.Reader, error) {
+ out.Reset()
in := bytes.NewBuffer(payload)
- out := bytes.NewBuffer(nil)
decoder := json.NewDecoder(in)
encoder := json.NewEncoder(out)
encoder.SetEscapeHTML(h.jsonResponseEscapeHTML)
@@ -250,16 +290,28 @@ func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler {
trace.ResponseEvent(ctx, val)
}
}
+
+ // encode to JSON
if err := encoder.Encode(val); err != nil {
+ // if response is not JSON serializable, but the response type is a reader, return it as-is
+ if reader, ok := val.(io.Reader); ok {
+ return reader, nil
+ }
return nil, err
}
- responseBytes := out.Bytes()
+ // if response value is an io.Reader, return it as-is
+ if reader, ok := val.(io.Reader); ok {
+ // back-compat, don't return the reader if the value serialized to a non-empty json
+ if strings.HasPrefix(out.String(), "{}") {
+ return reader, nil
+ }
+ }
+
// back-compat, strip the encoder's trailing newline unless WithSetIndent was used
if h.jsonResponseIndentValue == "" && h.jsonResponseIndentPrefix == "" {
- return responseBytes[:len(responseBytes)-1], nil
+ out.Truncate(out.Len() - 1)
}
-
- return responseBytes, nil
- })
+ return out, nil
+ }
}
diff --git a/lambda/handler_test.go b/lambda/handler_test.go
index 3c3c51d4..17303827 100644
--- a/lambda/handler_test.go
+++ b/lambda/handler_test.go
@@ -3,14 +3,19 @@
package lambda
import (
+ "bytes"
"context"
"errors"
"fmt"
+ "io"
+ "io/ioutil" //nolint: staticcheck
+ "strings"
"testing"
"github.com/aws/aws-lambda-go/lambda/handlertrace"
"github.com/aws/aws-lambda-go/lambda/messages"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func TestInvalidHandlers(t *testing.T) {
@@ -83,6 +88,23 @@ func TestInvalidHandlers(t *testing.T) {
}
}
+type arbitraryJSON struct {
+ json []byte
+ err error
+}
+
+func (a arbitraryJSON) MarshalJSON() ([]byte, error) {
+ return a.json, a.err
+}
+
+type staticHandler struct {
+ body []byte
+}
+
+func (h *staticHandler) Invoke(_ context.Context, _ []byte) ([]byte, error) {
+ return h.body, nil
+}
+
type expected struct {
val string
err error
@@ -106,10 +128,8 @@ func TestInvokes(t *testing.T) {
}{
{
input: `"Lambda"`,
- expected: expected{`"Hello Lambda!"`, nil},
- handler: func(name string) (string, error) {
- return hello(name), nil
- },
+ expected: expected{`null`, nil},
+ handler: func(_ string) {},
},
{
input: `"Lambda"`,
@@ -118,6 +138,12 @@ func TestInvokes(t *testing.T) {
return hello(name), nil
},
},
+ {
+ expected: expected{`"Hello No Value!"`, nil},
+ handler: func(ctx context.Context) (string, error) {
+ return hello("No Value"), nil
+ },
+ },
{
input: `"Lambda"`,
expected: expected{`"Hello Lambda!"`, nil},
@@ -232,22 +258,86 @@ func TestInvokes(t *testing.T) {
{
name: "Handler interface implementations are passthrough",
expected: expected{`hello`, nil},
- handler: bytesHandlerFunc(func(_ context.Context, _ []byte) ([]byte, error) {
- return []byte(`hello`), nil
- }),
+ handler: &staticHandler{body: []byte(`hello`)},
+ },
+ {
+ name: "io.Reader responses are passthrough",
+ expected: expected{`yolo`, nil},
+ handler: func() (io.Reader, error) {
+ return strings.NewReader(`yolo`), nil
+ },
+ },
+ {
+ name: "io.Reader responses that are byte buffers are passthrough",
+ expected: expected{`yolo`, nil},
+ handler: func() (*bytes.Buffer, error) {
+ return bytes.NewBuffer([]byte(`yolo`)), nil
+ },
+ },
+ {
+ name: "io.Reader responses that are also json serializable, handler returns the json, ignoring the reader",
+ expected: expected{`{"Yolo":"yolo"}`, nil},
+ handler: func() (io.Reader, error) {
+ return struct {
+ io.Reader `json:"-"`
+ Yolo string
+ }{
+ Reader: strings.NewReader(`yolo`),
+ Yolo: "yolo",
+ }, nil
+ },
+ },
+ {
+ name: "types that are not json serializable result in an error",
+ expected: expected{``, errors.New("json: error calling MarshalJSON for type struct { lambda.arbitraryJSON }: barf")},
+ handler: func() (interface{}, error) {
+ return struct {
+ arbitraryJSON
+ }{
+ arbitraryJSON{nil, errors.New("barf")},
+ }, nil
+ },
+ },
+ {
+ name: "io.Reader responses that not json serializable remain passthrough",
+ expected: expected{`wat`, nil},
+ handler: func() (io.Reader, error) {
+ return struct {
+ arbitraryJSON
+ io.Reader
+ }{
+ arbitraryJSON{nil, errors.New("barf")},
+ strings.NewReader("wat"),
+ }, nil
+ },
},
}
for i, testCase := range testCases {
testCase := testCase
t.Run(fmt.Sprintf("testCase[%d] %s", i, testCase.name), func(t *testing.T) {
lambdaHandler := newHandler(testCase.handler, testCase.options...)
- response, err := lambdaHandler.Invoke(context.TODO(), []byte(testCase.input))
- if testCase.expected.err != nil {
- assert.Equal(t, testCase.expected.err, err)
- } else {
- assert.NoError(t, err)
- assert.Equal(t, testCase.expected.val, string(response))
- }
+ t.Run("via Handler.Invoke", func(t *testing.T) {
+ response, err := lambdaHandler.Invoke(context.TODO(), []byte(testCase.input))
+ if testCase.expected.err != nil {
+ assert.EqualError(t, err, testCase.expected.err.Error())
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, testCase.expected.val, string(response))
+ }
+ })
+ t.Run("via handlerOptions.handlerFunc", func(t *testing.T) {
+ response, err := lambdaHandler.handlerFunc(context.TODO(), []byte(testCase.input))
+ if testCase.expected.err != nil {
+ assert.EqualError(t, err, testCase.expected.err.Error())
+ } else {
+ assert.NoError(t, err)
+ require.NotNil(t, response)
+ responseBytes, err := ioutil.ReadAll(response)
+ assert.NoError(t, err)
+ assert.Equal(t, testCase.expected.val, string(responseBytes))
+ }
+ })
+
})
}
}
diff --git a/lambda/invoke_loop.go b/lambda/invoke_loop.go
index f73689ba..9e2d6598 100644
--- a/lambda/invoke_loop.go
+++ b/lambda/invoke_loop.go
@@ -3,9 +3,11 @@
package lambda
import (
+ "bytes"
"context"
"encoding/json"
"fmt"
+ "io"
"log"
"os"
"strconv"
@@ -70,7 +72,7 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error {
ctx = context.WithValue(ctx, "x-amzn-trace-id", traceID)
// call the handler, marshal any returned error
- response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload, handler.Handler.Invoke)
+ response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload, handler.handlerFunc)
if invokeErr != nil {
if err := reportFailure(invoke, invokeErr); err != nil {
return err
@@ -80,7 +82,19 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error {
}
return nil
}
- if err := invoke.success(response, contentTypeJSON); err != nil {
+ // if the response needs to be closed (ex: net.Conn, os.File), ensure it's closed before the next invoke to prevent a resource leak
+ if response, ok := response.(io.Closer); ok {
+ defer response.Close()
+ }
+
+ // if the response defines a content-type, plumb it through
+ contentType := contentTypeBytes
+ type ContentType interface{ ContentType() string }
+ if response, ok := response.(ContentType); ok {
+ contentType = response.ContentType()
+ }
+
+ if err := invoke.success(response, contentType); err != nil {
return fmt.Errorf("unexpected error occurred when sending the function functionResponse to the API: %v", err)
}
@@ -90,13 +104,13 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error {
func reportFailure(invoke *invoke, invokeErr *messages.InvokeResponse_Error) error {
errorPayload := safeMarshal(invokeErr)
log.Printf("%s", errorPayload)
- if err := invoke.failure(errorPayload, contentTypeJSON); err != nil {
+ if err := invoke.failure(bytes.NewReader(errorPayload), contentTypeJSON); err != nil {
return fmt.Errorf("unexpected error occurred when sending the function error to the API: %v", err)
}
return nil
}
-func callBytesHandlerFunc(ctx context.Context, payload []byte, handler bytesHandlerFunc) (response []byte, invokeErr *messages.InvokeResponse_Error) {
+func callBytesHandlerFunc(ctx context.Context, payload []byte, handler handlerFunc) (response io.Reader, invokeErr *messages.InvokeResponse_Error) {
defer func() {
if err := recover(); err != nil {
invokeErr = lambdaPanicResponse(err)
diff --git a/lambda/invoke_loop_test.go b/lambda/invoke_loop_test.go
index 54ec96cf..fab800b9 100644
--- a/lambda/invoke_loop_test.go
+++ b/lambda/invoke_loop_test.go
@@ -86,6 +86,7 @@ func TestCustomErrorMarshaling(t *testing.T) {
assert.EqualError(t, startRuntimeAPILoop(endpoint, handler), expectedError)
for i := range errors {
assert.JSONEq(t, expected[i], string(record.responses[i]))
+ assert.Equal(t, contentTypeJSON, record.contentTypes[i])
}
}
@@ -156,7 +157,64 @@ func TestReadPayload(t *testing.T) {
endpoint := strings.Split(ts.URL, "://")[1]
_ = startRuntimeAPILoop(endpoint, handler)
assert.Equal(t, `"socat gnivarc ma I"`, string(record.responses[0]))
+ assert.Equal(t, contentTypeJSON, record.contentTypes[0])
+}
+
+type readCloser struct {
+ closed bool
+ reader *strings.Reader
+}
+
+func (r *readCloser) Read(p []byte) (int, error) {
+ return r.reader.Read(p)
+}
+
+func (r *readCloser) Close() error {
+ r.closed = true
+ return nil
+}
+
+func TestBinaryResponseDefaultContentType(t *testing.T) {
+ ts, record := runtimeAPIServer(`{"message": "I am craving tacos"}`, 1)
+ defer ts.Close()
+
+ handler := NewHandler(func(event struct{ Message string }) (io.Reader, error) {
+ length := utf8.RuneCountInString(event.Message)
+ reversed := make([]rune, length)
+ for i, v := range event.Message {
+ reversed[length-i-1] = v
+ }
+ return strings.NewReader(string(reversed)), nil
+ })
+ endpoint := strings.Split(ts.URL, "://")[1]
+ _ = startRuntimeAPILoop(endpoint, handler)
+ assert.Equal(t, `socat gnivarc ma I`, string(record.responses[0]))
+ assert.Equal(t, contentTypeBytes, record.contentTypes[0])
+}
+
+func TestBinaryResponseDoesNotLeakResources(t *testing.T) {
+ numResponses := 3
+ responses := make([]*readCloser, numResponses)
+ for i := 0; i < numResponses; i++ {
+ responses[i] = &readCloser{closed: false, reader: strings.NewReader(fmt.Sprintf("hello %d", i))}
+ }
+ timesCalled := 0
+ handler := NewHandler(func() (res io.Reader, _ error) {
+ res = responses[timesCalled]
+ timesCalled++
+ return
+ })
+ ts, record := runtimeAPIServer(`{}`, numResponses)
+ defer ts.Close()
+ endpoint := strings.Split(ts.URL, "://")[1]
+ _ = startRuntimeAPILoop(endpoint, handler)
+
+ for i := 0; i < numResponses; i++ {
+ assert.Equal(t, contentTypeBytes, record.contentTypes[i])
+ assert.Equal(t, fmt.Sprintf("hello %d", i), string(record.responses[i]))
+ assert.True(t, responses[i].closed)
+ }
}
func TestContextDeserializationErrors(t *testing.T) {
@@ -209,9 +267,10 @@ func TestSafeMarshal_SerializationError(t *testing.T) {
}
type requestRecord struct {
- nGets int
- nPosts int
- responses [][]byte
+ nGets int
+ nPosts int
+ responses [][]byte
+ contentTypes []string
}
type eventMetadata struct {
@@ -276,6 +335,7 @@ func runtimeAPIServer(eventPayload string, failAfter int, overrides ...eventMeta
_ = r.Body.Close()
w.WriteHeader(http.StatusAccepted)
record.responses = append(record.responses, response.Bytes())
+ record.contentTypes = append(record.contentTypes, r.Header.Get("Content-Type"))
default:
w.WriteHeader(http.StatusBadRequest)
}
diff --git a/lambda/rpc_function_test.go b/lambda/rpc_function_test.go
index 6935084d..515fc62a 100644
--- a/lambda/rpc_function_test.go
+++ b/lambda/rpc_function_test.go
@@ -9,7 +9,10 @@ import (
"context"
"encoding/json"
"errors"
+ "io"
"os"
+ "strconv"
+ "strings"
"testing"
"time"
@@ -63,14 +66,13 @@ func TestInvoke(t *testing.T) {
func TestInvokeWithContext(t *testing.T) {
key := struct{}{}
srv := NewFunction(&handlerOptions{
- Handler: testWrapperHandler(
- func(ctx context.Context, input []byte) (interface{}, error) {
- assert.Equal(t, "dummy", ctx.Value(key))
- if deadline, ok := ctx.Deadline(); ok {
- return deadline.UnixNano(), nil
- }
- return nil, errors.New("!?!?!?!?!")
- }),
+ handlerFunc: func(ctx context.Context, _ []byte) (io.Reader, error) {
+ assert.Equal(t, "dummy", ctx.Value(key))
+ if deadline, ok := ctx.Deadline(); ok {
+ return strings.NewReader(strconv.FormatInt(deadline.UnixNano(), 10)), nil
+ }
+ return nil, errors.New("!?!?!?!?!")
+ },
baseContext: context.WithValue(context.Background(), key, "dummy"),
})
deadline := time.Now()
@@ -231,3 +233,56 @@ func TestXAmznTraceID(t *testing.T) {
}
}
+
+type closeableResponse struct {
+ reader io.Reader
+ closed bool
+}
+
+func (c *closeableResponse) Read(p []byte) (int, error) {
+ return c.reader.Read(p)
+}
+
+func (c *closeableResponse) Close() error {
+ c.closed = true
+ return nil
+}
+
+type readerError struct {
+ err error
+}
+
+func (r *readerError) Read(_ []byte) (int, error) {
+ return 0, r.err
+}
+
+func TestRPCModeInvokeClosesCloserIfResponseIsCloser(t *testing.T) {
+ handlerResource := &closeableResponse{
+ reader: strings.NewReader(""),
+ closed: false,
+ }
+ srv := NewFunction(newHandler(func() (interface{}, error) {
+ return handlerResource, nil
+ }))
+ var response messages.InvokeResponse
+ err := srv.Invoke(&messages.InvokeRequest{}, &response)
+ require.NoError(t, err)
+ assert.Equal(t, "", string(response.Payload))
+ assert.True(t, handlerResource.closed)
+}
+
+func TestRPCModeInvokeReaderErrorPropogated(t *testing.T) {
+ handlerResource := &closeableResponse{
+ reader: &readerError{errors.New("yolo")},
+ closed: false,
+ }
+ srv := NewFunction(newHandler(func() (interface{}, error) {
+ return handlerResource, nil
+ }))
+ var response messages.InvokeResponse
+ err := srv.Invoke(&messages.InvokeRequest{}, &response)
+ require.NoError(t, err)
+ assert.Equal(t, "", string(response.Payload))
+ assert.Equal(t, "yolo", response.Error.Message)
+ assert.True(t, handlerResource.closed)
+}
diff --git a/lambda/runtime_api_client.go b/lambda/runtime_api_client.go
index 843f7ace..a83c3ce8 100644
--- a/lambda/runtime_api_client.go
+++ b/lambda/runtime_api_client.go
@@ -22,6 +22,7 @@ const (
headerClientContext = "Lambda-Runtime-Client-Context"
headerInvokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn"
contentTypeJSON = "application/json"
+ contentTypeBytes = "application/octet-stream"
apiVersion = "2018-06-01"
)
@@ -51,9 +52,9 @@ type invoke struct {
// success sends the response payload for an in-progress invocation.
// Notes:
// - An invoke is not complete until next() is called again!
-func (i *invoke) success(payload []byte, contentType string) error {
+func (i *invoke) success(body io.Reader, contentType string) error {
url := i.client.baseURL + i.id + "/response"
- return i.client.post(url, payload, contentType)
+ return i.client.post(url, body, contentType)
}
// failure sends the payload to the Runtime API. This marks the function's invoke as a failure.
@@ -61,9 +62,9 @@ func (i *invoke) success(payload []byte, contentType string) error {
// - The execution of the function process continues, and is billed, until next() is called again!
// - A Lambda Function continues to be re-used for future invokes even after a failure.
// If the error is fatal (panic, unrecoverable state), exit the process immediately after calling failure()
-func (i *invoke) failure(payload []byte, contentType string) error {
+func (i *invoke) failure(body io.Reader, contentType string) error {
url := i.client.baseURL + i.id + "/error"
- return i.client.post(url, payload, contentType)
+ return i.client.post(url, body, contentType)
}
// next connects to the Runtime API and waits for a new invoke Request to be available.
@@ -104,8 +105,8 @@ func (c *runtimeAPIClient) next() (*invoke, error) {
}, nil
}
-func (c *runtimeAPIClient) post(url string, payload []byte, contentType string) error {
- req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(payload))
+func (c *runtimeAPIClient) post(url string, body io.Reader, contentType string) error {
+ req, err := http.NewRequest(http.MethodPost, url, body)
if err != nil {
return fmt.Errorf("failed to construct POST request to %s: %v", url, err)
}
diff --git a/lambda/runtime_api_client_test.go b/lambda/runtime_api_client_test.go
index 4693048b..7ccd47fb 100644
--- a/lambda/runtime_api_client_test.go
+++ b/lambda/runtime_api_client_test.go
@@ -3,6 +3,7 @@
package lambda
import (
+ "bytes"
"fmt"
"io/ioutil" //nolint: staticcheck
"net/http"
@@ -87,11 +88,11 @@ func TestClientDoneAndError(t *testing.T) {
client: client,
}
t.Run(fmt.Sprintf("happy Done with payload[%d]", i), func(t *testing.T) {
- err := invoke.success(payload, contentTypeJSON)
+ err := invoke.success(bytes.NewReader(payload), contentTypeJSON)
assert.NoError(t, err)
})
t.Run(fmt.Sprintf("happy Error with payload[%d]", i), func(t *testing.T) {
- err := invoke.failure(payload, contentTypeJSON)
+ err := invoke.failure(bytes.NewReader(payload), contentTypeJSON)
assert.NoError(t, err)
})
}