119 changes: 117 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_json_schema.py
@@ -4,12 +4,14 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Literal
from typing import Any, Literal, cast

from .exceptions import UserError

JsonSchema = dict[str, Any]

__all__ = ['JsonSchemaTransformer', 'InlineDefsJsonSchemaTransformer', 'flatten_allof']
Collaborator:
Let's make flatten_allof private and not export it



@dataclass(init=False)
class JsonSchemaTransformer(ABC):
@@ -26,14 +28,18 @@ def __init__(
strict: bool | None = None,
prefer_inlined_defs: bool = False,
simplify_nullable_unions: bool = False,
flatten_allof: bool = False,
):
self.schema = schema

self.strict = strict
self.is_strict_compatible = True # Can be set to False by subclasses to set `strict` on `ToolDefinition` when set not set by user explicitly
# Can be set to False by subclasses to set `strict` on `ToolDefinition`
# when not set explicitly by the user.
self.is_strict_compatible = True

self.prefer_inlined_defs = prefer_inlined_defs
self.simplify_nullable_unions = simplify_nullable_unions
self.flatten_allof = flatten_allof

self.defs: dict[str, JsonSchema] = self.schema.get('$defs', {})
self.refs_stack: list[str] = []
@@ -73,6 +79,10 @@ def walk(self) -> JsonSchema:
return handled

def _handle(self, schema: JsonSchema) -> JsonSchema:
# Flatten allOf if requested, before processing the schema
if self.flatten_allof:
schema = flatten_allof(schema)

nested_refs = 0
if self.prefer_inlined_defs:
while ref := schema.get('$ref'):
@@ -188,3 +198,108 @@ def __init__(self, schema: JsonSchema, *, strict: bool | None = None):

def transform(self, schema: JsonSchema) -> JsonSchema:
return schema


def _allof_is_object_like(member: JsonSchema) -> bool:
member_type = member.get('type')
if member_type is None:
keys = ('properties', 'additionalProperties', 'patternProperties')
return bool(any(k in member for k in keys))
return member_type == 'object'


def _merge_additional_properties_values(values: list[Any]) -> bool | JsonSchema:
if any(isinstance(v, dict) for v in values):
return True
return False if values and all(v is False for v in values) else True


def _flatten_current_level(s: JsonSchema) -> JsonSchema:
raw_members = s.get('allOf')
if not isinstance(raw_members, list) or not raw_members:
return s

members = cast(list[JsonSchema], raw_members)
for raw in members:
if not isinstance(raw, dict):
return s
if not all(_allof_is_object_like(member) for member in members):
return s

processed_members = [_recurse_flatten_allof(member) for member in members]
merged: JsonSchema = {k: v for k, v in s.items() if k != 'allOf'}
merged['type'] = 'object'

properties: dict[str, JsonSchema] = {}
if isinstance(merged.get('properties'), dict):
properties.update(merged['properties'])

required: set[str] = set(merged.get('required', []) or [])
pattern_properties: dict[str, JsonSchema] = dict(merged.get('patternProperties', {}) or {})
additional_values: list[Any] = []

for m in processed_members:
if isinstance(m.get('properties'), dict):
properties.update(m['properties'])
if isinstance(m.get('required'), list):
required.update(m['required'])
if isinstance(m.get('patternProperties'), dict):
pattern_properties.update(m['patternProperties'])
if 'additionalProperties' in m:
additional_values.append(m['additionalProperties'])

if properties:
merged['properties'] = {k: _recurse_flatten_allof(v) for k, v in properties.items()}
if required:
merged['required'] = sorted(required)
if pattern_properties:
merged['patternProperties'] = {k: _recurse_flatten_allof(v) for k, v in pattern_properties.items()}

if additional_values:
merged['additionalProperties'] = _merge_additional_properties_values(additional_values)

return merged


def _recurse_children(s: JsonSchema) -> JsonSchema:
t = s.get('type')
if t == 'object':
if isinstance(s.get('properties'), dict):
s['properties'] = {
k: _recurse_flatten_allof(cast(JsonSchema, v))
for k, v in s['properties'].items()
if isinstance(v, dict)
}
ap = s.get('additionalProperties')
if isinstance(ap, dict):
ap_schema = cast(JsonSchema, ap)
s['additionalProperties'] = _recurse_flatten_allof(ap_schema)
if isinstance(s.get('patternProperties'), dict):
s['patternProperties'] = {
k: _recurse_flatten_allof(cast(JsonSchema, v))
for k, v in s['patternProperties'].items()
if isinstance(v, dict)
}
elif t == 'array':
items = s.get('items')
if isinstance(items, dict):
s['items'] = _recurse_flatten_allof(cast(JsonSchema, items))
return s


def _recurse_flatten_allof(schema: JsonSchema) -> JsonSchema:
s = deepcopy(schema)
s = _flatten_current_level(s)
s = _recurse_children(s)
return s


def flatten_allof(schema: JsonSchema) -> JsonSchema:
Collaborator:
We may not need this method at all if the JSON transformer can call _recurse_flatten_allof directly

"""Flatten simple object-only allOf combinations by merging object members.

- Merges properties and unions required lists.
- Combines additionalProperties conservatively: only False if all are False; otherwise True.
- Recurses into nested object/array members.
- Leaves non-object allOfs untouched.
"""
return _recurse_flatten_allof(schema)
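
A minimal usage sketch, not part of the diff, assuming the helper stays importable as flatten_allof from pydantic_ai._json_schema (the review above suggests making it private, in which case the import would change); the input schema is invented for illustration:

from pydantic_ai._json_schema import flatten_allof

schema = {
    'allOf': [
        {'type': 'object', 'properties': {'foo': {'type': 'string'}}, 'required': ['foo']},
        {'type': 'object', 'properties': {'bar': {'type': 'integer'}}, 'required': ['bar']},
    ],
}

flattened = flatten_allof(schema)
# Per the docstring: properties are merged and required lists are unioned (sorted),
# so this should print something like:
# {'type': 'object',
#  'properties': {'foo': {'type': 'string'}, 'bar': {'type': 'integer'}},
#  'required': ['bar', 'foo']}
print(flattened)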
9 changes: 7 additions & 2 deletions pydantic_ai_slim/pydantic_ai/profiles/openai.py
@@ -143,7 +143,7 @@ class OpenAIJsonSchemaTransformer(JsonSchemaTransformer):
"""

def __init__(self, schema: JsonSchema, *, strict: bool | None = None):
super().__init__(schema, strict=strict)
super().__init__(schema, strict=strict, flatten_allof=True)
self.root_ref = schema.get('$ref')

def walk(self) -> JsonSchema:
@@ -157,7 +157,12 @@ def walk(self) -> JsonSchema:
if self.root_ref is not None:
result.pop('$ref', None) # We replace references to the self.root_ref with just '#' in the transform method
root_key = re.sub(r'^#/\$defs/', '', self.root_ref)
result.update(self.defs.get(root_key) or {})
# Use the transformed schema from $defs, not the original self.defs
if '$defs' in result and root_key in result['$defs']:
result.update(result['$defs'][root_key])
else:
# Fallback to original if transformed version not available (shouldn't happen in normal flow)
result.update(self.defs.get(root_key) or {})

return result
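
For comparison, a hedged sketch of how another transformer could opt into the same behaviour through the new base-class flag; MyFlatteningTransformer and the sample schema are invented for illustration, and only JsonSchemaTransformer, JsonSchema, and the flatten_allof keyword come from this diff:

from pydantic_ai._json_schema import JsonSchema, JsonSchemaTransformer


class MyFlatteningTransformer(JsonSchemaTransformer):
    def __init__(self, schema: JsonSchema, *, strict: bool | None = None):
        # Opt in to allOf flattening, which runs before each node is transformed.
        super().__init__(schema, strict=strict, flatten_allof=True)

    def transform(self, schema: JsonSchema) -> JsonSchema:
        # No per-node changes; walk() still flattens allOf at every level.
        return schema


schema: JsonSchema = {
    'allOf': [
        {'type': 'object', 'properties': {'a': {'type': 'string'}}},
        {'type': 'object', 'properties': {'b': {'type': 'integer'}}},
    ]
}
print(MyFlatteningTransformer(schema).walk())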

105 changes: 105 additions & 0 deletions tests/models/test_openai.py
@@ -1915,6 +1915,111 @@ class MyModel(BaseModel):
)


def test_openai_transformer_with_recursive_ref() -> None:
"""Test that OpenAIJsonSchemaTransformer correctly handles recursive models with $ref root."""
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer

# Create a schema with $ref root (recursive model scenario)
schema: dict[str, Any] = {
'$ref': '#/$defs/MyModel',
'$defs': {
'MyModel': {
'type': 'object',
'properties': {'foo': {'type': 'string'}},
'required': ['foo'],
},
},
}

transformer = OpenAIJsonSchemaTransformer(schema, strict=True)
result = transformer.walk()

# The transformer should resolve the $ref and use the transformed schema from $defs
# (not the original self.defs, which was the bug we fixed)
assert isinstance(result, dict)
Collaborator:
Please use snapshot, I want to see the entire result

# In strict mode, all properties should be required
assert 'properties' in result
assert 'required' in result
# The transformed schema should have strict mode applied (additionalProperties: False)
assert result.get('additionalProperties') is False
# All properties should be in required list (strict mode requirement)
assert 'foo' in result['required']


def test_openai_transformer_fallback_when_defs_missing() -> None:
"""Test fallback path when root_key is not in result['$defs'] (line 165).

This tests the safety net fallback that shouldn't happen in normal flow.
The fallback uses self.defs (original schema) when the transformed $defs
doesn't contain the root_key. This edge case is simulated using a mock.
"""
from unittest.mock import patch

from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer

schema: dict[str, Any] = {
'$ref': '#/$defs/MyModel',
'$defs': {
'MyModel': {
'type': 'object',
'properties': {'foo': {'type': 'string'}},
'required': ['foo'],
},
},
}

transformer = OpenAIJsonSchemaTransformer(schema, strict=True)

# Simulate edge case: super().walk() returns $defs without root_key
# This shouldn't happen in normal flow, but we test the fallback path
with patch.object(
transformer.__class__.__bases__[0],
Collaborator:
I don't feel comfortable with this; if we can't come up with a real schema that gets us to that line, do we need it at all?

Contributor Author:
@DouweM that was the old code before my change, so I kept it as a fallback.
If you think it's no longer necessary, we can remove the fallback.

Collaborator:
@pulphix If I understand the new code correctly, if a ref was in self.defs previously, it's now guaranteed to be in result['$defs'], so the change should be fine without the fallback. Although we could keep the or {} that we had previously just to be sure.

We use self.defs in a few more places though, so if we now can't rely on that anymore, should we also update those other uses? Or actually set self.defs to the updated value after flattening? It could be worth verifying the interaction between the new flag and prefer_inlined_defs or recursive_defs.

'walk',
return_value={'$defs': {'OtherModel': {'type': 'object'}}},
):
result = transformer.walk()
# Fallback should use self.defs.get(root_key) which contains MyModel
assert isinstance(result, dict)
assert 'properties' in result or 'type' in result


def test_openai_transformer_flattens_allof() -> None:
"""Test that OpenAIJsonSchemaTransformer flattens allOf schemas."""
from pydantic_ai._json_schema import JsonSchema
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer

schema: JsonSchema = {
'type': 'object',
'allOf': [
{
'type': 'object',
'properties': {'foo': {'type': 'string'}},
'required': ['foo'],
},
{
'type': 'object',
'properties': {'bar': {'type': 'integer'}},
'required': ['bar'],
},
],
}

transformer = OpenAIJsonSchemaTransformer(schema, strict=True)
transformed = transformer.walk()

assert transformed == snapshot(
{
'type': 'object',
'properties': {
'foo': {'type': 'string'},
'bar': {'type': 'integer'},
},
'required': ['foo', 'bar'],
'additionalProperties': False,
}
)


def test_native_output_strict_mode(allow_model_requests: None):
class CityLocation(BaseModel):
city: str