diff --git a/sentry_sdk/integrations/anthropic.py b/sentry_sdk/integrations/anthropic.py index 08c40bc7b6..87e69a3113 100644 --- a/sentry_sdk/integrations/anthropic.py +++ b/sentry_sdk/integrations/anthropic.py @@ -13,16 +13,15 @@ ) try: - from anthropic.resources import Messages + from anthropic.resources import AsyncMessages, Messages if TYPE_CHECKING: from anthropic.types import MessageStreamEvent except ImportError: raise DidNotEnable("Anthropic not installed") - if TYPE_CHECKING: - from typing import Any, Iterator + from typing import Any, AsyncIterator, Iterator from sentry_sdk.tracing import Span @@ -46,6 +45,7 @@ def setup_once(): raise DidNotEnable("anthropic 0.16 or newer required.") Messages.create = _wrap_message_create(Messages.create) + AsyncMessages.create = _wrap_message_create_async(AsyncMessages.create) def _capture_exception(exc): @@ -75,7 +75,9 @@ def _calculate_token_usage(result, span): def _get_responses(content): # type: (list[Any]) -> list[dict[str, Any]] - """Get JSON of a Anthropic responses.""" + """ + Get JSON of a Anthropic responses. + """ responses = [] for item in content: if hasattr(item, "text"): @@ -88,94 +90,202 @@ def _get_responses(content): return responses +def _collect_ai_data(event, input_tokens, output_tokens, content_blocks): + # type: (MessageStreamEvent, int, int, list[str]) -> tuple[int, int, list[str]] + """ + Count token usage and collect content blocks from the AI streaming response. + """ + with capture_internal_exceptions(): + if hasattr(event, "type"): + if event.type == "message_start": + usage = event.message.usage + input_tokens += usage.input_tokens + output_tokens += usage.output_tokens + elif event.type == "content_block_start": + pass + elif event.type == "content_block_delta": + if hasattr(event.delta, "text"): + content_blocks.append(event.delta.text) + elif event.type == "content_block_stop": + pass + elif event.type == "message_delta": + output_tokens += event.usage.output_tokens + + return input_tokens, output_tokens, content_blocks + + +def _add_ai_data_to_span( + span, integration, input_tokens, output_tokens, content_blocks +): + # type: (Span, AnthropicIntegration, int, int, list[str]) -> None + """ + Add token usage and content blocks from the AI streaming response to the span. + """ + with capture_internal_exceptions(): + if should_send_default_pii() and integration.include_prompts: + complete_message = "".join(content_blocks) + span.set_data( + SPANDATA.AI_RESPONSES, + [{"type": "text", "text": complete_message}], + ) + total_tokens = input_tokens + output_tokens + record_token_usage(span, input_tokens, output_tokens, total_tokens) + span.set_data(SPANDATA.AI_STREAMING, True) + + +def _sentry_patched_create_common(f, *args, **kwargs): + # type: (Any, *Any, **Any) -> Any + integration = kwargs.pop("integration") + if integration is None: + return f(*args, **kwargs) + + if "messages" not in kwargs: + return f(*args, **kwargs) + + try: + iter(kwargs["messages"]) + except TypeError: + return f(*args, **kwargs) + + span = sentry_sdk.start_span( + op=OP.ANTHROPIC_MESSAGES_CREATE, + description="Anthropic messages create", + origin=AnthropicIntegration.origin, + ) + span.__enter__() + + result = yield f, args, kwargs + + # add data to span and finish it + messages = list(kwargs["messages"]) + model = kwargs.get("model") + + with capture_internal_exceptions(): + span.set_data(SPANDATA.AI_MODEL_ID, model) + span.set_data(SPANDATA.AI_STREAMING, False) + + if should_send_default_pii() and integration.include_prompts: + span.set_data(SPANDATA.AI_INPUT_MESSAGES, messages) + + if hasattr(result, "content"): + if should_send_default_pii() and integration.include_prompts: + span.set_data(SPANDATA.AI_RESPONSES, _get_responses(result.content)) + _calculate_token_usage(result, span) + span.__exit__(None, None, None) + + # Streaming response + elif hasattr(result, "_iterator"): + old_iterator = result._iterator + + def new_iterator(): + # type: () -> Iterator[MessageStreamEvent] + input_tokens = 0 + output_tokens = 0 + content_blocks = [] # type: list[str] + + for event in old_iterator: + input_tokens, output_tokens, content_blocks = _collect_ai_data( + event, input_tokens, output_tokens, content_blocks + ) + if event.type != "message_stop": + yield event + + _add_ai_data_to_span( + span, integration, input_tokens, output_tokens, content_blocks + ) + span.__exit__(None, None, None) + + async def new_iterator_async(): + # type: () -> AsyncIterator[MessageStreamEvent] + input_tokens = 0 + output_tokens = 0 + content_blocks = [] # type: list[str] + + async for event in old_iterator: + input_tokens, output_tokens, content_blocks = _collect_ai_data( + event, input_tokens, output_tokens, content_blocks + ) + if event.type != "message_stop": + yield event + + _add_ai_data_to_span( + span, integration, input_tokens, output_tokens, content_blocks + ) + span.__exit__(None, None, None) + + if str(type(result._iterator)) == "": + result._iterator = new_iterator_async() + else: + result._iterator = new_iterator() + + else: + span.set_data("unknown_response", True) + span.__exit__(None, None, None) + + return result + + def _wrap_message_create(f): # type: (Any) -> Any + def _execute_sync(f, *args, **kwargs): + # type: (Any, *Any, **Any) -> Any + gen = _sentry_patched_create_common(f, *args, **kwargs) + + try: + f, args, kwargs = next(gen) + except StopIteration as e: + return e.value + + try: + try: + result = f(*args, **kwargs) + except Exception as exc: + _capture_exception(exc) + raise exc from None + + return gen.send(result) + except StopIteration as e: + return e.value + @wraps(f) - def _sentry_patched_create(*args, **kwargs): + def _sentry_patched_create_sync(*args, **kwargs): # type: (*Any, **Any) -> Any integration = sentry_sdk.get_client().get_integration(AnthropicIntegration) + kwargs["integration"] = integration - if integration is None or "messages" not in kwargs: - return f(*args, **kwargs) + return _execute_sync(f, *args, **kwargs) - try: - iter(kwargs["messages"]) - except TypeError: - return f(*args, **kwargs) + return _sentry_patched_create_sync - messages = list(kwargs["messages"]) - model = kwargs.get("model") - span = sentry_sdk.start_span( - op=OP.ANTHROPIC_MESSAGES_CREATE, - name="Anthropic messages create", - origin=AnthropicIntegration.origin, - ) - span.__enter__() +def _wrap_message_create_async(f): + # type: (Any) -> Any + async def _execute_async(f, *args, **kwargs): + # type: (Any, *Any, **Any) -> Any + gen = _sentry_patched_create_common(f, *args, **kwargs) try: - result = f(*args, **kwargs) - except Exception as exc: - _capture_exception(exc) - span.__exit__(None, None, None) - raise exc from None + f, args, kwargs = next(gen) + except StopIteration as e: + return await e.value - with capture_internal_exceptions(): - span.set_data(SPANDATA.AI_MODEL_ID, model) - span.set_data(SPANDATA.AI_STREAMING, False) - if should_send_default_pii() and integration.include_prompts: - span.set_data(SPANDATA.AI_INPUT_MESSAGES, messages) - if hasattr(result, "content"): - if should_send_default_pii() and integration.include_prompts: - span.set_data(SPANDATA.AI_RESPONSES, _get_responses(result.content)) - _calculate_token_usage(result, span) - span.__exit__(None, None, None) - elif hasattr(result, "_iterator"): - old_iterator = result._iterator - - def new_iterator(): - # type: () -> Iterator[MessageStreamEvent] - input_tokens = 0 - output_tokens = 0 - content_blocks = [] - with capture_internal_exceptions(): - for event in old_iterator: - if hasattr(event, "type"): - if event.type == "message_start": - usage = event.message.usage - input_tokens += usage.input_tokens - output_tokens += usage.output_tokens - elif event.type == "content_block_start": - pass - elif event.type == "content_block_delta": - if hasattr(event.delta, "text"): - content_blocks.append(event.delta.text) - elif event.type == "content_block_stop": - pass - elif event.type == "message_delta": - output_tokens += event.usage.output_tokens - elif event.type == "message_stop": - continue - yield event - - if should_send_default_pii() and integration.include_prompts: - complete_message = "".join(content_blocks) - span.set_data( - SPANDATA.AI_RESPONSES, - [{"type": "text", "text": complete_message}], - ) - total_tokens = input_tokens + output_tokens - record_token_usage( - span, input_tokens, output_tokens, total_tokens - ) - span.set_data(SPANDATA.AI_STREAMING, True) - span.__exit__(None, None, None) + try: + try: + result = await f(*args, **kwargs) + except Exception as exc: + _capture_exception(exc) + raise exc from None - result._iterator = new_iterator() - else: - span.set_data("unknown_response", True) - span.__exit__(None, None, None) + return gen.send(result) + except StopIteration as e: + return e.value + + @wraps(f) + async def _sentry_patched_create_async(*args, **kwargs): + # type: (*Any, **Any) -> Any + integration = sentry_sdk.get_client().get_integration(AnthropicIntegration) + kwargs["integration"] = integration - return result + return await _execute_async(f, *args, **kwargs) - return _sentry_patched_create + return _sentry_patched_create_async diff --git a/sentry_sdk/integrations/openai.py b/sentry_sdk/integrations/openai.py index 272f142b05..e6ac36f3cb 100644 --- a/sentry_sdk/integrations/openai.py +++ b/sentry_sdk/integrations/openai.py @@ -15,12 +15,12 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any, Iterable, List, Optional, Callable, Iterator + from typing import Any, Iterable, List, Optional, Callable, AsyncIterator, Iterator from sentry_sdk.tracing import Span try: - from openai.resources.chat.completions import Completions - from openai.resources import Embeddings + from openai.resources.chat.completions import Completions, AsyncCompletions + from openai.resources import Embeddings, AsyncEmbeddings if TYPE_CHECKING: from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk @@ -48,6 +48,11 @@ def setup_once(): Completions.create = _wrap_chat_completion_create(Completions.create) Embeddings.create = _wrap_embeddings_create(Embeddings.create) + AsyncCompletions.create = _wrap_async_chat_completion_create( + AsyncCompletions.create + ) + AsyncEmbeddings.create = _wrap_async_embeddings_create(AsyncEmbeddings.create) + def count_tokens(self, s): # type: (OpenAIIntegration, str) -> int if self.tiktoken_encoding is not None: @@ -109,160 +114,316 @@ def _calculate_chat_completion_usage( record_token_usage(span, prompt_tokens, completion_tokens, total_tokens) +def _new_chat_completion_common(f, *args, **kwargs): + # type: (Any, *Any, **Any) -> Any + integration = sentry_sdk.get_client().get_integration(OpenAIIntegration) + if integration is None: + return f(*args, **kwargs) + + if "messages" not in kwargs: + # invalid call (in all versions of openai), let it return error + return f(*args, **kwargs) + + try: + iter(kwargs["messages"]) + except TypeError: + # invalid call (in all versions), messages must be iterable + return f(*args, **kwargs) + + kwargs["messages"] = list(kwargs["messages"]) + messages = kwargs["messages"] + model = kwargs.get("model") + streaming = kwargs.get("stream") + + span = sentry_sdk.start_span( + op=consts.OP.OPENAI_CHAT_COMPLETIONS_CREATE, + description="Chat Completion", + origin=OpenAIIntegration.origin, + ) + span.__enter__() + + res = yield f, args, kwargs + + with capture_internal_exceptions(): + if should_send_default_pii() and integration.include_prompts: + set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, messages) + + set_data_normalized(span, SPANDATA.AI_MODEL_ID, model) + set_data_normalized(span, SPANDATA.AI_STREAMING, streaming) + + if hasattr(res, "choices"): + if should_send_default_pii() and integration.include_prompts: + set_data_normalized( + span, + "ai.responses", + list(map(lambda x: x.message, res.choices)), + ) + _calculate_chat_completion_usage( + messages, res, span, None, integration.count_tokens + ) + span.__exit__(None, None, None) + elif hasattr(res, "_iterator"): + data_buf: list[list[str]] = [] # one for each choice + + old_iterator = res._iterator + + def new_iterator(): + # type: () -> Iterator[ChatCompletionChunk] + with capture_internal_exceptions(): + for x in old_iterator: + if hasattr(x, "choices"): + choice_index = 0 + for choice in x.choices: + if hasattr(choice, "delta") and hasattr( + choice.delta, "content" + ): + content = choice.delta.content + if len(data_buf) <= choice_index: + data_buf.append([]) + data_buf[choice_index].append(content or "") + choice_index += 1 + yield x + if len(data_buf) > 0: + all_responses = list( + map(lambda chunk: "".join(chunk), data_buf) + ) + if should_send_default_pii() and integration.include_prompts: + set_data_normalized( + span, SPANDATA.AI_RESPONSES, all_responses + ) + _calculate_chat_completion_usage( + messages, + res, + span, + all_responses, + integration.count_tokens, + ) + span.__exit__(None, None, None) + + async def new_iterator_async(): + # type: () -> AsyncIterator[ChatCompletionChunk] + with capture_internal_exceptions(): + async for x in old_iterator: + if hasattr(x, "choices"): + choice_index = 0 + for choice in x.choices: + if hasattr(choice, "delta") and hasattr( + choice.delta, "content" + ): + content = choice.delta.content + if len(data_buf) <= choice_index: + data_buf.append([]) + data_buf[choice_index].append(content or "") + choice_index += 1 + yield x + if len(data_buf) > 0: + all_responses = list( + map(lambda chunk: "".join(chunk), data_buf) + ) + if should_send_default_pii() and integration.include_prompts: + set_data_normalized( + span, SPANDATA.AI_RESPONSES, all_responses + ) + _calculate_chat_completion_usage( + messages, + res, + span, + all_responses, + integration.count_tokens, + ) + span.__exit__(None, None, None) + + if str(type(res._iterator)) == "": + res._iterator = new_iterator_async() + else: + res._iterator = new_iterator() + + else: + set_data_normalized(span, "unknown_response", True) + span.__exit__(None, None, None) + return res + + def _wrap_chat_completion_create(f): # type: (Callable[..., Any]) -> Callable[..., Any] + def _execute_sync(f, *args, **kwargs): + # type: (Any, *Any, **Any) -> Any + gen = _new_chat_completion_common(f, *args, **kwargs) + + try: + f, args, kwargs = next(gen) + except StopIteration as e: + return e.value + + try: + try: + result = f(*args, **kwargs) + except Exception as e: + _capture_exception(e) + raise e from None + + return gen.send(result) + except StopIteration as e: + return e.value @wraps(f) - def new_chat_completion(*args, **kwargs): + def _sentry_patched_create_sync(*args, **kwargs): # type: (*Any, **Any) -> Any integration = sentry_sdk.get_client().get_integration(OpenAIIntegration) if integration is None or "messages" not in kwargs: # no "messages" means invalid call (in all versions of openai), let it return error return f(*args, **kwargs) - try: - iter(kwargs["messages"]) - except TypeError: - # invalid call (in all versions), messages must be iterable - return f(*args, **kwargs) + return _execute_sync(f, *args, **kwargs) - kwargs["messages"] = list(kwargs["messages"]) - messages = kwargs["messages"] - model = kwargs.get("model") - streaming = kwargs.get("stream") + return _sentry_patched_create_sync + + +def _wrap_async_chat_completion_create(f): + # type: (Callable[..., Any]) -> Callable[..., Any] + async def _execute_async(f, *args, **kwargs): + # type: (Any, *Any, **Any) -> Any + gen = _new_chat_completion_common(f, *args, **kwargs) - span = sentry_sdk.start_span( - op=consts.OP.OPENAI_CHAT_COMPLETIONS_CREATE, - name="Chat Completion", - origin=OpenAIIntegration.origin, - ) - span.__enter__() try: - res = f(*args, **kwargs) - except Exception as e: - _capture_exception(e) - span.__exit__(None, None, None) - raise e from None + f, args, kwargs = next(gen) + except StopIteration as e: + return await e.value - with capture_internal_exceptions(): - if should_send_default_pii() and integration.include_prompts: - set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, messages) - - set_data_normalized(span, SPANDATA.AI_MODEL_ID, model) - set_data_normalized(span, SPANDATA.AI_STREAMING, streaming) - - if hasattr(res, "choices"): - if should_send_default_pii() and integration.include_prompts: - set_data_normalized( - span, - "ai.responses", - list(map(lambda x: x.message, res.choices)), - ) - _calculate_chat_completion_usage( - messages, res, span, None, integration.count_tokens - ) - span.__exit__(None, None, None) - elif hasattr(res, "_iterator"): - data_buf: list[list[str]] = [] # one for each choice - - old_iterator = res._iterator # type: Iterator[ChatCompletionChunk] - - def new_iterator(): - # type: () -> Iterator[ChatCompletionChunk] - with capture_internal_exceptions(): - for x in old_iterator: - if hasattr(x, "choices"): - choice_index = 0 - for choice in x.choices: - if hasattr(choice, "delta") and hasattr( - choice.delta, "content" - ): - content = choice.delta.content - if len(data_buf) <= choice_index: - data_buf.append([]) - data_buf[choice_index].append(content or "") - choice_index += 1 - yield x - if len(data_buf) > 0: - all_responses = list( - map(lambda chunk: "".join(chunk), data_buf) - ) - if ( - should_send_default_pii() - and integration.include_prompts - ): - set_data_normalized( - span, SPANDATA.AI_RESPONSES, all_responses - ) - _calculate_chat_completion_usage( - messages, - res, - span, - all_responses, - integration.count_tokens, - ) - span.__exit__(None, None, None) + try: + try: + result = await f(*args, **kwargs) + except Exception as e: + _capture_exception(e) + raise e from None - res._iterator = new_iterator() - else: - set_data_normalized(span, "unknown_response", True) - span.__exit__(None, None, None) - return res + return gen.send(result) + except StopIteration as e: + return e.value + + @wraps(f) + async def _sentry_patched_create_async(*args, **kwargs): + # type: (*Any, **Any) -> Any + integration = sentry_sdk.get_client().get_integration(OpenAIIntegration) + if integration is None or "messages" not in kwargs: + # no "messages" means invalid call (in all versions of openai), let it return error + return await f(*args, **kwargs) + + return await _execute_async(f, *args, **kwargs) + + return _sentry_patched_create_async + + +def _new_embeddings_create_common(f, *args, **kwargs): + # type: (Any, *Any, **Any) -> Any + integration = sentry_sdk.get_client().get_integration(OpenAIIntegration) + if integration is None: + return f(*args, **kwargs) + + with sentry_sdk.start_span( + op=consts.OP.OPENAI_EMBEDDINGS_CREATE, + description="OpenAI Embedding Creation", + origin=OpenAIIntegration.origin, + ) as span: + if "input" in kwargs and ( + should_send_default_pii() and integration.include_prompts + ): + if isinstance(kwargs["input"], str): + set_data_normalized(span, "ai.input_messages", [kwargs["input"]]) + elif ( + isinstance(kwargs["input"], list) + and len(kwargs["input"]) > 0 + and isinstance(kwargs["input"][0], str) + ): + set_data_normalized(span, "ai.input_messages", kwargs["input"]) + if "model" in kwargs: + set_data_normalized(span, "ai.model_id", kwargs["model"]) + + response = yield f, args, kwargs + + prompt_tokens = 0 + total_tokens = 0 + if hasattr(response, "usage"): + if hasattr(response.usage, "prompt_tokens") and isinstance( + response.usage.prompt_tokens, int + ): + prompt_tokens = response.usage.prompt_tokens + if hasattr(response.usage, "total_tokens") and isinstance( + response.usage.total_tokens, int + ): + total_tokens = response.usage.total_tokens + + if prompt_tokens == 0: + prompt_tokens = integration.count_tokens(kwargs["input"] or "") - return new_chat_completion + record_token_usage(span, prompt_tokens, None, total_tokens or prompt_tokens) + + return response def _wrap_embeddings_create(f): - # type: (Callable[..., Any]) -> Callable[..., Any] + # type: (Any) -> Any + def _execute_sync(f, *args, **kwargs): + # type: (Any, *Any, **Any) -> Any + gen = _new_embeddings_create_common(f, *args, **kwargs) + + try: + f, args, kwargs = next(gen) + except StopIteration as e: + return e.value + + try: + try: + result = f(*args, **kwargs) + except Exception as e: + _capture_exception(e) + raise e from None + + return gen.send(result) + except StopIteration as e: + return e.value @wraps(f) - def new_embeddings_create(*args, **kwargs): + def _sentry_patched_create_sync(*args, **kwargs): # type: (*Any, **Any) -> Any integration = sentry_sdk.get_client().get_integration(OpenAIIntegration) if integration is None: return f(*args, **kwargs) - with sentry_sdk.start_span( - op=consts.OP.OPENAI_EMBEDDINGS_CREATE, - name="OpenAI Embedding Creation", - origin=OpenAIIntegration.origin, - ) as span: - if "input" in kwargs and ( - should_send_default_pii() and integration.include_prompts - ): - if isinstance(kwargs["input"], str): - set_data_normalized(span, "ai.input_messages", [kwargs["input"]]) - elif ( - isinstance(kwargs["input"], list) - and len(kwargs["input"]) > 0 - and isinstance(kwargs["input"][0], str) - ): - set_data_normalized(span, "ai.input_messages", kwargs["input"]) - if "model" in kwargs: - set_data_normalized(span, "ai.model_id", kwargs["model"]) + return _execute_sync(f, *args, **kwargs) + + return _sentry_patched_create_sync + + +def _wrap_async_embeddings_create(f): + # type: (Any) -> Any + async def _execute_async(f, *args, **kwargs): + # type: (Any, *Any, **Any) -> Any + gen = _new_embeddings_create_common(f, *args, **kwargs) + + try: + f, args, kwargs = next(gen) + except StopIteration as e: + return await e.value + + try: try: - response = f(*args, **kwargs) + result = await f(*args, **kwargs) except Exception as e: _capture_exception(e) raise e from None - prompt_tokens = 0 - total_tokens = 0 - if hasattr(response, "usage"): - if hasattr(response.usage, "prompt_tokens") and isinstance( - response.usage.prompt_tokens, int - ): - prompt_tokens = response.usage.prompt_tokens - if hasattr(response.usage, "total_tokens") and isinstance( - response.usage.total_tokens, int - ): - total_tokens = response.usage.total_tokens - - if prompt_tokens == 0: - prompt_tokens = integration.count_tokens(kwargs["input"] or "") + return gen.send(result) + except StopIteration as e: + return e.value - record_token_usage(span, prompt_tokens, None, total_tokens or prompt_tokens) + @wraps(f) + async def _sentry_patched_create_async(*args, **kwargs): + # type: (*Any, **Any) -> Any + integration = sentry_sdk.get_client().get_integration(OpenAIIntegration) + if integration is None: + return await f(*args, **kwargs) - return response + return await _execute_async(f, *args, **kwargs) - return new_embeddings_create + return _sentry_patched_create_async diff --git a/tests/integrations/anthropic/test_anthropic.py b/tests/integrations/anthropic/test_anthropic.py index 7e33ac831d..8ce12e70f5 100644 --- a/tests/integrations/anthropic/test_anthropic.py +++ b/tests/integrations/anthropic/test_anthropic.py @@ -1,7 +1,16 @@ from unittest import mock +try: + from unittest.mock import AsyncMock +except ImportError: + + class AsyncMock(mock.MagicMock): + async def __call__(self, *args, **kwargs): + return super(AsyncMock, self).__call__(*args, **kwargs) + + import pytest -from anthropic import Anthropic, AnthropicError, Stream +from anthropic import AsyncAnthropic, Anthropic, AnthropicError, AsyncStream, Stream from anthropic.types import MessageDeltaUsage, TextDelta, Usage from anthropic.types.content_block_delta_event import ContentBlockDeltaEvent from anthropic.types.content_block_start_event import ContentBlockStartEvent @@ -48,6 +57,11 @@ ) +async def async_iterator(values): + for value in values: + yield value + + @pytest.mark.parametrize( "send_default_pii, include_prompts", [ @@ -115,6 +129,74 @@ def test_nonstreaming_create_message( assert span["data"]["ai.streaming"] is False +@pytest.mark.asyncio +@pytest.mark.parametrize( + "send_default_pii, include_prompts", + [ + (True, True), + (True, False), + (False, True), + (False, False), + ], +) +async def test_nonstreaming_create_message_async( + sentry_init, capture_events, send_default_pii, include_prompts +): + sentry_init( + integrations=[AnthropicIntegration(include_prompts=include_prompts)], + traces_sample_rate=1.0, + send_default_pii=send_default_pii, + ) + events = capture_events() + client = AsyncAnthropic(api_key="z") + client.messages._post = AsyncMock(return_value=EXAMPLE_MESSAGE) + + messages = [ + { + "role": "user", + "content": "Hello, Claude", + } + ] + + with start_transaction(name="anthropic"): + response = await client.messages.create( + max_tokens=1024, messages=messages, model="model" + ) + + assert response == EXAMPLE_MESSAGE + usage = response.usage + + assert usage.input_tokens == 10 + assert usage.output_tokens == 20 + + assert len(events) == 1 + (event,) = events + + assert event["type"] == "transaction" + assert event["transaction"] == "anthropic" + + assert len(event["spans"]) == 1 + (span,) = event["spans"] + + assert span["op"] == OP.ANTHROPIC_MESSAGES_CREATE + assert span["description"] == "Anthropic messages create" + assert span["data"][SPANDATA.AI_MODEL_ID] == "model" + + if send_default_pii and include_prompts: + assert span["data"][SPANDATA.AI_INPUT_MESSAGES] == messages + assert span["data"][SPANDATA.AI_RESPONSES] == [ + {"type": "text", "text": "Hi, I'm Claude."} + ] + else: + assert SPANDATA.AI_INPUT_MESSAGES not in span["data"] + assert SPANDATA.AI_RESPONSES not in span["data"] + + assert span["measurements"]["ai_prompt_tokens_used"]["value"] == 10 + assert span["measurements"]["ai_completion_tokens_used"]["value"] == 20 + assert span["measurements"]["ai_total_tokens_used"]["value"] == 30 + assert span["data"]["ai.streaming"] is False + + @pytest.mark.parametrize( "send_default_pii, include_prompts", [ @@ -215,6 +297,109 @@ def test_streaming_create_message( assert span["data"]["ai.streaming"] is True +@pytest.mark.asyncio +@pytest.mark.parametrize( + "send_default_pii, include_prompts", + [ + (True, True), + (True, False), + (False, True), + (False, False), + ], +) +async def test_streaming_create_message_async( + sentry_init, capture_events, send_default_pii, include_prompts +): + client = AsyncAnthropic(api_key="z") + returned_stream = AsyncStream(cast_to=None, response=None, client=client) + returned_stream._iterator = async_iterator( + [ + MessageStartEvent( + message=EXAMPLE_MESSAGE, + type="message_start", + ), + ContentBlockStartEvent( + type="content_block_start", + index=0, + content_block=TextBlock(type="text", text=""), + ), + ContentBlockDeltaEvent( + delta=TextDelta(text="Hi", type="text_delta"), + index=0, + type="content_block_delta", + ), + ContentBlockDeltaEvent( + delta=TextDelta(text="!", type="text_delta"), + index=0, + type="content_block_delta", + ), + ContentBlockDeltaEvent( + delta=TextDelta(text=" I'm Claude!", type="text_delta"), + index=0, + type="content_block_delta", + ), + ContentBlockStopEvent(type="content_block_stop", index=0), + MessageDeltaEvent( + delta=Delta(), + usage=MessageDeltaUsage(output_tokens=10), + type="message_delta", + ), + ] + ) + + sentry_init( + integrations=[AnthropicIntegration(include_prompts=include_prompts)], + traces_sample_rate=1.0, + send_default_pii=send_default_pii, + ) + events = capture_events() + client.messages._post = AsyncMock(return_value=returned_stream) + + messages = [ + { + "role": "user", + "content": "Hello, Claude", + } + ] + + with start_transaction(name="anthropic"): + message = await client.messages.create( + max_tokens=1024, messages=messages, model="model", stream=True + ) + + async for _ in message: + pass + + assert message == returned_stream + assert len(events) == 1 + (event,) = events + + assert event["type"] == "transaction" + assert event["transaction"] == "anthropic" + + assert len(event["spans"]) == 1 + (span,) = event["spans"] + + assert span["op"] == OP.ANTHROPIC_MESSAGES_CREATE + assert span["description"] == "Anthropic messages create" + assert span["data"][SPANDATA.AI_MODEL_ID] == "model" + + if send_default_pii and include_prompts: + assert span["data"][SPANDATA.AI_INPUT_MESSAGES] == messages + assert span["data"][SPANDATA.AI_RESPONSES] == [ + {"type": "text", "text": "Hi! I'm Claude!"} + ] + + else: + assert SPANDATA.AI_INPUT_MESSAGES not in span["data"] + assert SPANDATA.AI_RESPONSES not in span["data"] + + assert span["measurements"]["ai_prompt_tokens_used"]["value"] == 10 + assert span["measurements"]["ai_completion_tokens_used"]["value"] == 30 + assert span["measurements"]["ai_total_tokens_used"]["value"] == 40 + assert span["data"]["ai.streaming"] is True + + @pytest.mark.skipif( ANTHROPIC_VERSION < (0, 27), reason="Versions <0.27.0 do not include InputJSONDelta, which was introduced in >=0.27.0 along with a new message delta type for tool calling.", @@ -345,6 +530,143 @@ def test_streaming_create_message_with_input_json_delta( assert span["data"]["ai.streaming"] is True +@pytest.mark.asyncio +@pytest.mark.skipif( + ANTHROPIC_VERSION < (0, 27), + reason="Versions <0.27.0 do not include InputJSONDelta, which was introduced in >=0.27.0 along with a new message delta type for tool calling.", +) +@pytest.mark.parametrize( + "send_default_pii, include_prompts", + [ + (True, True), + (True, False), + (False, True), + (False, False), + ], +) +async def test_streaming_create_message_with_input_json_delta_async( + sentry_init, capture_events, send_default_pii, include_prompts +): + client = AsyncAnthropic(api_key="z") + returned_stream = AsyncStream(cast_to=None, response=None, client=client) + returned_stream._iterator = async_iterator( + [ + MessageStartEvent( + message=Message( + id="msg_0", + content=[], + model="claude-3-5-sonnet-20240620", + role="assistant", + stop_reason=None, + stop_sequence=None, + type="message", + usage=Usage(input_tokens=366, output_tokens=10), + ), + type="message_start", + ), + ContentBlockStartEvent( + type="content_block_start", + index=0, + content_block=ToolUseBlock( + id="toolu_0", input={}, name="get_weather", type="tool_use" + ), + ), + ContentBlockDeltaEvent( + delta=InputJSONDelta(partial_json="", type="input_json_delta"), + index=0, + type="content_block_delta", + ), + ContentBlockDeltaEvent( + delta=InputJSONDelta( + partial_json="{'location':", type="input_json_delta" + ), + index=0, + type="content_block_delta", + ), + ContentBlockDeltaEvent( + delta=InputJSONDelta(partial_json=" 'S", type="input_json_delta"), + index=0, + type="content_block_delta", + ), + ContentBlockDeltaEvent( + delta=InputJSONDelta(partial_json="an ", type="input_json_delta"), + index=0, + type="content_block_delta", + ), + ContentBlockDeltaEvent( + delta=InputJSONDelta( + partial_json="Francisco, C", type="input_json_delta" + ), + index=0, + type="content_block_delta", + ), + ContentBlockDeltaEvent( + delta=InputJSONDelta(partial_json="A'}", type="input_json_delta"), + index=0, + type="content_block_delta", + ), + ContentBlockStopEvent(type="content_block_stop", index=0), + MessageDeltaEvent( + delta=Delta(stop_reason="tool_use", stop_sequence=None), + usage=MessageDeltaUsage(output_tokens=41), + type="message_delta", + ), + ] + ) + + sentry_init( + integrations=[AnthropicIntegration(include_prompts=include_prompts)], + traces_sample_rate=1.0, + send_default_pii=send_default_pii, + ) + events = capture_events() + client.messages._post = AsyncMock(return_value=returned_stream) + + messages = [ + { + "role": "user", + "content": "What is the weather like in San Francisco?", + } + ] + + with start_transaction(name="anthropic"): + message = await client.messages.create( + max_tokens=1024, messages=messages, model="model", stream=True + ) + + async for _ in message: + pass + + assert message == returned_stream + assert len(events) == 1 + (event,) = events + + assert event["type"] == "transaction" + assert event["transaction"] == "anthropic" + + assert len(event["spans"]) == 1 + (span,) = event["spans"] + + assert span["op"] == OP.ANTHROPIC_MESSAGES_CREATE + assert span["description"] == "Anthropic messages create" + assert span["data"][SPANDATA.AI_MODEL_ID] == "model" + + if send_default_pii and include_prompts: + assert span["data"][SPANDATA.AI_INPUT_MESSAGES] == messages + assert span["data"][SPANDATA.AI_RESPONSES] == [ + {"text": "", "type": "text"} + ] # we do not record InputJSONDelta because it could contain PII + + else: + assert SPANDATA.AI_INPUT_MESSAGES not in span["data"] + assert SPANDATA.AI_RESPONSES not in span["data"] + + assert span["measurements"]["ai_prompt_tokens_used"]["value"] == 366 + assert span["measurements"]["ai_completion_tokens_used"]["value"] == 51 + assert span["measurements"]["ai_total_tokens_used"]["value"] == 417 + assert span["data"]["ai.streaming"] is True + + def test_exception_message_create(sentry_init, capture_events): sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0) events = capture_events() @@ -364,6 +686,26 @@ def test_exception_message_create(sentry_init, capture_events): assert event["level"] == "error" +@pytest.mark.asyncio +async def test_exception_message_create_async(sentry_init, capture_events): + sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0) + events = capture_events() + + client = AsyncAnthropic(api_key="z") + client.messages._post = AsyncMock( + side_effect=AnthropicError("API rate limit reached") + ) + with pytest.raises(AnthropicError): + await client.messages.create( + model="some-model", + messages=[{"role": "system", "content": "I'm throwing an exception"}], + max_tokens=1024, + ) + + (event,) = events + assert event["level"] == "error" + + def test_span_origin(sentry_init, capture_events): sentry_init( integrations=[AnthropicIntegration()], @@ -388,3 +730,30 @@ def test_span_origin(sentry_init, capture_events): assert event["contexts"]["trace"]["origin"] == "manual" assert event["spans"][0]["origin"] == "auto.ai.anthropic" + + +@pytest.mark.asyncio +async def test_span_origin_async(sentry_init, capture_events): + sentry_init( + integrations=[AnthropicIntegration()], + traces_sample_rate=1.0, + ) + events = capture_events() + + client = AsyncAnthropic(api_key="z") + client.messages._post = AsyncMock(return_value=EXAMPLE_MESSAGE) + + messages = [ + { + "role": "user", + "content": "Hello, Claude", + } + ] + + with start_transaction(name="anthropic"): + await client.messages.create(max_tokens=1024, messages=messages, model="model") + + (event,) = events + + assert event["contexts"]["trace"]["origin"] == "manual" + assert event["spans"][0]["origin"] == "auto.ai.anthropic" diff --git a/tests/integrations/openai/test_openai.py b/tests/integrations/openai/test_openai.py index b0ffc9e768..011192e49f 100644 --- a/tests/integrations/openai/test_openai.py +++ b/tests/integrations/openai/test_openai.py @@ -1,5 +1,5 @@ import pytest -from openai import OpenAI, Stream, OpenAIError +from openai import AsyncOpenAI, OpenAI, AsyncStream, Stream, OpenAIError from openai.types import CompletionUsage, CreateEmbeddingResponse, Embedding from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionChunk from openai.types.chat.chat_completion import Choice @@ -7,10 +7,21 @@ from openai.types.create_embedding_response import Usage as EmbeddingTokenUsage from sentry_sdk import start_transaction -from sentry_sdk.integrations.openai import OpenAIIntegration +from sentry_sdk.integrations.openai import ( + OpenAIIntegration, + _calculate_chat_completion_usage, +) from unittest import mock # python 3.3 and above +try: + from unittest.mock import AsyncMock +except ImportError: + + class AsyncMock(mock.MagicMock): + async def __call__(self, *args, **kwargs): + return super(AsyncMock, self).__call__(*args, **kwargs) + EXAMPLE_CHAT_COMPLETION = ChatCompletion( id="chat-id", @@ -34,6 +45,11 @@ ) +async def async_iterator(values): + for value in values: + yield value + + @pytest.mark.parametrize( "send_default_pii, include_prompts", [(True, True), (True, False), (False, True), (False, False)], @@ -78,6 +94,48 @@ def test_nonstreaming_chat_completion( assert span["measurements"]["ai_total_tokens_used"]["value"] == 30 +@pytest.mark.asyncio +@pytest.mark.parametrize( + "send_default_pii, include_prompts", + [(True, True), (True, False), (False, True), (False, False)], +) +async def test_nonstreaming_chat_completion_async( + sentry_init, capture_events, send_default_pii, include_prompts +): + sentry_init( + integrations=[OpenAIIntegration(include_prompts=include_prompts)], + traces_sample_rate=1.0, + send_default_pii=send_default_pii, + ) + events = capture_events() + + client = AsyncOpenAI(api_key="z") + client.chat.completions._post = AsyncMock(return_value=EXAMPLE_CHAT_COMPLETION) + + with start_transaction(name="openai tx"): + response = await client.chat.completions.create( + model="some-model", messages=[{"role": "system", "content": "hello"}] + ) + response = response.choices[0].message.content + + assert response == "the model response" + tx = events[0] + assert tx["type"] == "transaction" + span = tx["spans"][0] + assert span["op"] == "ai.chat_completions.create.openai" + + if send_default_pii and include_prompts: + assert "hello" in span["data"]["ai.input_messages"]["content"] + assert "the model response" in span["data"]["ai.responses"]["content"] + else: + assert "ai.input_messages" not in span["data"] + assert "ai.responses" not in span["data"] + + assert span["measurements"]["ai_completion_tokens_used"]["value"] == 10 + assert span["measurements"]["ai_prompt_tokens_used"]["value"] == 20 + assert span["measurements"]["ai_total_tokens_used"]["value"] == 30 + + def tiktoken_encoding_if_installed(): try: import tiktoken # type: ignore # noqa # pylint: disable=unused-import @@ -176,6 +234,102 @@ def test_streaming_chat_completion( pass # if tiktoken is not installed, we can't guarantee token usage will be calculated properly +# noinspection PyTypeChecker +@pytest.mark.asyncio +@pytest.mark.parametrize( + "send_default_pii, include_prompts", + [(True, True), (True, False), (False, True), (False, False)], +) +async def test_streaming_chat_completion_async( + sentry_init, capture_events, send_default_pii, include_prompts +): + sentry_init( + integrations=[ + OpenAIIntegration( + include_prompts=include_prompts, + tiktoken_encoding_name=tiktoken_encoding_if_installed(), + ) + ], + traces_sample_rate=1.0, + send_default_pii=send_default_pii, + ) + events = capture_events() + + client = AsyncOpenAI(api_key="z") + returned_stream = AsyncStream(cast_to=None, response=None, client=client) + returned_stream._iterator = async_iterator( + [ + ChatCompletionChunk( + id="1", + choices=[ + DeltaChoice( + index=0, delta=ChoiceDelta(content="hel"), finish_reason=None + ) + ], + created=100000, + model="model-id", + object="chat.completion.chunk", + ), + ChatCompletionChunk( + id="1", + choices=[ + DeltaChoice( + index=1, delta=ChoiceDelta(content="lo "), finish_reason=None + ) + ], + created=100000, + model="model-id", + object="chat.completion.chunk", + ), + ChatCompletionChunk( + id="1", + choices=[ + DeltaChoice( + index=2, + delta=ChoiceDelta(content="world"), + finish_reason="stop", + ) + ], + created=100000, + model="model-id", + object="chat.completion.chunk", + ), + ] + ) + + client.chat.completions._post = AsyncMock(return_value=returned_stream) + with start_transaction(name="openai tx"): + response_stream = await client.chat.completions.create( + model="some-model", messages=[{"role": "system", "content": "hello"}] + ) + + response_string = "" + async for x in response_stream: + response_string += x.choices[0].delta.content + + assert response_string == "hello world" + tx = events[0] + assert tx["type"] == "transaction" + span = tx["spans"][0] + assert span["op"] == "ai.chat_completions.create.openai" + + if send_default_pii and include_prompts: + assert "hello" in span["data"]["ai.input_messages"]["content"] + assert "hello world" in span["data"]["ai.responses"] + else: + assert "ai.input_messages" not in span["data"] + assert "ai.responses" not in span["data"] + + try: + import tiktoken # type: ignore # noqa # pylint: disable=unused-import + + assert span["measurements"]["ai_completion_tokens_used"]["value"] == 2 + assert span["measurements"]["ai_prompt_tokens_used"]["value"] == 1 + assert span["measurements"]["ai_total_tokens_used"]["value"] == 3 + except ImportError: + pass # if tiktoken is not installed, we can't guarantee token usage will be calculated properly + + def test_bad_chat_completion(sentry_init, capture_events): sentry_init(integrations=[OpenAIIntegration()], traces_sample_rate=1.0) events = capture_events() @@ -193,6 +347,24 @@ def test_bad_chat_completion(sentry_init, capture_events): assert event["level"] == "error" +@pytest.mark.asyncio +async def test_bad_chat_completion_async(sentry_init, capture_events): + sentry_init(integrations=[OpenAIIntegration()], traces_sample_rate=1.0) + events = capture_events() + + client = AsyncOpenAI(api_key="z") + client.chat.completions._post = AsyncMock( + side_effect=OpenAIError("API rate limit reached") + ) + with pytest.raises(OpenAIError): + await client.chat.completions.create( + model="some-model", messages=[{"role": "system", "content": "hello"}] + ) + + (event,) = events + assert event["level"] == "error" + + @pytest.mark.parametrize( "send_default_pii, include_prompts", [(True, True), (True, False), (False, True), (False, False)], @@ -240,6 +412,109 @@ def test_embeddings_create( assert span["measurements"]["ai_total_tokens_used"]["value"] == 30 +@pytest.mark.asyncio +@pytest.mark.parametrize( + "send_default_pii, include_prompts", + [(True, True), (True, False), (False, True), (False, False)], +) +async def test_embeddings_create_async( + sentry_init, capture_events, send_default_pii, include_prompts +): + sentry_init( + integrations=[OpenAIIntegration(include_prompts=include_prompts)], + traces_sample_rate=1.0, + send_default_pii=send_default_pii, + ) + events = capture_events() + + client = AsyncOpenAI(api_key="z") + + returned_embedding = CreateEmbeddingResponse( + data=[Embedding(object="embedding", index=0, embedding=[1.0, 2.0, 3.0])], + model="some-model", + object="list", + usage=EmbeddingTokenUsage( + prompt_tokens=20, + total_tokens=30, + ), + ) + + client.embeddings._post = AsyncMock(return_value=returned_embedding) + with start_transaction(name="openai tx"): + response = await client.embeddings.create( + input="hello", model="text-embedding-3-large" + ) + + assert len(response.data[0].embedding) == 3 + + tx = events[0] + assert tx["type"] == "transaction" + span = tx["spans"][0] + assert span["op"] == "ai.embeddings.create.openai" + if send_default_pii and include_prompts: + assert "hello" in span["data"]["ai.input_messages"] + else: + assert "ai.input_messages" not in span["data"] + + assert span["measurements"]["ai_prompt_tokens_used"]["value"] == 20 + assert span["measurements"]["ai_total_tokens_used"]["value"] == 30 + + +@pytest.mark.parametrize( + "send_default_pii, include_prompts", + [(True, True), (True, False), (False, True), (False, False)], +) +def test_embeddings_create_raises_error( + sentry_init, capture_events, send_default_pii, include_prompts +): + sentry_init( + integrations=[OpenAIIntegration(include_prompts=include_prompts)], + traces_sample_rate=1.0, + send_default_pii=send_default_pii, + ) + events = capture_events() + + client = OpenAI(api_key="z") + + client.embeddings._post = mock.Mock( + side_effect=OpenAIError("API rate limit reached") + ) + + with pytest.raises(OpenAIError): + client.embeddings.create(input="hello", model="text-embedding-3-large") + + (event,) = events + assert event["level"] == "error" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "send_default_pii, include_prompts", + [(True, True), (True, False), (False, True), (False, False)], +) +async def test_embeddings_create_raises_error_async( + sentry_init, capture_events, send_default_pii, include_prompts +): + sentry_init( + integrations=[OpenAIIntegration(include_prompts=include_prompts)], + traces_sample_rate=1.0, + send_default_pii=send_default_pii, + ) + events = capture_events() + + client = AsyncOpenAI(api_key="z") + + client.embeddings._post = AsyncMock( + side_effect=OpenAIError("API rate limit reached") + ) + + with pytest.raises(OpenAIError): + await client.embeddings.create(input="hello", model="text-embedding-3-large") + + (event,) = events + assert event["level"] == "error" + + def test_span_origin_nonstreaming_chat(sentry_init, capture_events): sentry_init( integrations=[OpenAIIntegration()], @@ -261,6 +536,28 @@ def test_span_origin_nonstreaming_chat(sentry_init, capture_events): assert event["spans"][0]["origin"] == "auto.ai.openai" +@pytest.mark.asyncio +async def test_span_origin_nonstreaming_chat_async(sentry_init, capture_events): + sentry_init( + integrations=[OpenAIIntegration()], + traces_sample_rate=1.0, + ) + events = capture_events() + + client = AsyncOpenAI(api_key="z") + client.chat.completions._post = AsyncMock(return_value=EXAMPLE_CHAT_COMPLETION) + + with start_transaction(name="openai tx"): + await client.chat.completions.create( + model="some-model", messages=[{"role": "system", "content": "hello"}] + ) + + (event,) = events + + assert event["contexts"]["trace"]["origin"] == "manual" + assert event["spans"][0]["origin"] == "auto.ai.openai" + + def test_span_origin_streaming_chat(sentry_init, capture_events): sentry_init( integrations=[OpenAIIntegration()], @@ -311,6 +608,7 @@ def test_span_origin_streaming_chat(sentry_init, capture_events): response_stream = client.chat.completions.create( model="some-model", messages=[{"role": "system", "content": "hello"}] ) + "".join(map(lambda x: x.choices[0].delta.content, response_stream)) (event,) = events @@ -319,6 +617,72 @@ def test_span_origin_streaming_chat(sentry_init, capture_events): assert event["spans"][0]["origin"] == "auto.ai.openai" +@pytest.mark.asyncio +async def test_span_origin_streaming_chat_async(sentry_init, capture_events): + sentry_init( + integrations=[OpenAIIntegration()], + traces_sample_rate=1.0, + ) + events = capture_events() + + client = AsyncOpenAI(api_key="z") + returned_stream = AsyncStream(cast_to=None, response=None, client=client) + returned_stream._iterator = async_iterator( + [ + ChatCompletionChunk( + id="1", + choices=[ + DeltaChoice( + index=0, delta=ChoiceDelta(content="hel"), finish_reason=None + ) + ], + created=100000, + model="model-id", + object="chat.completion.chunk", + ), + ChatCompletionChunk( + id="1", + choices=[ + DeltaChoice( + index=1, delta=ChoiceDelta(content="lo "), finish_reason=None + ) + ], + created=100000, + model="model-id", + object="chat.completion.chunk", + ), + ChatCompletionChunk( + id="1", + choices=[ + DeltaChoice( + index=2, + delta=ChoiceDelta(content="world"), + finish_reason="stop", + ) + ], + created=100000, + model="model-id", + object="chat.completion.chunk", + ), + ] + ) + + client.chat.completions._post = AsyncMock(return_value=returned_stream) + with start_transaction(name="openai tx"): + response_stream = await client.chat.completions.create( + model="some-model", messages=[{"role": "system", "content": "hello"}] + ) + async for _ in response_stream: + pass + + # "".join(map(lambda x: x.choices[0].delta.content, response_stream)) + + (event,) = events + + assert event["contexts"]["trace"]["origin"] == "manual" + assert event["spans"][0]["origin"] == "auto.ai.openai" + + def test_span_origin_embeddings(sentry_init, capture_events): sentry_init( integrations=[OpenAIIntegration()], @@ -346,3 +710,154 @@ def test_span_origin_embeddings(sentry_init, capture_events): assert event["contexts"]["trace"]["origin"] == "manual" assert event["spans"][0]["origin"] == "auto.ai.openai" + + +@pytest.mark.asyncio +async def test_span_origin_embeddings_async(sentry_init, capture_events): + sentry_init( + integrations=[OpenAIIntegration()], + traces_sample_rate=1.0, + ) + events = capture_events() + + client = AsyncOpenAI(api_key="z") + + returned_embedding = CreateEmbeddingResponse( + data=[Embedding(object="embedding", index=0, embedding=[1.0, 2.0, 3.0])], + model="some-model", + object="list", + usage=EmbeddingTokenUsage( + prompt_tokens=20, + total_tokens=30, + ), + ) + + client.embeddings._post = AsyncMock(return_value=returned_embedding) + with start_transaction(name="openai tx"): + await client.embeddings.create(input="hello", model="text-embedding-3-large") + + (event,) = events + + assert event["contexts"]["trace"]["origin"] == "manual" + assert event["spans"][0]["origin"] == "auto.ai.openai" + + +def test_calculate_chat_completion_usage_a(): + span = mock.MagicMock() + + def count_tokens(msg): + return len(str(msg)) + + response = mock.MagicMock() + response.usage = mock.MagicMock() + response.usage.completion_tokens = 10 + response.usage.prompt_tokens = 20 + response.usage.total_tokens = 30 + messages = [] + streaming_message_responses = [] + + with mock.patch( + "sentry_sdk.integrations.openai.record_token_usage" + ) as mock_record_token_usage: + _calculate_chat_completion_usage( + messages, response, span, streaming_message_responses, count_tokens + ) + mock_record_token_usage.assert_called_once_with(span, 20, 10, 30) + + +def test_calculate_chat_completion_usage_b(): + span = mock.MagicMock() + + def count_tokens(msg): + return len(str(msg)) + + response = mock.MagicMock() + response.usage = mock.MagicMock() + response.usage.completion_tokens = 10 + response.usage.total_tokens = 10 + messages = [ + {"content": "one"}, + {"content": "two"}, + {"content": "three"}, + ] + streaming_message_responses = [] + + with mock.patch( + "sentry_sdk.integrations.openai.record_token_usage" + ) as mock_record_token_usage: + _calculate_chat_completion_usage( + messages, response, span, streaming_message_responses, count_tokens + ) + mock_record_token_usage.assert_called_once_with(span, 11, 10, 10) + + +def test_calculate_chat_completion_usage_c(): + span = mock.MagicMock() + + def count_tokens(msg): + return len(str(msg)) + + response = mock.MagicMock() + response.usage = mock.MagicMock() + response.usage.prompt_tokens = 20 + response.usage.total_tokens = 20 + messages = [] + streaming_message_responses = [ + "one", + "two", + "three", + ] + + with mock.patch( + "sentry_sdk.integrations.openai.record_token_usage" + ) as mock_record_token_usage: + _calculate_chat_completion_usage( + messages, response, span, streaming_message_responses, count_tokens + ) + mock_record_token_usage.assert_called_once_with(span, 20, 11, 20) + + +def test_calculate_chat_completion_usage_d(): + span = mock.MagicMock() + + def count_tokens(msg): + return len(str(msg)) + + response = mock.MagicMock() + response.usage = mock.MagicMock() + response.usage.prompt_tokens = 20 + response.usage.total_tokens = 20 + response.choices = [ + mock.MagicMock(message="one"), + mock.MagicMock(message="two"), + mock.MagicMock(message="three"), + ] + messages = [] + streaming_message_responses = [] + + with mock.patch( + "sentry_sdk.integrations.openai.record_token_usage" + ) as mock_record_token_usage: + _calculate_chat_completion_usage( + messages, response, span, streaming_message_responses, count_tokens + ) + mock_record_token_usage.assert_called_once_with(span, 20, None, 20) + + +def test_calculate_chat_completion_usage_e(): + span = mock.MagicMock() + + def count_tokens(msg): + return len(str(msg)) + + response = mock.MagicMock() + messages = [] + streaming_message_responses = None + + with mock.patch( + "sentry_sdk.integrations.openai.record_token_usage" + ) as mock_record_token_usage: + _calculate_chat_completion_usage( + messages, response, span, streaming_message_responses, count_tokens + ) + mock_record_token_usage.assert_called_once_with(span, None, None, None) diff --git a/tox.ini b/tox.ini index 0302c3ebb7..a90a7fa248 100644 --- a/tox.ini +++ b/tox.ini @@ -316,6 +316,7 @@ deps = aiohttp-latest: pytest-asyncio # Anthropic + anthropic: pytest-asyncio anthropic-v0.25: anthropic~=0.25.0 anthropic-v0.16: anthropic~=0.16.0 anthropic-latest: anthropic @@ -532,6 +533,7 @@ deps = loguru-latest: loguru # OpenAI + openai: pytest-asyncio openai-v1: openai~=1.0.0 openai-v1: tiktoken~=0.6.0 openai-latest: openai