Skip to content

Commit 62842db

Browse files
committed
Updated pull request based on feedbacks
1 parent 3ea8465 commit 62842db

File tree

5 files changed

+232
-143
lines changed

5 files changed

+232
-143
lines changed

pydantic_ai_slim/pydantic_ai/_json_schema.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(
2828
strict: bool | None = None,
2929
prefer_inlined_defs: bool = False,
3030
simplify_nullable_unions: bool = False,
31+
flatten_allof: bool = False,
3132
):
3233
self.schema = schema
3334

@@ -38,6 +39,7 @@ def __init__(
3839

3940
self.prefer_inlined_defs = prefer_inlined_defs
4041
self.simplify_nullable_unions = simplify_nullable_unions
42+
self.flatten_allof = flatten_allof
4143

4244
self.defs: dict[str, JsonSchema] = self.schema.get('$defs', {})
4345
self.refs_stack: list[str] = []
@@ -77,6 +79,10 @@ def walk(self) -> JsonSchema:
7779
return handled
7880

7981
def _handle(self, schema: JsonSchema) -> JsonSchema:
82+
# Flatten allOf if requested, before processing the schema
83+
if self.flatten_allof:
84+
schema = flatten_allof(schema)
85+
8086
nested_refs = 0
8187
if self.prefer_inlined_defs:
8288
while ref := schema.get('$ref'):

pydantic_ai_slim/pydantic_ai/profiles/openai.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dataclasses import dataclass
77
from typing import Any, Literal
88

9-
from .._json_schema import JsonSchema, JsonSchemaTransformer, flatten_allof
9+
from .._json_schema import JsonSchema, JsonSchemaTransformer
1010
from . import ModelProfile
1111

1212
OpenAISystemPromptRole = Literal['system', 'developer', 'user']
@@ -167,8 +167,6 @@ def walk(self) -> JsonSchema:
167167
return result
168168

169169
def transform(self, schema: JsonSchema) -> JsonSchema: # noqa C901
170-
# First, flatten object-only allOf constructs to avoid unsupported combinators in strict mode
171-
schema = flatten_allof(schema)
172170
# Remove unnecessary keys
173171
schema.pop('title', None)
174172
schema.pop('$schema', None)

tests/models/test_openai.py

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1915,13 +1915,11 @@ class MyModel(BaseModel):
19151915
)
19161916

19171917

1918-
def test_openai_transformer_fallback_when_defs_missing() -> None:
1919-
"""Test that OpenAIJsonSchemaTransformer falls back to original defs when root_key not in $defs."""
1920-
from unittest.mock import patch
1921-
1918+
def test_openai_transformer_with_recursive_ref() -> None:
1919+
"""Test that OpenAIJsonSchemaTransformer correctly handles recursive models with $ref root."""
19221920
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer
19231921

1924-
# Create a schema with $ref pointing to a key that exists in original
1922+
# Create a schema with $ref root (recursive model scenario)
19251923
schema: dict[str, Any] = {
19261924
'$ref': '#/$defs/MyModel',
19271925
'$defs': {
@@ -1934,15 +1932,55 @@ def test_openai_transformer_fallback_when_defs_missing() -> None:
19341932
}
19351933

19361934
transformer = OpenAIJsonSchemaTransformer(schema, strict=True)
1935+
result = transformer.walk()
1936+
1937+
# The transformer should resolve the $ref and use the transformed schema from $defs
1938+
# (not the original self.defs, which was the bug we fixed)
1939+
assert isinstance(result, dict)
1940+
# In strict mode, all properties should be required
1941+
assert 'properties' in result
1942+
assert 'required' in result
1943+
# The transformed schema should have strict mode applied (additionalProperties: False)
1944+
assert result.get('additionalProperties') is False
1945+
# All properties should be in required list (strict mode requirement)
1946+
assert 'foo' in result['required']
1947+
1948+
1949+
def test_openai_transformer_flattens_allof() -> None:
1950+
"""Test that OpenAIJsonSchemaTransformer flattens allOf schemas."""
1951+
from pydantic_ai._json_schema import JsonSchema
1952+
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer
1953+
1954+
schema: JsonSchema = {
1955+
'type': 'object',
1956+
'allOf': [
1957+
{
1958+
'type': 'object',
1959+
'properties': {'foo': {'type': 'string'}},
1960+
'required': ['foo'],
1961+
},
1962+
{
1963+
'type': 'object',
1964+
'properties': {'bar': {'type': 'integer'}},
1965+
'required': ['bar'],
1966+
},
1967+
],
1968+
}
1969+
1970+
transformer = OpenAIJsonSchemaTransformer(schema, strict=True)
1971+
transformed = transformer.walk()
19371972

1938-
# Mock super().walk() to return a result without $defs to test the fallback path
1939-
with patch.object(transformer.__class__.__bases__[0], 'walk', return_value={'$ref': '#/$defs/MyModel'}):
1940-
result = transformer.walk()
1941-
# The fallback should use self.defs.get(root_key) or {}
1942-
# Since we mocked to return just $ref, the fallback should trigger
1943-
assert isinstance(result, dict)
1944-
# Result should have been updated with the original defs content
1945-
assert 'properties' in result or 'type' in result
1973+
assert transformed == snapshot(
1974+
{
1975+
'type': 'object',
1976+
'properties': {
1977+
'foo': {'type': 'string'},
1978+
'bar': {'type': 'integer'},
1979+
},
1980+
'required': ['foo', 'bar'],
1981+
'additionalProperties': False,
1982+
}
1983+
)
19461984

19471985

19481986
def test_native_output_strict_mode(allow_model_requests: None):

tests/models/test_openai_like_providers.py

Lines changed: 0 additions & 46 deletions
This file was deleted.

0 commit comments

Comments
 (0)