diff --git a/pydantic_ai_slim/pydantic_ai/_parts_manager.py b/pydantic_ai_slim/pydantic_ai/_parts_manager.py index 41d6357994..d69342bc6b 100644 --- a/pydantic_ai_slim/pydantic_ai/_parts_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_parts_manager.py @@ -13,9 +13,11 @@ from __future__ import annotations as _annotations -from collections.abc import Hashable +from collections.abc import Generator, Hashable from dataclasses import dataclass, field, replace -from typing import Any +from typing import Any, Generic, Literal, TypeVar, cast + +from pydantic import BaseModel, model_validator from pydantic_ai.exceptions import UnexpectedModelBehavior from pydantic_ai.messages import ( @@ -36,9 +38,11 @@ VendorId = Hashable """ -Type alias for a vendor identifier, which can be any hashable type (e.g., a string, UUID, etc.) +Type alias for a vendor part identifier, which can be any hashable type (e.g., a string, UUID, etc.) """ +ThinkingTags = tuple[str, str] + ManagedPart = ModelResponsePart | ToolCallPartDelta """ A union of types that are managed by the ModelResponsePartsManager. @@ -46,6 +50,159 @@ this includes ToolCallPartDelta's in addition to the more fully-formed ModelResponsePart's. """ +PartT = TypeVar('PartT', bound=ModelResponsePart) + + +@dataclass +class _ExistingPart(Generic[PartT]): + part: PartT + index: int + found_by: Literal['vendor_part_id', 'latest_part'] + + +def suffix_prefix_overlap(s1: str, s2: str) -> int: + """Return the length of the longest suffix of s1 that is a prefix of s2.""" + n = min(len(s1), len(s2)) + for k in range(n, 0, -1): + if s1.endswith(s2[:k]): + return k + return 0 + + +class PartialThinkingTag(BaseModel, validate_assignment=True): + respective_tag: str + buffer: str = '' + previous_part_index: int + vendor_part_id: VendorId | None = None + + @model_validator(mode='after') + def validate_buffer(self) -> PartialThinkingTag: + if not self.respective_tag.startswith(self.buffer): # pragma: no cover + raise ValueError(f"Buffer '{self.buffer}' does not match the start of tag '{self.respective_tag}'") + return self + + @property + def expected_next(self) -> str: + return self.respective_tag[len(self.buffer) :] + + @property + def is_complete(self) -> bool: + return self.buffer == self.respective_tag + + @property + def has_previous_part(self) -> bool: + return self.previous_part_index >= 0 + + +@dataclass +class StartTagValidation: + flushed_buffer: str = '' + """Any buffered content that was flushed because the tag was invalid.""" + + thinking_content: str = '' + """Any content following the valid opening tag.""" + + +class PartialStartTag(PartialThinkingTag): + def validate_new_content(self, new_content: str) -> StartTagValidation: + combined = self.buffer + new_content + if combined.startswith(self.respective_tag): + # combined = 'content' + self.buffer = combined[: len(self.respective_tag)] + thinking_content = combined[len(self.respective_tag) :] + return StartTagValidation(thinking_content=thinking_content) + elif self.respective_tag.startswith(combined): + # combined = '' - buffer new_content, flush old buffer - handles stutter + flushed_buffer = self.buffer + self.buffer = new_content + return StartTagValidation(flushed_buffer=flushed_buffer) + elif new_content.startswith(self.respective_tag): + # new_content = 'content' + flushed_buffer = self.buffer + self.buffer = new_content[: len(self.respective_tag)] + thinking_content = new_content[len(self.respective_tag) :] + return StartTagValidation(flushed_buffer=flushed_buffer, 
thinking_content=thinking_content) + else: + self.buffer = '' + return StartTagValidation(flushed_buffer=combined) + + +@dataclass +class EndTagValidation: + content_before_closed: str = '' + """Any content before the tag was closed.""" + + content_after_closed: str = '' + """Any content remaining after the tag was closed.""" + + +class PartialEndTag(PartialThinkingTag): + """A partial end tag that tracks the closing of a thinking part. + + A PartialEndTag is created when an opening thinking tag completes (e.g., after seeing ``). + PartialEndTags are tracked in `_partial_tags_list` by their vendor_part_id and previous_part_index fields. + + The PartialEndTag.previous_part_index initially inherits from the preceding PartialStartTag, + which may be -1 (if `` was first content) or a TextPart index. + + If content follows the opening tag, a ThinkingPart is created and previous_part_index is updated to point to it. + + Lifecycle: + - Empty thinking (``): PartialEndTag removed, no ThinkingPart created, no event emitted + - Normal completion: PartialEndTag removed when closing tag completes + - Stream ends with buffer: Buffered content (e.g., ` str: + """Return buffered content for flushing. + + - if no ThinkingPart was emitted (delayed thinking), include opening tag. + - if ThinkingPart was emitted, only return closing tag buffer. + """ + if self.thinking_was_emitted: + return self.buffer + else: + return self.respective_opening_tag + self.buffer + + def validate_new_content(self, new_content: str, trim_whitespace: bool = False) -> EndTagValidation: + if trim_whitespace and not self.has_previous_part: # pragma: no cover + new_content = new_content.lstrip() + + if not new_content: + return EndTagValidation() + combined = self.buffer + new_content + + # check if the complete closing tag appears in combined + if self.respective_tag in combined: + self.buffer = self.respective_tag + content_before_closed, content_after_closed = combined.split(self.respective_tag, 1) + return EndTagValidation( + content_before_closed=content_before_closed, content_after_closed=content_after_closed + ) + + if new_content.startswith(self.expected_next): # pragma: no cover + tag_content = combined[: len(self.respective_tag)] + self.buffer = tag_content + content_after_closed = combined[len(self.respective_tag) :] + return EndTagValidation(content_after_closed=content_after_closed) + elif (overlap := suffix_prefix_overlap(combined, self.respective_tag)) > 0: + content_to_add = combined[:-overlap] + content_to_buffer = combined[-overlap:] + # buffer partial closing tags + self.buffer = content_to_buffer + return EndTagValidation(content_before_closed=content_to_add) + else: + content_before_closed = combined + self.buffer = '' + return EndTagValidation(content_before_closed=content_before_closed) + @dataclass class ModelResponsePartsManager: @@ -56,8 +213,121 @@ class ModelResponsePartsManager: _parts: list[ManagedPart] = field(default_factory=list, init=False) """A list of parts (text or tool calls) that make up the current state of the model's response.""" + _vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict, init=False) - """Maps a vendor's "part" ID (if provided) to the index in `_parts` where that part resides.""" + """Tracks the vendor part IDs of parts to their indices in the `_parts` list. + + Not all parts arrive with vendor part IDs, so the length of the tracker doesn't mirror the length of the _parts. 
+ `ThinkingPart`s that are created via embedded thinking will stop being tracked once their closing tag is seen. + """ + + _partial_tags_list: list[PartialStartTag | PartialEndTag] = field(default_factory=list, init=False) + """Tracks active partial thinking tags. Tags contain their own previous_part_index and vendor_part_id.""" + + def _append_and_track_new_part(self, part: ManagedPart, vendor_part_id: VendorId | None) -> int: + """Append a new part to the manager and track it by vendor part ID if provided. + + Will overwrite any existing mapping for the given vendor part ID. + """ + new_part_index = len(self._parts) + if vendor_part_id is not None: + self._vendor_id_to_part_index[vendor_part_id] = new_part_index + self._parts.append(part) + return new_part_index + + def _replace_part(self, part_index: int, part: ManagedPart, vendor_part_id: VendorId) -> int: + """Replace an existing part at the given index.""" + self._parts[part_index] = part + self._vendor_id_to_part_index[vendor_part_id] = part_index + return part_index + + def _stop_tracking_vendor_id(self, vendor_part_id: VendorId) -> None: + """Stop tracking the given vendor part ID. + + This is useful when a part is considered complete and should no longer be updated. + + Args: + vendor_part_id: The vendor part ID to stop tracking. + """ + self._vendor_id_to_part_index.pop(vendor_part_id, None) + + def _get_part_and_index_by_vendor_id(self, vendor_part_id: VendorId) -> tuple[ManagedPart | None, int | None]: + """Get a part by its vendor part ID.""" + part_index = self._vendor_id_to_part_index.get(vendor_part_id) + if part_index is not None: + return self._parts[part_index], part_index + return None, None + + def _get_partial_by_part_index(self, part_index: int) -> PartialStartTag | PartialEndTag | None: + """Get a partial thinking tag by its associated part index.""" + for tag in self._partial_tags_list: + if tag.previous_part_index == part_index: + return tag + return None + + def _stop_tracking_partial_tag(self, partial_tag: PartialStartTag | PartialEndTag) -> None: + """Stop tracking a partial tag.""" + if partial_tag in self._partial_tags_list: # pragma: no cover + # this is a defensive check in case we try to remove a tag that wasn't tracked + self._partial_tags_list.remove(partial_tag) + + def _get_active_partial_tag( + self, + existing_part: _ExistingPart[TextPart] | _ExistingPart[ThinkingPart] | None, + vendor_part_id: VendorId | None = None, + ) -> PartialStartTag | PartialEndTag | None: + """Get the active partial tag. 
+ + - if vendor_part_id provided: lookup by vendor_id first (most relevant) + - if existing_part exists: lookup by that part's index + - if no existing_part: lookup by latest part's index, or index -1 for unattached tags + """ + if vendor_part_id is not None: + for tag in self._partial_tags_list: + if tag.vendor_part_id == vendor_part_id: # pragma: no branch + return tag + + if existing_part is not None: + return self._get_partial_by_part_index(existing_part.index) + elif self._parts: + latest_index = len(self._parts) - 1 + return self._get_partial_by_part_index(latest_index) + else: + return self._get_partial_by_part_index(-1) + + def _emit_text_start( + self, + *, + content: str, + vendor_part_id: VendorId | None, + id: str | None = None, + ) -> PartStartEvent: + new_text_part = TextPart(content=content, id=id) + new_part_index = self._append_and_track_new_part(new_text_part, vendor_part_id=vendor_part_id) + return PartStartEvent(index=new_part_index, part=new_text_part) + + def _emit_text_delta( + self, + *, + text_part: TextPart, + part_index: int, + content: str, + ) -> PartDeltaEvent: + part_delta = TextPartDelta(content_delta=content) + self._parts[part_index] = part_delta.apply(text_part) + return PartDeltaEvent(index=part_index, delta=part_delta) + + def _emit_thinking_delta_from_text( + self, + *, + thinking_part: ThinkingPart, + part_index: int, + content: str, + ) -> PartDeltaEvent: + """Emit a ThinkingPartDelta from text content. Used only for embedded thinking.""" + part_delta = ThinkingPartDelta(content_delta=content, signature_delta=None, provider_name=None) + self._parts[part_index] = part_delta.apply(thinking_part) + return PartDeltaEvent(index=part_index, delta=part_delta) def get_parts(self) -> list[ModelResponsePart]: """Return only model response parts that are complete (i.e., not ToolCallPartDelta's). @@ -73,14 +343,27 @@ def handle_text_delta( vendor_part_id: VendorId | None, content: str, id: str | None = None, - thinking_tags: tuple[str, str] | None = None, + thinking_tags: ThinkingTags | None = None, ignore_leading_whitespace: bool = False, - ) -> ModelResponseStreamEvent | None: + ) -> Generator[ModelResponseStreamEvent, None, None]: """Handle incoming text content, creating or updating a TextPart in the manager as appropriate. - When `vendor_part_id` is None, the latest part is updated if it exists and is a TextPart; - otherwise, a new TextPart is created. When a non-None ID is specified, the TextPart corresponding - to that vendor ID is either created or updated. + This function also handles what we'll call "embedded thinking", which is the generation of + `ThinkingPart`s via explicit thinking tags embedded in the text content. + Activating embedded thinking requires `thinking_tags` to be provided as a tuple of `(opening_tag, closing_tag)`. + + ### Embedded thinking will be processed under the following constraints: + - C1: Thinking tags are only processed when `thinking_tags` is provided. + - C2: Opening thinking tags are only recognized at the start of a content chunk. + - C3.0: Closing thinking tags are recognized anywhere within a content chunk. + - C3.1: Any text following a closing thinking tag in the same content chunk is treated as a new TextPart. + + ### Supported edge cases of embedded thinking: + - Thinking tags may arrive split across multiple content chunks. E.g., '' in the next. + - Partial Opening and Closing tags without adjacent content won't emit an event. 
+ - EC2: No event is emitted for opening tags until they are fully formed and there is content following them. + - This is called 'delayed thinking' + - No event is emitted for closing tags that complete a `ThinkingPart` without any adjacent content. Args: vendor_part_id: The ID the vendor uses to identify this piece @@ -99,58 +382,334 @@ def handle_text_delta( Raises: UnexpectedModelBehavior: If attempting to apply text content to a part that is not a TextPart. """ - existing_text_part_and_index: tuple[TextPart, int] | None = None + existing_part: _ExistingPart[TextPart] | _ExistingPart[ThinkingPart] | None = None if vendor_part_id is None: - # If the vendor_part_id is None, check if the latest part is a TextPart to update if self._parts: part_index = len(self._parts) - 1 latest_part = self._parts[part_index] if isinstance(latest_part, TextPart): - existing_text_part_and_index = latest_part, part_index + existing_part = _ExistingPart(part=latest_part, index=part_index, found_by='latest_part') + elif isinstance(latest_part, ThinkingPart): + # Only update ThinkingParts created by embedded thinking (have PartialEndTag) + # to avoid incorrectly updating ThinkingParts from handle_thinking_delta (native thinking) + partial = self._get_partial_by_part_index(part_index) + if isinstance(partial, PartialEndTag): + existing_part = _ExistingPart(part=latest_part, index=part_index, found_by='latest_part') else: - # Otherwise, attempt to look up an existing TextPart by vendor_part_id - part_index = self._vendor_id_to_part_index.get(vendor_part_id) + maybe_part, part_index = self._get_part_and_index_by_vendor_id(vendor_part_id) if part_index is not None: - existing_part = self._parts[part_index] - - if thinking_tags and isinstance(existing_part, ThinkingPart): - # We may be building a thinking part instead of a text part if we had previously seen a thinking tag - if content == thinking_tags[1]: - # When we see the thinking end tag, we're done with the thinking part and the next text delta will need a new part - self._vendor_id_to_part_index.pop(vendor_part_id) - return None - else: - return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=content) - elif isinstance(existing_part, TextPart): - existing_text_part_and_index = existing_part, part_index + if isinstance(maybe_part, ThinkingPart): + existing_part = _ExistingPart(part=maybe_part, index=part_index, found_by='vendor_part_id') + elif isinstance(maybe_part, TextPart): + existing_part = _ExistingPart(part=maybe_part, index=part_index, found_by='vendor_part_id') else: - raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}') - - if thinking_tags and content == thinking_tags[0]: - # When we see a thinking start tag (which is a single token), we'll build a new thinking part instead - self._vendor_id_to_part_index.pop(vendor_part_id, None) - return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='') - - if existing_text_part_and_index is None: - # This is a workaround for models that emit `\n\n\n` or an empty text part ahead of tool calls (e.g. Ollama + Qwen3), - # which we don't want to end up treating as a final result when using `run_stream` with `str` a valid `output_type`. 
- if ignore_leading_whitespace and (len(content) == 0 or content.isspace()): - return None - - # There is no existing text part that should be updated, so create a new one - new_part_index = len(self._parts) - part = TextPart(content=content, id=id) - if vendor_part_id is not None: - self._vendor_id_to_part_index[vendor_part_id] = new_part_index - self._parts.append(part) - return PartStartEvent(index=new_part_index, part=part) + raise UnexpectedModelBehavior(f'Cannot apply a text delta to {maybe_part=}') + + if existing_part is None and ignore_leading_whitespace: + content = content.lstrip() + + # NOTE this breaks `test_direct.py`, `test_streaming.py` and `test_ui.py` expectations. + # `test.py` (`TestModel`) is set to generate an empty part at the beginning of the stream. + # if not content: + # return + + # we quickly handle good ol' text + if not thinking_tags: + yield from self._handle_plain_text(existing_part, content, vendor_part_id, id) + return + + # from here on we handle embedded thinking + partial_tag = self._get_active_partial_tag(existing_part, vendor_part_id) + + # 6. Handle based on current state + if existing_part is not None and isinstance(existing_part.part, ThinkingPart): + # Must be closing a ThinkingPart + thinking_part_existing = cast(_ExistingPart[ThinkingPart], existing_part) + if partial_tag is None: # pragma: no cover + raise RuntimeError('Embedded ThinkingParts must have an associated PartialEndTag') + if not isinstance(partial_tag, PartialEndTag): # pragma: no cover + raise RuntimeError('ThinkingPart cannot be associated with a PartialStartTag') + + yield from self._handle_thinking_closing( + thinking_part_existing.part, + thinking_part_existing.index, + partial_tag, + content, + vendor_part_id, + ignore_leading_whitespace, + ) + return + + if isinstance(partial_tag, PartialEndTag): + # Delayed thinking: have PartialEndTag but no ThinkingPart yet + existing_part = cast(_ExistingPart[TextPart] | None, existing_part) + yield from self._handle_delayed_thinking( + existing_part, partial_tag, content, vendor_part_id, ignore_leading_whitespace + ) + else: - # Update the existing TextPart with the new content delta - existing_text_part, part_index = existing_text_part_and_index + # Opening tag scenario (partial_tag is None or PartialStartTag) + opening_tag, closing_tag = thinking_tags + yield from self._handle_thinking_opening( + existing_part, + partial_tag, + content, + opening_tag, + closing_tag, + vendor_part_id, + id, + ignore_leading_whitespace, + ) + + def _handle_plain_text( + self, + existing_part: _ExistingPart[TextPart] | _ExistingPart[ThinkingPart] | None, + content: str, + vendor_part_id: VendorId | None, + id: str | None, + ) -> Generator[PartDeltaEvent | PartStartEvent, None, None]: + """Handle plain text content (no thinking tags).""" + if existing_part and isinstance(existing_part.part, TextPart): + existing_part = cast(_ExistingPart[TextPart], existing_part) part_delta = TextPartDelta(content_delta=content) - self._parts[part_index] = part_delta.apply(existing_text_part) - return PartDeltaEvent(index=part_index, delta=part_delta) + self._parts[existing_part.index] = part_delta.apply(existing_part.part) + yield PartDeltaEvent(index=existing_part.index, delta=part_delta) + else: + new_text_part = TextPart(content=content, id=id) + new_part_index = self._append_and_track_new_part(new_text_part, vendor_part_id) + yield PartStartEvent(index=new_part_index, part=new_text_part) + + def _handle_thinking_closing( + self, + thinking_part: ThinkingPart, + 
part_index: int, + partial_end_tag: PartialEndTag, + content: str, + vendor_part_id: VendorId, + ignore_leading_whitespace: bool, + ) -> Generator[ModelResponseStreamEvent, None, None]: + """Handle closing tag validation for an existing ThinkingPart.""" + end_tag_validation = partial_end_tag.validate_new_content(content, trim_whitespace=ignore_leading_whitespace) + + if end_tag_validation.content_before_closed: + yield self._emit_thinking_delta_from_text( + thinking_part=thinking_part, + part_index=part_index, + content=end_tag_validation.content_before_closed, + ) + + if partial_end_tag.is_complete: + self._stop_tracking_vendor_id(vendor_part_id) + self._stop_tracking_partial_tag(partial_end_tag) + + if end_tag_validation.content_after_closed: + yield self._emit_text_start( + content=end_tag_validation.content_after_closed, + vendor_part_id=vendor_part_id, + id=None, + ) + + def _handle_delayed_thinking( + self, + text_part: _ExistingPart[TextPart] | None, + partial_end_tag: PartialEndTag, + content: str, + vendor_part_id: VendorId | None, + ignore_leading_whitespace: bool, + ) -> Generator[ModelResponseStreamEvent, None, None]: + """Handle delayed thinking: PartialEndTag exists but no ThinkingPart created yet.""" + end_tag_validation = partial_end_tag.validate_new_content(content, trim_whitespace=ignore_leading_whitespace) + + if end_tag_validation.content_before_closed: + # Create ThinkingPart with this content + new_thinking_part = ThinkingPart(content=end_tag_validation.content_before_closed) + new_part_index = self._append_and_track_new_part(new_thinking_part, vendor_part_id) + partial_end_tag.previous_part_index = new_part_index + partial_end_tag.thinking_was_emitted = True + + yield PartStartEvent(index=new_part_index, part=new_thinking_part) + + if partial_end_tag.is_complete: + self._stop_tracking_partial_tag(partial_end_tag) + + if end_tag_validation.content_after_closed: + yield self._emit_text_start( + content=end_tag_validation.content_after_closed, + vendor_part_id=vendor_part_id, + id=None, + ) + + def _handle_thinking_opening( + self, + text_part: _ExistingPart[TextPart] | _ExistingPart[ThinkingPart] | None, + partial_start_tag: PartialStartTag | None, + content: str, + opening_tag: str, + closing_tag: str, + vendor_part_id: VendorId | None, + id: str | None, + ignore_leading_whitespace: bool, + ) -> Generator[ModelResponseStreamEvent, None, None]: + """Handle opening tag validation and buffering.""" + text_part = cast(_ExistingPart[TextPart] | None, text_part) + + if partial_start_tag is None: + partial_start_tag = PartialStartTag( + respective_tag=opening_tag, + # Use -1 as sentinel for "no existing part" to enable consistent lookups via _get_partial_by_part_index + previous_part_index=text_part.index if text_part is not None else -1, + vendor_part_id=vendor_part_id, + ) + self._partial_tags_list.append(partial_start_tag) + + start_tag_validation = partial_start_tag.validate_new_content(content) + + # Emit flushed buffer as text + if start_tag_validation.flushed_buffer: + if text_part: + yield self._emit_text_delta( + text_part=text_part.part, + part_index=text_part.index, + content=start_tag_validation.flushed_buffer, + ) + else: + text_start_event = self._emit_text_start( + content=start_tag_validation.flushed_buffer, + vendor_part_id=vendor_part_id, + id=id, + ) + partial_start_tag.previous_part_index = text_start_event.index + yield text_start_event + + # if tag completed, transition to PartialEndTag + if partial_start_tag.is_complete: + # Remove PartialStartTag 
before creating PartialEndTag to avoid tracking both simultaneously + self._stop_tracking_partial_tag(partial_start_tag) + + # Create PartialEndTag to track closing tag and subsequent thinking content + yield from self._create_partial_end_tag( + closing_tag=closing_tag, + preceeding_partial_start_tag=partial_start_tag, + thinking_content=start_tag_validation.thinking_content, + vendor_part_id=vendor_part_id, + ignore_leading_whitespace=ignore_leading_whitespace, + ) + + def _create_partial_end_tag( + self, + *, + closing_tag: str, + preceeding_partial_start_tag: PartialStartTag, + thinking_content: str, + vendor_part_id: VendorId | None, + ignore_leading_whitespace: bool, + ) -> Generator[ModelResponseStreamEvent, None, None]: + """Create a PartialEndTag and process any thinking content.""" + partial_end_tag = PartialEndTag( + respective_tag=closing_tag, + previous_part_index=preceeding_partial_start_tag.previous_part_index, + respective_opening_tag=preceeding_partial_start_tag.buffer, + thinking_was_emitted=False, + vendor_part_id=vendor_part_id, + ) + + end_tag_validation = partial_end_tag.validate_new_content( + thinking_content, trim_whitespace=ignore_leading_whitespace + ) + + if end_tag_validation.content_before_closed: + new_thinking_part = ThinkingPart(content=end_tag_validation.content_before_closed) + new_part_index = self._append_and_track_new_part(new_thinking_part, vendor_part_id) + partial_end_tag.previous_part_index = new_part_index + partial_end_tag.thinking_was_emitted = True + + # Track PartialEndTag + self._partial_tags_list.append(partial_end_tag) + + yield PartStartEvent(index=new_part_index, part=new_thinking_part) + + if partial_end_tag.is_complete: + self._stop_tracking_vendor_id(vendor_part_id) + self._stop_tracking_partial_tag(partial_end_tag) + if end_tag_validation.content_after_closed: + yield self._emit_text_start( + content=end_tag_validation.content_after_closed, + vendor_part_id=vendor_part_id, + id=None, + ) + elif partial_end_tag.is_complete: + # Empty thinking: - no part to track + if end_tag_validation.content_after_closed: + yield self._emit_text_start( + content=end_tag_validation.content_after_closed, + vendor_part_id=vendor_part_id, + id=None, + ) + else: + # Partial closing tag but no content yet - add to tracking list + self._partial_tags_list.append(partial_end_tag) + + def final_flush(self) -> Generator[ModelResponseStreamEvent, None, None]: + """Emit any buffered content from the last part in the manager. + + This function isn't used internally, it's used by the overarching StreamedResponse + to ensure any buffered content is flushed when the stream ends. 
+ """ + last_part_index = len(self._parts) - 1 + + if last_part_index >= 0: + part = self._parts[last_part_index] + partial_tag = self._get_partial_by_part_index(last_part_index) + else: + part = None + partial_tag = None + + def remove_partial_and_emit_buffered( + partial: PartialStartTag | PartialEndTag, + part_index: int, + part: TextPart | ThinkingPart, + ) -> Generator[PartStartEvent | PartDeltaEvent, None, None]: + buffered_content = partial.flush() if isinstance(partial, PartialEndTag) else partial.buffer + + self._stop_tracking_partial_tag(partial) + + if buffered_content: + delta_type = TextPartDelta if isinstance(part, TextPart) else ThinkingPartDelta + if part.content: + content_delta = delta_type(content_delta=buffered_content) + self._parts[part_index] = content_delta.apply(part) + yield PartDeltaEvent(index=part_index, delta=content_delta) + else: + updated_part = replace(part, content=buffered_content) + self._parts[part_index] = updated_part + yield PartStartEvent(index=part_index, part=updated_part) + + if part is not None and isinstance(part, TextPart | ThinkingPart) and partial_tag is not None: + yield from remove_partial_and_emit_buffered(partial_tag, last_part_index, part) + + # Flush remaining partial tags + for partial_tag in list(self._partial_tags_list): + buffered_content = partial_tag.flush() if isinstance(partial_tag, PartialEndTag) else partial_tag.buffer + if not buffered_content: + self._stop_tracking_partial_tag(partial_tag) # partial tag has an associated part index of -1 here + continue + + if not partial_tag.has_previous_part: + # No associated part - create new TextPart + self._stop_tracking_partial_tag(partial_tag) # partial tag has an associated part index of -1 here + + new_text_part = TextPart(content='') + new_part_index = self._append_and_track_new_part(new_text_part, vendor_part_id=None) + yield from remove_partial_and_emit_buffered(partial_tag, new_part_index, new_text_part) + else: + # exclude the -1 sentinel (unattached tag) from part lookup + part_index = partial_tag.previous_part_index + part = self._parts[part_index] + if isinstance(part, TextPart | ThinkingPart): + yield from remove_partial_and_emit_buffered(partial_tag, part_index, part) + else: # pragma: no cover + raise RuntimeError('Partial tag is associated with a non-text/non-thinking part') def handle_thinking_delta( self, @@ -160,12 +719,12 @@ def handle_thinking_delta( id: str | None = None, signature: str | None = None, provider_name: str | None = None, - ) -> ModelResponseStreamEvent: + ) -> Generator[ModelResponseStreamEvent, None, None]: """Handle incoming thinking content, creating or updating a ThinkingPart in the manager as appropriate. When `vendor_part_id` is None, the latest part is updated if it exists and is a ThinkingPart; otherwise, a new ThinkingPart is created. When a non-None ID is specified, the ThinkingPart corresponding - to that vendor ID is either created or updated. + to that vendor part ID is either created or updated. 
Args: vendor_part_id: The ID the vendor uses to identify this piece @@ -185,41 +744,33 @@ def handle_thinking_delta( existing_thinking_part_and_index: tuple[ThinkingPart, int] | None = None if vendor_part_id is None: - # If the vendor_part_id is None, check if the latest part is a ThinkingPart to update if self._parts: part_index = len(self._parts) - 1 latest_part = self._parts[part_index] - if isinstance(latest_part, ThinkingPart): # pragma: no branch + if isinstance(latest_part, ThinkingPart): existing_thinking_part_and_index = latest_part, part_index else: - # Otherwise, attempt to look up an existing ThinkingPart by vendor_part_id - part_index = self._vendor_id_to_part_index.get(vendor_part_id) + existing_part, part_index = self._get_part_and_index_by_vendor_id(vendor_part_id) if part_index is not None: - existing_part = self._parts[part_index] if not isinstance(existing_part, ThinkingPart): raise UnexpectedModelBehavior(f'Cannot apply a thinking delta to {existing_part=}') existing_thinking_part_and_index = existing_part, part_index if existing_thinking_part_and_index is None: if content is not None or signature is not None: - # There is no existing thinking part that should be updated, so create a new one - new_part_index = len(self._parts) part = ThinkingPart(content=content or '', id=id, signature=signature, provider_name=provider_name) - if vendor_part_id is not None: # pragma: no branch - self._vendor_id_to_part_index[vendor_part_id] = new_part_index - self._parts.append(part) - return PartStartEvent(index=new_part_index, part=part) + new_part_index = self._append_and_track_new_part(part, vendor_part_id) + yield PartStartEvent(index=new_part_index, part=part) else: raise UnexpectedModelBehavior('Cannot create a ThinkingPart with no content or signature') else: if content is not None or signature is not None: - # Update the existing ThinkingPart with the new content and/or signature delta existing_thinking_part, part_index = existing_thinking_part_and_index part_delta = ThinkingPartDelta( content_delta=content, signature_delta=signature, provider_name=provider_name ) self._parts[part_index] = part_delta.apply(existing_thinking_part) - return PartDeltaEvent(index=part_index, delta=part_delta) + yield PartDeltaEvent(index=part_index, delta=part_delta) else: raise UnexpectedModelBehavior('Cannot update a ThinkingPart with no content or signature') @@ -261,46 +812,33 @@ def handle_tool_call_delta( ) if vendor_part_id is None: - # vendor_part_id is None, so check if the latest part is a matching tool call or delta to update - # When the vendor_part_id is None, if the tool_name is _not_ None, assume this should be a new part rather - # than a delta on an existing one. We can change this behavior in the future if necessary for some model. 
if tool_name is None and self._parts: part_index = len(self._parts) - 1 latest_part = self._parts[part_index] - if isinstance(latest_part, ToolCallPart | BuiltinToolCallPart | ToolCallPartDelta): # pragma: no branch + if isinstance(latest_part, ToolCallPart | BuiltinToolCallPart | ToolCallPartDelta): existing_matching_part_and_index = latest_part, part_index else: - # vendor_part_id is provided, so look up the corresponding part or delta - part_index = self._vendor_id_to_part_index.get(vendor_part_id) + existing_part, part_index = self._get_part_and_index_by_vendor_id(vendor_part_id) if part_index is not None: - existing_part = self._parts[part_index] if not isinstance(existing_part, ToolCallPartDelta | ToolCallPart | BuiltinToolCallPart): raise UnexpectedModelBehavior(f'Cannot apply a tool call delta to {existing_part=}') existing_matching_part_and_index = existing_part, part_index if existing_matching_part_and_index is None: - # No matching part/delta was found, so create a new ToolCallPartDelta (or ToolCallPart if fully formed) delta = ToolCallPartDelta(tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id) part = delta.as_part() or delta - if vendor_part_id is not None: - self._vendor_id_to_part_index[vendor_part_id] = len(self._parts) - new_part_index = len(self._parts) - self._parts.append(part) - # Only emit a PartStartEvent if we have enough information to produce a full ToolCallPart + new_part_index = self._append_and_track_new_part(part, vendor_part_id) if isinstance(part, ToolCallPart | BuiltinToolCallPart): return PartStartEvent(index=new_part_index, part=part) else: - # Update the existing part or delta with the new information existing_part, part_index = existing_matching_part_and_index delta = ToolCallPartDelta(tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id) updated_part = delta.apply(existing_part) self._parts[part_index] = updated_part if isinstance(updated_part, ToolCallPart | BuiltinToolCallPart): if isinstance(existing_part, ToolCallPartDelta): - # We just upgraded a delta to a full part, so emit a PartStartEvent return PartStartEvent(index=part_index, part=updated_part) else: - # We updated an existing part, so emit a PartDeltaEvent if updated_part.tool_call_id and not delta.tool_call_id: delta = replace(delta, tool_call_id=updated_part.tool_call_id) return PartDeltaEvent(index=part_index, delta=delta) @@ -338,18 +876,14 @@ def handle_tool_call_part( ) if vendor_part_id is None: # vendor_part_id is None, so we unconditionally append a new ToolCallPart to the end of the list - new_part_index = len(self._parts) - self._parts.append(new_part) + new_part_index = self._append_and_track_new_part(new_part, vendor_part_id) else: # vendor_part_id is provided, so find and overwrite or create a new ToolCallPart. 
- maybe_part_index = self._vendor_id_to_part_index.get(vendor_part_id) - if maybe_part_index is not None and isinstance(self._parts[maybe_part_index], ToolCallPart): - new_part_index = maybe_part_index - self._parts[new_part_index] = new_part + maybe_part, part_index = self._get_part_and_index_by_vendor_id(vendor_part_id) + if part_index is not None and isinstance(maybe_part, ToolCallPart): + new_part_index = self._replace_part(part_index, new_part, vendor_part_id) else: - new_part_index = len(self._parts) - self._parts.append(new_part) - self._vendor_id_to_part_index[vendor_part_id] = new_part_index + new_part_index = self._append_and_track_new_part(new_part, vendor_part_id) return PartStartEvent(index=new_part_index, part=new_part) def handle_part( @@ -371,16 +905,12 @@ def handle_part( """ if vendor_part_id is None: # vendor_part_id is None, so we unconditionally append a new part to the end of the list - new_part_index = len(self._parts) - self._parts.append(part) + new_part_index = self._append_and_track_new_part(part, vendor_part_id) else: # vendor_part_id is provided, so find and overwrite or create a new part. - maybe_part_index = self._vendor_id_to_part_index.get(vendor_part_id) - if maybe_part_index is not None and isinstance(self._parts[maybe_part_index], type(part)): - new_part_index = maybe_part_index - self._parts[new_part_index] = part + maybe_part, part_index = self._get_part_and_index_by_vendor_id(vendor_part_id) + if part_index is not None and isinstance(maybe_part, type(part)): + new_part_index = self._replace_part(part_index, part, vendor_part_id) else: - new_part_index = len(self._parts) - self._parts.append(part) - self._vendor_id_to_part_index[vendor_part_id] = new_part_index + new_part_index = self._append_and_track_new_part(part, vendor_part_id) return PartStartEvent(index=new_part_index, part=part) diff --git a/pydantic_ai_slim/pydantic_ai/_thinking_part.py b/pydantic_ai_slim/pydantic_ai/_thinking_part.py index db67fda847..0c303720a9 100644 --- a/pydantic_ai_slim/pydantic_ai/_thinking_part.py +++ b/pydantic_ai_slim/pydantic_ai/_thinking_part.py @@ -29,3 +29,8 @@ def split_content_into_text_and_thinking(content: str, thinking_tags: tuple[str, if content: parts.append(TextPart(content=content)) return parts + + +# NOTE: this utility is used by models/: `groq`, `huggingface`, `openai`, `outlines` and `tests/test_thinking_part.py` +# not sure if it could be replaced by the new handling in the `_parts_manager.py` but it's worth taking a closer look. +# if that's the case we could use this file to partly isolate the embedded thinking handling logic and declutter the parts manager. 
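Illustrative note (not part of the patch): the sketch below shows how the generator-based handle_text_delta and final_flush added in _parts_manager.py are expected to behave when a closing thinking tag arrives split across chunks, and how suffix_prefix_overlap decides how much of a chunk is buffered as a possible partial tag. The chunk contents and the ('<think>', '</think>') tag pair are made-up example values, not taken from the patch.

from pydantic_ai._parts_manager import ModelResponsePartsManager, suffix_prefix_overlap

# suffix_prefix_overlap returns the longest suffix of the text seen so far that could
# still grow into the closing tag; that suffix is what gets buffered.
assert suffix_prefix_overlap('step two</thi', '</think>') == 5  # '</thi' would be buffered
assert suffix_prefix_overlap('no tag here', '</think>') == 0    # nothing to buffer

manager = ModelResponsePartsManager()
events = []
# Closing tag split across chunks: '</thi' is held back, then completed by 'nk>'.
for chunk in ['<think>step one ', 'step two</thi', 'nk>answer']:
    events.extend(
        manager.handle_text_delta(
            vendor_part_id='content',
            content=chunk,
            thinking_tags=('<think>', '</think>'),
        )
    )
events.extend(manager.final_flush())  # flush anything still buffered at stream end
print([type(p).__name__ for p in manager.get_parts()])
# expected per the docstring constraints above: ['ThinkingPart', 'TextPart']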
diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 98214910bd..0fb9961952 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -7,6 +7,7 @@ from __future__ import annotations as _annotations import base64 +import copy import warnings from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Callable, Iterator @@ -646,7 +647,11 @@ def part_end_event(next_part: ModelResponsePart | None = None) -> PartEndEvent | if end_event: yield end_event - self._event_iterator = iterator_with_part_end(iterator_with_final_event(self._get_event_iterator())) + self._event_iterator = iterator_with_part_end( + iterator_with_final_event( + chain_async_and_sync_iters(self._get_event_iterator(), self._parts_manager.final_flush()) + ) + ) return self._event_iterator @abstractmethod @@ -664,8 +669,14 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: def get(self) -> ModelResponse: """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far.""" + # Flush any buffered content before building response + # clone parts manager to avoid modifying the ongoing stream state + cloned_manager = copy.deepcopy(self._parts_manager) + for _ in cloned_manager.final_flush(): + pass + return ModelResponse( - parts=self._parts_manager.get_parts(), + parts=cloned_manager.get_parts(), model_name=self.model_name, timestamp=self.timestamp, usage=self.usage(), @@ -699,6 +710,16 @@ def timestamp(self) -> datetime: raise NotImplementedError() +async def chain_async_and_sync_iters( + iter1: AsyncIterator[ModelResponseStreamEvent], iter2: Iterator[ModelResponseStreamEvent] +) -> AsyncIterator[ModelResponseStreamEvent]: + """Chain an async iterator with a sync iterator.""" + async for event in iter1: + yield event + for event in iter2: + yield event + + ALLOW_MODEL_REQUESTS = True """Whether to allow requests to models. 
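Illustrative note (not part of the patch): chain_async_and_sync_iters above is what lets the parts manager's final_flush() run after the model's own event stream is exhausted. Because final_flush() is a generator, it is created eagerly when the chain is built but its body only executes once the async iterator finishes. A minimal standalone sketch, with str standing in for ModelResponseStreamEvent and toy generators standing in for _get_event_iterator() and final_flush():

import asyncio
from collections.abc import AsyncIterator, Iterator

async def chain_async_and_sync_iters(iter1: AsyncIterator[str], iter2: Iterator[str]) -> AsyncIterator[str]:
    # Same shape as the helper added above: drain the async iterator, then the sync one.
    async for event in iter1:
        yield event
    for event in iter2:
        yield event

async def demo() -> list[str]:
    async def stream() -> AsyncIterator[str]:  # stand-in for _get_event_iterator()
        yield 'delta-1'
        yield 'delta-2'

    def flush() -> Iterator[str]:  # stand-in for parts_manager.final_flush()
        yield 'flushed-buffer'

    return [event async for event in chain_async_and_sync_iters(stream(), flush())]

print(asyncio.run(demo()))  # ['delta-1', 'delta-2', 'flushed-buffer']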
diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index c636ba9cfc..c087f4e1f5 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -821,25 +821,26 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: elif isinstance(event, BetaRawContentBlockStartEvent): current_block = event.content_block if isinstance(current_block, BetaTextBlock) and current_block.text: - maybe_event = self._parts_manager.handle_text_delta( + for event_item in self._parts_manager.handle_text_delta( vendor_part_id=event.index, content=current_block.text - ) - if maybe_event is not None: # pragma: no branch - yield maybe_event + ): + yield event_item elif isinstance(current_block, BetaThinkingBlock): - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id=event.index, content=current_block.thinking, signature=current_block.signature, provider_name=self.provider_name, - ) + ): + yield e elif isinstance(current_block, BetaRedactedThinkingBlock): - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id=event.index, id='redacted_thinking', signature=current_block.data, provider_name=self.provider_name, - ) + ): + yield e elif isinstance(current_block, BetaToolUseBlock): maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=event.index, @@ -895,23 +896,24 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: elif isinstance(event, BetaRawContentBlockDeltaEvent): if isinstance(event.delta, BetaTextDelta): - maybe_event = self._parts_manager.handle_text_delta( + for event_item in self._parts_manager.handle_text_delta( vendor_part_id=event.index, content=event.delta.text - ) - if maybe_event is not None: # pragma: no branch - yield maybe_event + ): + yield event_item elif isinstance(event.delta, BetaThinkingDelta): - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id=event.index, content=event.delta.thinking, provider_name=self.provider_name, - ) + ): + yield e elif isinstance(event.delta, BetaSignatureDelta): - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id=event.index, signature=event.delta.signature, provider_name=self.provider_name, - ) + ): + yield e elif isinstance(event.delta, BetaInputJSONDelta): maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=event.index, diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index acb98e5ec0..799c80af27 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -747,24 +747,25 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: delta = content_block_delta['delta'] if 'reasoningContent' in delta: if redacted_content := delta['reasoningContent'].get('redactedContent'): - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id=index, id='redacted_content', signature=redacted_content.decode('utf-8'), provider_name=self.provider_name, - ) + ): + yield e else: signature = delta['reasoningContent'].get('signature') - yield self._parts_manager.handle_thinking_delta( + for e in 
self._parts_manager.handle_thinking_delta( vendor_part_id=index, content=delta['reasoningContent'].get('text'), signature=signature, provider_name=self.provider_name if signature else None, - ) + ): + yield e if text := delta.get('text'): - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=text) - if maybe_event is not None: # pragma: no branch - yield maybe_event + for event in self._parts_manager.handle_text_delta(vendor_part_id=index, content=text): + yield event if 'toolUse' in delta: tool_use = delta['toolUse'] maybe_event = self._parts_manager.handle_tool_call_delta( diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 37876e3e80..77f366b94a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -186,6 +186,7 @@ async def request_stream( yield FunctionStreamedResponse( model_request_parameters=model_request_parameters, + _model_profile=self.profile, _model_name=self._model_name, _iter=response_stream, ) @@ -286,32 +287,35 @@ class FunctionStreamedResponse(StreamedResponse): """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel].""" _model_name: str + _model_profile: ModelProfile _iter: AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls | BuiltinToolCallsReturns] _timestamp: datetime = field(default_factory=_utils.now_utc) def __post_init__(self): self._usage += _estimate_usage([]) - async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901 async for item in self._iter: if isinstance(item, str): response_tokens = _estimate_string_tokens(item) self._usage += usage.RequestUsage(output_tokens=response_tokens) - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item) - if maybe_event is not None: # pragma: no branch - yield maybe_event + for event in self._parts_manager.handle_text_delta( + vendor_part_id='content', content=item, thinking_tags=self._model_profile.thinking_tags + ): + yield event elif isinstance(item, dict) and item: for dtc_index, delta in item.items(): if isinstance(delta, DeltaThinkingPart): if delta.content: # pragma: no branch response_tokens = _estimate_string_tokens(delta.content) self._usage += usage.RequestUsage(output_tokens=response_tokens) - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id=dtc_index, content=delta.content, signature=delta.signature, provider_name='function' if delta.signature else None, - ) + ): + yield e elif isinstance(delta, DeltaToolCall): if delta.json_args: response_tokens = _estimate_string_tokens(delta.json_args) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 10c227d0db..a3d88b3d61 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -465,11 +465,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if 'text' in gemini_part: # Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled # amongst the tool call deltas - maybe_event = self._parts_manager.handle_text_delta( + for event in self._parts_manager.handle_text_delta( vendor_part_id=None, content=gemini_part['text'] - ) - if maybe_event is not None: # pragma: no 
branch - yield maybe_event + ): + yield event elif 'function_call' in gemini_part: # Here, we assume all function_call parts are complete and don't have deltas. diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 3a5cfe9258..8d252fe8bd 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -675,24 +675,25 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: for part in parts: if part.thought_signature: signature = base64.b64encode(part.thought_signature).decode('utf-8') - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id='thinking', signature=signature, provider_name=self.provider_name, - ) + ): + yield e if part.text is not None: if len(part.text) > 0: if part.thought: - yield self._parts_manager.handle_thinking_delta( + for event in self._parts_manager.handle_thinking_delta( vendor_part_id='thinking', content=part.text - ) + ): + yield event else: - maybe_event = self._parts_manager.handle_text_delta( + for event in self._parts_manager.handle_text_delta( vendor_part_id='content', content=part.text - ) - if maybe_event is not None: # pragma: no branch - yield maybe_event + ): + yield event elif part.function_call: maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=uuid4(), diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 67c27a19c2..3df184a42c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -549,9 +549,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: reasoning = True # NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`. - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id=f'reasoning-{reasoning_index}', content=choice.delta.reasoning - ) + ): + yield e else: reasoning = False @@ -574,14 +575,17 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # Handle the text part of the response content = choice.delta.content if content: - maybe_event = self._parts_manager.handle_text_delta( + for event in self._parts_manager.handle_text_delta( vendor_part_id='content', content=content, thinking_tags=self._model_profile.thinking_tags, + # where does `ignore_leading_whitespace` come from? + # `GroqModel._process_streamed_response()` returns a `GroqStreamedResponse(_model_profile=self.profile,)` + # `Groq.profile`` is set at `super().__init__(settings=settings, profile=profile or provider.model_profile)` + # so `_model_profile` comes either from `GroqModel(profile=...)` or `GroqModel(provider=GroqProvider(...))` where the provider infers a profile. 
ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace, - ) - if maybe_event is not None: # pragma: no branch - yield maybe_event + ): + yield event # Handle the tool calls for dtc in choice.delta.tool_calls or []: diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 94598aee7e..fb398eaba5 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -487,14 +487,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # Handle the text part of the response content = choice.delta.content if content: - maybe_event = self._parts_manager.handle_text_delta( + for event in self._parts_manager.handle_text_delta( vendor_part_id='content', content=content, thinking_tags=self._model_profile.thinking_tags, ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace, - ) - if maybe_event is not None: # pragma: no branch - yield maybe_event + ): + yield event for dtc in choice.delta.tool_calls or []: maybe_event = self._parts_manager.handle_tool_call_delta( diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 770c8ff6ca..dd0065a848 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -639,7 +639,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: content = choice.delta.content text, thinking = _map_content(content) for thought in thinking: - self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=thought) + for event in self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=thought): + yield event if text: # Attempt to produce an output tool call from the received text output_tools = {c.name: c for c in self.model_request_parameters.output_tools} @@ -655,9 +656,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: tool_call_id=maybe_tool_call_part.tool_call_id, ) else: - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=text) - if maybe_event is not None: # pragma: no branch - yield maybe_event + for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=text): + yield event # Handle the explicit tool calls for index, dtc in enumerate(choice.delta.tool_calls or []): diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 5da7e0ccd4..2d5f6bc212 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -1713,7 +1713,7 @@ class OpenAIStreamedResponse(StreamedResponse): _provider_name: str _provider_url: str - async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901 async for chunk in self._response: self._usage += _map_usage(chunk, self._provider_name, self._provider_url, self._model_name) @@ -1739,38 +1739,39 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # The `reasoning_content` field is only present in DeepSeek models. 
# https://api-docs.deepseek.com/guides/reasoning_model if reasoning_content := getattr(choice.delta, 'reasoning_content', None): - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id='reasoning_content', id='reasoning_content', content=reasoning_content, provider_name=self.provider_name, - ) + ): + yield e # The `reasoning` field is only present in gpt-oss via Ollama and OpenRouter. # - https://cookbook.openai.com/articles/gpt-oss/handle-raw-cot#chat-completions-api # - https://openrouter.ai/docs/use-cases/reasoning-tokens#basic-usage-with-reasoning-tokens if reasoning := getattr(choice.delta, 'reasoning', None): # pragma: no cover - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id='reasoning', id='reasoning', content=reasoning, provider_name=self.provider_name, - ) + ): + yield e # Handle the text part of the response content = choice.delta.content if content: - maybe_event = self._parts_manager.handle_text_delta( + for event in self._parts_manager.handle_text_delta( vendor_part_id='content', content=content, thinking_tags=self._model_profile.thinking_tags, ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace, - ) - if maybe_event is not None: # pragma: no branch - if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart): - maybe_event.part.id = 'content' - maybe_event.part.provider_name = self.provider_name - yield maybe_event + ): + if isinstance(event, PartStartEvent) and isinstance(event.part, ThinkingPart): + event.part.id = 'content' + event.part.provider_name = self.provider_name + yield event for dtc in choice.delta.tool_calls or []: maybe_event = self._parts_manager.handle_tool_call_delta( @@ -1921,12 +1922,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if isinstance(chunk.item, responses.ResponseReasoningItem): if signature := chunk.item.encrypted_content: # pragma: no branch # Add the signature to the part corresponding to the first summary item - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id=f'{chunk.item.id}-0', id=chunk.item.id, signature=signature, provider_name=self.provider_name, - ) + ): + yield e elif isinstance(chunk.item, responses.ResponseCodeInterpreterToolCall): _, return_part, file_parts = _map_code_interpreter_tool_call(chunk.item, self.provider_name) for i, file_part in enumerate(file_parts): @@ -1959,11 +1961,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-return', part=return_part) elif isinstance(chunk, responses.ResponseReasoningSummaryPartAddedEvent): - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}', content=chunk.part.text, id=chunk.item_id, - ) + ): + yield e elif isinstance(chunk, responses.ResponseReasoningSummaryPartDoneEvent): pass # there's nothing we need to do here @@ -1972,22 +1975,22 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: pass # there's nothing we need to do here elif isinstance(chunk, responses.ResponseReasoningSummaryTextDeltaEvent): - yield self._parts_manager.handle_thinking_delta( + for e in self._parts_manager.handle_thinking_delta( 
vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}', content=chunk.delta, id=chunk.item_id, - ) + ): + yield e elif isinstance(chunk, responses.ResponseOutputTextAnnotationAddedEvent): # TODO(Marcelo): We should support annotations in the future. pass # there's nothing we need to do here elif isinstance(chunk, responses.ResponseTextDeltaEvent): - maybe_event = self._parts_manager.handle_text_delta( + for event in self._parts_manager.handle_text_delta( vendor_part_id=chunk.item_id, content=chunk.delta, id=chunk.item_id - ) - if maybe_event is not None: # pragma: no branch - yield maybe_event + ): + yield event elif isinstance(chunk, responses.ResponseTextDoneEvent): pass # there's nothing we need to do here diff --git a/pydantic_ai_slim/pydantic_ai/models/outlines.py b/pydantic_ai_slim/pydantic_ai/models/outlines.py index 5b439952c1..170e17e6fd 100644 --- a/pydantic_ai_slim/pydantic_ai/models/outlines.py +++ b/pydantic_ai_slim/pydantic_ai/models/outlines.py @@ -6,7 +6,7 @@ from __future__ import annotations import io -from collections.abc import AsyncIterable, AsyncIterator, Sequence +from collections.abc import AsyncIterable, AsyncIterator, Iterator, Sequence from contextlib import asynccontextmanager from dataclasses import dataclass, replace from datetime import datetime, timezone @@ -546,15 +546,18 @@ class OutlinesStreamedResponse(StreamedResponse): _provider_name: str async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: - async for event in self._response: - event = self._parts_manager.handle_text_delta( - vendor_part_id='content', - content=event, - thinking_tags=self._model_profile.thinking_tags, - ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace, + async for chunk in self._response: + events = cast( + Iterator[ModelResponseStreamEvent], + self._parts_manager.handle_text_delta( + vendor_part_id='content', + content=chunk, + thinking_tags=self._model_profile.thinking_tags, + ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace, + ), ) - if event is not None: # pragma: no branch - yield event + for e in events: + yield e @property def model_name(self) -> str: diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 170113a999..eddc98787b 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -313,14 +313,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: mid = len(text) // 2 words = [text[:mid], text[mid:]] self._usage += _get_string_usage('') - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=i, content='') - if maybe_event is not None: # pragma: no branch - yield maybe_event + for event in self._parts_manager.handle_text_delta(vendor_part_id=i, content=''): + yield event for word in words: self._usage += _get_string_usage(word) - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=i, content=word) - if maybe_event is not None: # pragma: no branch - yield maybe_event + for event in self._parts_manager.handle_text_delta(vendor_part_id=i, content=word): + yield event elif isinstance(part, ToolCallPart): yield self._parts_manager.handle_tool_call_part( vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id diff --git a/pyproject.toml b/pyproject.toml index 3c13afdece..dc0f190dc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -317,4 +317,4 @@ skip = '.git*,*.svg,*.lock,*.css,*.yaml' 
check-hidden = true # Ignore "formatting" like **L**anguage ignore-regex = '\*\*[A-Z]\*\*[a-z]+\b' -ignore-words-list = 'asend,aci' +ignore-words-list = 'asend,aci,thi' diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index 515892d58c..57f385f514 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -2014,7 +2014,6 @@ async def test_groq_model_thinking_part_iter(allow_model_requests: None, groq_ap parts=[ ThinkingPart( content="""\ - Okay, so I want to make Uruguayan alfajores. I've heard they're a type of South American cookie sandwich with dulce de leche. I'm not entirely sure about the exact steps, but I can try to figure it out based on what I know. First, I think alfajores are cookies, so I'll need to make the cookie part. From what I remember, the dough is probably made with flour, sugar, butter, eggs, vanilla, and maybe some baking powder or baking soda. I should look up a typical cookie dough recipe and adjust it for alfajores. @@ -2103,9 +2102,7 @@ async def test_groq_model_thinking_part_iter(allow_model_requests: None, groq_ap assert event_parts == snapshot( [ - PartStartEvent(index=0, part=ThinkingPart(content='')), - PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='\n')), - PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='Okay')), + PartStartEvent(index=0, part=ThinkingPart(content='Okay')), PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=',')), PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' so')), PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' I')), @@ -2625,7 +2622,6 @@ async def test_groq_model_thinking_part_iter(allow_model_requests: None, groq_ap index=0, part=ThinkingPart( content="""\ - Okay, so I want to make Uruguayan alfajores. I've heard they're a type of South American cookie sandwich with dulce de leche. I'm not entirely sure about the exact steps, but I can try to figure it out based on what I know. First, I think alfajores are cookies, so I'll need to make the cookie part. From what I remember, the dough is probably made with flour, sugar, butter, eggs, vanilla, and maybe some baking powder or baking soda. I should look up a typical cookie dough recipe and adjust it for alfajores. 
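Reviewer note: every model-adapter hunk above (DeepSeek/gpt-oss reasoning, OpenAI responses, outlines, the test model) moves from `maybe_event = ...; if maybe_event is not None: yield maybe_event` to iterating the events the parts manager now yields. A minimal sketch of that shared consumption pattern, assuming only the API shown in this diff (the helper name `stream_text_events` and the `('<think>', '</think>')` default are illustrative, not part of the change):

from collections.abc import Iterable, Iterator

from pydantic_ai import PartStartEvent, ThinkingPart
from pydantic_ai._parts_manager import ModelResponsePartsManager
from pydantic_ai.messages import ModelResponseStreamEvent


def stream_text_events(
    manager: ModelResponsePartsManager,
    chunks: Iterable[str],
    *,
    provider_name: str,
    thinking_tags: tuple[str, str] | None = ('<think>', '</think>'),
    ignore_leading_whitespace: bool = False,
) -> Iterator[ModelResponseStreamEvent]:
    for chunk in chunks:
        # handle_text_delta is now a generator: one chunk can produce zero events
        # (content buffered while a partial thinking tag is pending) or several
        # (e.g. closing a ThinkingPart and immediately starting a TextPart).
        for event in manager.handle_text_delta(
            vendor_part_id='content',
            content=chunk,
            thinking_tags=thinking_tags,
            ignore_leading_whitespace=ignore_leading_whitespace,
        ):
            # Stamp provider metadata onto embedded-thinking parts, mirroring the
            # OpenAI chat-completions hunk above.
            if isinstance(event, PartStartEvent) and isinstance(event.part, ThinkingPart):
                event.part.id = 'content'
                event.part.provider_name = provider_name
            yield event
    # Once the stream ends, final_flush() releases anything still buffered,
    # e.g. an opening-tag prefix that never completed.
    yield from manager.final_flush()

With chunks like ['<think>reasoning</think>', 'answer'] this should yield a ThinkingPart start followed by a TextPart start, matching the snapshots in the new tests/test_parts_manager_thinking_tags.py cases.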
diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index b6a52e0c25..651a5b6a41 100644 --- a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -118,12 +118,10 @@ async def request_stream( class MyResponseStream(StreamedResponse): async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: self._usage = RequestUsage(input_tokens=300, output_tokens=400) - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=0, content='text1') - if maybe_event is not None: # pragma: no branch - yield maybe_event - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=0, content='text2') - if maybe_event is not None: # pragma: no branch - yield maybe_event + for event in self._parts_manager.handle_text_delta(vendor_part_id=0, content='text1'): + yield event + for event in self._parts_manager.handle_text_delta(vendor_part_id=0, content='text2'): + yield event @property def model_name(self) -> str: diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index e68c64abe3..99f4e9a731 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -625,11 +625,53 @@ async def test_stream_text_empty_think_tag_and_text_before_tool_call(allow_model async with agent.run_stream('') as result: assert not result.is_complete assert [c async for c in result.stream_output(debounce_by=None)] == snapshot( - [{}, {'first': 'One'}, {'first': 'One', 'second': 'Two'}, {'first': 'One', 'second': 'Two'}] + [{'first': 'One'}, {'first': 'One', 'second': 'Two'}, {'first': 'One', 'second': 'Two'}] ) assert await result.get_output() == snapshot({'first': 'One', 'second': 'Two'}) +async def test_stream_with_embedded_thinking_sets_metadata(allow_model_requests: None): + """Test that embedded thinking creates ThinkingPart with id='content' and provider_name='openai'. 
+ + COVERAGE: This test covers openai.py lines 1748-1749 which set: + event.part.id = 'content' + event.part.provider_name = self.provider_name + """ + stream = [ + text_chunk('<think>'), + text_chunk('reasoning'), + text_chunk('</think>'), + text_chunk('response'), + chunk([]), + ] + mock_client = MockOpenAI.create_mock_stream(stream) + m = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(m) + + async with agent.run_stream('') as result: + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['response']) + + # Verify ThinkingPart has id='content' and provider_name='openai' (covers lines 1748-1749) + messages = result.all_messages() + assert len(messages) == 2 + assert isinstance(messages[0], ModelRequest) + assert isinstance(messages[1], ModelResponse) + + response = messages[1] + assert len(response.parts) == 2 + + # This is what we're testing - the ThinkingPart should have these metadata fields set + thinking_part = response.parts[0] + assert isinstance(thinking_part, ThinkingPart) + assert thinking_part.id == 'content' # Line 1748 in openai.py + assert thinking_part.provider_name == 'openai' # Line 1749 in openai.py + assert thinking_part.content == 'reasoning' + + text_part = response.parts[1] + assert isinstance(text_part, TextPart) + assert text_part.content == 'response' + + async def test_no_delta(allow_model_requests: None): stream = [ chunk([]), diff --git a/tests/test_parts_manager.py b/tests/test_parts_manager.py index 59ce3e31a9..86f981b7f9 100644 --- a/tests/test_parts_manager.py +++ b/tests/test_parts_manager.py @@ -28,140 +28,41 @@ def test_handle_text_deltas(vendor_part_id: str | None): manager = ModelResponsePartsManager() assert manager.get_parts() == [] - event = manager.handle_text_delta(vendor_part_id=vendor_part_id, content='hello ') - assert event == snapshot( - PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start') - ) + events = list(manager.handle_text_delta(vendor_part_id=vendor_part_id, content='hello ')) + assert events == snapshot([PartStartEvent(index=0, part=TextPart(content='hello '))]) assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')]) - event = manager.handle_text_delta(vendor_part_id=vendor_part_id, content='world') - assert event == snapshot( - PartDeltaEvent( - index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta' - ) - ) + events = list(manager.handle_text_delta(vendor_part_id=vendor_part_id, content='world')) + assert events == snapshot([PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='world'))]) assert manager.get_parts() == snapshot([TextPart(content='hello world', part_kind='text')]) def test_handle_dovetailed_text_deltas(): manager = ModelResponsePartsManager() - event = manager.handle_text_delta(vendor_part_id='first', content='hello ') - assert event == snapshot( - PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start') - ) + events = list(manager.handle_text_delta(vendor_part_id='first', content='hello ')) + assert events == snapshot([PartStartEvent(index=0, part=TextPart(content='hello '))]) assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')]) - event = manager.handle_text_delta(vendor_part_id='second', content='goodbye ') - assert event == snapshot( - PartStartEvent(index=1, part=TextPart(content='goodbye ', part_kind='text'), event_kind='part_start') - ) + events =
list(manager.handle_text_delta(vendor_part_id='second', content='goodbye ')) + assert events == snapshot([PartStartEvent(index=1, part=TextPart(content='goodbye '))]) assert manager.get_parts() == snapshot( [TextPart(content='hello ', part_kind='text'), TextPart(content='goodbye ', part_kind='text')] ) - event = manager.handle_text_delta(vendor_part_id='first', content='world') - assert event == snapshot( - PartDeltaEvent( - index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta' - ) - ) + events = list(manager.handle_text_delta(vendor_part_id='first', content='world')) + assert events == snapshot([PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='world'))]) assert manager.get_parts() == snapshot( [TextPart(content='hello world', part_kind='text'), TextPart(content='goodbye ', part_kind='text')] ) - event = manager.handle_text_delta(vendor_part_id='second', content='Samuel') - assert event == snapshot( - PartDeltaEvent( - index=1, delta=TextPartDelta(content_delta='Samuel', part_delta_kind='text'), event_kind='part_delta' - ) - ) + events = list(manager.handle_text_delta(vendor_part_id='second', content='Samuel')) + assert events == snapshot([PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='Samuel'))]) assert manager.get_parts() == snapshot( [TextPart(content='hello world', part_kind='text'), TextPart(content='goodbye Samuel', part_kind='text')] ) -def test_handle_text_deltas_with_think_tags(): - manager = ModelResponsePartsManager() - thinking_tags = ('', '') - - event = manager.handle_text_delta(vendor_part_id='content', content='pre-', thinking_tags=thinking_tags) - assert event == snapshot( - PartStartEvent(index=0, part=TextPart(content='pre-', part_kind='text'), event_kind='part_start') - ) - assert manager.get_parts() == snapshot([TextPart(content='pre-', part_kind='text')]) - - event = manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags) - assert event == snapshot( - PartDeltaEvent( - index=0, delta=TextPartDelta(content_delta='thinking', part_delta_kind='text'), event_kind='part_delta' - ) - ) - assert manager.get_parts() == snapshot([TextPart(content='pre-thinking', part_kind='text')]) - - event = manager.handle_text_delta(vendor_part_id='content', content='', thinking_tags=thinking_tags) - assert event == snapshot( - PartStartEvent(index=1, part=ThinkingPart(content='', part_kind='thinking'), event_kind='part_start') - ) - assert manager.get_parts() == snapshot( - [TextPart(content='pre-thinking', part_kind='text'), ThinkingPart(content='', part_kind='thinking')] - ) - - event = manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags) - assert event == snapshot( - PartDeltaEvent( - index=1, - delta=ThinkingPartDelta(content_delta='thinking', part_delta_kind='thinking'), - event_kind='part_delta', - ) - ) - assert manager.get_parts() == snapshot( - [TextPart(content='pre-thinking', part_kind='text'), ThinkingPart(content='thinking', part_kind='thinking')] - ) - - event = manager.handle_text_delta(vendor_part_id='content', content=' more', thinking_tags=thinking_tags) - assert event == snapshot( - PartDeltaEvent( - index=1, delta=ThinkingPartDelta(content_delta=' more', part_delta_kind='thinking'), event_kind='part_delta' - ) - ) - assert manager.get_parts() == snapshot( - [ - TextPart(content='pre-thinking', part_kind='text'), - ThinkingPart(content='thinking more', part_kind='thinking'), - ] - ) - - event = 
manager.handle_text_delta(vendor_part_id='content', content='', thinking_tags=thinking_tags) - assert event is None - - event = manager.handle_text_delta(vendor_part_id='content', content='post-', thinking_tags=thinking_tags) - assert event == snapshot( - PartStartEvent(index=2, part=TextPart(content='post-', part_kind='text'), event_kind='part_start') - ) - assert manager.get_parts() == snapshot( - [ - TextPart(content='pre-thinking', part_kind='text'), - ThinkingPart(content='thinking more', part_kind='thinking'), - TextPart(content='post-', part_kind='text'), - ] - ) - - event = manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags) - assert event == snapshot( - PartDeltaEvent( - index=2, delta=TextPartDelta(content_delta='thinking', part_delta_kind='text'), event_kind='part_delta' - ) - ) - assert manager.get_parts() == snapshot( - [ - TextPart(content='pre-thinking', part_kind='text'), - ThinkingPart(content='thinking more', part_kind='thinking'), - TextPart(content='post-thinking', part_kind='text'), - ] - ) - - def test_handle_tool_call_deltas(): manager = ModelResponsePartsManager() @@ -376,10 +277,8 @@ def test_handle_tool_call_part(): def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | None, tool_vendor_part_id: str | None): manager = ModelResponsePartsManager() - event = manager.handle_text_delta(vendor_part_id=text_vendor_part_id, content='hello ') - assert event == snapshot( - PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start') - ) + events = list(manager.handle_text_delta(vendor_part_id=text_vendor_part_id, content='hello ')) + assert events == snapshot([PartStartEvent(index=0, part=TextPart(content='hello '))]) assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')]) event = manager.handle_tool_call_delta( @@ -393,15 +292,9 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non ) ) - event = manager.handle_text_delta(vendor_part_id=text_vendor_part_id, content='world') + events = list(manager.handle_text_delta(vendor_part_id=text_vendor_part_id, content='world')) if text_vendor_part_id is None: - assert event == snapshot( - PartStartEvent( - index=2, - part=TextPart(content='world', part_kind='text'), - event_kind='part_start', - ) - ) + assert events == snapshot([PartStartEvent(index=2, part=TextPart(content='world'))]) assert manager.get_parts() == snapshot( [ TextPart(content='hello ', part_kind='text'), @@ -410,11 +303,7 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non ] ) else: - assert event == snapshot( - PartDeltaEvent( - index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta' - ) - ) + assert events == snapshot([PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='world'))]) assert manager.get_parts() == snapshot( [ TextPart(content='hello world', part_kind='text'), @@ -425,7 +314,8 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non def test_cannot_convert_from_text_to_tool_call(): manager = ModelResponsePartsManager() - manager.handle_text_delta(vendor_part_id=1, content='hello') + for _ in manager.handle_text_delta(vendor_part_id=1, content='hello'): + pass with pytest.raises( UnexpectedModelBehavior, match=re.escape('Cannot apply a tool call delta to existing_part=TextPart(') ): @@ -436,9 +326,10 @@ def test_cannot_convert_from_tool_call_to_text(): manager = 
ModelResponsePartsManager() manager.handle_tool_call_delta(vendor_part_id=1, tool_name='tool1', args='{"arg1":', tool_call_id=None) with pytest.raises( - UnexpectedModelBehavior, match=re.escape('Cannot apply a text delta to existing_part=ToolCallPart(') + UnexpectedModelBehavior, match=re.escape('Cannot apply a text delta to maybe_part=ToolCallPart(') ): - manager.handle_text_delta(vendor_part_id=1, content='hello') + for _ in manager.handle_text_delta(vendor_part_id=1, content='hello'): + pass def test_tool_call_id_delta(): @@ -529,14 +420,12 @@ def test_handle_thinking_delta_no_vendor_id_with_existing_thinking_part(): manager = ModelResponsePartsManager() # Add a thinking part first - event = manager.handle_thinking_delta(vendor_part_id='first', content='initial thought', signature=None) - assert isinstance(event, PartStartEvent) - assert event.index == 0 + events = list(manager.handle_thinking_delta(vendor_part_id='first', content='initial thought', signature=None)) + assert events == snapshot([PartStartEvent(index=0, part=ThinkingPart(content='initial thought'))]) # Now add another thinking delta with no vendor_part_id - should update the latest thinking part - event = manager.handle_thinking_delta(vendor_part_id=None, content=' more', signature=None) - assert isinstance(event, PartDeltaEvent) - assert event.index == 0 + events = list(manager.handle_thinking_delta(vendor_part_id=None, content=' more', signature=None)) + assert events == snapshot([PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' more'))]) parts = manager.get_parts() assert parts == snapshot([ThinkingPart(content='initial thought more')]) @@ -546,19 +435,20 @@ def test_handle_thinking_delta_wrong_part_type(): manager = ModelResponsePartsManager() # Add a text part first - manager.handle_text_delta(vendor_part_id='text', content='hello') + for _ in manager.handle_text_delta(vendor_part_id='text', content='hello'): + pass # Try to apply thinking delta to the text part - should raise error with pytest.raises(UnexpectedModelBehavior, match=r'Cannot apply a thinking delta to existing_part='): - manager.handle_thinking_delta(vendor_part_id='text', content='thinking', signature=None) + for _ in manager.handle_thinking_delta(vendor_part_id='text', content='thinking', signature=None): + pass def test_handle_thinking_delta_new_part_with_vendor_id(): manager = ModelResponsePartsManager() - event = manager.handle_thinking_delta(vendor_part_id='thinking', content='new thought', signature=None) - assert isinstance(event, PartStartEvent) - assert event.index == 0 + events = list(manager.handle_thinking_delta(vendor_part_id='thinking', content='new thought', signature=None)) + assert events == snapshot([PartStartEvent(index=0, part=ThinkingPart(content='new thought'))]) parts = manager.get_parts() assert parts == snapshot([ThinkingPart(content='new thought')]) @@ -568,18 +458,21 @@ def test_handle_thinking_delta_no_content(): manager = ModelResponsePartsManager() with pytest.raises(UnexpectedModelBehavior, match='Cannot create a ThinkingPart with no content'): - manager.handle_thinking_delta(vendor_part_id=None, content=None, signature=None) + for _ in manager.handle_thinking_delta(vendor_part_id=None, content=None, signature=None): + pass def test_handle_thinking_delta_no_content_or_signature(): manager = ModelResponsePartsManager() # Add a thinking part first - manager.handle_thinking_delta(vendor_part_id='thinking', content='initial', signature=None) + for _ in manager.handle_thinking_delta(vendor_part_id='thinking', 
content='initial', signature=None): + pass # Try to update with no content or signature - should raise error with pytest.raises(UnexpectedModelBehavior, match='Cannot update a ThinkingPart with no content or signature'): - manager.handle_thinking_delta(vendor_part_id='thinking', content=None, signature=None) + for _ in manager.handle_thinking_delta(vendor_part_id='thinking', content=None, signature=None): + pass def test_handle_part(): @@ -611,3 +504,52 @@ def test_handle_part(): event = manager.handle_part(vendor_part_id=None, part=part3) assert event == snapshot(PartStartEvent(index=1, part=part3)) assert manager.get_parts() == snapshot([part2, part3]) + + +def test_handle_thinking_delta_when_latest_is_not_thinking(): + """Test that handle_thinking_delta creates new part when latest part is not ThinkingPart.""" + manager = ModelResponsePartsManager() + + # Create TextPart first + list(manager.handle_text_delta(vendor_part_id='content', content='text')) + + # Call handle_thinking_delta with vendor_part_id=None + # Should create NEW ThinkingPart instead of trying to update TextPart + events = list(manager.handle_thinking_delta(vendor_part_id=None, content='thinking')) + + assert events == snapshot([PartStartEvent(index=1, part=ThinkingPart(content='thinking'))]) + assert manager.get_parts() == snapshot([TextPart(content='text'), ThinkingPart(content='thinking')]) + + +def test_handle_tool_call_delta_when_latest_is_not_tool_call(): + """Test that handle_tool_call_delta creates new part when latest part is not a tool call.""" + manager = ModelResponsePartsManager() + + # Create TextPart first + list(manager.handle_text_delta(vendor_part_id='content', content='text')) + + # Call handle_tool_call_delta with vendor_part_id=None + # Should create NEW ToolCallPart instead of trying to update TextPart + event = manager.handle_tool_call_delta(vendor_part_id=None, tool_name='my_tool') + + assert event == snapshot(PartStartEvent(index=1, part=ToolCallPart(tool_name='my_tool', tool_call_id=IsStr()))) + assert manager.get_parts() == snapshot( + [TextPart(content='text'), ToolCallPart(tool_name='my_tool', tool_call_id=IsStr())] + ) + + +def test_handle_tool_call_delta_without_tool_name_when_latest_is_not_tool_call(): + """Test handle_tool_call_delta with vendor_part_id=None and tool_name=None when latest is not a tool call.""" + manager = ModelResponsePartsManager() + + # Create TextPart first + list(manager.handle_text_delta(vendor_part_id='content', content='text')) + + # Call handle_tool_call_delta with BOTH vendor_part_id=None AND tool_name=None + # Latest part is TextPart (not a tool call), so should create new ToolCallPartDelta + event = manager.handle_tool_call_delta(vendor_part_id=None, tool_name=None, args='{"foo": "bar"}') + + # Since no tool_name provided, no event is emitted until we have enough info + assert event == snapshot(None) + # But a ToolCallPartDelta should not be in get_parts() (only complete parts) + assert manager.get_parts() == snapshot([TextPart(content='text')]) diff --git a/tests/test_parts_manager_thinking_tags.py b/tests/test_parts_manager_thinking_tags.py new file mode 100644 index 0000000000..f12cb24002 --- /dev/null +++ b/tests/test_parts_manager_thinking_tags.py @@ -0,0 +1,570 @@ +"""This file tests the "embedded thinking handling" functionality of the Parts Manager (_parts_manager.py). + +It tests each case with both vendor_part_id='content' and vendor_part_id=None to ensure consistent behavior. 
+""" + +from __future__ import annotations as _annotations + +from collections.abc import Hashable, Sequence +from dataclasses import dataclass, field + +import pytest +from inline_snapshot import snapshot + +from pydantic_ai import ( + PartDeltaEvent, + PartStartEvent, + TextPart, + TextPartDelta, + ThinkingPart, + ThinkingPartDelta, +) +from pydantic_ai._parts_manager import ModelResponsePart, ModelResponsePartsManager +from pydantic_ai.messages import ModelResponseStreamEvent + + +def stream_text_deltas( + case: Case, +) -> tuple[list[ModelResponseStreamEvent], list[ModelResponseStreamEvent], list[ModelResponsePart]]: + """Helper to stream chunks through manager and return all events + final parts.""" + manager = ModelResponsePartsManager() + normal_events: list[ModelResponseStreamEvent] = [] + + for chunk in case.chunks: + for event in manager.handle_text_delta( + vendor_part_id=case.vendor_part_id, + content=chunk, + thinking_tags=case.thinking_tags, + ignore_leading_whitespace=case.ignore_leading_whitespace, + ): + normal_events.append(event) + + flushed_events: list[ModelResponseStreamEvent] = [] + for event in manager.final_flush(): + flushed_events.append(event) + + return normal_events, flushed_events, manager.get_parts() + + +def init_model_response_stream_event_iterator() -> Sequence[ModelResponseStreamEvent]: + # both pyright and pre-commit asked for this + return [] + + +@dataclass +class Case: + name: str + chunks: list[str] + expected_parts: list[ModelResponsePart] # [TextPart|ThinkingPart('final content')] + expected_normal_events: Sequence[ModelResponseStreamEvent] + expected_flushed_events: Sequence[ModelResponseStreamEvent] = field( + default_factory=init_model_response_stream_event_iterator + ) + vendor_part_id: Hashable | None = 'content' + ignore_leading_whitespace: bool = False + thinking_tags: tuple[str, str] | None = ('', '') + + +FULL_SPLITS = [ + Case( + name='full_split_partial_closing', + chunks=['con', 'tentcon', 'tent', 'after'], + expected_parts=[ThinkingPart('content'), TextPart('after')], + expected_normal_events=[ + PartStartEvent(index=0, part=ThinkingPart('con')), + PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='tent')), + PartStartEvent(index=1, part=TextPart('after')), + ], + ), + Case( + name='full_split_on_both_sides_closing_buffer_and_stutter', + chunks=['con', 'tent', 'after'], + expected_parts=[ThinkingPart('contentcon', 'tent', 'after', 'content'], + expected_parts=[ThinkingPart('content')], + expected_normal_events=[ + PartStartEvent(index=0, part=ThinkingPart('content')), + ], + ), +] + +# Category 2: Delayed Thinking (no event until content after complete opening) +DELAYED_THINKING_CASES: list[Case] = [ + Case( + name='delayed_thinking_with_content_closes_in_next_chunk', + chunks=['', 'content'], + expected_parts=[ThinkingPart('content')], + expected_normal_events=[ + PartStartEvent(index=0, part=ThinkingPart('content')), + ], + ), + Case( + name='delayed_thinking_with_leading_whitespace_trimmed', + chunks=['', ' content', ''], + expected_parts=[ThinkingPart('content')], + expected_normal_events=[ + PartStartEvent(index=0, part=ThinkingPart('content')), + ], + ignore_leading_whitespace=True, + ), + Case( + name='delayed_empty_thinking_closes_in_separate_chunk_with_after', + chunks=['', 'after'], + expected_parts=[TextPart('after')], + expected_normal_events=[ + PartStartEvent(index=0, part=TextPart('after')), + ], + # NOTE empty thinking is skipped entirely + expected_flushed_events=[], + ), +] + +# Category 3: Invalid 
Opening Tags (prefixes, invalid continuations, flushes) +INVALID_OPENING_CASES: list[Case] = [ + Case( + name='multiple_partial_openings_buffered_until_invalid_continuation', + chunks=[''], + expected_parts=[TextPart('pre')], + expected_normal_events=[ + PartStartEvent(index=0, part=TextPart('pre')), + ], + ), +] + +# Category 4: Full Thinking Tags (complete cycles: open + content + close, with/without after) +FULL_THINKING_CASES: list[Case] = [ + Case( + name='new_part_empty_thinking_treated_as_text', + chunks=[''], + expected_parts=[], # Empty thinking is now skipped entirely + expected_normal_events=[], + ), + Case( + name='new_part_empty_thinking_with_after_treated_as_text', + chunks=['more'], + expected_parts=[TextPart('more')], + expected_normal_events=[ + PartStartEvent(index=0, part=TextPart('more')), + ], + ), + Case( + name='new_part_complete_thinking_with_content_no_after', + chunks=['content'], + expected_parts=[ThinkingPart('content')], + expected_normal_events=[ + PartStartEvent(index=0, part=ThinkingPart('content')), + ], + ), + Case( + name='new_part_complete_thinking_with_content_with_after', + chunks=['contentmore'], + expected_parts=[ThinkingPart('content'), TextPart('more')], + expected_normal_events=[ + PartStartEvent(index=0, part=ThinkingPart('content')), + PartStartEvent(index=1, part=TextPart('more')), + ], + ), +] + +# Category 5: Closing Tag Handling (clean closings, with before/after, no before) +CLOSING_TAG_CASES: list[Case] = [ + Case( + name='existing_thinking_clean_closing', + chunks=['content', ''], + expected_parts=[ThinkingPart('content')], + expected_normal_events=[ + PartStartEvent(index=0, part=ThinkingPart('content')), + ], + ), + Case( + name='existing_thinking_closing_with_before', + chunks=['content', 'more'], + expected_parts=[ThinkingPart('contentmore')], + expected_normal_events=[ + PartStartEvent(index=0, part=ThinkingPart('content')), + PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='more')), + ], + ), + Case( + name='existing_thinking_closing_with_before_after', + chunks=['content', 'moreafter'], + expected_parts=[ThinkingPart('contentmore'), TextPart('after')], + expected_normal_events=[ + PartStartEvent(index=0, part=ThinkingPart('content')), + PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='more')), + PartStartEvent(index=1, part=TextPart('after')), + ], + ), + Case( + name='existing_thinking_closing_no_before_with_after', + chunks=['content', 'after'], + expected_parts=[ThinkingPart('content'), TextPart('after')], + expected_normal_events=[ + PartStartEvent(index=0, part=ThinkingPart('content')), + PartStartEvent(index=1, part=TextPart('after')), + ], + ), +] + +# Category 6: Partial Closing Tags (partials, overlaps, completes, with content) +PARTIAL_CLOSING_CASES: list[Case] = [ + Case( + name='new_part_opening_with_content_partial_closing', + chunks=['contentcontent', 'content', ''], + expected_parts=[ThinkingPart('content')], + expected_normal_events=[ + PartStartEvent(index=0, part=ThinkingPart('content')), + ], + ), + Case( + name='existing_thinking_partial_closing_with_content_to_add', + chunks=['content', 'morecontent', 'more'], + expected_parts=[ThinkingPart('contentmore')], + expected_normal_events=[ + PartStartEvent(index=0, part=ThinkingPart('content')), + PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='more')), + ], + ), + Case( + name='new_part_empty_thinking_with_partial_closing_treated_as_text', + chunks=['content', 'more'], + expected_parts=[ThinkingPart('contentmore')], + 
expected_normal_events=[ + PartStartEvent(index=0, part=ThinkingPart('content')), + PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='more')), + PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='')), + ], + ), +] + +# Category 7: Adding Content to Existing (updates without closing) +ADDING_CONTENT_CASES: list[Case] = [ + Case( + name='existing_thinking_add_more_content', + chunks=['content', 'more'], + expected_parts=[ThinkingPart('contentmore')], + expected_normal_events=[ + PartStartEvent(index=0, part=ThinkingPart('content')), + PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='more')), + ], + ), +] + +# Category 8: Whitespace Handling (ignore leading, mixed, not ignore) +WHITESPACE_CASES: list[Case] = [ + Case( + name='new_part_ignore_whitespace_empty', + chunks=[' '], + expected_parts=[], + expected_normal_events=[], + ignore_leading_whitespace=True, + ), + Case( + name='new_part_not_ignore_whitespace', + chunks=[' '], + expected_parts=[TextPart(' ')], + expected_normal_events=[ + PartStartEvent(index=0, part=TextPart(' ')), + ], + ), + Case( + name='new_part_no_vendor_id_ignore_whitespace_not_empty', + chunks=[' content'], + expected_parts=[TextPart('content')], + expected_normal_events=[ + PartStartEvent(index=0, part=TextPart('content')), + ], + ignore_leading_whitespace=True, + ), + Case( + name='new_part_ignore_whitespace_mixed_with_partial_opening', + chunks=[' '], + expected_parts=[TextPart('')], + expected_normal_events=[], + expected_flushed_events=[ + PartStartEvent(index=0, part=TextPart('')), + ], + ignore_leading_whitespace=True, + ), +] + +# Category 9: No Vendor ID (updates, new after thinking, closings as text) +NO_VENDOR_ID_CASES: list[Case] = [] + +# Category 10: No Thinking Tags (tags treated as text) +NO_THINKING_TAGS_CASES: list[Case] = [ + Case( + name='new_part_tags_as_text_when_thinking_tags_none', + chunks=['content'], + expected_parts=[TextPart('content')], + expected_normal_events=[ + PartStartEvent(index=0, part=TextPart('content')), + ], + thinking_tags=None, + ) +] + +# Category 11: Buffer Management (stutter, flushed) +BUFFER_MANAGEMENT_CASES: list[Case] = [ + Case( + name='empty_first_chunk_with_buffered_partial_opening_flushed', + chunks=['', 'content'], + expected_parts=[TextPart('content'], + expected_parts=[TextPart('hellocontent', ''], + expected_parts=[ThinkingPart('content')], + expected_normal_events=[ + PartStartEvent(index=0, part=ThinkingPart('content')), + PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='')), + ], + ), + Case( + name='existing_thinking_fake_partial_closing_added_to_content', + chunks=['content', 'foo', 'bar None: + """ + Parametrized coverage for all cases described in the report. + Tests each case with both vendor_part_id='content' and vendor_part_id=None. 
+ """ + case.vendor_part_id = vendor_part_id + + normal_events, flushed_events, final_parts = stream_text_deltas(case) + + # Parts observed from final state (after all deltas have been applied) + assert final_parts == case.expected_parts, f'\nObserved: {final_parts}\nExpected: {case.expected_parts}' + + # Events observed from streaming during normal processing + assert normal_events == case.expected_normal_events, ( + f'\nObserved: {normal_events}\nExpected: {case.expected_normal_events}' + ) + + # Events observed from final_flush + assert flushed_events == case.expected_flushed_events, ( + f'\nObserved: {flushed_events}\nExpected: {case.expected_flushed_events}' + ) + + +def test_final_flush_with_partial_tag_on_non_latest_part(): + """Test that final_flush properly handles partial tags attached to earlier parts.""" + manager = ModelResponsePartsManager() + + # Create ThinkingPart at index 0 with partial closing tag buffered + for _ in manager.handle_text_delta( + vendor_part_id='thinking', + content='content<', + thinking_tags=('', ''), + ): + pass + + # Create new part at index 1 using different vendor_part_id (makes ThinkingPart non-latest) + # Use tool call to create a different part type + manager.handle_tool_call_delta( + vendor_part_id='tool', + tool_name='my_tool', + args='{}', + ) + + # final_flush should emit PartDeltaEvent to index 0 (non-latest ThinkingPart with buffered '<') + events = list(manager.final_flush()) + assert events == snapshot([PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='<'))]) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 230a19501a..b94cf3e1dd 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -42,7 +42,12 @@ ) from pydantic_ai.agent import AgentRun from pydantic_ai.exceptions import ApprovalRequired, CallDeferred -from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel +from pydantic_ai.models.function import ( + AgentInfo, + DeltaToolCall, + DeltaToolCalls, + FunctionModel, +) from pydantic_ai.models.test import TestModel from pydantic_ai.output import PromptedOutput, TextOutput from pydantic_ai.result import AgentStream, FinalResult, RunUsage @@ -2176,6 +2181,33 @@ async def ret_a(x: str) -> str: ) +async def test_run_stream_finalize_with_incomplete_thinking_tag(): + """Test that incomplete thinking tags are flushed via finalize when using run_stream().""" + + async def stream_with_incomplete_thinking( + _messages: list[ModelMessage], _agent_info: AgentInfo + ) -> AsyncIterator[str]: + yield '', '') + + agent = Agent(function_model) + + events: list[Any] = [] + async for event in agent.run_stream_events('Hello'): + events.append(event) + + assert events == snapshot( + [ + PartStartEvent(index=0, part=TextPart(content=' AsyncIterator[DeltaToolCalls]: assert agent_info.output_tools is not None
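Closing note: the run_stream test above only works because the parts manager holds back a partial opening tag until the response is finalized. A minimal sketch of that buffering behavior exercised directly against ModelResponsePartsManager, under the assumption that ('<think>', '</think>') is the tag pair in play (as in the other new tests):

from pydantic_ai._parts_manager import ModelResponsePartsManager

manager = ModelResponsePartsManager()

# A lone '<think' prefix might still become an opening tag, so it is buffered
# and no event is emitted for it yet.
events = list(
    manager.handle_text_delta(
        vendor_part_id='content',
        content='<think',
        thinking_tags=('<think>', '</think>'),
    )
)
assert events == []

# The stream ends without the tag completing; final_flush() releases the buffered
# prefix, which the tests above show coming back as ordinary text
# (a PartStartEvent carrying a TextPart).
flushed = list(manager.final_flush())
print(flushed)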