Skip to content

WIP - Introduce support for event with payload 2.0 #53

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 153 additions & 26 deletions lambda_proxy/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,125 @@ def _converters(value: str, pathArg: str) -> Any:
return value


class Event:
"""Basic class for payload based events"""

payload: dict
_headers: dict

def __init__(self, payload) -> None:
"""Initialize event object"""
self.payload = payload
self.__init_headers__()

def __init_headers__(self):
"""
Initialize headers bases on hacked payload
HACK: For an unknown reason some keys can have lower or upper case.
To make sure the app works well we cast all the keys to lowercase.
"""
headers = self.payload.get("headers", {}) or {}
self._headers = dict((key.lower(), value) for key, value in headers.items())

def version(self) -> str:
"""Get payload version"""
return self.payload.get("version")

def http_method(self) -> str:
"""Get HTTP method"""
pass

def headers(self) -> dict:
"""Get headers"""
return self._headers

def path(self) -> Optional[str]:
"""Get path"""
pass

def raw_path(self):
"""Get raw path"""
pass

def resource(self) -> str:
"""Get resource"""
return self.payload.get("resource", "/")

def request_context(self) -> dict:
"""Get requestContext"""
return self.payload.get("requestContext", {})

def path_parameters(self) -> dict:
"""Get path parameters"""
return self.payload.get("pathParameters", {}) or {}

def query_string_parameters(self) -> dict:
"""Get query string parameters"""
return self.payload.get("queryStringParameters", {}) or {}

def body(self) -> str:
"""Get body"""
return self.payload.get("body")

def is_base64_encoded(self) -> bool:
"""Return is payload base64 encoded"""
return self.payload.get("isBase64Encoded")

def to_dict(self):
"""Return Event object as dict"""
return {
"version": self.version(),
"path": self.path(),
"httpMethod": self.http_method(),
"headers": self.headers(),
"queryStringParameters": self.query_string_parameters(),
}


class Event10(Event):
"""Event based on payload version 1.0"""

def http_method(self) -> str:
"""Get HTTP method"""
return self.payload.get("httpMethod")

def path(self) -> Optional[str]:
"""Get path"""
return self.payload.get("path")


class Event20(Event):
"""Event based on payload version 2.0"""

def http_method(self) -> str:
"""Get HTTP method"""
return self.request_context().get("http").get("method")

def path(self) -> Optional[str]:
"""Get raw path - path in 1.0"""
return self.payload.get("rawPath")

def resource(self) -> str:
"""Get route key - resource in 1.0"""
return self.payload.get("routeKey")


class EventFactory:
"""Create Event from payload."""

@staticmethod
def create_from_payload(payload):
"""Create Event object"""
versions = {
"1.0": Event10,
"2.0": Event20,
}
event = versions.get(payload.get("version")) # type: Event()
if event is None:
raise Exception("Unsupported payload version")
return event(payload)


class RouteEntry(object):
"""Decode request path."""

Expand Down Expand Up @@ -122,37 +241,37 @@ def _get_path_args(self) -> Sequence[Any]:
return args


def _get_apigw_stage(event: Dict) -> str:
def _get_apigw_stage(event: Event) -> str:
"""Return API Gateway stage name."""
header = event.get("headers", {})
header = event.headers()
host = header.get("x-forwarded-host", header.get("host", ""))
if ".execute-api." in host and ".amazonaws.com" in host:
stage = event["requestContext"].get("stage", "")
stage = event.request_context().get("stage", "")
return stage
return ""


def _get_request_path(event: Dict) -> Optional[str]:
def _get_request_path(event: Event) -> Optional[str]:
"""Return API call path."""
resource_proxy = proxy_pattern.search(event.get("resource", "/"))
resource_proxy = proxy_pattern.search(event.resource())
if resource_proxy:
proxy_path = event["pathParameters"].get(resource_proxy["name"])
proxy_path = event.path_parameters().get(resource_proxy["name"])
return f"/{proxy_path}"

return event.get("path")
return event.path()


class ApigwPath(object):
"""Parse path of API Call."""

def __init__(self, event: Dict):
def __init__(self, event: Event):
"""Initialize API Gateway Path Info object."""
self.version = event.get("version")
self.version = event.version()
self.apigw_stage = _get_apigw_stage(event)
self.path = _get_request_path(event)
self.api_prefix = proxy_pattern.sub("", event.get("resource", "")).rstrip("/")
self.api_prefix = proxy_pattern.sub("", event.resource()).rstrip("/")
if not self.apigw_stage and self.path:
path = event.get("path", "")
path = event.path()
suffix = self.api_prefix + self.path
self.path_mapping = path.replace(suffix, "")
else:
Expand Down Expand Up @@ -190,7 +309,7 @@ def __init__(
self.version: str = version
self.routes: List[RouteEntry] = []
self.context: Dict = {}
self.event: Dict = {}
self.event: Event
self.request_path: ApigwPath
self.debug: bool = debug
self.https: bool = https
Expand All @@ -203,8 +322,9 @@ def __init__(
@property
def host(self) -> str:
"""Construct api gateway endpoint url."""
host = self.event["headers"].get(
"x-forwarded-host", self.event["headers"].get("host", "")
# host = self.event["headers"].get(
host = self.event.headers().get(
"x-forwarded-host", self.event.headers().get("host", "")
)
path_info = self.request_path
if path_info.apigw_stage and not path_info.apigw_stage == "$default":
Expand Down Expand Up @@ -494,7 +614,7 @@ def new_func(*args, **kwargs) -> Callable:

def setup_docs(self) -> None:
"""Add default documentation routes."""
openapi_url = f"/openapi.json"
openapi_url = f'{"/openapi.json"}'

def _openapi() -> Tuple[str, str, str]:
"""Return OpenAPI json."""
Expand Down Expand Up @@ -641,15 +761,16 @@ def __call__(self, event, context):
"""Initialize route and handlers."""
self.log.debug(json.dumps(event, default=str))

self.event = event
# self.event = event
self.event = EventFactory.create_from_payload(event)
print(self.event)
self.context = context

# HACK: For an unknown reason some keys can have lower or upper case.
# To make sure the app works well we cast all the keys to lowercase.
headers = self.event.get("headers", {}) or {}
self.event["headers"] = dict(
(key.lower(), value) for key, value in headers.items()
)
headers = event.get("headers", {}) or {}
# headers = self.event.get("headers", {}) or {}
event["headers"] = dict((key.lower(), value) for key, value in headers.items())

self.request_path = ApigwPath(self.event)
if self.request_path.path is None:
Expand All @@ -659,7 +780,8 @@ def __call__(self, event, context):
json.dumps({"errorMessage": "Missing or invalid path"}),
)

http_method = event["httpMethod"]
# http_method = event["httpMethod"]
http_method = self.event.http_method()
route_entry = self._url_matching(self.request_path.path, http_method)
if not route_entry:
return self.response(
Expand All @@ -674,7 +796,8 @@ def __call__(self, event, context):
),
)

request_params = event.get("queryStringParameters", {}) or {}
# request_params = event.get("queryStringParameters", {}) or {}
request_params = self.event.query_string_parameters()
if route_entry.token:
if not self._validate_token(request_params.get("access_token")):
return self.response(
Expand All @@ -688,9 +811,12 @@ def __call__(self, event, context):

function_kwargs = self._get_matching_args(route_entry, self.request_path.path)
function_kwargs.update(request_params.copy())
if http_method in ["POST", "PUT", "PATCH"] and event.get("body"):
body = event["body"]
if event.get("isBase64Encoded"):
# if http_method in ["POST", "PUT", "PATCH"] and event.get("body"):
if http_method in ["POST", "PUT", "PATCH"] and self.event.body():
# body = event["body"]
body = self.event.body()
# if event.get("isBase64Encoded"):
if self.event.is_base64_encoded():
body = base64.b64decode(body).decode()
function_kwargs.update(dict(body=body))

Expand All @@ -710,7 +836,8 @@ def __call__(self, event, context):
response[2],
cors=route_entry.cors,
accepted_methods=route_entry.methods,
accepted_compression=self.event["headers"].get("accept-encoding", ""),
# accepted_compression=self.event["headers"].get("accept-encoding", ""),
accepted_compression=self.event.headers().get("accept-encoding", ""),
compression=route_entry.compression,
b64encode=route_entry.b64encode,
ttl=route_entry.ttl,
Expand Down
Loading