Skip to content

Commit 9ae5820

Browse files
vetyy and antonpirker authored
Add support for async calls in Anthropic and OpenAI integration (#3497)
--------- Co-authored-by: Anton Pirker <[email protected]>
1 parent 891afee commit 9ae5820

File tree

5 files changed

+1366
-209
lines changed

5 files changed

+1366
-209
lines changed

sentry_sdk/integrations/anthropic.py

Lines changed: 190 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,15 @@
1313
)
1414

1515
try:
16-
from anthropic.resources import Messages
16+
from anthropic.resources import AsyncMessages, Messages
1717

1818
if TYPE_CHECKING:
1919
from anthropic.types import MessageStreamEvent
2020
except ImportError:
2121
raise DidNotEnable("Anthropic not installed")
2222

23-
2423
if TYPE_CHECKING:
25-
from typing import Any, Iterator
24+
from typing import Any, AsyncIterator, Iterator
2625
from sentry_sdk.tracing import Span
2726

2827

@@ -46,6 +45,7 @@ def setup_once():
4645
raise DidNotEnable("anthropic 0.16 or newer required.")
4746

4847
Messages.create = _wrap_message_create(Messages.create)
48+
AsyncMessages.create = _wrap_message_create_async(AsyncMessages.create)
4949

5050

5151
def _capture_exception(exc):
@@ -75,7 +75,9 @@ def _calculate_token_usage(result, span):
7575

7676
def _get_responses(content):
7777
# type: (list[Any]) -> list[dict[str, Any]]
78-
"""Get JSON of a Anthropic responses."""
78+
"""
79+
Get JSON of a Anthropic responses.
80+
"""
7981
responses = []
8082
for item in content:
8183
if hasattr(item, "text"):
@@ -88,94 +90,202 @@ def _get_responses(content):
8890
return responses
8991

9092

93+
def _collect_ai_data(event, input_tokens, output_tokens, content_blocks):
    # type: (MessageStreamEvent, int, int, list[str]) -> tuple[int, int, list[str]]
    """
    Count token usage and collect content blocks from one AI streaming event.

    Returns the updated ``(input_tokens, output_tokens, content_blocks)``
    triple so the caller can thread the accumulators through the stream.
    Events of unhandled types are ignored.
    """
    with capture_internal_exceptions():
        if hasattr(event, "type"):
            if event.type == "message_start":
                usage = event.message.usage
                input_tokens += usage.input_tokens
                output_tokens += usage.output_tokens
            elif event.type == "content_block_delta":
                # Only text deltas carry a ``text`` attribute; other delta
                # kinds (e.g. tool use) are skipped.
                if hasattr(event.delta, "text"):
                    content_blocks.append(event.delta.text)
            elif event.type == "message_delta":
                output_tokens += event.usage.output_tokens

    return input_tokens, output_tokens, content_blocks
115+
116+
117+
def _add_ai_data_to_span(
    span, integration, input_tokens, output_tokens, content_blocks
):
    # type: (Span, AnthropicIntegration, int, int, list[str]) -> None
    """
    Record token usage and collected content blocks on the given span.

    The concatenated response text is attached only when sending default
    PII is allowed and the integration is configured to include prompts.
    """
    with capture_internal_exceptions():
        if should_send_default_pii() and integration.include_prompts:
            span.set_data(
                SPANDATA.AI_RESPONSES,
                [{"type": "text", "text": "".join(content_blocks)}],
            )
        record_token_usage(
            span, input_tokens, output_tokens, input_tokens + output_tokens
        )
        span.set_data(SPANDATA.AI_STREAMING, True)
134+
135+
136+
def _sentry_patched_create_common(f, *args, **kwargs):
    # type: (Any, *Any, **Any) -> Any
    """
    Shared sync/async instrumentation core for ``Messages.create``.

    Implemented as a generator so both the sync and async wrappers can drive
    it: it first yields ``(f, args, kwargs)`` for the wrapper to invoke, then
    receives the call's result via ``send()``, records span data, and returns
    the (possibly stream-wrapped) result through ``StopIteration.value``.
    Early ``return f(*args, **kwargs)`` bail-outs likewise surface through
    ``StopIteration.value`` (for the async wrapper this is an un-awaited
    coroutine, which the wrapper awaits).
    """
    import inspect  # local import: only needed for async-generator detection

    integration = kwargs.pop("integration")
    if integration is None:
        return f(*args, **kwargs)

    if "messages" not in kwargs:
        return f(*args, **kwargs)

    try:
        iter(kwargs["messages"])
    except TypeError:
        return f(*args, **kwargs)

    span = sentry_sdk.start_span(
        op=OP.ANTHROPIC_MESSAGES_CREATE,
        description="Anthropic messages create",
        origin=AnthropicIntegration.origin,
    )
    span.__enter__()

    # Hand the actual call back to the (sync or async) wrapper.
    result = yield f, args, kwargs

    # add data to span and finish it
    messages = list(kwargs["messages"])
    model = kwargs.get("model")

    with capture_internal_exceptions():
        span.set_data(SPANDATA.AI_MODEL_ID, model)
        span.set_data(SPANDATA.AI_STREAMING, False)

        if should_send_default_pii() and integration.include_prompts:
            span.set_data(SPANDATA.AI_INPUT_MESSAGES, messages)

        # Non-streaming response: token usage is available immediately.
        if hasattr(result, "content"):
            if should_send_default_pii() and integration.include_prompts:
                span.set_data(SPANDATA.AI_RESPONSES, _get_responses(result.content))
            _calculate_token_usage(result, span)
            span.__exit__(None, None, None)

        # Streaming response: wrap the underlying iterator so the span is
        # populated and closed only once the stream is exhausted.
        elif hasattr(result, "_iterator"):
            old_iterator = result._iterator

            def new_iterator():
                # type: () -> Iterator[MessageStreamEvent]
                input_tokens = 0
                output_tokens = 0
                content_blocks = []  # type: list[str]

                for event in old_iterator:
                    input_tokens, output_tokens, content_blocks = _collect_ai_data(
                        event, input_tokens, output_tokens, content_blocks
                    )
                    if event.type != "message_stop":
                        yield event

                _add_ai_data_to_span(
                    span, integration, input_tokens, output_tokens, content_blocks
                )
                span.__exit__(None, None, None)

            async def new_iterator_async():
                # type: () -> AsyncIterator[MessageStreamEvent]
                input_tokens = 0
                output_tokens = 0
                content_blocks = []  # type: list[str]

                async for event in old_iterator:
                    input_tokens, output_tokens, content_blocks = _collect_ai_data(
                        event, input_tokens, output_tokens, content_blocks
                    )
                    if event.type != "message_stop":
                        yield event

                _add_ai_data_to_span(
                    span, integration, input_tokens, output_tokens, content_blocks
                )
                span.__exit__(None, None, None)

            # Robust async-generator detection instead of comparing the
            # stringified type name ("<class 'async_generator'>").
            if inspect.isasyncgen(result._iterator):
                result._iterator = new_iterator_async()
            else:
                result._iterator = new_iterator()

        else:
            span.set_data("unknown_response", True)
            span.__exit__(None, None, None)

    return result
226+
227+
91228
def _wrap_message_create(f):
    # type: (Any) -> Any
    """Return a sync wrapper around ``Messages.create`` that instruments it."""

    def _execute_sync(fn, *args, **kwargs):
        # type: (Any, *Any, **Any) -> Any
        # Drive the shared instrumentation generator: it yields the call to
        # make, we send the result back, and it returns the final value via
        # StopIteration.value.
        gen = _sentry_patched_create_common(fn, *args, **kwargs)

        try:
            target, call_args, call_kwargs = next(gen)
            try:
                result = target(*call_args, **call_kwargs)
            except Exception as exc:
                _capture_exception(exc)
                raise exc from None
            return gen.send(result)
        except StopIteration as stop:
            return stop.value

    @wraps(f)
    def _sentry_patched_create_sync(*args, **kwargs):
        # type: (*Any, **Any) -> Any
        integration = sentry_sdk.get_client().get_integration(AnthropicIntegration)
        kwargs["integration"] = integration
        return _execute_sync(f, *args, **kwargs)

    return _sentry_patched_create_sync
105259

106-
messages = list(kwargs["messages"])
107-
model = kwargs.get("model")
108260

109-
span = sentry_sdk.start_span(
110-
op=OP.ANTHROPIC_MESSAGES_CREATE,
111-
name="Anthropic messages create",
112-
origin=AnthropicIntegration.origin,
113-
)
114-
span.__enter__()
261+
def _wrap_message_create_async(f):
    # type: (Any) -> Any
    """Return an async wrapper around ``AsyncMessages.create`` that instruments it."""

    async def _execute_async(fn, *args, **kwargs):
        # type: (Any, *Any, **Any) -> Any
        gen = _sentry_patched_create_common(fn, *args, **kwargs)

        try:
            target, call_args, call_kwargs = next(gen)
        except StopIteration as stop:
            # Instrumentation bailed out before the call; its return value is
            # the un-awaited coroutine from the original method — await it.
            return await stop.value

        try:
            try:
                result = await target(*call_args, **call_kwargs)
            except Exception as exc:
                _capture_exception(exc)
                raise exc from None
            return gen.send(result)
        except StopIteration as stop:
            return stop.value

    @wraps(f)
    async def _sentry_patched_create_async(*args, **kwargs):
        # type: (*Any, **Any) -> Any
        integration = sentry_sdk.get_client().get_integration(AnthropicIntegration)
        kwargs["integration"] = integration
        return await _execute_async(f, *args, **kwargs)

    return _sentry_patched_create_async

0 commit comments

Comments
 (0)