Skip to content

Commit eb1d7a8

Browse files
author
Michael Brewer
committed
chore: some refactoring
1 parent 04648fe commit eb1d7a8

File tree

1 file changed

+36
-36
lines changed

1 file changed

+36
-36
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
_SAFE_URI = "-._~()'!*:@,;" # https://www.ietf.org/rfc/rfc3986.txt
2828
# API GW/ALB decode non-safe URI chars; we must support them too
2929
_UNSAFE_URI = "%<>\[\]{}|^" # noqa: W605
30-
3130
_NAMED_GROUP_BOUNDARY_PATTERN = fr"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)"
3231

3332

@@ -435,7 +434,7 @@ def __init__(
435434
self._proxy_type = proxy_type
436435
self._routes: List[Route] = []
437436
self._route_keys: List[str] = []
438-
self._exception_handlers: Dict[Union[int, Type], Callable] = {}
437+
self._exception_handlers: Dict[Type, Callable] = {}
439438
self._cors = cors
440439
self._cors_enabled: bool = cors is not None
441440
self._cors_methods: Set[str] = {"OPTIONS"}
@@ -597,8 +596,7 @@ def _not_found(self, method: str) -> ResponseBuilder:
597596
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
598597
return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=None))
599598

600-
# Allow for custom exception handlers
601-
handler = self._exception_handlers.get(404)
599+
handler = self._lookup_exception_handler(NotFoundError)
602600
if handler:
603601
return ResponseBuilder(handler(NotFoundError()))
604602

@@ -635,6 +633,40 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
635633

636634
raise
637635

636+
def not_found(self, func: Callable):
637+
return self.exception_handler(NotFoundError)(func)
638+
639+
def exception_handler(self, exc_class: Type[Exception]):
640+
def register_exception_handler(func: Callable):
641+
self._exception_handlers[exc_class] = func
642+
643+
return register_exception_handler
644+
645+
def _lookup_exception_handler(self, exp_type: Type) -> Optional[Callable]:
646+
# Use "Method Resolution Order" to allow for matching against a base class
647+
# of an exception
648+
for cls in exp_type.__mro__:
649+
if cls in self._exception_handlers:
650+
return self._exception_handlers[cls]
651+
return None
652+
653+
def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[ResponseBuilder]:
654+
handler = self._lookup_exception_handler(type(exp))
655+
if handler:
656+
return ResponseBuilder(handler(exp), route)
657+
658+
if isinstance(exp, ServiceError):
659+
return ResponseBuilder(
660+
Response(
661+
status_code=exp.status_code,
662+
content_type=content_types.APPLICATION_JSON,
663+
body=self._json_dump({"statusCode": exp.status_code, "message": exp.msg}),
664+
),
665+
route,
666+
)
667+
668+
return None
669+
638670
def _to_response(self, result: Union[Dict, Response]) -> Response:
639671
"""Convert the route's result to a Response
640672
@@ -679,38 +711,6 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None
679711

680712
self.route(*route)(func)
681713

682-
def not_found(self, func: Callable):
683-
return self.exception_handler(404)(func)
684-
685-
def exception_handler(self, exc_class_or_status_code: Union[int, Type[Exception]]):
686-
def register_exception_handler(func: Callable):
687-
self._exception_handlers[exc_class_or_status_code] = func
688-
689-
return register_exception_handler
690-
691-
def _lookup_exception_handler(self, exp: Exception) -> Optional[Callable]:
692-
for cls in type(exp).__mro__:
693-
if cls in self._exception_handlers:
694-
return self._exception_handlers[cls]
695-
return None
696-
697-
def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[ResponseBuilder]:
698-
handler = self._lookup_exception_handler(exp)
699-
if handler:
700-
return ResponseBuilder(handler(exp), route)
701-
702-
if isinstance(exp, ServiceError):
703-
return ResponseBuilder(
704-
Response(
705-
status_code=exp.status_code,
706-
content_type=content_types.APPLICATION_JSON,
707-
body=self._json_dump({"statusCode": exp.status_code, "message": exp.msg}),
708-
),
709-
route,
710-
)
711-
712-
return None
713-
714714

715715
class Router(BaseRouter):
716716
"""Router helper class to allow splitting ApiGatewayResolver into multiple files"""

0 commit comments

Comments
 (0)