diff --git a/lambda/function.go b/lambda/function.go index 8977ce1b..1400dbcc 100644 --- a/lambda/function.go +++ b/lambda/function.go @@ -5,6 +5,7 @@ package lambda import ( "context" "encoding/json" + "os" "reflect" "time" @@ -63,6 +64,7 @@ func (fn *Function) Invoke(req *messages.InvokeRequest, response *messages.Invok invokeContext = lambdacontext.NewContext(invokeContext, lc) invokeContext = context.WithValue(invokeContext, "x-amzn-trace-id", req.XAmznTraceId) + os.Setenv("_X_AMZN_TRACE_ID", req.XAmznTraceId) payload, err := fn.handler.Invoke(invokeContext, req.Payload) if err != nil { diff --git a/lambda/function_test.go b/lambda/function_test.go index 1ef64a4c..9531e5e4 100644 --- a/lambda/function_test.go +++ b/lambda/function_test.go @@ -7,12 +7,14 @@ import ( "encoding/json" "errors" "fmt" + "os" "testing" "time" "github.com/aws/aws-lambda-go/lambda/messages" "github.com/aws/aws-lambda-go/lambdacontext" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type testWrapperHandler func(ctx context.Context, input []byte) (interface{}, error) @@ -140,3 +142,65 @@ func TestContextPlumbing(t *testing.T) { ` assert.JSONEq(t, expected, string(response.Payload)) } + +func TestXAmznTraceID(t *testing.T) { + type XRayResponse struct { + Env string + Ctx string + } + + srv := &Function{handler: testWrapperHandler( + func(ctx context.Context, input []byte) (interface{}, error) { + return &XRayResponse{ + Env: os.Getenv("_X_AMZN_TRACE_ID"), + Ctx: ctx.Value("x-amzn-trace-id").(string), + }, nil + }, + )} + + sequence := []struct { + Input string + Expected string + }{ + { + "", + `{"Env": "", "Ctx": ""}`, + }, + { + "dummyid", + `{"Env": "dummyid", "Ctx": "dummyid"}`, + }, + { + "", + `{"Env": "", "Ctx": ""}`, + }, + { + "123dummyid", + `{"Env": "123dummyid", "Ctx": "123dummyid"}`, + }, + { + "", + `{"Env": "", "Ctx": ""}`, + }, + { + "", + `{"Env": "", "Ctx": ""}`, + }, + { + "567", + `{"Env": "567", "Ctx": "567"}`, + }, + { + "hihihi", + `{"Env": "hihihi", "Ctx": "hihihi"}`, + }, + } + + for i, test := range sequence { + var response messages.InvokeResponse + err := srv.Invoke(&messages.InvokeRequest{XAmznTraceId: test.Input}, &response) + require.NoError(t, err, "failed test sequence[%d]", i) + assert.JSONEq(t, test.Expected, string(response.Payload), "failed test sequence[%d]", i) + } + +}