diff --git a/aws_lambda_powertools/logging/formatter.py b/aws_lambda_powertools/logging/formatter.py
index 0140d057f0d..6a3c10847a4 100644
--- a/aws_lambda_powertools/logging/formatter.py
+++ b/aws_lambda_powertools/logging/formatter.py
@@ -1,14 +1,18 @@
 import json
 import logging
 import os
-from typing import Dict, Iterable, Optional, Union
+import time
+from abc import ABCMeta, abstractmethod
+from functools import partial
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 from ..shared import constants
 
-STD_LOGGING_KEYS = (
+RESERVED_LOG_ATTRS = (
     "name",
     "msg",
     "args",
+    "level",
     "levelname",
     "levelno",
     "pathname",
@@ -27,50 +31,124 @@
     "processName",
     "process",
     "asctime",
+    "location",
+    "timestamp",
 )
 
 
-class JsonFormatter(logging.Formatter):
-    """AWS Lambda Logging formatter.
+class BasePowertoolsFormatter(logging.Formatter, metaclass=ABCMeta):
+    @abstractmethod
+    def append_keys(self, **additional_keys):
+        raise NotImplementedError()
 
-    Formats the log message as a JSON encoded string. If the message is a
-    dict it will be used directly. If the message can be parsed as JSON, then
-    the parsed value is used in the output record.
+    @abstractmethod
+    def remove_keys(self, keys: Iterable[str]):
+        raise NotImplementedError()
 
-    Originally taken from https://gitlab.com/hadrien/aws_lambda_logging/
-    """
 
+class LambdaPowertoolsFormatter(BasePowertoolsFormatter):
+    """AWS Lambda Powertools Logging formatter.
 
-    def __init__(self, **kwargs):
-        """Return a JsonFormatter instance.
+    Formats the log message as a JSON encoded string. If the message is a
+    dict it will be used directly.
+    """
 
-        The `json_default` kwarg is used to specify a formatter for otherwise
-        unserializable values. It must not throw. Defaults to a function that
-        coerces the value to a string.
+    default_time_format = "%Y-%m-%d %H:%M:%S.%F%z"  # '2021-04-17 18:19:57.656+0200'
+    custom_ms_time_directive = "%F"
+
+    def __init__(
+        self,
+        json_serializer: Optional[Callable[[Any], Any]] = None,
+        json_deserializer: Optional[Callable[[Any], Any]] = None,
+        json_default: Optional[Callable[[Any], Any]] = None,
+        datefmt: Optional[str] = None,
+        log_record_order: Optional[List[str]] = None,
+        utc: bool = False,
+        **kwargs
+    ):
+        """Return a LambdaPowertoolsFormatter instance.
 
         The `log_record_order` kwarg is used to specify the order of the keys used in
         the structured json logs. By default the order is: "level", "location", "message", "timestamp",
         "service" and "sampling_rate".
 
         Other kwargs are used to specify log field format strings.
+
+        Parameters
+        ----------
+        json_serializer : Callable, optional
+            function to serialize `obj` to a JSON formatted `str`, by default json.dumps
+        json_deserializer : Callable, optional
+            function to deserialize `str`, `bytes`, `bytearray` containing a JSON document to a Python `obj`,
+            by default json.loads
+        json_default : Callable, optional
+            function to coerce unserializable values, by default str
+
+            Only used when no custom JSON encoder is set
+
+        datefmt : str, optional
+            String directives (strftime) to format log timestamp
+
+            See https://docs.python.org/3/library/time.html#time.strftime
+        utc : bool, optional
+            set logging timestamp to UTC, by default False to continue to use local time as per stdlib
+        log_record_order : list, optional
+            set order of log keys when logging, by default ["level", "location", "message", "timestamp"]
+        kwargs
+            Key-value to be included in log messages
         """
-        # Set the default unserializable function, by default values will be cast as str.
-        self.default_json_formatter = kwargs.pop("json_default", str)
-        # Set the insertion order for the log messages
-        self.log_format = dict.fromkeys(kwargs.pop("log_record_order", ["level", "location", "message", "timestamp"]))
-        self.reserved_keys = ["timestamp", "level", "location"]
-        # Set the date format used by `asctime`
-        super(JsonFormatter, self).__init__(datefmt=kwargs.pop("datefmt", None))
+        self.json_deserializer = json_deserializer or json.loads
+        self.json_default = json_default or str
+        self.json_serializer = json_serializer or partial(json.dumps, default=self.json_default, separators=(",", ":"))
+        self.datefmt = datefmt
+        self.utc = utc
+        self.log_record_order = log_record_order or ["level", "location", "message", "timestamp"]
+        self.log_format = dict.fromkeys(self.log_record_order)  # Set the insertion order for the log messages
+        self.update_formatter = self.append_keys  # alias to old method
 
-        self.log_format.update(self._build_root_keys(**kwargs))
+        if self.utc:
+            self.converter = time.gmtime
+
+        super(LambdaPowertoolsFormatter, self).__init__(datefmt=self.datefmt)
+
+        keys_combined = {**self._build_default_keys(), **kwargs}
+        self.log_format.update(**keys_combined)
+
+    def format(self, record: logging.LogRecord) -> str:  # noqa: A003
+        """Format logging record as structured JSON str"""
+        formatted_log = self._extract_log_keys(log_record=record)
+        formatted_log["message"] = self._extract_log_message(log_record=record)
+        formatted_log["exception"], formatted_log["exception_name"] = self._extract_log_exception(log_record=record)
+        formatted_log["xray_trace_id"] = self._get_latest_trace_id()
+        formatted_log = self._strip_none_records(records=formatted_log)
+
+        return self.json_serializer(formatted_log)
+
+    def formatTime(self, record: logging.LogRecord, datefmt: Optional[str] = None) -> str:
+        record_ts = self.converter(record.created)
+        if datefmt:
+            return time.strftime(datefmt, record_ts)
+
+        # NOTE: Python's `time.strftime` doesn't provide a msec directive,
+        # so we define a custom one (%F) and substitute it with the record's msecs;
+        # this also lets msecs come after the timezone, which std logging doesn't support
+        msecs = "%03d" % record.msecs
+        custom_fmt = self.default_time_format.replace(self.custom_ms_time_directive, msecs)
+        return time.strftime(custom_fmt, record_ts)
+
+    def append_keys(self, **additional_keys):
+        self.log_format.update(additional_keys)
+
+    def remove_keys(self, keys: Iterable[str]):
+        for key in keys:
+            self.log_format.pop(key, None)
 
     @staticmethod
-    def _build_root_keys(**kwargs):
+    def _build_default_keys():
         return {
             "level": "%(levelname)s",
             "location": "%(funcName)s:%(lineno)d",
             "timestamp": "%(asctime)s",
-            **kwargs,
         }
 
     @staticmethod
@@ -78,12 +156,8 @@ def _get_latest_trace_id():
         xray_trace_id = os.getenv(constants.XRAY_TRACE_ID_ENV)
         return xray_trace_id.split(";")[0].replace("Root=", "") if xray_trace_id else None
 
-    def update_formatter(self, **kwargs):
-        self.log_format.update(kwargs)
-
-    @staticmethod
-    def _extract_log_message(log_record: logging.LogRecord) -> Union[Dict, str, bool, Iterable]:
-        """Extract message from log record and attempt to JSON decode it
+    def _extract_log_message(self, log_record: logging.LogRecord) -> Union[Dict[str, Any], str, bool, Iterable]:
+        """Extract message from log record and attempt to JSON decode it if str
 
         Parameters
         ----------
@@ -95,20 +169,19 @@ def _extract_log_message(log_record: logging.LogRecord) -> Union[Dict, str, bool
         message: Union[Dict, str, bool, Iterable]
             Extracted message
         """
-        if isinstance(log_record.msg, dict):
-            return log_record.msg
+        message = log_record.msg
+        if isinstance(message, dict):
+            return message
 
-        message: str = log_record.getMessage()
-
-        # Attempt to decode non-str messages e.g. msg = '{"x": "y"}'
-        try:
-            message = json.loads(log_record.msg)
-        except (json.decoder.JSONDecodeError, TypeError, ValueError):
-            pass
+        if isinstance(message, str):  # could be a JSON string
+            try:
+                message = self.json_deserializer(message)
+            except (json.decoder.JSONDecodeError, TypeError, ValueError):
+                pass
 
         return message
 
-    def _extract_log_exception(self, log_record: logging.LogRecord) -> Optional[str]:
+    def _extract_log_exception(self, log_record: logging.LogRecord) -> Union[Tuple[str, str], Tuple[None, None]]:
         """Format traceback information, if available
 
         Parameters
@@ -118,33 +191,15 @@ def _extract_log_exception(self, log_record: logging.LogRecord) -> Optional[str]
 
         Returns
         -------
-        log_record: Optional[str]
-            Log record with constant traceback info
+        exception: Union[Tuple[str, str], Tuple[None, None]]
+            Formatted traceback and exception name if available, otherwise a (None, None) tuple
         """
         if log_record.exc_info:
-            return self.formatException(log_record.exc_info)
-
-        return None
+            return self.formatException(log_record.exc_info), log_record.exc_info[0].__name__
 
-    def _extract_log_exception_name(self, log_record: logging.LogRecord) -> Optional[str]:
-        """Extract the exception name, if available
+        return None, None
 
-        Parameters
-        ----------
-        log_record : logging.LogRecord
-            Log record to extract exception name from
-
-        Returns
-        -------
-        log_record: Optional[str]
-            Log record with exception name
-        """
-        if log_record.exc_info:
-            return log_record.exc_info[0].__name__
-
-        return None
-
-    def _extract_log_keys(self, log_record: logging.LogRecord) -> Dict:
+    def _extract_log_keys(self, log_record: logging.LogRecord) -> Dict[str, Any]:
         """Extract and parse custom and reserved log keys
 
         Parameters
@@ -157,36 +212,27 @@ def _extract_log_keys(self, log_record: logging.LogRecord) -> Dict:
         formatted_log: Dict
             Structured log as dictionary
         """
-        record_dict = log_record.__dict__.copy()  # has extra kwargs we are after
-        record_dict["asctime"] = self.formatTime(log_record, self.datefmt)
+        record_dict = log_record.__dict__.copy()
+        record_dict["asctime"] = self.formatTime(record=log_record, datefmt=self.datefmt)
+        extras = {k: v for k, v in record_dict.items() if k not in RESERVED_LOG_ATTRS}
 
-        formatted_log = {}
+        formatted_log = {**extras}
 
-        # We have to iterate over a default or existing log structure
-        # then replace any logging expression for reserved keys e.g. '%(level)s' to 'INFO'
-        # and lastly add or replace incoming keys (those added within the constructor or .structure_logs method)
+        # Iterate over a default or existing log structure
+        # then replace any std log attribute e.g. '%(level)s' to 'INFO', '%(process)d' to '4773'
+        # lastly add or replace incoming keys (those added within the constructor or .structure_logs method)
         for key, value in self.log_format.items():
-            if value and key in self.reserved_keys:
+            if value and key in RESERVED_LOG_ATTRS:
                 formatted_log[key] = value % record_dict
             else:
                 formatted_log[key] = value
 
log.info("my message", extra={"additional_key": "value"} - # these messages will be added to the root of the final structure not within `message` key - for key, value in record_dict.items(): - if key not in STD_LOGGING_KEYS: - formatted_log[key] = value - return formatted_log - def format(self, record): # noqa: A003 - formatted_log = self._extract_log_keys(log_record=record) - formatted_log["message"] = self._extract_log_message(log_record=record) - formatted_log["exception_name"] = self._extract_log_exception_name(log_record=record) - formatted_log["exception"] = self._extract_log_exception(log_record=record) - formatted_log.update({"xray_trace_id": self._get_latest_trace_id()}) # fetch latest Trace ID, if any + @staticmethod + def _strip_none_records(records: Dict[str, Any]) -> Dict[str, Any]: + """Remove any key with None as value""" + return {k: v for k, v in records.items() if v is not None} - # Filter out top level key with values that are None - formatted_log = {k: v for k, v in formatted_log.items() if v is not None} - return json.dumps(formatted_log, default=self.default_json_formatter) +JsonFormatter = LambdaPowertoolsFormatter # alias to previous formatter diff --git a/aws_lambda_powertools/logging/logger.py b/aws_lambda_powertools/logging/logger.py index 2e9cbb78d2e..77e0f3db059 100644 --- a/aws_lambda_powertools/logging/logger.py +++ b/aws_lambda_powertools/logging/logger.py @@ -4,7 +4,7 @@ import os import random import sys -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Iterable, Optional, Union import jmespath @@ -12,7 +12,7 @@ from ..shared.functions import resolve_env_var_choice, resolve_truthy_env_var_choice from .exceptions import InvalidLoggerSamplingRateError from .filters import SuppressFilter -from .formatter import JsonFormatter +from .formatter import BasePowertoolsFormatter, LambdaPowertoolsFormatter from .lambda_context import build_lambda_context_model logger = logging.getLogger(__name__) @@ -42,6 +42,7 @@ def _is_cold_start() -> bool: # so we need to return to subclassing removed in #97 # All methods/properties continue to be proxied to inner logger # https://github.com/awslabs/aws-lambda-powertools-python/issues/107 +# noinspection PyRedeclaration class Logger(logging.Logger): # lgtm [py/missing-call-to-init] """Creates and setups a logger to format statements in JSON. @@ -69,6 +70,28 @@ class Logger(logging.Logger): # lgtm [py/missing-call-to-init] sample rate for debug calls within execution context defaults to 0.0 stream: sys.stdout, optional valid output for a logging stream, by default sys.stdout + logger_formatter: BasePowertoolsFormatter, optional + custom logging formatter that implements BasePowertoolsFormatter + logger_handler: logging.Handler, optional + custom logging handler e.g. logging.FileHandler("file.log") + + Parameters propagated to LambdaPowertoolsFormatter + --------------------------------------------- + datefmt: str, optional + String directives (strftime) to format log timestamp, by default it uses RFC 3339. 
diff --git a/aws_lambda_powertools/logging/logger.py b/aws_lambda_powertools/logging/logger.py
index 2e9cbb78d2e..77e0f3db059 100644
--- a/aws_lambda_powertools/logging/logger.py
+++ b/aws_lambda_powertools/logging/logger.py
@@ -4,7 +4,7 @@
 import os
 import random
 import sys
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Callable, Dict, Iterable, Optional, Union
 
 import jmespath
 
@@ -12,7 +12,7 @@
 from ..shared.functions import resolve_env_var_choice, resolve_truthy_env_var_choice
 from .exceptions import InvalidLoggerSamplingRateError
 from .filters import SuppressFilter
-from .formatter import JsonFormatter
+from .formatter import BasePowertoolsFormatter, LambdaPowertoolsFormatter
 from .lambda_context import build_lambda_context_model
 
 logger = logging.getLogger(__name__)
 
@@ -42,6 +42,7 @@ def _is_cold_start() -> bool:
 # so we need to return to subclassing removed in #97
 # All methods/properties continue to be proxied to inner logger
 # https://github.com/awslabs/aws-lambda-powertools-python/issues/107
+# noinspection PyRedeclaration
 class Logger(logging.Logger):  # lgtm [py/missing-call-to-init]
     """Creates and setups a logger to format statements in JSON.
 
@@ -69,6 +70,28 @@ class Logger(logging.Logger):  # lgtm [py/missing-call-to-init]
         sample rate for debug calls within execution context defaults to 0.0
     stream: sys.stdout, optional
         valid output for a logging stream, by default sys.stdout
+    logger_formatter: BasePowertoolsFormatter, optional
+        custom logging formatter that implements BasePowertoolsFormatter
+    logger_handler: logging.Handler, optional
+        custom logging handler e.g. logging.FileHandler("file.log")
+
+    Parameters propagated to LambdaPowertoolsFormatter
+    --------------------------------------------------
+    datefmt: str, optional
+        String directives (strftime) to format log timestamp, by default it uses RFC 3339.
+    json_serializer : Callable, optional
+        function to serialize `obj` to a JSON formatted `str`, by default json.dumps
+    json_deserializer : Callable, optional
+        function to deserialize `str`, `bytes`, `bytearray` containing a JSON document to a Python `obj`,
+        by default json.loads
+    json_default : Callable, optional
+        function to coerce unserializable values, by default `str()`
+
+        Only used when no custom JSON encoder is set
+    utc : bool, optional
+        set logging timestamp to UTC, by default False to continue to use local time as per stdlib
+    log_record_order : list, optional
+        set order of log keys when logging, by default ["level", "location", "message", "timestamp"]
 
     Example
     -------
@@ -96,7 +119,7 @@ class Logger(logging.Logger):  # lgtm [py/missing-call-to-init]
     >>> logger = Logger(service="payment")
     >>>
     >>> def handler(event, context):
-            logger.structure_logs(append=True, payment_id=event["payment_id"])
+            logger.append_keys(payment_id=event["payment_id"])
             logger.info("Hello")
 
     **Create child Logger using logging inheritance via child param**
@@ -110,6 +133,30 @@ class Logger(logging.Logger):  # lgtm [py/missing-call-to-init]
     >>> from aws_lambda_powertools import Logger
     >>> logger = Logger(service="payment", child=True)
 
+    **Logging in UTC timezone**
+
+    >>> # app.py
+    >>> import logging
+    >>> from aws_lambda_powertools import Logger
+    >>>
+    >>> logger = Logger(service="payment", utc=True)
+
+    **Brings message as the first key in log statements**
+
+    >>> # app.py
+    >>> import logging
+    >>> from aws_lambda_powertools import Logger
+    >>>
+    >>> logger = Logger(service="payment", log_record_order=["message"])
+
+    **Logging to a file instead of standard output for testing**
+
+    >>> # app.py
+    >>> import logging
+    >>> from aws_lambda_powertools import Logger
+    >>>
+    >>> logger = Logger(service="payment", logger_handler=logging.FileHandler("log.json"))
+
     Raises
     ------
     InvalidLoggerSamplingRateError
@@ -123,22 +170,26 @@ def __init__(
         child: bool = False,
         sampling_rate: float = None,
         stream: sys.stdout = None,
+        logger_formatter: Optional[BasePowertoolsFormatter] = None,
+        logger_handler: Optional[logging.Handler] = None,
         **kwargs,
     ):
         self.service = resolve_env_var_choice(
             choice=service, env=os.getenv(constants.SERVICE_NAME_ENV, "service_undefined")
         )
         self.sampling_rate = resolve_env_var_choice(
-            choice=sampling_rate, env=os.getenv(constants.LOGGER_LOG_SAMPLING_RATE, 0.0)
+            choice=sampling_rate, env=os.getenv(constants.LOGGER_LOG_SAMPLING_RATE)
         )
+        self.child = child
+        self.logger_formatter = logger_formatter
+        self.logger_handler = logger_handler or logging.StreamHandler(stream)
+        self.log_level = self._get_log_level(level)
         self._is_deduplication_disabled = resolve_truthy_env_var_choice(
            env=os.getenv(constants.LOGGER_LOG_DEDUPLICATION_ENV, "false")
        )
-        self.log_level = self._get_log_level(level)
-        self.child = child
-        self._handler = logging.StreamHandler(stream) if stream is not None else logging.StreamHandler(sys.stdout)
         self._default_log_keys = {"service": self.service, "sampling_rate": self.sampling_rate}
         self._logger = self._get_logger()
+
         self._init_logger(**kwargs)
 
     def __getattr__(self, name):
@@ -168,7 +219,7 @@ def _init_logger(self, **kwargs):
         self._configure_sampling()
         self._logger.setLevel(self.log_level)
-        self._logger.addHandler(self._handler)
+        self._logger.addHandler(self.logger_handler)
         self.structure_logs(**kwargs)
 
         # Pytest Live Log feature duplicates log records for colored output
@@ -269,7 +320,7 @@ def handler(event, context):
         def decorate(event, context):
             lambda_context = build_lambda_context_model(context)
             cold_start = _is_cold_start()
-            self.structure_logs(append=True, cold_start=cold_start, **lambda_context.__dict__)
+            self.append_keys(cold_start=cold_start, **lambda_context.__dict__)
 
             if correlation_id_path:
                 self.set_correlation_id(jmespath.search(correlation_id_path, event))
@@ -282,30 +333,44 @@ def decorate(event, context):
 
         return decorate
 
-    def structure_logs(self, append: bool = False, **kwargs):
+    def append_keys(self, **additional_keys):
+        self.registered_formatter.append_keys(**additional_keys)
+
+    def remove_keys(self, keys: Iterable[str]):
+        self.registered_formatter.remove_keys(keys)
+
+    @property
+    def registered_handler(self) -> logging.Handler:
+        """Convenience property to access logger handler"""
+        handlers = self._logger.parent.handlers if self.child else self._logger.handlers
+        return handlers[0]
+
+    @property
+    def registered_formatter(self) -> Optional[BasePowertoolsFormatter]:
+        """Convenience property to access logger formatter"""
+        return self.registered_handler.formatter
+
+    def structure_logs(self, append: bool = False, **keys):
         """Sets logging formatting to JSON.
 
         Optionally, it can append keyword arguments
-        to an existing logger so it is available
-        across future log statements.
+        to an existing logger so they are available across future log statements.
 
         Last keyword argument and value wins if duplicated.
 
         Parameters
         ----------
         append : bool, optional
-            [description], by default False
+            append keys provided to logger formatter, by default False
         """
-        # Child loggers don't have handlers attached, use its parent handlers
-        handlers = self._logger.parent.handlers if self.child else self._logger.handlers
-        for handler in handlers:
-            if append:
-                # Update existing formatter in an existing logger handler
-                handler.formatter.update_formatter(**kwargs)
-            else:
-                # Set a new formatter for a logger handler
-                handler.setFormatter(JsonFormatter(**self._default_log_keys, **kwargs))
+        if append:
+            # Maintenance: Add deprecation warning for major version. Refer to append_keys() when docs are updated
+            self.append_keys(**keys)
+        else:
+            log_keys = {**self._default_log_keys, **keys}
+            formatter = self.logger_formatter or LambdaPowertoolsFormatter(**log_keys)
+            self.registered_handler.setFormatter(formatter)
 
     def set_correlation_id(self, value: str):
         """Sets the correlation_id in the logging json
@@ -315,7 +380,7 @@ def set_correlation_id(self, value: str):
         value : str
             Value for the correlation id
         """
-        self.structure_logs(append=True, correlation_id=value)
+        self.append_keys(correlation_id=value)
 
     @staticmethod
     def _get_log_level(level: Union[str, int, None]) -> Union[str, int]:
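Note: taken together, the Logger changes replace the `structure_logs(append=True, ...)` idiom with `append_keys`/`remove_keys`, both proxied to the registered formatter. A sketch of the new call pattern inside a Lambda handler — the service name and keys are illustrative:

    from aws_lambda_powertools import Logger

    logger = Logger(service="payment", utc=True, log_record_order=["message"])

    @logger.inject_lambda_context
    def handler(event, context):
        logger.append_keys(order_id=event.get("order_id"))  # was: logger.structure_logs(append=True, ...)
        logger.info("order received")
        logger.remove_keys(["order_id"])  # drop the key once it is no longer relevant
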
diff --git a/tests/functional/test_logger.py b/tests/functional/test_logger.py
index ddf5ee226f5..ba6e82b72af 100644
--- a/tests/functional/test_logger.py
+++ b/tests/functional/test_logger.py
@@ -5,12 +5,14 @@
 import random
 import string
 from collections import namedtuple
+from typing import Iterable
 
 import pytest
 
 from aws_lambda_powertools import Logger, Tracer
 from aws_lambda_powertools.logging import correlation_paths
 from aws_lambda_powertools.logging.exceptions import InvalidLoggerSamplingRateError
+from aws_lambda_powertools.logging.formatter import BasePowertoolsFormatter
 from aws_lambda_powertools.logging.logger import set_package_logger
 from aws_lambda_powertools.shared import constants
 
@@ -474,3 +476,89 @@ def handler(event, context):
     # THEN
     log = capture_logging_output(stdout)
     assert request_id == log["correlation_id"]
+
+
+def test_logger_append_remove_keys(stdout, service_name):
+    # GIVEN a Logger is initialized
+    logger = Logger(service=service_name, stream=stdout)
+    extra_keys = {"request_id": "id", "context": "value"}
+
+    # WHEN keys are updated
+    logger.append_keys(**extra_keys)
+    logger.info("message with new keys")
+
+    # And removed
+    logger.remove_keys(extra_keys.keys())
+    logger.info("message after keys being removed")
+
+    # THEN additional keys should only be present in the first log statement
+    extra_keys_log, keys_removed_log = capture_multiple_logging_statements_output(stdout)
+
+    assert extra_keys.items() <= extra_keys_log.items()
+    assert (extra_keys.items() <= keys_removed_log.items()) is False
+
+
+def test_logger_custom_formatter(stdout, service_name, lambda_context):
+    class CustomFormatter(BasePowertoolsFormatter):
+        custom_format = {}
+
+        def append_keys(self, **additional_keys):
+            self.custom_format.update(additional_keys)
+
+        def remove_keys(self, keys: Iterable[str]):
+            for key in keys:
+                self.custom_format.pop(key, None)
+
+        def format(self, record: logging.LogRecord) -> str:  # noqa: A003
+            return json.dumps(
+                {
+                    "message": super().format(record),
+                    "timestamp": self.formatTime(record),
+                    "my_default_key": "test",
+                    **self.custom_format,
+                }
+            )
+
+    custom_formatter = CustomFormatter()
+
+    # GIVEN a Logger is initialized with a custom formatter
+    logger = Logger(service=service_name, stream=stdout, logger_formatter=custom_formatter)
+
+    # WHEN a lambda function is decorated with logger
+    @logger.inject_lambda_context
+    def handler(event, context):
+        logger.info("Hello")
+
+    handler({}, lambda_context)
+
+    lambda_context_keys = (
+        "function_name",
+        "function_memory_size",
+        "function_arn",
+        "function_request_id",
+    )
+
+    log = capture_logging_output(stdout)
+
+    # THEN custom key should always be present
+    # and lambda contextual info should also be in the logs
+    assert "my_default_key" in log
+    assert all(k in log for k in lambda_context_keys)
+
+
+def test_logger_custom_handler(lambda_context, service_name, tmp_path):
+    # GIVEN a Logger is initialized with a FileHandler
+    log_file = tmp_path / "log.json"
+    handler = logging.FileHandler(filename=log_file)
+    logger = Logger(service=service_name, logger_handler=handler)
+
+    # WHEN a log statement happens
+    @logger.inject_lambda_context
+    def handler(event, context):
+        logger.info("custom handler")
+
+    handler({}, lambda_context)
+
+    # THEN we should output to a file not stdout
+    log = log_file.read_text()
+    assert "custom handler" in log
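Note: the tests above lean on capture helpers defined elsewhere in the test module. An illustrative equivalent of what they are assumed to do (the real fixtures may differ):

    import io
    import json

    def capture_logging_output(stdout: io.StringIO) -> dict:
        # parse a single structured log statement written to the in-memory stream
        return json.loads(stdout.getvalue().strip())

    def capture_multiple_logging_statements_output(stdout: io.StringIO) -> list:
        # parse one JSON document per emitted line
        return [json.loads(line) for line in stdout.getvalue().strip().split("\n")]
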
diff --git a/tests/functional/test_aws_lambda_logging.py b/tests/functional/test_logger_powertools_formatter.py
similarity index 84%
rename from tests/functional/test_aws_lambda_logging.py
rename to tests/functional/test_logger_powertools_formatter.py
index b28a753fc98..4b92e6b47b9 100644
--- a/tests/functional/test_aws_lambda_logging.py
+++ b/tests/functional/test_logger_powertools_formatter.py
@@ -3,6 +3,7 @@
 import json
 import random
 import string
+import time
 
 import pytest
 
@@ -167,13 +168,13 @@ def test_log_custom_formatting(stdout, service_name):
     log_dict: dict = json.loads(stdout.getvalue())
 
     # THEN the `location` and "timestamp" should match the formatting
-    assert log_dict["location"] == "[test_log_custom_formatting] test_aws_lambda_logging"
+    assert log_dict["location"] == "[test_log_custom_formatting] test_logger_powertools_formatter"
     assert log_dict["timestamp"] == "fake-datefmt"
 
 
 def test_log_dict_key_strip_nones(stdout, service_name):
     # GIVEN a logger configuration where we set `location` and `timestamp` to None
-    # Note: level, sampling_rate and service can not be suppressed
+    # Note: level and service cannot be suppressed
     logger = Logger(stream=stdout, level=None, location=None, timestamp=None, sampling_rate=None, service=None)
 
     # WHEN logging a message
@@ -181,8 +182,8 @@ def test_log_dict_key_strip_nones(stdout, service_name):
     logger.info("foo")
 
     log_dict: dict = json.loads(stdout.getvalue())
 
-    # THEN the keys should only include `level`, `message`, `service`, `sampling_rate`
-    assert sorted(log_dict.keys()) == ["level", "message", "sampling_rate", "service"]
+    # THEN the keys should only include `level`, `message`, `service`
+    assert sorted(log_dict.keys()) == ["level", "message", "service"]
 
     assert log_dict["service"] == "service_undefined"
 
@@ -241,3 +242,36 @@ def test_log_dict_xray_is_updated_when_tracing_id_changes(stdout, monkeypatch, s
     assert log_dict_2["xray_trace_id"] == trace_id_2
 
     monkeypatch.delenv(name="_X_AMZN_TRACE_ID")
+
+
+def test_log_custom_std_log_attribute(stdout, service_name):
+    # GIVEN a logger where we have a standard log attr process
+    # https://docs.python.org/3/library/logging.html#logrecord-attributes
+    logger = Logger(service=service_name, stream=stdout, process="%(process)d")
+
+    # WHEN logging a message
+    logger.info("foo")
+
+    log_dict: dict = json.loads(stdout.getvalue())
+
+    # THEN process key should be evaluated
+    assert "%" not in log_dict["process"]
+
+
+def test_log_in_utc(service_name):
+    # GIVEN a logger where UTC TZ has been set
+    logger = Logger(service=service_name, utc=True)
+
+    # THEN logging formatter time converter should use gmtime fn
+    assert logger._logger.handlers[0].formatter.converter == time.gmtime
+
+
+@pytest.mark.parametrize("message", ["hello", 1.10, {}, [], True, object()])
+def test_logging_various_primitives(stdout, service_name, message):
+    # GIVEN a logger with default settings
+    logger = Logger(service=service_name, stream=stdout)
+
+    # WHEN logging a message of multiple common types
+    # THEN it should raise no serialization/deserialization error
+    logger.info(message)
+    json.loads(stdout.getvalue())
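Note: because `json_serializer`, `json_deserializer` and `json_default` propagate from Logger to the formatter, serialization of awkward values can be customized without subclassing anything. A minimal sketch assuming stdlib json only — the service name and the datetime handling are illustrative:

    from datetime import datetime

    from aws_lambda_powertools import Logger

    def unserializable_to_str(obj):
        # fallback applied via json.dumps(default=...) for values it cannot encode
        if isinstance(obj, datetime):
            return obj.isoformat()
        return str(obj)

    logger = Logger(service="orders", json_default=unserializable_to_str)
    logger.info({"created_at": datetime.utcnow(), "status": "RECEIVED"})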