diff --git a/sentry_sdk/integrations/asgi.py b/sentry_sdk/integrations/asgi.py index 4cd7f402a9..e8267d539d 100644 --- a/sentry_sdk/integrations/asgi.py +++ b/sentry_sdk/integrations/asgi.py @@ -4,7 +4,9 @@ Based on Tom Christie's `sentry-asgi `_. """ +import asyncio import functools +import inspect import urllib from sentry_sdk._types import MYPY @@ -17,6 +19,7 @@ from typing import Dict from typing import Any from typing import Optional + from typing import Callable from sentry_sdk._types import Event, Hint @@ -37,26 +40,45 @@ def _capture_exception(hub, exc): hub.capture_event(event, hint=hint) +def _looks_like_asgi3(app): + # type: (Any) -> bool + """ + Try to figure out if an application object supports ASGI3. + + This is how uvicorn figures out the application version as well. + """ + if inspect.isclass(app): + return hasattr(app, "__await__") + elif inspect.isfunction(app): + return asyncio.iscoroutinefunction(app) + else: + call = getattr(app, "__call__", None) # noqa + return asyncio.iscoroutinefunction(call) + + class SentryAsgiMiddleware: - __slots__ = ("app",) + __slots__ = ("app", "__call__") def __init__(self, app): # type: (Any) -> None self.app = app - def __call__(self, scope, receive=None, send=None): - # type: (Any, Any, Any) -> Any - if receive is None or send is None: + if _looks_like_asgi3(app): + self.__call__ = self._run_asgi3 # type: Callable[..., Any] + else: + self.__call__ = self._run_asgi2 - async def run_asgi2(receive, send): - # type: (Any, Any) -> Any - return await self._run_app( - scope, lambda: self.app(scope)(receive, send) - ) + def _run_asgi2(self, scope): + # type: (Any) -> Any + async def inner(receive, send): + # type: (Any, Any) -> Any + return await self._run_app(scope, lambda: self.app(scope)(receive, send)) - return run_asgi2 - else: - return self._run_app(scope, lambda: self.app(scope, receive, send)) + return inner + + async def _run_asgi3(self, scope, receive, send): + # type: (Any, Any, Any) -> Any + return await self._run_app(scope, lambda: self.app(scope, receive, send)) async def _run_app(self, scope, callback): # type: (Any, Any) -> Any