Skip to content

Commit 0047a68

Browse files
fatelei and DouweM authored
Add ModelResponse.finish_reason and set provider_response_id while streaming (#2590)
Co-authored-by: Douwe Maan <[email protected]>
1 parent f724d01 commit 0047a68

20 files changed

+499
-36
lines changed

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@
5252
DocumentFormat: TypeAlias = Literal['csv', 'doc', 'docx', 'html', 'md', 'pdf', 'txt', 'xls', 'xlsx']
5353
VideoFormat: TypeAlias = Literal['mkv', 'mov', 'mp4', 'webm', 'flv', 'mpeg', 'mpg', 'wmv', 'three_gp']
5454

55+
FinishReason: TypeAlias = Literal[
56+
'stop',
57+
'length',
58+
'content_filter',
59+
'tool_call',
60+
'error',
61+
]
62+
"""Reason the model finished generating the response, normalized to OpenTelemetry values."""
63+
5564

5665
@dataclass(repr=False)
5766
class SystemPromptPart:
@@ -1032,6 +1041,9 @@ class ModelResponse:
10321041
] = None
10331042
"""request ID as specified by the model provider. This can be used to track the specific request to the model."""
10341043

1044+
finish_reason: FinishReason | None = None
1045+
"""Reason the model finished generating the response, normalized to OpenTelemetry values."""
1046+
10351047
@deprecated('`price` is deprecated, use `cost` instead')
10361048
def price(self) -> genai_types.PriceCalculation: # pragma: no cover
10371049
return self.cost()

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from ..messages import (
2929
FileUrl,
3030
FinalResultEvent,
31+
FinishReason,
3132
ModelMessage,
3233
ModelRequest,
3334
ModelResponse,
@@ -555,6 +556,10 @@ class StreamedResponse(ABC):
555556

556557
final_result_event: FinalResultEvent | None = field(default=None, init=False)
557558

559+
provider_response_id: str | None = field(default=None, init=False)
560+
provider_details: dict[str, Any] | None = field(default=None, init=False)
561+
finish_reason: FinishReason | None = field(default=None, init=False)
562+
558563
_parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
559564
_event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
560565
_usage: RequestUsage = field(default_factory=RequestUsage, init=False)
@@ -609,6 +614,9 @@ def get(self) -> ModelResponse:
609614
timestamp=self.timestamp,
610615
usage=self.usage(),
611616
provider_name=self.provider_name,
617+
provider_response_id=self.provider_response_id,
618+
provider_details=self.provider_details,
619+
finish_reason=self.finish_reason,
612620
)
613621

614622
def usage(self) -> RequestUsage:

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
BuiltinToolCallPart,
2222
BuiltinToolReturnPart,
2323
DocumentUrl,
24+
FinishReason,
2425
ImageUrl,
2526
ModelMessage,
2627
ModelRequest,
@@ -42,6 +43,16 @@
4243
from ..tools import ToolDefinition
4344
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
4445

46+
_FINISH_REASON_MAP: dict[BetaStopReason, FinishReason] = {
47+
'end_turn': 'stop',
48+
'max_tokens': 'length',
49+
'stop_sequence': 'stop',
50+
'tool_use': 'tool_call',
51+
'pause_turn': 'stop',
52+
'refusal': 'content_filter',
53+
}
54+
55+
4556
try:
4657
from anthropic import NOT_GIVEN, APIStatusError, AsyncStream
4758
from anthropic.types.beta import (
@@ -70,6 +81,7 @@
7081
BetaServerToolUseBlock,
7182
BetaServerToolUseBlockParam,
7283
BetaSignatureDelta,
84+
BetaStopReason,
7385
BetaTextBlock,
7486
BetaTextBlockParam,
7587
BetaTextDelta,
@@ -326,12 +338,20 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
326338
)
327339
)
328340

341+
finish_reason: FinishReason | None = None
342+
provider_details: dict[str, Any] | None = None
343+
if raw_finish_reason := response.stop_reason: # pragma: no branch
344+
provider_details = {'finish_reason': raw_finish_reason}
345+
finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
346+
329347
return ModelResponse(
330348
parts=items,
331349
usage=_map_usage(response),
332350
model_name=response.model,
333351
provider_response_id=response.id,
334352
provider_name=self._provider.name,
353+
finish_reason=finish_reason,
354+
provider_details=provider_details,
335355
)
336356

337357
async def _process_streamed_response(
@@ -583,6 +603,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
583603
async for event in self._response:
584604
if isinstance(event, BetaRawMessageStartEvent):
585605
self._usage = _map_usage(event)
606+
self.provider_response_id = event.message.id
586607

587608
elif isinstance(event, BetaRawContentBlockStartEvent):
588609
current_block = event.content_block
@@ -646,6 +667,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
646667

647668
elif isinstance(event, BetaRawMessageDeltaEvent):
648669
self._usage = _map_usage(event)
670+
if raw_finish_reason := event.delta.stop_reason: # pragma: no branch
671+
self.provider_details = {'finish_reason': raw_finish_reason}
672+
self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
649673

650674
elif isinstance(event, BetaRawContentBlockStopEvent | BetaRawMessageStopEvent): # pragma: no branch
651675
current_block = None

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
BuiltinToolCallPart,
2121
BuiltinToolReturnPart,
2222
FileUrl,
23+
FinishReason,
2324
ModelMessage,
2425
ModelRequest,
2526
ModelResponse,
@@ -54,6 +55,7 @@
5455
ContentUnionDict,
5556
CountTokensConfigDict,
5657
ExecutableCodeDict,
58+
FinishReason as GoogleFinishReason,
5759
FunctionCallDict,
5860
FunctionCallingConfigDict,
5961
FunctionCallingConfigMode,
@@ -99,6 +101,22 @@
99101
See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
100102
"""
101103

104+
_FINISH_REASON_MAP: dict[GoogleFinishReason, FinishReason | None] = {
105+
GoogleFinishReason.FINISH_REASON_UNSPECIFIED: None,
106+
GoogleFinishReason.STOP: 'stop',
107+
GoogleFinishReason.MAX_TOKENS: 'length',
108+
GoogleFinishReason.SAFETY: 'content_filter',
109+
GoogleFinishReason.RECITATION: 'content_filter',
110+
GoogleFinishReason.LANGUAGE: 'error',
111+
GoogleFinishReason.OTHER: None,
112+
GoogleFinishReason.BLOCKLIST: 'content_filter',
113+
GoogleFinishReason.PROHIBITED_CONTENT: 'content_filter',
114+
GoogleFinishReason.SPII: 'content_filter',
115+
GoogleFinishReason.MALFORMED_FUNCTION_CALL: 'error',
116+
GoogleFinishReason.IMAGE_SAFETY: 'content_filter',
117+
GoogleFinishReason.UNEXPECTED_TOOL_CALL: 'error',
118+
}
119+
102120

103121
class GoogleModelSettings(ModelSettings, total=False):
104122
"""Settings used for a Gemini model request."""
@@ -403,11 +421,14 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
403421
'Content field missing from Gemini response', str(response)
404422
) # pragma: no cover
405423
parts = candidate.content.parts or []
406-
vendor_id = response.response_id or None
424+
425+
vendor_id = response.response_id
407426
vendor_details: dict[str, Any] | None = None
408-
finish_reason = candidate.finish_reason
409-
if finish_reason: # pragma: no branch
410-
vendor_details = {'finish_reason': finish_reason.value}
427+
finish_reason: FinishReason | None = None
428+
if raw_finish_reason := candidate.finish_reason: # pragma: no branch
429+
vendor_details = {'finish_reason': raw_finish_reason.value}
430+
finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
431+
411432
usage = _metadata_as_usage(response)
412433
return _process_response_from_parts(
413434
parts,
@@ -416,6 +437,7 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
416437
usage,
417438
vendor_id=vendor_id,
418439
vendor_details=vendor_details,
440+
finish_reason=finish_reason,
419441
)
420442

421443
async def _process_streamed_response(
@@ -550,6 +572,14 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
550572

551573
assert chunk.candidates is not None
552574
candidate = chunk.candidates[0]
575+
576+
if chunk.response_id: # pragma: no branch
577+
self.provider_response_id = chunk.response_id
578+
579+
if raw_finish_reason := candidate.finish_reason:
580+
self.provider_details = {'finish_reason': raw_finish_reason.value}
581+
self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
582+
553583
if candidate.content is None or candidate.content.parts is None:
554584
if candidate.finish_reason == 'STOP': # pragma: no cover
555585
# Normal completion - skip this chunk
@@ -632,6 +662,7 @@ def _process_response_from_parts(
632662
usage: usage.RequestUsage,
633663
vendor_id: str | None,
634664
vendor_details: dict[str, Any] | None = None,
665+
finish_reason: FinishReason | None = None,
635666
) -> ModelResponse:
636667
items: list[ModelResponsePart] = []
637668
for part in parts:
@@ -672,6 +703,7 @@ def _process_response_from_parts(
672703
provider_response_id=vendor_id,
673704
provider_details=vendor_details,
674705
provider_name=provider_name,
706+
finish_reason=finish_reason,
675707
)
676708

677709

pydantic_ai_slim/pydantic_ai/models/instrumented.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,10 @@ def messages_to_otel_messages(self, messages: list[ModelMessage]) -> list[_otel_
221221
_otel_messages.ChatMessage(role='system' if is_system else 'user', parts=message_parts)
222222
)
223223
elif isinstance(message, ModelResponse): # pragma: no branch
224-
result.append(_otel_messages.ChatMessage(role='assistant', parts=message.otel_message_parts(self)))
224+
otel_message = _otel_messages.OutputMessage(role='assistant', parts=message.otel_message_parts(self))
225+
if message.finish_reason is not None:
226+
otel_message['finish_reason'] = message.finish_reason
227+
result.append(otel_message)
225228
return result
226229

227230
def handle_messages(self, input_messages: list[ModelMessage], response: ModelResponse, system: str, span: Span):
@@ -246,12 +249,10 @@ def handle_messages(self, input_messages: list[ModelMessage], response: ModelRes
246249
else:
247250
output_messages = self.messages_to_otel_messages([response])
248251
assert len(output_messages) == 1
249-
output_message = cast(_otel_messages.OutputMessage, output_messages[0])
250-
if response.provider_details and 'finish_reason' in response.provider_details:
251-
output_message['finish_reason'] = response.provider_details['finish_reason']
252+
output_message = output_messages[0]
252253
instructions = InstrumentedModel._get_instructions(input_messages) # pyright: ignore [reportPrivateUsage]
253254
system_instructions_attributes = self.system_instructions_attributes(instructions)
254-
attributes = {
255+
attributes: dict[str, AttributeValue] = {
255256
'gen_ai.input.messages': json.dumps(self.messages_to_otel_messages(input_messages)),
256257
'gen_ai.output.messages': json.dumps([output_message]),
257258
**system_instructions_attributes,
@@ -436,6 +437,8 @@ def _record_metrics():
436437
)
437438
if response.provider_response_id is not None:
438439
attributes_to_set['gen_ai.response.id'] = response.provider_response_id
440+
if response.finish_reason is not None:
441+
attributes_to_set['gen_ai.response.finish_reasons'] = [response.finish_reason]
439442
span.set_attributes(attributes_to_set)
440443
span.update_name(f'{operation} {request_model}')
441444

0 commit comments

Comments (0)