11
11
from typing import Any
12
12
13
13
import orjson
14
+ from asgi_tools import ResponseWebSocket
14
15
from asgiref import typing as asgi_types
15
16
from asgiref.compatibility import guarantee_single_callable
16
17
from servestatic import ServeStaticASGI
26
27
AsgiHttpApp,
27
28
AsgiLifespanApp,
28
29
AsgiWebsocketApp,
30
+ AsgiWebsocketReceive,
31
+ AsgiWebsocketSend,
29
32
Connection,
30
33
Location,
31
34
ReactPyConfig,
@@ -153,41 +156,56 @@ async def __call__(
153
156
send: asgi_types.ASGISendCallable,
154
157
) -> None:
155
158
"""ASGI app for rendering ReactPy Python components."""
156
- dispatcher: asyncio.Task[Any] | None = None
157
- recv_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
158
-
159
159
# Start a loop that handles ASGI websocket events
160
- while True:
161
- event = await receive()
162
- if event["type"] == "websocket.connect":
163
- await send(
164
- {"type": "websocket.accept", "subprotocol": None, "headers": []}
165
- )
166
- dispatcher = asyncio.create_task(
167
- self.run_dispatcher(scope, receive, send, recv_queue)
168
- )
169
-
170
- elif event["type"] == "websocket.disconnect":
171
- if dispatcher:
172
- dispatcher.cancel()
173
- break
174
-
175
- elif event["type"] == "websocket.receive" and event["text"]:
176
- queue_put_func = recv_queue.put(orjson.loads(event["text"]))
177
- await queue_put_func
178
-
179
- async def run_dispatcher(
160
+ async with ReactPyWebsocket(scope, receive, send, parent=self.parent) as ws: # type: ignore
161
+ while True:
162
+ # Wait for the webserver to notify us of a new event
163
+ event: dict[str, Any] = await ws.receive(raw=True) # type: ignore
164
+
165
+ # If the event is a `receive` event, parse the message and send it to the rendering queue
166
+ if event["type"] == "websocket.receive":
167
+ msg: dict[str, str] = orjson.loads(event["text"])
168
+ if msg.get("type") == "layout-event":
169
+ await ws.rendering_queue.put(msg)
170
+ else: # pragma: no cover
171
+ await asyncio.to_thread(
172
+ _logger.warning, f"Unknown message type: {msg.get('type')}"
173
+ )
174
+
175
+ # If the event is a `disconnect` event, break the rendering loop and close the connection
176
+ elif event["type"] == "websocket.disconnect":
177
+ break
178
+
179
+
180
+ class ReactPyWebsocket(ResponseWebSocket):
181
+ def __init__(
180
182
self,
181
183
scope: asgi_types.WebSocketScope,
182
- receive: asgi_types.ASGIReceiveCallable ,
183
- send: asgi_types.ASGISendCallable ,
184
- recv_queue: asyncio.Queue[dict[str, Any]] ,
184
+ receive: AsgiWebsocketReceive ,
185
+ send: AsgiWebsocketSend ,
186
+ parent: ReactPyMiddleware ,
185
187
) -> None:
186
- """Asyncio background task that renders and transmits layout updates of ReactPy components."""
188
+ super().__init__(scope=scope, receive=receive, send=send) # type: ignore
189
+ self.scope = scope
190
+ self.parent = parent
191
+ self.rendering_queue: asyncio.Queue[dict[str, str]] = asyncio.Queue()
192
+ self.dispatcher: asyncio.Task[Any] | None = None
193
+
194
+ async def __aenter__(self) -> ReactPyWebsocket:
195
+ self.dispatcher = asyncio.create_task(self.run_dispatcher())
196
+ return await super().__aenter__() # type: ignore
197
+
198
+ async def __aexit__(self, *_: Any) -> None:
199
+ if self.dispatcher:
200
+ self.dispatcher.cancel()
201
+ await super().__aexit__() # type: ignore
202
+
203
+ async def run_dispatcher(self) -> None:
204
+ """Async background task that renders ReactPy components over a websocket."""
187
205
try:
188
206
# Determine component to serve by analyzing the URL and/or class parameters.
189
207
if self.parent.multiple_root_components:
190
- url_match = re.match(self.parent.dispatcher_pattern, scope["path"])
208
+ url_match = re.match(self.parent.dispatcher_pattern, self. scope["path"])
191
209
if not url_match: # pragma: no cover
192
210
raise RuntimeError("Could not find component in URL path.")
193
211
dotted_path = url_match["dotted_path"]
@@ -203,10 +221,10 @@ async def run_dispatcher(
203
221
204
222
# Create a connection object by analyzing the websocket's query string.
205
223
ws_query_string = urllib.parse.parse_qs(
206
- scope["query_string"].decode(), strict_parsing=True
224
+ self. scope["query_string"].decode(), strict_parsing=True
207
225
)
208
226
connection = Connection(
209
- scope=scope,
227
+ scope=self. scope,
210
228
location=Location(
211
229
path=ws_query_string.get("http_pathname", [""])[0],
212
230
query_string=ws_query_string.get("http_query_string", [""])[0],
@@ -217,20 +235,19 @@ async def run_dispatcher(
217
235
# Start the ReactPy component rendering loop
218
236
await serve_layout(
219
237
Layout(ConnectionContext(component(), value=connection)),
220
- lambda msg: send(
221
- {
222
- "type": "websocket.send",
223
- "text": orjson.dumps(msg).decode(),
224
- "bytes": None,
225
- }
226
- ),
227
- recv_queue.get, # type: ignore
238
+ self.send_json,
239
+ self.rendering_queue.get, # type: ignore
228
240
)
229
241
230
242
# Manually log exceptions since this function is running in a separate asyncio task.
231
243
except Exception as error:
232
244
await asyncio.to_thread(_logger.error, f"{error}\n{traceback.format_exc()}")
233
245
246
+ async def send_json(self, data: Any) -> None:
247
+ return await self._send(
248
+ {"type": "websocket.send", "text": orjson.dumps(data).decode()}
249
+ )
250
+
234
251
235
252
@dataclass
236
253
class StaticFileApp:
0 commit comments