diff --git a/core/requestalb.go b/core/requestalb.go
new file mode 100644
index 0000000..648dcb0
--- /dev/null
+++ b/core/requestalb.go
@@ -0,0 +1,156 @@
+// Package core provides utility methods that help convert proxy events
+// into an http.Request and http.ResponseWriter
+package core
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log"
+ "net/http"
+ "net/url"
+ "strings"
+
+ "github.com/aws/aws-lambda-go/events"
+ "github.com/aws/aws-lambda-go/lambdacontext"
+)
+
+const (
+ // ALBTgContextHeader is the custom header key used to store the
+ // ALB Target Group Request context. To access the Context properties use the
+ // RequestAccessorALB method of the RequestAccessorALB object.
+ ALBTgContextHeader = "X-Golambdaproxy-Albtargetgroup-Context"
+)
+
+// RequestAccessorALB objects give access to custom ALB Target Group properties
+// in the request.
+type RequestAccessorALB struct{}
+
+// GetALBTargetGroupRequestContext extracts the ALB Target Group Responce context object from a
+// request's custom header.
+// Returns a populated events.ALBTargetGroupRequestContext object from
+// the request.
+func (r *RequestAccessorALB) GetALBTargetGroupRequestContext(req *http.Request) (events.ALBTargetGroupRequestContext, error) {
+ if req.Header.Get(ALBTgContextHeader) == "" {
+ return events.ALBTargetGroupRequestContext{}, errors.New("No context header in request")
+ }
+ context := events.ALBTargetGroupRequestContext{}
+ err := json.Unmarshal([]byte(req.Header.Get(ALBTgContextHeader)), &context)
+ if err != nil {
+ log.Println("Erorr while unmarshalling context")
+ log.Println(err)
+ return events.ALBTargetGroupRequestContext{}, err
+ }
+ return context, nil
+}
+
+// ProxyEventToHTTPRequest converts an ALB Target Group proxy event into a http.Request object.
+// Returns the populated http request with additional two custom headers for ALB Tg Req context.
+// To access these properties use the GetALBTargetGroupRequestContext method of the RequestAccessor object.
+func (r *RequestAccessorALB) ProxyEventToHTTPRequest(req events.ALBTargetGroupRequest) (*http.Request, error) {
+ httpRequest, err := r.EventToRequest(req)
+ if err != nil {
+ log.Println(err)
+ return nil, err
+ }
+ return addToHeaderALB(httpRequest, req)
+}
+
+// EventToRequestWithContext converts an ALB Target Group proxy event and context into an http.Request object.
+// Returns the populated http request with lambda context, ALBTargetGroupRequestContext as part of its context.
+// Access those using GetRuntimeContextFromContextALB and GetRuntimeContextFromContext functions in this package.
+func (r *RequestAccessorALB) EventToRequestWithContext(ctx context.Context, req events.ALBTargetGroupRequest) (*http.Request, error) {
+ httpRequest, err := r.EventToRequest(req)
+ if err != nil {
+ log.Println(err)
+ return nil, err
+ }
+ return addToContextALB(ctx, httpRequest, req), nil
+}
+
+// EventToRequest converts an ALB Target group proxy event into an http.Request object.
+// Returns the populated request maintaining headers
+func (r *RequestAccessorALB) EventToRequest(req events.ALBTargetGroupRequest) (*http.Request, error) {
+ decodedBody := []byte(req.Body)
+ if req.IsBase64Encoded {
+ base64Body, err := base64.StdEncoding.DecodeString(req.Body)
+ if err != nil {
+ return nil, err
+ }
+ decodedBody = base64Body
+ }
+
+ path := req.Path
+
+ if !strings.HasPrefix(path, "/") {
+ path = "/" + path
+ }
+
+ if len(req.QueryStringParameters) > 0 {
+ values := url.Values{}
+ for key, value := range req.QueryStringParameters {
+ values.Add(key, value)
+ }
+ path += "?" + values.Encode()
+ }
+
+ httpRequest, err := http.NewRequest(
+ strings.ToUpper(req.HTTPMethod),
+ path,
+ bytes.NewReader(decodedBody),
+ )
+
+ if err != nil {
+ fmt.Printf("Could not convert request %s:%s to http.Request\n", req.HTTPMethod, req.Path)
+ log.Println(err)
+ return nil, err
+ }
+
+ for headerKey, headerValue := range req.Headers {
+ for _, val := range strings.Split(headerValue, ",") {
+ httpRequest.Header.Add(headerKey, strings.Trim(val, " "))
+ }
+ }
+
+ for headerKey, headerValue := range req.MultiValueHeaders {
+ for _, arrVal := range headerValue {
+ for _, val := range strings.Split(arrVal, ",") {
+ httpRequest.Header.Add(headerKey, strings.Trim(val, " "))
+ }
+ }
+ }
+
+ httpRequest.RequestURI = httpRequest.URL.RequestURI()
+
+ return httpRequest, nil
+}
+
+func addToHeaderALB(req *http.Request, albTgRequest events.ALBTargetGroupRequest) (*http.Request, error) {
+ albTgContext, err := json.Marshal(albTgRequest.RequestContext)
+ if err != nil {
+ log.Println("Could not Marshal ALB Tg context for custom header")
+ return req, err
+ }
+ req.Header.Add(ALBTgContextHeader, string(albTgContext))
+ return req, nil
+}
+
+func addToContextALB(ctx context.Context, req *http.Request, albTgRequest events.ALBTargetGroupRequest) *http.Request {
+ lc, _ := lambdacontext.FromContext(ctx)
+ rc := requestContextALB{lambdaContext: lc, gatewayProxyContext: albTgRequest.RequestContext}
+ ctx = context.WithValue(ctx, ctxKey{}, rc)
+ return req.WithContext(ctx)
+}
+
+func GetRuntimeContextFromContextALB(ctx context.Context) (*lambdacontext.LambdaContext, bool) {
+ v, ok := ctx.Value(ctxKey{}).(requestContextALB)
+ return v.lambdaContext, ok
+}
+
+type requestContextALB struct {
+ lambdaContext *lambdacontext.LambdaContext
+ gatewayProxyContext events.ALBTargetGroupRequestContext
+}
diff --git a/core/requestalb_test.go b/core/requestalb_test.go
new file mode 100644
index 0000000..fe84f89
--- /dev/null
+++ b/core/requestalb_test.go
@@ -0,0 +1,239 @@
+package core_test
+
+import (
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "io/ioutil"
+ "math/rand"
+ "strings"
+
+ "github.com/aws/aws-lambda-go/events"
+ "github.com/aws/aws-lambda-go/lambdacontext"
+ "github.com/awslabs/aws-lambda-go-api-proxy/core"
+
+ . "github.com/onsi/ginkgo"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("RequestAccessorALB tests", func() {
+ Context("event conversion", func() {
+ accessor := core.RequestAccessorALB{}
+ basicRequest := getProxyRequestALB("/hello", "GET")
+ It("Correctly converts a basic event", func() {
+ httpReq, err := accessor.EventToRequestWithContext(context.Background(), basicRequest)
+ Expect(err).To(BeNil())
+ Expect("/hello").To(Equal(httpReq.URL.Path))
+ Expect("/hello").To(Equal(httpReq.RequestURI))
+ Expect("GET").To(Equal(httpReq.Method))
+ })
+
+ basicRequest = getProxyRequestALB("/hello", "get")
+ It("Converts method to uppercase", func() {
+ // calling old method to verify reverse compatibility
+ httpReq, err := accessor.ProxyEventToHTTPRequest(basicRequest)
+ Expect(err).To(BeNil())
+ Expect("/hello").To(Equal(httpReq.URL.Path))
+ Expect("/hello").To(Equal(httpReq.RequestURI))
+ Expect("GET").To(Equal(httpReq.Method))
+ })
+
+ binaryBody := make([]byte, 256)
+ _, err := rand.Read(binaryBody)
+ if err != nil {
+ Fail("Could not generate random binary body")
+ }
+
+ encodedBody := base64.StdEncoding.EncodeToString(binaryBody)
+
+ binaryRequest := getProxyRequestALB("/hello", "POST")
+ binaryRequest.Body = encodedBody
+ binaryRequest.IsBase64Encoded = true
+
+ It("Decodes a base64 encoded body", func() {
+ httpReq, err := accessor.EventToRequestWithContext(context.Background(), binaryRequest)
+ Expect(err).To(BeNil())
+ Expect("/hello").To(Equal(httpReq.URL.Path))
+ Expect("/hello").To(Equal(httpReq.RequestURI))
+ Expect("POST").To(Equal(httpReq.Method))
+
+ bodyBytes, err := ioutil.ReadAll(httpReq.Body)
+
+ Expect(err).To(BeNil())
+ Expect(len(binaryBody)).To(Equal(len(bodyBytes)))
+ Expect(binaryBody).To(Equal(bodyBytes))
+ })
+
+ mqsRequest := getProxyRequestALB("/hello", "GET")
+ mqsRequest.QueryStringParameters = map[string]string{
+ "hello": "1",
+ "world": "2",
+ }
+ It("Populates multiple value query string correctly", func() {
+ httpReq, err := accessor.EventToRequestWithContext(context.Background(), mqsRequest)
+ Expect(err).To(BeNil())
+ Expect("/hello").To(Equal(httpReq.URL.Path))
+ Expect(httpReq.RequestURI).To(ContainSubstring("hello=1"))
+ Expect(httpReq.RequestURI).To(ContainSubstring("world=2"))
+ Expect("GET").To(Equal(httpReq.Method))
+
+ query := httpReq.URL.Query()
+ Expect(2).To(Equal(len(query)))
+ Expect(query["hello"]).ToNot(BeNil())
+ Expect(query["world"]).ToNot(BeNil())
+ Expect(1).To(Equal(len(query["hello"])))
+ Expect(1).To(Equal(len(query["world"])))
+ Expect("1").To(Equal(query["hello"][0]))
+ Expect("2").To(Equal(query["world"][0]))
+ })
+
+ // Support `QueryStringParameters` for backward compatibility.
+ // https://github.com/awslabs/aws-lambda-go-api-proxy/issues/37
+ qsRequest := getProxyRequestALB("/hello", "GET")
+ qsRequest.QueryStringParameters = map[string]string{
+ "hello": "1",
+ "world": "2",
+ }
+ It("Populates query string correctly", func() {
+ httpReq, err := accessor.EventToRequestWithContext(context.Background(), qsRequest)
+ Expect(err).To(BeNil())
+ Expect("/hello").To(Equal(httpReq.URL.Path))
+ Expect(httpReq.RequestURI).To(ContainSubstring("hello=1"))
+ Expect(httpReq.RequestURI).To(ContainSubstring("world=2"))
+ Expect("GET").To(Equal(httpReq.Method))
+
+ query := httpReq.URL.Query()
+ Expect(2).To(Equal(len(query)))
+ Expect(query["hello"]).ToNot(BeNil())
+ Expect(query["world"]).ToNot(BeNil())
+ Expect(1).To(Equal(len(query["hello"])))
+ Expect(1).To(Equal(len(query["world"])))
+ Expect("1").To(Equal(query["hello"][0]))
+ Expect("2").To(Equal(query["world"][0]))
+ })
+
+ mvhRequest := getProxyRequestALB("/hello", "GET")
+ mvhRequest.Headers = map[string]string{
+ "hello": "1",
+ "world": "2,3",
+ }
+ mvhRequest.MultiValueHeaders = map[string][]string{
+ "hello world": []string{"4", "5", "6"},
+ }
+
+ It("Populates multiple value headers correctly", func() {
+ httpReq, err := accessor.EventToRequestWithContext(context.Background(), mvhRequest)
+ Expect(err).To(BeNil())
+ Expect("/hello").To(Equal(httpReq.URL.Path))
+ Expect("GET").To(Equal(httpReq.Method))
+
+ headers := httpReq.Header
+ Expect(3).To(Equal(len(headers)))
+
+ for k, value := range headers {
+ if mvhRequest.Headers[strings.ToLower(k)] != "" {
+ Expect(strings.Join(value, ",")).To(Equal(mvhRequest.Headers[strings.ToLower(k)]))
+ } else {
+ Expect(strings.Join(value, ",")).To(Equal(strings.Join(mvhRequest.MultiValueHeaders[strings.ToLower(k)], ",")))
+ }
+ }
+ })
+
+ svhRequest := getProxyRequestALB("/hello", "GET")
+ svhRequest.Headers = map[string]string{
+ "hello": "1",
+ "world": "2",
+ }
+ It("Populates single value headers correctly", func() {
+ httpReq, err := accessor.EventToRequestWithContext(context.Background(), svhRequest)
+ Expect(err).To(BeNil())
+ Expect("/hello").To(Equal(httpReq.URL.Path))
+ Expect("GET").To(Equal(httpReq.Method))
+
+ headers := httpReq.Header
+ Expect(2).To(Equal(len(headers)))
+
+ for k, value := range headers {
+ Expect(value[0]).To(Equal(svhRequest.Headers[strings.ToLower(k)]))
+ }
+ })
+
+ basePathRequest := getProxyRequestALB("/orders", "GET")
+
+ It("Stips the base path correct", func() {
+ httpReq, err := accessor.EventToRequestWithContext(context.Background(), basePathRequest)
+
+ Expect(err).To(BeNil())
+ Expect("/orders").To(Equal(httpReq.URL.Path))
+ Expect("/orders").To(Equal(httpReq.RequestURI))
+ })
+
+ contextRequest := getProxyRequestALB("/orders", "GET")
+ contextRequest.RequestContext = getRequestContextALB()
+
+ It("Populates context header correctly", func() {
+ // calling old method to verify reverse compatibility
+ httpReq, err := accessor.ProxyEventToHTTPRequest(contextRequest)
+ Expect(err).To(BeNil())
+ Expect(1).To(Equal(len(httpReq.Header)))
+ Expect(httpReq.Header.Get(core.ALBTgContextHeader)).ToNot(BeNil())
+ })
+ })
+
+ Context("Retrieves ALB Target Group context", func() {
+ It("Returns a correctly unmarshalled object", func() {
+ contextRequest := getProxyRequestALB("/orders", "GET")
+ contextRequest.RequestContext = getRequestContextALB()
+
+ accessor := core.RequestAccessorALB{}
+ httpReq, err := accessor.ProxyEventToHTTPRequest(contextRequest)
+ Expect(err).To(BeNil())
+ ctx := httpReq.Header[core.ALBTgContextHeader][0]
+ var parsedCtx events.ALBTargetGroupRequestContext
+ json.Unmarshal([]byte(ctx), &parsedCtx)
+ Expect("foo").To(Equal(parsedCtx.ELB.TargetGroupArn))
+
+ headerContext, err := accessor.GetALBTargetGroupRequestContext(httpReq)
+ Expect(err).To(BeNil())
+ Expect("foo").To(Equal(headerContext.ELB.TargetGroupArn))
+
+ httpReq, err = accessor.EventToRequestWithContext(context.Background(), contextRequest)
+ Expect(err).To(BeNil())
+ Expect("/orders").To(Equal(httpReq.RequestURI))
+ runtimeContext, ok := core.GetRuntimeContextFromContextALB(httpReq.Context())
+ Expect(ok).To(BeTrue())
+ Expect(runtimeContext).To(BeNil())
+
+ lambdaContext := lambdacontext.NewContext(context.Background(), &lambdacontext.LambdaContext{AwsRequestID: "abc123"})
+ httpReq, err = accessor.EventToRequestWithContext(lambdaContext, contextRequest)
+ Expect(err).To(BeNil())
+ Expect("/orders").To(Equal(httpReq.RequestURI))
+
+ headerContext, err = accessor.GetALBTargetGroupRequestContext(httpReq)
+ // should fail as new context method doesn't populate headers
+ Expect(err).ToNot(BeNil())
+ Expect("").To(Equal(headerContext.ELB.TargetGroupArn))
+ runtimeContext, ok = core.GetRuntimeContextFromContextALB(httpReq.Context())
+ Expect(ok).To(BeTrue())
+ Expect(runtimeContext).ToNot(BeNil())
+ Expect("abc123").To(Equal(runtimeContext.AwsRequestID))
+ })
+ })
+})
+
+func getProxyRequestALB(path string, method string) events.ALBTargetGroupRequest {
+ return events.ALBTargetGroupRequest{
+ RequestContext: events.ALBTargetGroupRequestContext{},
+ Path: path,
+ HTTPMethod: method,
+ Headers: map[string]string{},
+ }
+}
+
+func getRequestContextALB() events.ALBTargetGroupRequestContext {
+ return events.ALBTargetGroupRequestContext{
+ ELB: events.ELBContext{
+ TargetGroupArn: "foo",
+ },
+ }
+}
diff --git a/core/responsealb.go b/core/responsealb.go
new file mode 100644
index 0000000..cca6e17
--- /dev/null
+++ b/core/responsealb.go
@@ -0,0 +1,126 @@
+// Package core provides utility methods that help convert proxy events
+// into an http.Request and http.ResponseWriter
+package core
+
+import (
+ "bytes"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "net/http"
+ "strings"
+ "unicode/utf8"
+
+ "github.com/aws/aws-lambda-go/events"
+)
+
+// ProxyResponseWriterALB implements http.ResponseWriter and adds the method
+// necessary to return an events.ALBTargetGroupResponse object
+type ProxyResponseWriterALB struct {
+ headers http.Header
+ body bytes.Buffer
+ status int
+ observers []chan<- bool
+}
+
+// NewProxyResponseWriter returns a new ProxyResponseWriter object.
+// The object is initialized with an empty map of headers and a
+// status code of -1
+func NewProxyResponseWriterALB() *ProxyResponseWriterALB {
+ return &ProxyResponseWriterALB{
+ headers: make(http.Header),
+ status: defaultStatusCode,
+ observers: make([]chan<- bool, 0),
+ }
+
+}
+
+func (r *ProxyResponseWriterALB) CloseNotify() <-chan bool {
+ ch := make(chan bool, 1)
+
+ r.observers = append(r.observers, ch)
+
+ return ch
+}
+
+func (r *ProxyResponseWriterALB) notifyClosed() {
+ for _, v := range r.observers {
+ v <- true
+ }
+}
+
+// Header implementation from the http.ResponseWriter interface.
+func (r *ProxyResponseWriterALB) Header() http.Header {
+ return r.headers
+}
+
+// Write sets the response body in the object. If no status code
+// was set before with the WriteHeader method it sets the status
+// for the response to 200 OK.
+func (r *ProxyResponseWriterALB) Write(body []byte) (int, error) {
+ if r.status == defaultStatusCode {
+ r.status = http.StatusOK
+ }
+
+ // if the content type header is not set when we write the body we try to
+ // detect one and set it by default. If the content type cannot be detected
+ // it is automatically set to "application/octet-stream" by the
+ // DetectContentType method
+ if r.Header().Get(contentTypeHeaderKey) == "" {
+ r.Header().Add(contentTypeHeaderKey, http.DetectContentType(body))
+ }
+
+ return (&r.body).Write(body)
+}
+
+// WriteHeader sets a status code for the response. This method is used
+// for error responses.
+func (r *ProxyResponseWriterALB) WriteHeader(status int) {
+ r.status = status
+}
+
+// GetProxyResponse converts the data passed to the response writer into
+// an events.ALBTargetGroupResponse object.
+// Returns a populated proxy response object. If the response is invalid, for example
+// has no headers or an invalid status code returns an error.
+func (r *ProxyResponseWriterALB) GetProxyResponse() (events.ALBTargetGroupResponse, error) {
+ r.notifyClosed()
+
+ if r.status == defaultStatusCode {
+ return events.ALBTargetGroupResponse{}, errors.New("Status code not set on response")
+ }
+
+ var output string
+ isBase64 := false
+
+ bb := (&r.body).Bytes()
+
+ if utf8.Valid(bb) {
+ output = string(bb)
+ } else {
+ output = base64.StdEncoding.EncodeToString(bb)
+ isBase64 = true
+ }
+
+ headers := make(map[string]string)
+ multiHeaders := make(map[string][]string)
+
+ // set both Headers and MultiValueHeaders
+ for headerKey, headerValue := range http.Header(r.headers) {
+ headers[headerKey] = strings.Join(headerValue, ",")
+ if multiHeaders[headerKey] != nil {
+ multiHeaders[headerKey] = append(multiHeaders[headerKey], strings.Join(headerValue, ","))
+ } else {
+ multiHeaders[headerKey] = []string{strings.Join(headerValue, ",")}
+ }
+ }
+
+ return events.ALBTargetGroupResponse{
+ StatusCode: r.status,
+ StatusDescription: fmt.Sprintf("%d %s", r.status, http.StatusText(r.status)),
+ Headers: headers,
+ MultiValueHeaders: multiHeaders,
+ Body: output,
+ IsBase64Encoded: isBase64,
+ }, nil
+}
diff --git a/core/responsealb_test.go b/core/responsealb_test.go
new file mode 100644
index 0000000..99badb4
--- /dev/null
+++ b/core/responsealb_test.go
@@ -0,0 +1,182 @@
+package core
+
+import (
+ "encoding/base64"
+ "math/rand"
+ "net/http"
+ "strings"
+
+ . "github.com/onsi/ginkgo"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("ResponseWriter tests", func() {
+ Context("writing to response object", func() {
+ response := NewProxyResponseWriterALB()
+
+ It("Sets the correct default status", func() {
+ Expect(defaultStatusCode).To(Equal(response.status))
+ })
+
+ It("Initializes the headers map", func() {
+ Expect(response.headers).ToNot(BeNil())
+ Expect(0).To(Equal(len(response.headers)))
+ })
+
+ It("Writes headers correctly", func() {
+ response.Header().Add("Content-Type", "application/json")
+ response.Header().Add("Content-Type", "charset=utf-8")
+
+ Expect(1).To(Equal(len(response.headers)))
+ Expect("application/json").To(Equal(response.headers["Content-Type"][0]))
+ Expect("charset=utf-8").To(Equal(response.headers["Content-Type"][1]))
+ })
+
+ It("Writes body content correctly", func() {
+ binaryBody := make([]byte, 256)
+ _, err := rand.Read(binaryBody)
+ Expect(err).To(BeNil())
+
+ written, err := response.Write(binaryBody)
+ Expect(err).To(BeNil())
+ Expect(len(binaryBody)).To(Equal(written))
+ })
+
+ It("Automatically set the status code to 200", func() {
+ Expect(http.StatusOK).To(Equal(response.status))
+ })
+
+ It("Forces the status to a new code", func() {
+ response.WriteHeader(http.StatusAccepted)
+ Expect(http.StatusAccepted).To(Equal(response.status))
+ })
+ })
+
+ Context("Automatically set response content type", func() {
+ xmlBodyContent := "