diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 06f41fab..14e115d8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -25,11 +25,15 @@ jobs: - run: go version + - name: install lambda runtime interface emulator + run: curl -L -o /usr/local/bin/aws-lambda-rie https://github.com/aws/aws-lambda-runtime-interface-emulator/releases/latest/download/aws-lambda-rie-x86_64 + - run: chmod +x /usr/local/bin/aws-lambda-rie + - name: Check out code into the Go module directory uses: actions/checkout@v2 - name: go test - run: go test -race -coverprofile=coverage.txt -covermode=atomic ./... + run: go test -v -race -coverprofile=coverage.txt -covermode=atomic ./... - name: Upload coverage to Codecov uses: codecov/codecov-action@v2 diff --git a/lambda/extensions_api_client.go b/lambda/extensions_api_client.go new file mode 100644 index 00000000..e17292b3 --- /dev/null +++ b/lambda/extensions_api_client.go @@ -0,0 +1,90 @@ +package lambda + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" +) + +const ( + headerExtensionName = "Lambda-Extension-Name" + headerExtensionIdentifier = "Lambda-Extension-Identifier" + extensionAPIVersion = "2020-01-01" +) + +type extensionAPIEventType string + +const ( + extensionInvokeEvent extensionAPIEventType = "INVOKE" //nolint:deadcode,unused,varcheck + extensionShutdownEvent extensionAPIEventType = "SHUTDOWN" //nolint:deadcode,unused,varcheck +) + +type extensionAPIClient struct { + baseURL string + httpClient *http.Client +} + +func newExtensionAPIClient(address string) *extensionAPIClient { + client := &http.Client{ + Timeout: 0, // connections to the extensions API are never expected to time out + } + endpoint := "http://" + address + "/" + extensionAPIVersion + "/extension/" + return &extensionAPIClient{ + baseURL: endpoint, + httpClient: client, + } +} + +func (c *extensionAPIClient) register(name string, events ...extensionAPIEventType) (string, error) { + url := c.baseURL + "register" + body, _ := json.Marshal(struct { + Events []extensionAPIEventType `json:"events"` + }{ + Events: events, + }) + + req, _ := http.NewRequest(http.MethodPost, url, bytes.NewReader(body)) + req.Header.Add(headerExtensionName, name) + res, err := c.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("failed to register extension: %v", err) + } + defer res.Body.Close() + _, _ = io.Copy(ioutil.Discard, res.Body) + + if res.StatusCode != http.StatusOK { + return "", fmt.Errorf("failed to register extension, got response status: %d %s", res.StatusCode, http.StatusText(res.StatusCode)) + } + + return res.Header.Get(headerExtensionIdentifier), nil +} + +type extensionEventResponse struct { + EventType extensionAPIEventType + // ... the rest not implemented +} + +func (c *extensionAPIClient) next(id string) (response extensionEventResponse, err error) { + url := c.baseURL + "event/next" + + req, _ := http.NewRequest(http.MethodGet, url, nil) + req.Header.Add(headerExtensionIdentifier, id) + res, err := c.httpClient.Do(req) + if err != nil { + err = fmt.Errorf("failed to get extension event: %v", err) + return + } + defer res.Body.Close() + _, _ = io.Copy(ioutil.Discard, res.Body) + + if res.StatusCode != http.StatusOK { + err = fmt.Errorf("failed to register extension, got response status: %d %s", res.StatusCode, http.StatusText(res.StatusCode)) + return + } + + err = json.NewDecoder(res.Body).Decode(&response) + return +} diff --git a/lambda/handler.go b/lambda/handler.go index 354459b1..2e9f53c6 100644 --- a/lambda/handler.go +++ b/lambda/handler.go @@ -23,6 +23,8 @@ type handlerOptions struct { jsonResponseEscapeHTML bool jsonResponseIndentPrefix string jsonResponseIndentValue string + enableSIGTERM bool + sigtermCallbacks []func() } type Option func(*handlerOptions) @@ -73,6 +75,26 @@ func WithSetIndent(prefix, indent string) Option { }) } +// WithEnableSIGTERM enables SIGTERM behavior within the Lambda platform on container spindown. +// SIGKILL will occur ~500ms after SIGTERM. +// Optionally, an array of callback functions to run on SIGTERM may be provided. +// +// Usage: +// lambda.StartWithOptions( +// func (event any) (any error) { +// return event, nil +// }, +// lambda.WithEnableSIGTERM(func() { +// log.Print("function container shutting down...") +// }) +// ) +func WithEnableSIGTERM(callbacks ...func()) Option { + return Option(func(h *handlerOptions) { + h.sigtermCallbacks = append(h.sigtermCallbacks, callbacks...) + h.enableSIGTERM = true + }) +} + func validateArguments(handler reflect.Type) (bool, error) { handlerTakesContext := false if handler.NumIn() > 2 { @@ -139,6 +161,9 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions { for _, option := range options { option(h) } + if h.enableSIGTERM { + enableSIGTERM(h.sigtermCallbacks) + } h.Handler = reflectHandler(handlerFunc, h) return h } diff --git a/lambda/sigterm.go b/lambda/sigterm.go new file mode 100644 index 00000000..b742e911 --- /dev/null +++ b/lambda/sigterm.go @@ -0,0 +1,53 @@ +// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +package lambda + +import ( + "log" + "os" + "os/signal" + "syscall" +) + +// enableSIGTERM configures an optional list of sigtermHandlers to run on process shutdown. +// This non-default behavior is enabled within Lambda using the extensions API. +func enableSIGTERM(sigtermHandlers []func()) { + // for fun, we'll also optionally register SIGTERM handlers + if len(sigtermHandlers) > 0 { + signaled := make(chan os.Signal, 1) + signal.Notify(signaled, syscall.SIGTERM) + go func() { + <-signaled + for _, f := range sigtermHandlers { + f() + } + }() + } + + // detect if we're actually running within Lambda + endpoint := os.Getenv("AWS_LAMBDA_RUNTIME_API") + if endpoint == "" { + log.Print("WARNING! AWS_LAMBDA_RUNTIME_API environment variable not found. Skipping attempt to register internal extension...") + return + } + + // Now to do the AWS Lambda specific stuff. + // The default Lambda behavior is for functions to get SIGKILL at the end of lifetime, or after a timeout. + // Any use of the Lambda extension register API enables SIGTERM to be sent to the function process before the SIGKILL. + // We'll register an extension that does not listen for any lifecycle events named "GoLangEnableSIGTERM". + // The API will respond with an ID we need to pass in future requests. + client := newExtensionAPIClient(endpoint) + id, err := client.register("GoLangEnableSIGTERM") + if err != nil { + log.Printf("WARNING! Failed to register internal extension! SIGTERM events may not be enabled! err: %v", err) + return + } + + // We didn't actually register for any events, but we need to call /next anyways to let the API know we're done initalizing. + // Because we didn't register for any events, /next will never return, so we'll do this in a go routine that is doomed to stay blocked. + go func() { + _, err := client.next(id) + log.Printf("WARNING! Reached expected unreachable code! Extension /next call expected to block forever! err: %v", err) + }() + +} diff --git a/lambda/sigterm_test.go b/lambda/sigterm_test.go new file mode 100644 index 00000000..195c9550 --- /dev/null +++ b/lambda/sigterm_test.go @@ -0,0 +1,93 @@ +//go:build go1.15 +// +build go1.15 + +package lambda + +import ( + "io/ioutil" + "net/http" + "os" + "os/exec" + "path" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + rieInvokeAPI = "http://localhost:8080/2015-03-31/functions/function/invocations" +) + +func TestEnableSigterm(t *testing.T) { + if _, err := exec.LookPath("aws-lambda-rie"); err != nil { + t.Skipf("%v - install from https://github.com/aws/aws-lambda-runtime-interface-emulator/", err) + } + + testDir := t.TempDir() + + // compile our handler, it'll always run to timeout ensuring the SIGTERM is triggered by aws-lambda-rie + handlerBuild := exec.Command("go", "build", "-o", path.Join(testDir, "sigterm.handler"), "./testdata/sigterm.go") + handlerBuild.Stderr = os.Stderr + handlerBuild.Stdout = os.Stderr + require.NoError(t, handlerBuild.Run()) + + for name, opts := range map[string]struct { + envVars []string + assertLogs func(t *testing.T, logs string) + }{ + "baseline": { + assertLogs: func(t *testing.T, logs string) { + assert.NotContains(t, logs, "Hello SIGTERM!") + assert.NotContains(t, logs, "I've been TERMINATED!") + }, + }, + "sigterm enabled": { + envVars: []string{"ENABLE_SIGTERM=please"}, + assertLogs: func(t *testing.T, logs string) { + assert.Contains(t, logs, "Hello SIGTERM!") + assert.Contains(t, logs, "I've been TERMINATED!") + }, + }, + } { + t.Run(name, func(t *testing.T) { + // run the runtime interface emulator, capture the logs for assertion + cmd := exec.Command("aws-lambda-rie", "sigterm.handler") + cmd.Env = append([]string{ + "PATH=" + testDir, + "AWS_LAMBDA_FUNCTION_TIMEOUT=2", + }, opts.envVars...) + cmd.Stderr = os.Stderr + stdout, err := cmd.StdoutPipe() + require.NoError(t, err) + var logs string + done := make(chan interface{}) // closed on completion of log flush + go func() { + logBytes, err := ioutil.ReadAll(stdout) + require.NoError(t, err) + logs = string(logBytes) + close(done) + }() + require.NoError(t, cmd.Start()) + t.Cleanup(func() { _ = cmd.Process.Kill() }) + + // give a moment for the port to bind + time.Sleep(500 * time.Millisecond) + + client := &http.Client{Timeout: 5 * time.Second} // http client timeout to prevent case from hanging on aws-lambda-rie + resp, err := client.Post(rieInvokeAPI, "application/json", strings.NewReader("{}")) + require.NoError(t, err) + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, string(body), "Task timed out after 2.00 seconds") + + require.NoError(t, cmd.Process.Kill()) // now ensure the logs are drained + <-done + t.Logf("stdout:\n%s", logs) + opts.assertLogs(t, logs) + }) + } +} diff --git a/lambda/testdata/sigterm.go b/lambda/testdata/sigterm.go new file mode 100644 index 00000000..69183e5f --- /dev/null +++ b/lambda/testdata/sigterm.go @@ -0,0 +1,42 @@ +package main + +import ( + "context" + "fmt" + "os" + "os/signal" + "syscall" + "time" + + "github.com/aws/aws-lambda-go/lambda" +) + +func init() { + // conventional SIGTERM callback + signaled := make(chan os.Signal, 1) + signal.Notify(signaled, syscall.SIGTERM) + go func() { + <-signaled + fmt.Println("I've been TERMINATED!") + }() + +} + +func main() { + // lambda option to enable sigterm, plus optional extra sigterm callbacks + sigtermOption := lambda.WithEnableSIGTERM(func() { + fmt.Println("Hello SIGTERM!") + }) + handlerOptions := []lambda.Option{} + if os.Getenv("ENABLE_SIGTERM") != "" { + handlerOptions = append(handlerOptions, sigtermOption) + } + lambda.StartWithOptions( + func(ctx context.Context) { + deadline, _ := ctx.Deadline() + <-time.After(time.Until(deadline) + time.Second) + panic("unreachable line reached!") + }, + handlerOptions..., + ) +}