From 5fdbb31020e54ce09dc63252e17a6a37e58d1e98 Mon Sep 17 00:00:00 2001 From: rmorshea Date: Fri, 14 Jan 2022 15:39:55 -0800 Subject: [PATCH] add starlette server implementation --- requirements/pkg-extras.txt | 4 + src/idom/server/fastapi.py | 281 ++--------------- src/idom/server/starlette.py | 291 ++++++++++++++++++ src/idom/server/utils.py | 1 + .../test_server/test_common/test_multiview.py | 6 +- .../test_common/test_per_client_state.py | 6 +- .../test_common/test_shared_state_client.py | 2 + 7 files changed, 328 insertions(+), 263 deletions(-) create mode 100644 src/idom/server/starlette.py diff --git a/requirements/pkg-extras.txt b/requirements/pkg-extras.txt index 6205adf00..4cc2425eb 100644 --- a/requirements/pkg-extras.txt +++ b/requirements/pkg-extras.txt @@ -6,6 +6,10 @@ sanic-cors fastapi >=0.63.0 uvicorn[standard] >=0.13.4 +# extra=starlette +fastapi >=0.16.0 +uvicorn[standard] >=0.13.4 + # extra=flask flask<2.0 flask-cors diff --git a/src/idom/server/fastapi.py b/src/idom/server/fastapi.py index ed73d3710..f36078a48 100644 --- a/src/idom/server/fastapi.py +++ b/src/idom/server/fastapi.py @@ -1,67 +1,25 @@ -from __future__ import annotations +from typing import Optional -import asyncio -import json -import logging -import sys -from asyncio import Future -from threading import Event, Thread, current_thread -from typing import Any, Dict, Optional, Tuple, Union +from fastapi import FastAPI -from fastapi import APIRouter, FastAPI, Request, WebSocket -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import RedirectResponse -from fastapi.staticfiles import StaticFiles -from mypy_extensions import TypedDict -from starlette.websockets import WebSocketDisconnect -from uvicorn.config import Config as UvicornConfig -from uvicorn.server import Server as UvicornServer -from uvicorn.supervisors.multiprocess import Multiprocess -from uvicorn.supervisors.statreload import StatReload as ChangeReload - -from idom.config import IDOM_WED_MODULES_DIR -from idom.core.dispatcher import ( - RecvCoroutine, - SendCoroutine, - SharedViewDispatcher, - VdomJsonPatch, - dispatch_single_view, - ensure_shared_view_dispatcher_future, -) -from idom.core.layout import Layout, LayoutEvent from idom.core.proto import ComponentConstructor -from .utils import CLIENT_BUILD_DIR, poll, threaded - - -logger = logging.getLogger(__name__) - - -class Config(TypedDict, total=False): - """Config for :class:`FastApiRenderServer`""" - - cors: Union[bool, Dict[str, Any]] - """Enable or configure Cross Origin Resource Sharing (CORS) - - For more information see docs for ``fastapi.middleware.cors.CORSMiddleware`` - """ - - redirect_root_to_index: bool - """Whether to redirect the root URL (with prefix) to ``index.html``""" - - serve_static_files: bool - """Whether or not to serve static files (i.e. web modules)""" - - url_prefix: str - """The URL prefix where IDOM resources will be served from""" +from .starlette import ( + Config, + StarletteServer, + _setup_common_routes, + _setup_config_and_app, + _setup_shared_view_dispatcher_route, + _setup_single_view_dispatcher_route, +) def PerClientStateServer( constructor: ComponentConstructor, config: Optional[Config] = None, app: Optional[FastAPI] = None, -) -> FastApiServer: - """Return a :class:`FastApiServer` where each client has its own state. +) -> StarletteServer: + """Return a :class:`StarletteServer` where each client has its own state. Implements the :class:`~idom.server.proto.ServerFactory` protocol @@ -70,20 +28,18 @@ def PerClientStateServer( config: Options for configuring server behavior app: An application instance (otherwise a default instance is created) """ - config, app = _setup_config_and_app(config, app) - router = APIRouter(prefix=config["url_prefix"]) - _setup_common_routes(app, router, config) - _setup_single_view_dispatcher_route(router, constructor) - app.include_router(router) - return FastApiServer(app) + config, app = _setup_config_and_app(config, app, FastAPI) + _setup_common_routes(config, app) + _setup_single_view_dispatcher_route(config["url_prefix"], app, constructor) + return StarletteServer(app) def SharedClientStateServer( constructor: ComponentConstructor, config: Optional[Config] = None, app: Optional[FastAPI] = None, -) -> FastApiServer: - """Return a :class:`FastApiServer` where each client shares state. +) -> StarletteServer: + """Return a :class:`StarletteServer` where each client shares state. Implements the :class:`~idom.server.proto.ServerFactory` protocol @@ -92,200 +48,7 @@ def SharedClientStateServer( config: Options for configuring server behavior app: An application instance (otherwise a default instance is created) """ - config, app = _setup_config_and_app(config, app) - router = APIRouter(prefix=config["url_prefix"]) - _setup_common_routes(app, router, config) - _setup_shared_view_dispatcher_route(app, router, constructor) - app.include_router(router) - return FastApiServer(app) - - -class FastApiServer: - """A thin wrapper for running a FastAPI application - - See :class:`idom.server.proto.Server` for more info - """ - - _server: UvicornServer - _current_thread: Thread - - def __init__(self, app: FastAPI) -> None: - self.app = app - self._did_stop = Event() - app.on_event("shutdown")(self._server_did_stop) - - def run(self, host: str, port: int, *args: Any, **kwargs: Any) -> None: - self._current_thread = current_thread() - - self._server = server = UvicornServer( - UvicornConfig( - self.app, host=host, port=port, loop="asyncio", *args, **kwargs - ) - ) - - # The following was copied from the uvicorn source with minimal modification. We - # shouldn't need to do this, but unfortunately there's no easy way to gain access to - # the server instance so you can stop it. - # BUG: https://github.com/encode/uvicorn/issues/742 - config = server.config - - if (config.reload or config.workers > 1) and not isinstance( - server.config.app, str - ): # pragma: no cover - logger = logging.getLogger("uvicorn.error") - logger.warning( - "You must pass the application as an import string to enable 'reload' or " - "'workers'." - ) - sys.exit(1) - - if config.should_reload: # pragma: no cover - sock = config.bind_socket() - supervisor = ChangeReload(config, target=server.run, sockets=[sock]) - supervisor.run() - elif config.workers > 1: # pragma: no cover - sock = config.bind_socket() - supervisor = Multiprocess(config, target=server.run, sockets=[sock]) - supervisor.run() - else: - import asyncio - - asyncio.set_event_loop(asyncio.new_event_loop()) - server.run() - - run_in_thread = threaded(run) - - def wait_until_started(self, timeout: Optional[float] = 3.0) -> None: - poll( - f"start {self.app}", - 0.01, - timeout, - lambda: hasattr(self, "_server") and self._server.started, - ) - - def stop(self, timeout: Optional[float] = 3.0) -> None: - self._server.should_exit = True - self._did_stop.wait(timeout) - - async def _server_did_stop(self) -> None: - self._did_stop.set() - - -def _setup_config_and_app( - config: Optional[Config], - app: Optional[FastAPI], -) -> Tuple[Config, FastAPI]: - return ( - { - "cors": False, - "url_prefix": "", - "serve_static_files": True, - "redirect_root_to_index": True, - **(config or {}), # type: ignore - }, - app or FastAPI(), - ) - - -def _setup_common_routes(app: FastAPI, router: APIRouter, config: Config) -> None: - cors_config = config["cors"] - if cors_config: # pragma: no cover - cors_params = ( - cors_config if isinstance(cors_config, dict) else {"allow_origins": ["*"]} - ) - app.add_middleware(CORSMiddleware, **cors_params) - - # This really should be added to the APIRouter, but there's a bug in FastAPI - # BUG: https://github.com/tiangolo/fastapi/issues/1469 - url_prefix = config["url_prefix"] - if config["serve_static_files"]: - app.mount( - f"{url_prefix}/client", - StaticFiles( - directory=str(CLIENT_BUILD_DIR), - html=True, - check_dir=True, - ), - name="idom_static_files", - ) - app.mount( - f"{url_prefix}/modules", - StaticFiles( - directory=str(IDOM_WED_MODULES_DIR.current), - html=True, - check_dir=True, - ), - name="idom_static_files", - ) - - if config["redirect_root_to_index"]: - - @app.route(f"{url_prefix}/") - def redirect_to_index(request: Request) -> RedirectResponse: - return RedirectResponse( - f"{url_prefix}/client/index.html?{request.query_params}" - ) - - -def _setup_single_view_dispatcher_route( - router: APIRouter, constructor: ComponentConstructor -) -> None: - @router.websocket("/stream") - async def model_stream(socket: WebSocket) -> None: - await socket.accept() - send, recv = _make_send_recv_callbacks(socket) - try: - await dispatch_single_view( - Layout(constructor(**dict(socket.query_params))), send, recv - ) - except WebSocketDisconnect as error: - logger.info(f"WebSocket disconnect: {error.code}") - - -def _setup_shared_view_dispatcher_route( - app: FastAPI, router: APIRouter, constructor: ComponentConstructor -) -> None: - dispatcher_future: Future[None] - dispatch_coroutine: SharedViewDispatcher - - @app.on_event("startup") - async def activate_dispatcher() -> None: - nonlocal dispatcher_future - nonlocal dispatch_coroutine - dispatcher_future, dispatch_coroutine = ensure_shared_view_dispatcher_future( - Layout(constructor()) - ) - - @app.on_event("shutdown") - async def deactivate_dispatcher() -> None: - logger.debug("Stopping dispatcher - server is shutting down") - dispatcher_future.cancel() - await asyncio.wait([dispatcher_future]) - - @router.websocket("/stream") - async def model_stream(socket: WebSocket) -> None: - await socket.accept() - - if socket.query_params: - raise ValueError( - "SharedClientState server does not support per-client view parameters" - ) - - send, recv = _make_send_recv_callbacks(socket) - - try: - await dispatch_coroutine(send, recv) - except WebSocketDisconnect as error: - logger.info(f"WebSocket disconnect: {error.code}") - - -def _make_send_recv_callbacks( - socket: WebSocket, -) -> Tuple[SendCoroutine, RecvCoroutine]: - async def sock_send(value: VdomJsonPatch) -> None: - await socket.send_text(json.dumps(value)) - - async def sock_recv() -> LayoutEvent: - return LayoutEvent(**json.loads(await socket.receive_text())) - - return sock_send, sock_recv + config, app = _setup_config_and_app(config, app, FastAPI) + _setup_common_routes(config, app) + _setup_shared_view_dispatcher_route(config["url_prefix"], app, constructor) + return StarletteServer(app) diff --git a/src/idom/server/starlette.py b/src/idom/server/starlette.py new file mode 100644 index 000000000..5b751e140 --- /dev/null +++ b/src/idom/server/starlette.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +import asyncio +import json +import logging +import sys +from asyncio import Future +from threading import Event, Thread, current_thread +from typing import Any, Dict, Optional, Tuple, TypeVar, Union + +from mypy_extensions import TypedDict +from starlette.applications import Starlette +from starlette.middleware.cors import CORSMiddleware +from starlette.requests import Request +from starlette.responses import RedirectResponse +from starlette.staticfiles import StaticFiles +from starlette.websockets import WebSocket, WebSocketDisconnect +from uvicorn.config import Config as UvicornConfig +from uvicorn.server import Server as UvicornServer +from uvicorn.supervisors.multiprocess import Multiprocess +from uvicorn.supervisors.statreload import StatReload as ChangeReload + +from idom.config import IDOM_DEBUG_MODE, IDOM_WED_MODULES_DIR +from idom.core.dispatcher import ( + RecvCoroutine, + SendCoroutine, + SharedViewDispatcher, + VdomJsonPatch, + dispatch_single_view, + ensure_shared_view_dispatcher_future, +) +from idom.core.layout import Layout, LayoutEvent +from idom.core.proto import ComponentConstructor + +from .utils import CLIENT_BUILD_DIR, poll, threaded + + +logger = logging.getLogger(__name__) + +_StarletteType = TypeVar("_StarletteType", bound=Starlette) + + +class Config(TypedDict, total=False): + """Config for :class:`StarletteRenderServer`""" + + cors: Union[bool, Dict[str, Any]] + """Enable or configure Cross Origin Resource Sharing (CORS) + + For more information see docs for ``starlette.middleware.cors.CORSMiddleware`` + """ + + redirect_root_to_index: bool + """Whether to redirect the root URL (with prefix) to ``index.html``""" + + serve_static_files: bool + """Whether or not to serve static files (i.e. web modules)""" + + url_prefix: str + """The URL prefix where IDOM resources will be served from""" + + +def PerClientStateServer( + constructor: ComponentConstructor, + config: Optional[Config] = None, + app: Optional[Starlette] = None, +) -> StarletteServer: + """Return a :class:`StarletteServer` where each client has its own state. + + Implements the :class:`~idom.server.proto.ServerFactory` protocol + + Parameters: + constructor: A component constructor + config: Options for configuring server behavior + app: An application instance (otherwise a default instance is created) + """ + config, app = _setup_config_and_app(config, app, Starlette) + _setup_common_routes(config, app) + _setup_single_view_dispatcher_route(config["url_prefix"], app, constructor) + return StarletteServer(app) + + +def SharedClientStateServer( + constructor: ComponentConstructor, + config: Optional[Config] = None, + app: Optional[Starlette] = None, +) -> StarletteServer: + """Return a :class:`StarletteServer` where each client shares state. + + Implements the :class:`~idom.server.proto.ServerFactory` protocol + + Parameters: + constructor: A component constructor + config: Options for configuring server behavior + app: An application instance (otherwise a default instance is created) + """ + config, app = _setup_config_and_app(config, app, Starlette) + _setup_common_routes(config, app) + _setup_shared_view_dispatcher_route(config["url_prefix"], app, constructor) + return StarletteServer(app) + + +class StarletteServer: + """A thin wrapper for running a Starlette application + + See :class:`idom.server.proto.Server` for more info + """ + + _server: UvicornServer + _current_thread: Thread + + def __init__(self, app: Starlette) -> None: + self.app = app + self._did_stop = Event() + app.on_event("shutdown")(self._server_did_stop) + + def run(self, host: str, port: int, *args: Any, **kwargs: Any) -> None: + self._current_thread = current_thread() + + self._server = server = UvicornServer( + UvicornConfig( + self.app, host=host, port=port, loop="asyncio", *args, **kwargs + ) + ) + + # The following was copied from the uvicorn source with minimal modification. We + # shouldn't need to do this, but unfortunately there's no easy way to gain access to + # the server instance so you can stop it. + # BUG: https://github.com/encode/uvicorn/issues/742 + config = server.config + + if (config.reload or config.workers > 1) and not isinstance( + server.config.app, str + ): # pragma: no cover + logger = logging.getLogger("uvicorn.error") + logger.warning( + "You must pass the application as an import string to enable 'reload' or " + "'workers'." + ) + sys.exit(1) + + if config.should_reload: # pragma: no cover + sock = config.bind_socket() + supervisor = ChangeReload(config, target=server.run, sockets=[sock]) + supervisor.run() + elif config.workers > 1: # pragma: no cover + sock = config.bind_socket() + supervisor = Multiprocess(config, target=server.run, sockets=[sock]) + supervisor.run() + else: + import asyncio + + asyncio.set_event_loop(asyncio.new_event_loop()) + server.run() + + run_in_thread = threaded(run) + + def wait_until_started(self, timeout: Optional[float] = 3.0) -> None: + poll( + f"start {self.app}", + 0.01, + timeout, + lambda: hasattr(self, "_server") and self._server.started, + ) + + def stop(self, timeout: Optional[float] = 3.0) -> None: + self._server.should_exit = True + self._did_stop.wait(timeout) + + async def _server_did_stop(self) -> None: + self._did_stop.set() + + +def _setup_config_and_app( + config: Optional[Config], + app: Optional[_StarletteType], + app_type: type[_StarletteType], +) -> Tuple[Config, _StarletteType]: + return ( + { + "cors": False, + "url_prefix": "", + "serve_static_files": True, + "redirect_root_to_index": True, + **(config or {}), # type: ignore + }, + app or app_type(debug=IDOM_DEBUG_MODE.current), + ) + + +def _setup_common_routes(config: Config, app: Starlette) -> None: + cors_config = config["cors"] + if cors_config: # pragma: no cover + cors_params = ( + cors_config if isinstance(cors_config, dict) else {"allow_origins": ["*"]} + ) + app.add_middleware(CORSMiddleware, **cors_params) + + # This really should be added to the APIRouter, but there's a bug in Starlette + # BUG: https://github.com/tiangolo/fastapi/issues/1469 + url_prefix = config["url_prefix"] + if config["serve_static_files"]: + app.mount( + f"{url_prefix}/client", + StaticFiles( + directory=str(CLIENT_BUILD_DIR), + html=True, + check_dir=True, + ), + name="idom_client_files", + ) + app.mount( + f"{url_prefix}/modules", + StaticFiles( + directory=str(IDOM_WED_MODULES_DIR.current), + html=True, + check_dir=False, + ), + name="idom_web_module_files", + ) + + if config["redirect_root_to_index"]: + + @app.route(f"{url_prefix}/") + def redirect_to_index(request: Request) -> RedirectResponse: + return RedirectResponse( + f"{url_prefix}/client/index.html?{request.query_params}" + ) + + +def _setup_single_view_dispatcher_route( + url_prefix: str, app: Starlette, constructor: ComponentConstructor +) -> None: + @app.websocket_route(f"{url_prefix}/stream") + async def model_stream(socket: WebSocket) -> None: + await socket.accept() + send, recv = _make_send_recv_callbacks(socket) + try: + await dispatch_single_view( + Layout(constructor(**dict(socket.query_params))), send, recv + ) + except WebSocketDisconnect as error: + logger.info(f"WebSocket disconnect: {error.code}") + + +def _setup_shared_view_dispatcher_route( + url_prefix: str, app: Starlette, constructor: ComponentConstructor +) -> None: + dispatcher_future: Future[None] + dispatch_coroutine: SharedViewDispatcher + + @app.on_event("startup") + async def activate_dispatcher() -> None: + nonlocal dispatcher_future + nonlocal dispatch_coroutine + dispatcher_future, dispatch_coroutine = ensure_shared_view_dispatcher_future( + Layout(constructor()) + ) + + @app.on_event("shutdown") + async def deactivate_dispatcher() -> None: + logger.debug("Stopping dispatcher - server is shutting down") + dispatcher_future.cancel() + await asyncio.wait([dispatcher_future]) + + @app.websocket_route(f"{url_prefix}/stream") + async def model_stream(socket: WebSocket) -> None: + await socket.accept() + + if socket.query_params: + raise ValueError( + "SharedClientState server does not support per-client view parameters" + ) + + send, recv = _make_send_recv_callbacks(socket) + + try: + await dispatch_coroutine(send, recv) + except WebSocketDisconnect as error: + logger.info(f"WebSocket disconnect: {error.code}") + + +def _make_send_recv_callbacks( + socket: WebSocket, +) -> Tuple[SendCoroutine, RecvCoroutine]: + async def sock_send(value: VdomJsonPatch) -> None: + await socket.send_text(json.dumps(value)) + + async def sock_recv() -> LayoutEvent: + return LayoutEvent(**json.loads(await socket.receive_text())) + + return sock_send, sock_recv diff --git a/src/idom/server/utils.py b/src/idom/server/utils.py index 240ccb8e1..1aa775e93 100644 --- a/src/idom/server/utils.py +++ b/src/idom/server/utils.py @@ -22,6 +22,7 @@ "fastapi", "flask", "tornado", + "starlette", ] diff --git a/tests/test_server/test_common/test_multiview.py b/tests/test_server/test_common/test_multiview.py index f8c936aef..56c2deaf8 100644 --- a/tests/test_server/test_common/test_multiview.py +++ b/tests/test_server/test_common/test_multiview.py @@ -4,6 +4,7 @@ from idom.server import fastapi as idom_fastapi from idom.server import flask as idom_flask from idom.server import sanic as idom_sanic +from idom.server import starlette as idom_starlette from idom.server import tornado as idom_tornado from idom.server.prefab import multiview_server from idom.testing import ServerMountPoint @@ -14,10 +15,11 @@ params=[ # add new PerClientStateServer implementations here to # run a suite of tests which check basic functionality - idom_sanic.PerClientStateServer, + idom_fastapi.PerClientStateServer, idom_flask.PerClientStateServer, + idom_sanic.PerClientStateServer, + idom_starlette.PerClientStateServer, idom_tornado.PerClientStateServer, - idom_fastapi.PerClientStateServer, ], ids=lambda cls: f"{cls.__module__}.{cls.__name__}", ) diff --git a/tests/test_server/test_common/test_per_client_state.py b/tests/test_server/test_common/test_per_client_state.py index fe6d0c8bf..29bd2f26f 100644 --- a/tests/test_server/test_common/test_per_client_state.py +++ b/tests/test_server/test_common/test_per_client_state.py @@ -4,6 +4,7 @@ from idom.server import fastapi as idom_fastapi from idom.server import flask as idom_flask from idom.server import sanic as idom_sanic +from idom.server import starlette as idom_starlette from idom.server import tornado as idom_tornado from idom.testing import ServerMountPoint @@ -12,10 +13,11 @@ params=[ # add new PerClientStateServer implementations here to # run a suite of tests which check basic functionality - idom_sanic.PerClientStateServer, + idom_fastapi.PerClientStateServer, idom_flask.PerClientStateServer, + idom_sanic.PerClientStateServer, + idom_starlette.PerClientStateServer, idom_tornado.PerClientStateServer, - idom_fastapi.PerClientStateServer, ], ids=lambda cls: f"{cls.__module__}.{cls.__name__}", ) diff --git a/tests/test_server/test_common/test_shared_state_client.py b/tests/test_server/test_common/test_shared_state_client.py index b40d14993..440e73d07 100644 --- a/tests/test_server/test_common/test_shared_state_client.py +++ b/tests/test_server/test_common/test_shared_state_client.py @@ -6,6 +6,7 @@ import idom from idom.server import fastapi as idom_fastapi from idom.server import sanic as idom_sanic +from idom.server import starlette as idom_starlette from idom.testing import ServerMountPoint @@ -15,6 +16,7 @@ # run a suite of tests which check basic functionality idom_sanic.SharedClientStateServer, idom_fastapi.SharedClientStateServer, + idom_starlette.SharedClientStateServer, ], ids=lambda cls: f"{cls.__module__}.{cls.__name__}", )