diff --git a/aws_lambda_powertools/tracing/base.py b/aws_lambda_powertools/tracing/base.py index 74b146ad6e8..aa074680678 100644 --- a/aws_lambda_powertools/tracing/base.py +++ b/aws_lambda_powertools/tracing/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -import abc +from abc import ABC, abstractmethod from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Generator, Sequence @@ -9,10 +9,10 @@ import traceback -class BaseSegment(abc.ABC): +class BaseSegment(ABC): """Holds common properties and methods on segment and subsegment.""" - @abc.abstractmethod + @abstractmethod def close(self, end_time: int | None = None): """Close the trace entity by setting `end_time` and flip the in progress flag to False. @@ -23,15 +23,15 @@ def close(self, end_time: int | None = None): Time in epoch seconds, by default current time will be used. """ - @abc.abstractmethod + @abstractmethod def add_subsegment(self, subsegment: Any): """Add input subsegment as a child subsegment.""" - @abc.abstractmethod + @abstractmethod def remove_subsegment(self, subsegment: Any): """Remove input subsegment from child subsegments.""" - @abc.abstractmethod + @abstractmethod def put_annotation(self, key: str, value: str | numbers.Number | bool) -> None: """Annotate segment or subsegment with a key-value pair. @@ -45,7 +45,7 @@ def put_annotation(self, key: str, value: str | numbers.Number | bool) -> None: Annotation value """ - @abc.abstractmethod + @abstractmethod def put_metadata(self, key: str, value: Any, namespace: str = "default") -> None: """Add metadata to segment or subsegment. Metadata is not indexed but can be later retrieved by BatchGetTraces API. @@ -60,7 +60,7 @@ def put_metadata(self, key: str, value: Any, namespace: str = "default") -> None Metadata namespace, by default 'default' """ - @abc.abstractmethod + @abstractmethod def add_exception(self, exception: BaseException, stack: list[traceback.StackSummary], remote: bool = False): """Add an exception to trace entities. @@ -77,8 +77,8 @@ def add_exception(self, exception: BaseException, stack: list[traceback.StackSum """ -class BaseProvider(abc.ABC): - @abc.abstractmethod +class BaseProvider(ABC): + @abstractmethod @contextmanager def in_subsegment(self, name=None, **kwargs) -> Generator[BaseSegment, None, None]: """Return a subsegment context manger. @@ -91,7 +91,7 @@ def in_subsegment(self, name=None, **kwargs) -> Generator[BaseSegment, None, Non Optional parameters to be propagated to segment """ - @abc.abstractmethod + @abstractmethod @contextmanager def in_subsegment_async(self, name=None, **kwargs) -> Generator[BaseSegment, None, None]: """Return a subsegment async context manger. @@ -104,7 +104,7 @@ def in_subsegment_async(self, name=None, **kwargs) -> Generator[BaseSegment, Non Optional parameters to be propagated to segment """ - @abc.abstractmethod + @abstractmethod def put_annotation(self, key: str, value: str | numbers.Number | bool) -> None: """Annotate current active trace entity with a key-value pair. @@ -118,7 +118,7 @@ def put_annotation(self, key: str, value: str | numbers.Number | bool) -> None: Annotation value """ - @abc.abstractmethod + @abstractmethod def put_metadata(self, key: str, value: Any, namespace: str = "default") -> None: """Add metadata to the current active trace entity. @@ -134,7 +134,7 @@ def put_metadata(self, key: str, value: Any, namespace: str = "default") -> None Metadata namespace, by default 'default' """ - @abc.abstractmethod + @abstractmethod def patch(self, modules: Sequence[str]) -> None: """Instrument a set of supported libraries @@ -144,6 +144,6 @@ def patch(self, modules: Sequence[str]) -> None: Set of modules to be patched """ - @abc.abstractmethod + @abstractmethod def patch_all(self) -> None: """Instrument all supported libraries""" diff --git a/aws_lambda_powertools/tracing/provider/__init__.py b/aws_lambda_powertools/tracing/provider/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/aws_lambda_powertools/tracing/provider/aws_xray/aws_xray_tracer.py b/aws_lambda_powertools/tracing/provider/aws_xray/aws_xray_tracer.py new file mode 100644 index 00000000000..6eedd3aac43 --- /dev/null +++ b/aws_lambda_powertools/tracing/provider/aws_xray/aws_xray_tracer.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +from contextlib import asynccontextmanager, contextmanager +from numbers import Number +from typing import Any, AsyncGenerator, Generator, Literal, Sequence + +from aws_lambda_powertools.shared import constants +from aws_lambda_powertools.shared.lazy_import import LazyLoader +from aws_lambda_powertools.tracing.provider.base import BaseProvider, BaseSpan + +aws_xray_sdk = LazyLoader(constants.XRAY_SDK_MODULE, globals(), constants.XRAY_SDK_MODULE) + + +class XraySpan(BaseSpan): + def __init__(self, subsegment): + self.subsegment = subsegment + self.add_subsegment = self.subsegment.add_subsegment + self.remove_subsegment = self.subsegment.remove_subsegment + self.put_annotation = self.subsegment.put_annotation + self.put_metadata = self.subsegment.put_metadata + self.add_exception = self.subsegment.add_exception + self.close = self.subsegment.close + + def set_attribute( + self, + key: str, + value: Any, + category: Literal["Annotation", "Metadata", "Auto"] = "Auto", + **kwargs, + ) -> None: + """ + Set an attribute on this span with a key-value pair. + + Parameters + ---------- + key: str + attribute key + value: Any + Value for attribute + category: Literal["Annotation","Metadata","Auto"] = "Auto" + This parameter specifies the category of attribute to set. + - **"Annotation"**: Sets the attribute as an Annotation. + - **"Metadata"**: Sets the attribute as Metadata. + - **"Auto" (default)**: Automatically determines the attribute + type based on its value. + + kwargs: Optional[dict] + Optional parameters to be passed to provider.set_attributes + """ + if category == "Annotation": + self.put_annotation(key=key, value=value) + return + + if category == "Metadata": + self.put_metadata(key=key, value=value, namespace=kwargs.get("namespace", "dafault")) + return + + # Auto + if isinstance(value, (str, Number, bool)): + self.put_annotation(key=key, value=value) + return + + # Auto & not in (str, Number, bool) + self.put_metadata(key=key, value=value, namespace=kwargs.get("namespace", "dafault")) + + def record_exception(self, exception: BaseException, **kwargs): + stack = aws_xray_sdk.core.utils.stacktrace.get_stacktrace() + self.add_exception(exception=exception, stack=stack) + + +class AwsXrayProvider(BaseProvider): + + def __init__( + self, + service: str = "", + disabled: bool | None = None, + auto_patch: bool | None = None, + patch_modules: Sequence[str] | None = None, + ): + from aws_xray_sdk.core import xray_recorder # type: ignore + + self.recorder = xray_recorder + self.in_subsegment = self.recorder.in_subsegment + self.in_subsegment_async = self.recorder.in_subsegment_async + + self.service = service + + super().__init__( + service=self.service, + ) + + @contextmanager + def trace(self, name: str, **kwargs) -> Generator[XraySpan, None, None]: + with self.in_subsegment(name=name, **kwargs) as sub_segment: + yield XraySpan(subsegment=sub_segment) + + @asynccontextmanager + async def trace_async(self, name: str, **kwargs) -> AsyncGenerator[XraySpan, None]: + async with self.in_subsegment_async(name=name, **kwargs) as subsegment: + yield XraySpan(subsegment=subsegment) + + def set_attribute( + self, + key: str, + value: Any, + category: Literal["Annotation", "Metadata", "Auto"] = "Auto", + **kwargs, + ) -> None: + """ + Set an attribute on the current active span with a key-value pair. + + Parameters + ---------- + key: str + attribute key + value: Any + Value for attribute + category: Literal["Annotation","Metadata","Auto"] = "Auto" + This parameter specifies the type of attribute to set. + - **"Annotation"**: Sets the attribute as an Annotation. + - **"Metadata"**: Sets the attribute as Metadata. + - **"Auto" (default)**: Automatically determines the attribute + type based on its value. + + kwargs: Optional[dict] + Optional parameters to be passed to provider.set_attributes + """ + if category == "Annotation": + self.put_annotation(key=key, value=value) + return + + if category == "Metadata": + self.put_metadata(key=key, value=value, namespace=kwargs.get("namespace", "dafault")) + return + + # Auto + if isinstance(value, (str, Number, bool)): + self.put_annotation(key=key, value=value) + return + + # Auto & not in (str, Number, bool) + self.put_metadata(key=key, value=value, namespace=kwargs.get("namespace", "dafault")) + + def put_annotation(self, key: str, value: str | Number | bool) -> None: + return self.recorder.put_annotation(key=key, value=value) + + def put_metadata(self, key: str, value: Any, namespace: str = "default") -> None: + return self.recorder.put_metadata(key=key, value=value, namespace=namespace) + + def patch(self, modules: Sequence[str]) -> None: + return aws_xray_sdk.core.patch(modules) + + def patch_all(self) -> None: + return aws_xray_sdk.core.patch_all() diff --git a/aws_lambda_powertools/tracing/provider/base.py b/aws_lambda_powertools/tracing/provider/base.py new file mode 100644 index 00000000000..a29af067887 --- /dev/null +++ b/aws_lambda_powertools/tracing/provider/base.py @@ -0,0 +1,635 @@ +from __future__ import annotations + +import contextlib +import functools +import inspect +import logging +import os +from abc import ABC, abstractmethod +from contextlib import asynccontextmanager, contextmanager +from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Generator, Sequence, cast, overload + +from aws_lambda_powertools.shared import constants +from aws_lambda_powertools.shared.functions import resolve_truthy_env_var_choice, sanitize_xray_segment_name +from aws_lambda_powertools.shared.types import AnyCallableT + +if TYPE_CHECKING: + from aws_lambda_powertools.tracing.base import BaseSegment + +logger = logging.getLogger(__name__) + + +is_cold_start = True + + +class BaseSpan(ABC): + """A span represents a unit of work or operation within a trace. + Spans are the building blocks of Traces.""" + + @abstractmethod + def set_attribute(self, key: str, value: Any, **kwargs) -> None: + """Set an attribute for a span with a key-value pair. + + Parameters + ---------- + key: str + Attribute key + value: Any + Attribute value + kwargs: dict | None + Optional parameters + """ + + @abstractmethod + def record_exception(self, exception: BaseException, **kwargs): + """Records an exception to this Span. + + Parameters + ---------- + exception: Exception + Caught exception during the execution of this Span + kwargs: dict | None + Optional parameters + """ + + +class BaseProvider(ABC): + """BaseProvider is an abstract base class that defines the expected behavior for tracing providers + used by Tracer. Inheriting classes must implement this interface to be compatible with Tracer. + """ + + def __init__(self, service: str = ""): + self.service = service + + @abstractmethod + @contextmanager + def trace(self, name: str, **kwargs) -> Generator[BaseSpan, None, None]: + """Context manager for creating a new span and set it + as the current span in this tracer's context. + + Exiting the context manager will call the span's end method, + as well as return the current span to its previous value by + returning to the previous context. + + Parameters + ---------- + name: str + Span name + kwargs: dict | None + Optional parameters to be propagated to the span + """ + + @abstractmethod + @asynccontextmanager + def trace_async(self, name: str, **kwargs) -> AsyncGenerator[BaseSpan, None]: + """Async Context manager for creating a new span async and set it + as the current span in this tracer's context. + + Exiting the context manager will call the span's end method, + as well as return the current span to its previous value by + returning to the previous context. + + Parameters + ---------- + name: str + Span name + kwargs: dict | None + Optional parameters to be propagated to the span + """ + + @abstractmethod + def set_attribute(self, key: str, value: Any, **kwargs) -> None: + """set attribute on current active span with a key-value pair. + + Parameters + ---------- + key: str + attribute key + value: Any + attribute value + kwargs: dict | None + Optional parameters to be propagated to the span + """ + + @abstractmethod + def patch(self, modules: Sequence[str]) -> None: + """Instrument a set of given libraries if supported by provider + See specific provider for more detail + + Exmaple + ------- + tracer = Tracer(service="payment") + libraries = (['aioboto3',mysql]) + # provider.patch will be called by tracer.patch + tracer.patch(libraries) + + Parameters + ---------- + modules: Set[str] + Set of modules to be patched + """ + + @abstractmethod + def patch_all(self) -> None: + """Instrument all supported libraries""" + + def capture_lambda_handler( + self, + lambda_handler: Any = None, + capture_response: bool | None = None, + capture_error: bool | None = None, + ): + """Decorator to create subsegment for lambda handlers + + As Lambda follows (event, context) signature we can remove some of the boilerplate + and also capture any exception any Lambda function throws or its response as metadata + + Parameters + ---------- + lambda_handler : Callable + Method to annotate on + capture_response : bool, optional + Instructs tracer to not include handler's response as metadata + capture_error : bool, optional + Instructs tracer to not include handler's error as metadata, by default True + + Example + ------- + **Lambda function using capture_lambda_handler decorator** + + tracer = Tracer(service="payment") + @tracer.capture_lambda_handler + def handler(event, context): + ... + + **Preventing Tracer to log response as metadata** + + tracer = Tracer(service="payment") + @tracer.capture_lambda_handler(capture_response=False) + def handler(event, context): + ... + + Raises + ------ + err + Exception raised by method + """ + # If handler is None we've been called with parameters + # Return a partial function with args filled + if lambda_handler is None: + logger.debug("Decorator called with parameters") + return functools.partial( + self.capture_lambda_handler, + capture_response=capture_response, + capture_error=capture_error, + ) + + lambda_handler_name = lambda_handler.__name__ + capture_response = resolve_truthy_env_var_choice( + env=os.getenv(constants.TRACER_CAPTURE_RESPONSE_ENV, "true"), + choice=capture_response, + ) + capture_error = resolve_truthy_env_var_choice( + env=os.getenv(constants.TRACER_CAPTURE_ERROR_ENV, "true"), + choice=capture_error, + ) + + @functools.wraps(lambda_handler) + def decorate(event, context, **kwargs): + with self.trace(name=f"## {lambda_handler_name}") as subsegment: + try: + logger.debug("Calling lambda handler") + response = lambda_handler(event, context, **kwargs) + logger.debug("Received lambda handler response successfully") + self._add_response_as_metadata( + method_name=lambda_handler_name, + data=response, + subsegment=subsegment, + capture_response=capture_response, + ) + except Exception as err: + logger.exception(f"Exception received from {lambda_handler_name}") + self._add_full_exception_as_metadata( + method_name=lambda_handler_name, + error=err, + subsegment=subsegment, + capture_error=capture_error, + ) + + raise + finally: + global is_cold_start + logger.debug("Annotating cold start") + subsegment.put_annotation(key="ColdStart", value=is_cold_start) + + if is_cold_start: + is_cold_start = False + + if self.service: + subsegment.put_annotation(key="Service", value=self.service) + + return response + + return decorate + + # see #465 + @overload + def capture_method(self, method: AnyCallableT) -> AnyCallableT: ... # pragma: no cover + + @overload + def capture_method( + self, + method: None = None, + capture_response: bool | None = None, + capture_error: bool | None = None, + ) -> Callable[[AnyCallableT], AnyCallableT]: ... # pragma: no cover + + def capture_method( + self, + method: AnyCallableT | None = None, + capture_response: bool | None = None, + capture_error: bool | None = None, + ) -> AnyCallableT: + """Decorator to create subsegment for arbitrary functions + + It also captures both response and exceptions as metadata + and creates a subsegment named `## ` + # see here: [Qualified name for classes and functions](https://peps.python.org/pep-3155/) + + When running [async functions concurrently](https://docs.python.org/3/library/asyncio-task.html#id6), + methods may impact each others subsegment, and can trigger + and AlreadyEndedException from X-Ray due to async nature. + + For this use case, either use `capture_method` only where + `async.gather` is called, or use `in_subsegment_async` + context manager via our escape hatch mechanism - See examples. + + Parameters + ---------- + method : Callable + Method to annotate on + capture_response : bool, optional + Instructs tracer to not include method's response as metadata + capture_error : bool, optional + Instructs tracer to not include handler's error as metadata, by default True + + Example + ------- + **Custom function using capture_method decorator** + + tracer = Tracer(service="payment") + @tracer.capture_method + def some_function() + + **Custom async method using capture_method decorator** + + from aws_lambda_powertools import Tracer + tracer = Tracer(service="booking") + + @tracer.capture_method + async def confirm_booking(booking_id: str) -> Dict: + resp = call_to_booking_service() + + tracer.put_annotation("BookingConfirmation", resp["requestId"]) + tracer.put_metadata("Booking confirmation", resp) + + return resp + + def lambda_handler(event: dict, context: Any) -> Dict: + booking_id = event.get("booking_id") + asyncio.run(confirm_booking(booking_id=booking_id)) + + **Custom generator function using capture_method decorator** + + from aws_lambda_powertools import Tracer + tracer = Tracer(service="booking") + + @tracer.capture_method + def bookings_generator(booking_id): + resp = call_to_booking_service() + yield resp[0] + yield resp[1] + + def lambda_handler(event: dict, context: Any) -> Dict: + gen = bookings_generator(booking_id=booking_id) + result = list(gen) + + **Custom generator context manager using capture_method decorator** + + from aws_lambda_powertools import Tracer + tracer = Tracer(service="booking") + + @tracer.capture_method + @contextlib.contextmanager + def booking_actions(booking_id): + resp = call_to_booking_service() + yield "example result" + cleanup_stuff() + + def lambda_handler(event: dict, context: Any) -> Dict: + booking_id = event.get("booking_id") + + with booking_actions(booking_id=booking_id) as booking: + result = booking + + **Tracing nested async calls** + + from aws_lambda_powertools import Tracer + tracer = Tracer(service="booking") + + @tracer.capture_method + async def get_identity(): + ... + + @tracer.capture_method + async def long_async_call(): + ... + + @tracer.capture_method + async def async_tasks(): + await get_identity() + ret = await long_async_call() + + return { "task": "done", **ret } + + **Safely tracing concurrent async calls with decorator** + + This may not needed once [this bug is closed](https://github.com/aws/aws-xray-sdk-python/issues/164) + + from aws_lambda_powertools import Tracer + tracer = Tracer(service="booking") + + async def get_identity(): + async with aioboto3.client("sts") as sts: + account = await sts.get_caller_identity() + return account + + async def long_async_call(): + ... + + @tracer.capture_method + async def async_tasks(): + _, ret = await asyncio.gather(get_identity(), long_async_call(), return_exceptions=True) + + return { "task": "done", **ret } + + **Safely tracing each concurrent async calls with escape hatch** + + This may not needed once [this bug is closed](https://github.com/aws/aws-xray-sdk-python/issues/164) + + from aws_lambda_powertools import Tracer + tracer = Tracer(service="booking") + + async def get_identity(): + async tracer.provider.in_subsegment_async("## get_identity"): + ... + + async def long_async_call(): + async tracer.provider.in_subsegment_async("## long_async_call"): + ... + + @tracer.capture_method + async def async_tasks(): + _, ret = await asyncio.gather(get_identity(), long_async_call(), return_exceptions=True) + + return { "task": "done", **ret } + + Raises + ------ + err + Exception raised by method + """ + # If method is None we've been called with parameters + # Return a partial function with args filled + if method is None: + logger.debug("Decorator called with parameters") + return cast( + AnyCallableT, + functools.partial(self.capture_method, capture_response=capture_response, capture_error=capture_error), + ) + + # Example: app.ClassA.get_all # noqa ERA001 + # Valid characters can be found at http://docs.aws.amazon.com/xray/latest/devguide/xray-api-segmentdocuments.html + method_name = sanitize_xray_segment_name(f"{method.__module__}.{method.__qualname__}") + + capture_response = resolve_truthy_env_var_choice( + env=os.getenv(constants.TRACER_CAPTURE_RESPONSE_ENV, "true"), + choice=capture_response, + ) + capture_error = resolve_truthy_env_var_choice( + env=os.getenv(constants.TRACER_CAPTURE_ERROR_ENV, "true"), + choice=capture_error, + ) + + # Maintenance: Need a factory/builder here to simplify this now + if inspect.iscoroutinefunction(method): + return self._decorate_async_function( + method=method, + capture_response=capture_response, + capture_error=capture_error, + method_name=method_name, + ) + elif inspect.isgeneratorfunction(method): + return self._decorate_generator_function( + method=method, + capture_response=capture_response, + capture_error=capture_error, + method_name=method_name, + ) + elif hasattr(method, "__wrapped__") and inspect.isgeneratorfunction(method.__wrapped__): + return self._decorate_generator_function_with_context_manager( + method=method, + capture_response=capture_response, + capture_error=capture_error, + method_name=method_name, + ) + else: + return self._decorate_sync_function( + method=method, + capture_response=capture_response, + capture_error=capture_error, + method_name=method_name, + ) + + def _decorate_async_function( + self, + method: Callable, + capture_response: bool | str | None = None, + capture_error: bool | str | None = None, + method_name: str | None = None, + ): + @functools.wraps(method) + async def decorate(*args, **kwargs): + async with self.trace_async(name=f"## {method_name}") as subsegment: + try: + logger.debug(f"Calling method: {method_name}") + response = await method(*args, **kwargs) + self._add_response_as_metadata( + method_name=method_name, + data=response, + subsegment=subsegment, + capture_response=capture_response, + ) + except Exception as err: + logger.exception(f"Exception received from '{method_name}' method") + self._add_full_exception_as_metadata( + method_name=method_name, + error=err, + subsegment=subsegment, + capture_error=capture_error, + ) + raise + + return response + + return decorate + + def _decorate_generator_function( + self, + method: Callable, + capture_response: bool | str | None = None, + capture_error: bool | str | None = None, + method_name: str | None = None, + ): + @functools.wraps(method) + def decorate(*args, **kwargs): + with self.trace(name=f"## {method_name}") as subsegment: + try: + logger.debug(f"Calling method: {method_name}") + result = yield from method(*args, **kwargs) + self._add_response_as_metadata( + method_name=method_name, + data=result, + subsegment=subsegment, + capture_response=capture_response, + ) + except Exception as err: + logger.exception(f"Exception received from '{method_name}' method") + self._add_full_exception_as_metadata( + method_name=method_name, + error=err, + subsegment=subsegment, + capture_error=capture_error, + ) + raise + + return result + + return decorate + + def _decorate_generator_function_with_context_manager( + self, + method: Callable, + capture_response: bool | str | None = None, + capture_error: bool | str | None = None, + method_name: str | None = None, + ): + @functools.wraps(method) + @contextlib.contextmanager + def decorate(*args, **kwargs): + with self.trace(name=f"## {method_name}") as subsegment: + try: + logger.debug(f"Calling method: {method_name}") + with method(*args, **kwargs) as return_val: + result = return_val + yield result + self._add_response_as_metadata( + method_name=method_name, + data=result, + subsegment=subsegment, + capture_response=capture_response, + ) + except Exception as err: + logger.exception(f"Exception received from '{method_name}' method") + self._add_full_exception_as_metadata( + method_name=method_name, + error=err, + subsegment=subsegment, + capture_error=capture_error, + ) + raise + + return decorate + + def _decorate_sync_function( + self, + method: AnyCallableT, + capture_response: bool | str | None = None, + capture_error: bool | str | None = None, + method_name: str | None = None, + ) -> AnyCallableT: + @functools.wraps(method) + def decorate(*args, **kwargs): + with self.trace(name=f"## {method_name}") as subsegment: + try: + logger.debug(f"Calling method: {method_name}") + response = method(*args, **kwargs) + self._add_response_as_metadata( + method_name=method_name, + data=response, + subsegment=subsegment, + capture_response=capture_response, + ) + except Exception as err: + logger.exception(f"Exception received from '{method_name}' method") + self._add_full_exception_as_metadata( + method_name=method_name, + error=err, + subsegment=subsegment, + capture_error=capture_error, + ) + raise + + return response + + return cast(AnyCallableT, decorate) + + def _add_response_as_metadata( + self, + method_name: str | None = None, + data: Any | None = None, + subsegment: BaseSegment | None = None, + capture_response: bool | str | None = None, + ): + """Add response as metadata for given subsegment + + Parameters + ---------- + method_name : str, optional + method name to add as metadata key, by default None + data : Any, optional + data to add as subsegment metadata, by default None + subsegment : BaseSegment, optional + existing subsegment to add metadata on, by default None + capture_response : bool, optional + Do not include response as metadata + """ + if data is None or not capture_response or subsegment is None: + return + + subsegment.put_metadata(key=f"{method_name} response", value=data, namespace=self.service) + + def _add_full_exception_as_metadata( + self, + method_name: str, + error: Exception, + subsegment: BaseSegment, + capture_error: bool | None = None, + ): + """Add full exception object as metadata for given subsegment + + Parameters + ---------- + method_name : str + method name to add as metadata key, by default None + error : Exception + error to add as subsegment metadata, by default None + subsegment : BaseSegment + existing subsegment to add metadata on, by default None + capture_error : bool, optional + Do not include error as metadata, by default True + """ + if not capture_error: + return + + subsegment.put_metadata(key=f"{method_name} error", value=error, namespace=self.service) diff --git a/aws_lambda_powertools/tracing/tracer.py b/aws_lambda_powertools/tracing/tracer.py index 0e2bbeb4a3e..92e51dca183 100644 --- a/aws_lambda_powertools/tracing/tracer.py +++ b/aws_lambda_powertools/tracing/tracer.py @@ -6,7 +6,7 @@ import inspect import logging import os -from typing import TYPE_CHECKING, Any, Callable, Sequence, cast, overload +from typing import TYPE_CHECKING, Any, Callable, Sequence, TypeVar, cast, overload from aws_lambda_powertools.shared import constants from aws_lambda_powertools.shared.functions import ( @@ -16,6 +16,7 @@ ) from aws_lambda_powertools.shared.lazy_import import LazyLoader from aws_lambda_powertools.shared.types import AnyCallableT +from aws_lambda_powertools.tracing.provider.aws_xray.aws_xray_tracer import AwsXrayProvider if TYPE_CHECKING: import numbers @@ -27,6 +28,8 @@ aws_xray_sdk = LazyLoader(constants.XRAY_SDK_MODULE, globals(), constants.XRAY_SDK_MODULE) +T = TypeVar("T") + class Tracer: """Tracer using AWS-XRay to provide decorators with known defaults for Lambda functions @@ -255,7 +258,7 @@ def patch(self, modules: Sequence[str] | None = None): def capture_lambda_handler( self, - lambda_handler: Callable[[dict, Any], Any] | Callable[[dict, Any, dict | None], Any] | None = None, + lambda_handler: Callable[[T, Any], Any] | Callable[[T, Any, Any], Any] | None = None, capture_response: bool | None = None, capture_error: bool | None = None, ): @@ -803,27 +806,23 @@ def __build_config( is_service = resolve_env_var_choice(choice=service, env=os.getenv(constants.SERVICE_NAME_ENV)) # Logic: Choose overridden option first, previously cached config, or default if available - self._config["provider"] = provider or self._config["provider"] or self._patch_xray_provider() self._config["auto_patch"] = auto_patch if auto_patch is not None else self._config["auto_patch"] self._config["service"] = is_service or self._config["service"] self._config["disabled"] = is_disabled or self._config["disabled"] self._config["patch_modules"] = patch_modules or self._config["patch_modules"] + self._config["provider"] = provider or self._config["provider"] or self._patch_xray_provider() @classmethod def _reset_config(cls): cls._config = copy.copy(cls._default_config) def _patch_xray_provider(self): - # Due to Lazy Import, we need to activate `core` attrib via import - # we also need to include `patch`, `patch_all` methods - # to ensure patch calls are done via the provider - from aws_xray_sdk.core import xray_recorder # type: ignore - - provider = xray_recorder - provider.patch = aws_xray_sdk.core.patch - provider.patch_all = aws_xray_sdk.core.patch_all - - return provider + return AwsXrayProvider( + service=self._config["service"], + auto_patch=self._config["auto_patch"], + patch_modules=self._config["patch_modules"], + disabled=self._config["disabled"], + ) def _disable_xray_trace_batching(self): """Configure X-Ray SDK to send subsegment individually over batching @@ -836,7 +835,7 @@ def _disable_xray_trace_batching(self): aws_xray_sdk.core.xray_recorder.configure(streaming_threshold=0) def _is_xray_provider(self): - return "aws_xray_sdk" in self.provider.__module__ + return any(module in self.provider.__module__ for module in ("aws_xray_sdk", "aws_xray_tracer")) def ignore_endpoint(self, hostname: str | None = None, urls: list[str] | None = None): """If you want to ignore certain httplib requests you can do so based on the hostname or URL that is being diff --git a/poetry.lock b/poetry.lock index e0b25c11ccf..ba6b60f0a54 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3746,6 +3746,17 @@ files = [ doc = ["sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] test = ["mypy", "pytest", "typing-extensions"] +[[package]] +name = "types-aws-xray-sdk" +version = "2.14.0.20240606" +description = "Typing stubs for aws-xray-sdk" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-aws-xray-sdk-2.14.0.20240606.tar.gz", hash = "sha256:3215f8f80b48c9da9f7ff16021234cd631b538095933d5432e3fa4c5e2d76a22"}, + {file = "types_aws_xray_sdk-2.14.0.20240606-py3-none-any.whl", hash = "sha256:c238ad639bb50896f1326c12bcc36b7832b5bc7c4b5e2b19a7efcd89d7d28b94"}, +] + [[package]] name = "types-awscrt" version = "0.21.0" @@ -4196,4 +4207,4 @@ validation = ["fastjsonschema"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0.0" -content-hash = "e6a93ae2514bd23686e766fcf06cd42cba18822272b07e116436edcaf9b3bfa7" +content-hash = "3b7bb5f4264d95c47b306fa4187235e0662c5ed2cf6c318c8f5b6fe722b8d56e" diff --git a/pyproject.toml b/pyproject.toml index 57df08cd021..c49745887e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,6 +113,7 @@ types-redis = "^4.6.0.7" testcontainers = { extras = ["redis"], version = "^3.7.1" } multiprocess = "^0.70.16" boto3-stubs = {extras = ["appconfig", "appconfigdata", "cloudformation", "cloudwatch", "dynamodb", "lambda", "logs", "s3", "secretsmanager", "ssm", "xray"], version = "^1.34.139"} +types-aws-xray-sdk = "^2.14.0.20240606" [tool.coverage.run] source = ["aws_lambda_powertools"] diff --git a/tests/functional/test_tracing.py b/tests/functional/test_tracing.py index 5f48b233d91..c7926c24b7a 100644 --- a/tests/functional/test_tracing.py +++ b/tests/functional/test_tracing.py @@ -3,6 +3,7 @@ import pytest from aws_lambda_powertools import Tracer +from aws_lambda_powertools.tracing.provider.aws_xray.aws_xray_tracer import AwsXrayProvider @pytest.fixture @@ -23,7 +24,7 @@ def service_name(): def test_capture_lambda_handler(dummy_response): # GIVEN tracer lambda handler decorator is used - tracer = Tracer(disabled=True) + tracer = AwsXrayProvider(disabled=True) # WHEN a lambda handler is run @tracer.capture_lambda_handler