|
27 | 27 | _SAFE_URI = "-._~()'!*:@,;" # https://www.ietf.org/rfc/rfc3986.txt
|
28 | 28 | # API GW/ALB decode non-safe URI chars; we must support them too
|
29 | 29 | _UNSAFE_URI = "%<>\[\]{}|^" # noqa: W605
|
30 |
| - |
31 | 30 | _NAMED_GROUP_BOUNDARY_PATTERN = fr"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)"
|
32 | 31 |
|
33 | 32 |
|
@@ -435,7 +434,7 @@ def __init__(
|
435 | 434 | self._proxy_type = proxy_type
|
436 | 435 | self._routes: List[Route] = []
|
437 | 436 | self._route_keys: List[str] = []
|
438 |
| - self._exception_handlers: Dict[Union[int, Type], Callable] = {} |
| 437 | + self._exception_handlers: Dict[Type, Callable] = {} |
439 | 438 | self._cors = cors
|
440 | 439 | self._cors_enabled: bool = cors is not None
|
441 | 440 | self._cors_methods: Set[str] = {"OPTIONS"}
|
@@ -597,8 +596,7 @@ def _not_found(self, method: str) -> ResponseBuilder:
|
597 | 596 | headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
|
598 | 597 | return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=None))
|
599 | 598 |
|
600 |
| - # Allow for custom exception handlers |
601 |
| - handler = self._exception_handlers.get(404) |
| 599 | + handler = self._lookup_exception_handler(NotFoundError) |
602 | 600 | if handler:
|
603 | 601 | return ResponseBuilder(handler(NotFoundError()))
|
604 | 602 |
|
@@ -635,6 +633,40 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
|
635 | 633 |
|
636 | 634 | raise
|
637 | 635 |
|
| 636 | + def not_found(self, func: Callable): |
| 637 | + return self.exception_handler(NotFoundError)(func) |
| 638 | + |
| 639 | + def exception_handler(self, exc_class: Type[Exception]): |
| 640 | + def register_exception_handler(func: Callable): |
| 641 | + self._exception_handlers[exc_class] = func |
| 642 | + |
| 643 | + return register_exception_handler |
| 644 | + |
| 645 | + def _lookup_exception_handler(self, exp_type: Type) -> Optional[Callable]: |
| 646 | + # Use "Method Resolution Order" to allow for matching against a base class |
| 647 | + # of an exception |
| 648 | + for cls in exp_type.__mro__: |
| 649 | + if cls in self._exception_handlers: |
| 650 | + return self._exception_handlers[cls] |
| 651 | + return None |
| 652 | + |
| 653 | + def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[ResponseBuilder]: |
| 654 | + handler = self._lookup_exception_handler(type(exp)) |
| 655 | + if handler: |
| 656 | + return ResponseBuilder(handler(exp), route) |
| 657 | + |
| 658 | + if isinstance(exp, ServiceError): |
| 659 | + return ResponseBuilder( |
| 660 | + Response( |
| 661 | + status_code=exp.status_code, |
| 662 | + content_type=content_types.APPLICATION_JSON, |
| 663 | + body=self._json_dump({"statusCode": exp.status_code, "message": exp.msg}), |
| 664 | + ), |
| 665 | + route, |
| 666 | + ) |
| 667 | + |
| 668 | + return None |
| 669 | + |
638 | 670 | def _to_response(self, result: Union[Dict, Response]) -> Response:
|
639 | 671 | """Convert the route's result to a Response
|
640 | 672 |
|
@@ -679,38 +711,6 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None
|
679 | 711 |
|
680 | 712 | self.route(*route)(func)
|
681 | 713 |
|
682 |
| - def not_found(self, func: Callable): |
683 |
| - return self.exception_handler(404)(func) |
684 |
| - |
685 |
| - def exception_handler(self, exc_class_or_status_code: Union[int, Type[Exception]]): |
686 |
| - def register_exception_handler(func: Callable): |
687 |
| - self._exception_handlers[exc_class_or_status_code] = func |
688 |
| - |
689 |
| - return register_exception_handler |
690 |
| - |
691 |
| - def _lookup_exception_handler(self, exp: Exception) -> Optional[Callable]: |
692 |
| - for cls in type(exp).__mro__: |
693 |
| - if cls in self._exception_handlers: |
694 |
| - return self._exception_handlers[cls] |
695 |
| - return None |
696 |
| - |
697 |
| - def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[ResponseBuilder]: |
698 |
| - handler = self._lookup_exception_handler(exp) |
699 |
| - if handler: |
700 |
| - return ResponseBuilder(handler(exp), route) |
701 |
| - |
702 |
| - if isinstance(exp, ServiceError): |
703 |
| - return ResponseBuilder( |
704 |
| - Response( |
705 |
| - status_code=exp.status_code, |
706 |
| - content_type=content_types.APPLICATION_JSON, |
707 |
| - body=self._json_dump({"statusCode": exp.status_code, "message": exp.msg}), |
708 |
| - ), |
709 |
| - route, |
710 |
| - ) |
711 |
| - |
712 |
| - return None |
713 |
| - |
714 | 714 |
|
715 | 715 | class Router(BaseRouter):
|
716 | 716 | """Router helper class to allow splitting ApiGatewayResolver into multiple files"""
|
|
0 commit comments