diff --git a/azure/functions/__init__.py b/azure/functions/__init__.py index aa17e277..a4638277 100644 --- a/azure/functions/__init__.py +++ b/azure/functions/__init__.py @@ -8,6 +8,7 @@ from ._http import HttpRequest from ._http import HttpResponse from ._http_wsgi import WsgiMiddleware +from ._http_asgi import AsgiMiddleware from .kafka import KafkaEvent, KafkaConverter, KafkaTriggerConverter from ._queue import QueueMessage from ._servicebus import ServiceBusMessage @@ -57,6 +58,7 @@ # Middlewares 'WsgiMiddleware', + 'AsgiMiddleware', # Extensions 'AppExtensionBase', diff --git a/azure/functions/_http_asgi.py b/azure/functions/_http_asgi.py new file mode 100644 index 00000000..cdae357e --- /dev/null +++ b/azure/functions/_http_asgi.py @@ -0,0 +1,136 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import asyncio +from typing import Callable, Dict, List, Tuple, Optional, Any, Union +import logging +from wsgiref.headers import Headers + +from ._abc import Context +from ._http import HttpRequest, HttpResponse +from ._http_wsgi import WsgiRequest + + +class AsgiRequest(WsgiRequest): + def __init__(self, func_req: HttpRequest, + func_ctx: Optional[Context] = None): + self.asgi_version = "2.1" + self.asgi_spec_version = "2.1" + self._headers = func_req.headers + super().__init__(func_req, func_ctx) + + def _get_encoded_http_headers(self) -> List[Tuple[bytes, bytes]]: + return [(k.encode("utf8"), v.encode("utf8")) + for k, v in self._headers.items()] + + def _get_server_address(self): + if self.server_name is not None: + return (self.server_name, int(self.server_port)) + return None + + def to_asgi_http_scope(self): + return { + "type": "http", + "asgi.version": self.asgi_version, + "asgi.spec_version": self.asgi_spec_version, + "http_version": "1.1", + "method": self.request_method, + "scheme": "https", + "path": self.path_info, + "raw_path": self.path_info.encode("utf-8"), + "query_string": self.query_string.encode("utf-8"), + "root_path": self.script_name, + "headers": self._get_encoded_http_headers(), + "server": self._get_server_address(), + "client": None, + "azure_functions.function_directory": self.af_function_directory, + "azure_functions.function_name": self.af_function_name, + "azure_functions.invocation_id": self.af_invocation_id + } + # Notes, missing client name, port + + +class AsgiResponse: + def __init__(self): + self._status_code = 0 + self._headers: Union[Headers, Dict] = {} + self._buffer: List[bytes] = [] + self._request_body = b"" + + @classmethod + async def from_app(cls, app, scope: Dict[str, Any], + body: bytes) -> "AsgiResponse": + res = cls() + res._request_body = body + await app(scope, res._receive, res._send) + return res + + def to_func_response(self) -> HttpResponse: + lowercased_headers = {k.lower(): v for k, v in self._headers.items()} + return HttpResponse( + body=b"".join(self._buffer), + status_code=self._status_code, + headers=self._headers, + mimetype=lowercased_headers.get("content-type"), + charset=lowercased_headers.get("content-encoding"), + ) + + def _handle_http_response_start(self, message: Dict[str, Any]): + self._headers = Headers( + [(k.decode(), v.decode()) + for k, v in message["headers"]]) + self._status_code = message["status"] + + def _handle_http_response_body(self, message: Dict[str, Any]): + self._buffer.append(message["body"]) + # XXX : Chunked bodies not supported, see + # https://github.com/Azure/azure-functions-host/issues/4926 + + async def _receive(self): + return { + "type": "http.request", + "body": self._request_body, + "more_body": False, + } + + async def _send(self, message): + logging.debug(f"Received {message} from ASGI worker.") + if message["type"] == "http.response.start": + self._handle_http_response_start(message) + elif message["type"] == "http.response.body": + self._handle_http_response_body(message) + elif message["type"] == "http.disconnect": + pass # Nothing todo here + + +class AsgiMiddleware: + def __init__(self, app): + logging.debug("Instantiating ASGI middleware.") + self._app = app + self.loop = asyncio.new_event_loop() + logging.debug("asyncio event loop initialized.") + + # Usage + # main = func.AsgiMiddleware(app).main + @property + def main(self) -> Callable[[HttpRequest, Context], HttpResponse]: + return self._handle + + # Usage + # return func.AsgiMiddleware(app).handle(req, context) + def handle( + self, req: HttpRequest, context: Optional[Context] = None + ) -> HttpResponse: + logging.info(f"Handling {req.url} as ASGI request.") + return self._handle(req, context) + + def _handle(self, req: HttpRequest, + context: Optional[Context]) -> HttpResponse: + asgi_request = AsgiRequest(req, context) + asyncio.set_event_loop(self.loop) + scope = asgi_request.to_asgi_http_scope() + asgi_response = self.loop.run_until_complete( + AsgiResponse.from_app(self._app, scope, req.get_body()) + ) + + return asgi_response.to_func_response() diff --git a/azure/functions/_http_wsgi.py b/azure/functions/_http_wsgi.py index d91672b7..734d2411 100644 --- a/azure/functions/_http_wsgi.py +++ b/azure/functions/_http_wsgi.py @@ -153,7 +153,7 @@ def main(self) -> Callable[[HttpRequest, Context], HttpResponse]: return self._handle # Usage - # return func.WsgiMiddlewawre(app).handle(req, context) + # return func.WsgiMiddleware(app).handle(req, context) def handle(self, req: HttpRequest, context: Optional[Context] = None) -> HttpResponse: diff --git a/tests/test_http_asgi.py b/tests/test_http_asgi.py new file mode 100644 index 00000000..b90097e5 --- /dev/null +++ b/tests/test_http_asgi.py @@ -0,0 +1,155 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import unittest + +import azure.functions as func +from azure.functions._http_asgi import ( + AsgiMiddleware +) + + +class MockAsgiApplication: + response_code = 200 + response_body = b'' + response_headers = [ + [b"content-type", b"text/plain"], + ] + + async def __call__(self, scope, receive, send): + self.received_scope = scope + # Verify against ASGI specification + assert scope['type'] == 'http' + assert isinstance(scope['type'], str) + + assert scope['asgi.spec_version'] in ['2.0', '2.1'] + assert isinstance(scope['asgi.spec_version'], str) + + assert scope['asgi.version'] in ['2.0', '2.1', '2.2'] + assert isinstance(scope['asgi.version'], str) + + assert scope['http_version'] in ['1.0', '1.1', '2'] + assert isinstance(scope['http_version'], str) + + assert scope['method'] in ['POST', 'GET', 'PUT', 'DELETE', 'PATCH'] + assert isinstance(scope['method'], str) + + assert scope['scheme'] in ['http', 'https'] + assert isinstance(scope['scheme'], str) + + assert isinstance(scope['path'], str) + assert isinstance(scope['raw_path'], bytes) + assert isinstance(scope['query_string'], bytes) + assert isinstance(scope['root_path'], str) + + assert hasattr(scope['headers'], '__iter__') + for k, v in scope['headers']: + assert isinstance(k, bytes) + assert isinstance(v, bytes) + + assert scope['client'] is None or hasattr(scope['client'], '__iter__') + if scope['client']: + assert len(scope['client']) == 2 + assert isinstance(scope['client'][0], str) + assert isinstance(scope['client'][1], int) + + assert scope['server'] is None or hasattr(scope['server'], '__iter__') + if scope['server']: + assert len(scope['server']) == 2 + assert isinstance(scope['server'][0], str) + assert isinstance(scope['server'][1], int) + + self.received_request = await receive() + assert self.received_request['type'] == 'http.request' + assert isinstance(self.received_request['body'], bytes) + assert isinstance(self.received_request['more_body'], bool) + + await send( + { + "type": "http.response.start", + "status": self.response_code, + "headers": self.response_headers, + } + ) + await send( + { + "type": "http.response.body", + "body": self.response_body, + } + ) + + +class TestHttpAsgiMiddleware(unittest.TestCase): + def _generate_func_request( + self, + method="POST", + url="https://function.azurewebsites.net/api/http?firstname=rt", + headers={ + "Content-Type": "application/json", + "x-ms-site-restricted-token": "xmsrt" + }, + params={ + "firstname": "roger" + }, + route_params={}, + body=b'{ "lastname": "tsang" }' + ) -> func.HttpRequest: + return func.HttpRequest( + method=method, + url=url, + headers=headers, + params=params, + route_params=route_params, + body=body + ) + + def _generate_func_context( + self, + invocation_id='123e4567-e89b-12d3-a456-426655440000', + function_name='httptrigger', + function_directory='/home/roger/wwwroot/httptrigger' + ) -> func.Context: + class MockContext(func.Context): + def __init__(self, ii, fn, fd): + self._invocation_id = ii + self._function_name = fn + self._function_directory = fd + + @property + def invocation_id(self): + return self._invocation_id + + @property + def function_name(self): + return self._function_name + + @property + def function_directory(self): + return self._function_directory + + return MockContext(invocation_id, function_name, function_directory) + + def test_middleware_calls_app(self): + app = MockAsgiApplication() + test_body = b'Hello world!' + app.response_body = test_body + app.response_code = 200 + req = func.HttpRequest(method='get', url='/test', body=b'') + response = AsgiMiddleware(app).handle(req) + + # Verify asserted + self.assertEqual(response.status_code, 200) + self.assertEqual(response.get_body(), test_body) + + def test_middleware_calls_app_with_context(self): + app = MockAsgiApplication() + test_body = b'Hello world!' + app.response_body = test_body + app.response_code = 200 + req = self._generate_func_request() + ctx = self._generate_func_context() + response = AsgiMiddleware(app).handle(req, ctx) + + # Verify asserted + self.assertEqual(response.status_code, 200) + self.assertEqual(response.get_body(), test_body)