diff --git a/azure_functions_worker/dispatcher.py b/azure_functions_worker/dispatcher.py index 7e9c8fb69..b657b360e 100644 --- a/azure_functions_worker/dispatcher.py +++ b/azure_functions_worker/dispatcher.py @@ -22,7 +22,8 @@ from . import protos from . import constants -from .logging import error_logger, logger +from .logging import error_logger, logger, is_system_log_category +from .logging import enable_console_logging, disable_console_logging from .tracing import marshall_exception_trace from .utils.wrappers import disable_feature_by @@ -115,13 +116,21 @@ async def dispatch_forever(self): self._loop.set_task_factory( lambda loop, coro: ContextEnabledTask(coro, loop=loop)) + # Attach gRPC logging to the root logger logging_handler = AsyncLoggingHandler() root_logger = logging.getLogger() root_logger.setLevel(logging.INFO) root_logger.addHandler(logging_handler) + + # Since gRPC channel is established, should use it for logging + disable_console_logging() + logger.info('Detach console logging. Switch to gRPC logging') + try: await forever finally: + # Reenable console logging when there's an exception + enable_console_logging() root_logger.removeHandler(logging_handler) finally: DispatcherMeta.__current_dispatcher__ = None @@ -155,10 +164,16 @@ def _on_logging(self, record: logging.LogRecord, formatted_msg: str): else: log_level = getattr(protos.RpcLog, 'None') + if is_system_log_category(record.name): + log_category = protos.RpcLog.RpcLogCategory.System + else: + log_category = protos.RpcLog.RpcLogCategory.User + log = dict( level=log_level, message=formatted_msg, category=record.name, + log_category=log_category ) invocation_id = get_current_invocation_id() @@ -310,12 +325,17 @@ async def _handle__invocation_request(self, req): args[name] = bindings.Out() if fi.is_async: + logger.info('Function is async, request ID: %s,' + 'function ID: %s, invocation ID: %s', + self.request_id, function_id, invocation_id) call_result = await fi.func(**args) else: + logger.info('Function is sync, request ID: %s,' + 'function ID: %s, invocation ID: %s', + self.request_id, function_id, invocation_id) call_result = await self._loop.run_in_executor( self._sync_call_tp, self.__run_sync_func, invocation_id, fi.func, args) - if call_result is not None and not fi.has_return: raise RuntimeError( f'function {fi.name!r} without a $return binding ' @@ -499,10 +519,10 @@ def gen(resp_queue): class AsyncLoggingHandler(logging.Handler): def emit(self, record): - if not record.name.startswith('azure_functions_worker'): - # Skip worker system logs - msg = self.format(record) - Dispatcher.current._on_logging(record, msg) + # Since we disable console log after gRPC channel is initiated + # We should redirect all the messages into dispatcher + msg = self.format(record) + Dispatcher.current._on_logging(record, msg) class ContextEnabledTask(asyncio.Task): diff --git a/azure_functions_worker/logging.py b/azure_functions_worker/logging.py index d2afc7ee5..c2029ca7a 100644 --- a/azure_functions_worker/logging.py +++ b/azure_functions_worker/logging.py @@ -6,6 +6,9 @@ logger = logging.getLogger('azure_functions_worker') error_logger = logging.getLogger('azure_functions_worker_errors') +handler = None +error_handler = None + def setup(log_level, log_destination): if log_level == 'TRACE': @@ -14,9 +17,6 @@ def setup(log_level, log_destination): formatter = logging.Formatter( 'LanguageWorkerConsoleLog %(levelname)s: %(message)s') - error_handler = None - handler = None - if log_destination is None: # With no explicit log destination we do split logging, # errors go into stderr, everything else -- to stdout. @@ -46,3 +46,28 @@ def setup(log_level, log_destination): error_logger.addHandler(error_handler) error_logger.setLevel(getattr(logging, log_level)) + + +def disable_console_logging(): + if logger and handler: + logger.removeHandler(handler) + + if error_logger and error_handler: + error_logger.removeHandler(error_handler) + + +def enable_console_logging(): + if logger and handler: + logger.addHandler(handler) + + if error_logger and error_handler: + error_logger.addHandler(error_handler) + + +def is_system_log_category(ctg: str): + return any( + [ctg.lower().startswith(c) for c in ( + 'azure_functions_worker', + 'azure_functions_worker_errors' + )] + ) diff --git a/azure_functions_worker/protos/_src/src/proto/FunctionRpc.proto b/azure_functions_worker/protos/_src/src/proto/FunctionRpc.proto index d9d4e4a4e..3ed1f0586 100644 --- a/azure_functions_worker/protos/_src/src/proto/FunctionRpc.proto +++ b/azure_functions_worker/protos/_src/src/proto/FunctionRpc.proto @@ -431,11 +431,12 @@ message RpcException { // Http cookie type. Note that only name and value are used for Http requests message RpcHttpCookie { - // Enum that lets servers require that a cookie shouoldn't be sent with cross-site requests + // Enum that lets servers require that a cookie shouldn't be sent with cross-site requests enum SameSite { None = 0; Lax = 1; Strict = 2; + ExplicitNone = 3; } // Cookie name diff --git a/azure_functions_worker/testutils.py b/azure_functions_worker/testutils.py index 219f27fa0..231216781 100644 --- a/azure_functions_worker/testutils.py +++ b/azure_functions_worker/testutils.py @@ -129,16 +129,13 @@ class WebHostTestCaseMeta(type(unittest.TestCase)): def __new__(mcls, name, bases, dct): for attrname, attr in dct.items(): if attrname.startswith('test_') and callable(attr): - test_case_name = attrname.lstrip('test_') test_case = attr - - check_log_case_name = f'check_log_{test_case_name}' - check_log_case = dct.get(check_log_case_name) + check_log_name = attrname.replace('test_', 'check_log_', 1) + check_log_case = dct.get(check_log_name) @functools.wraps(test_case) def wrapper(self, *args, __meth__=test_case, __check_log__=check_log_case, **kwargs): - if (__check_log__ is not None and callable(__check_log__) and not is_envvar_true(PYAZURE_WEBHOST_DEBUG)): diff --git a/tests/unittests/test_http_functions.py b/tests/unittests/test_http_functions.py index ef44a5663..0ebd04687 100644 --- a/tests/unittests/test_http_functions.py +++ b/tests/unittests/test_http_functions.py @@ -93,6 +93,10 @@ def check_log_async_logging(self, host_out: typing.List[str]): self.assertIn('hello info', host_out) self.assertIn('and another error', host_out) + # We should see "Function is async" + is_async = any([ho.startswith("Function is async") for ho in host_out]) + self.assertTrue(is_async) + def test_sync_logging(self): # Test that logging doesn't *break* things. r = self.webhost.request('GET', 'sync_logging') @@ -100,7 +104,11 @@ def test_sync_logging(self): self.assertEqual(r.text, 'OK-sync') def check_log_sync_logging(self, host_out: typing.List[str]): - self.assertIn('a gracefully handled error') + self.assertIn('a gracefully handled error', host_out) + + # We should see a warning "Function is sync" + is_sync = any([ho.startswith("Function is sync") for ho in host_out]) + self.assertTrue(is_sync) def test_return_context(self): r = self.webhost.request('GET', 'return_context') diff --git a/tests/unittests/test_mock_http_functions.py b/tests/unittests/test_mock_http_functions.py index d18424305..2cf4f2bed 100644 --- a/tests/unittests/test_mock_http_functions.py +++ b/tests/unittests/test_mock_http_functions.py @@ -20,9 +20,10 @@ async def test_call_sync_function_check_logs(self): self.assertEqual(r.response.result.status, protos.StatusResult.Success) - self.assertEqual(len(r.logs), 1) + user_logs = [l for l in r.logs if l.category == 'my function'] + self.assertEqual(len(user_logs), 1) - log = r.logs[0] + log = user_logs[0] self.assertEqual(log.invocation_id, invoke_id) self.assertTrue(log.message.startswith( 'a gracefully handled error')) @@ -45,15 +46,18 @@ async def test_call_async_function_check_logs(self): self.assertEqual(r.response.result.status, protos.StatusResult.Success) - self.assertEqual(len(r.logs), 2) + user_logs = [l for l in r.logs if l.category == 'my function'] + self.assertEqual(len(user_logs), 2) - self.assertEqual(r.logs[0].invocation_id, invoke_id) - self.assertEqual(r.logs[0].message, 'hello info') - self.assertEqual(r.logs[0].level, protos.RpcLog.Information) + first_msg = user_logs[0] + self.assertEqual(first_msg.invocation_id, invoke_id) + self.assertEqual(first_msg.message, 'hello info') + self.assertEqual(first_msg.level, protos.RpcLog.Information) - self.assertEqual(r.logs[1].invocation_id, invoke_id) - self.assertTrue(r.logs[1].message.startswith('and another error')) - self.assertEqual(r.logs[1].level, protos.RpcLog.Error) + second_msg = user_logs[1] + self.assertEqual(second_msg.invocation_id, invoke_id) + self.assertTrue(second_msg.message.startswith('and another error')) + self.assertEqual(second_msg.level, protos.RpcLog.Error) self.assertEqual(r.response.return_value.string, 'OK-async')