-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Flatten allOf properties for OpenAI compatibility #3451
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
afddade
8ff4db7
e4aebaf
e5815fd
b1058e9
1244bfd
f143029
eaa685e
cea2b10
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,12 +4,14 @@ | |
| from abc import ABC, abstractmethod | ||
| from copy import deepcopy | ||
| from dataclasses import dataclass | ||
| from typing import Any, Literal | ||
| from typing import Any, Literal, cast | ||
|
|
||
| from .exceptions import UserError | ||
|
|
||
| JsonSchema = dict[str, Any] | ||
|
|
||
| __all__ = ['JsonSchemaTransformer', 'InlineDefsJsonSchemaTransformer', 'flatten_allof'] | ||
|
|
||
|
|
||
| @dataclass(init=False) | ||
| class JsonSchemaTransformer(ABC): | ||
|
|
@@ -26,14 +28,18 @@ def __init__( | |
| strict: bool | None = None, | ||
| prefer_inlined_defs: bool = False, | ||
| simplify_nullable_unions: bool = False, | ||
| flatten_allof: bool = False, | ||
| ): | ||
| self.schema = schema | ||
|
|
||
| self.strict = strict | ||
| self.is_strict_compatible = True # Can be set to False by subclasses to set `strict` on `ToolDefinition` when set not set by user explicitly | ||
| # Can be set to False by subclasses to set `strict` on `ToolDefinition` | ||
| # when not set explicitly by the user. | ||
| self.is_strict_compatible = True | ||
|
|
||
| self.prefer_inlined_defs = prefer_inlined_defs | ||
| self.simplify_nullable_unions = simplify_nullable_unions | ||
| self.flatten_allof = flatten_allof | ||
|
|
||
| self.defs: dict[str, JsonSchema] = self.schema.get('$defs', {}) | ||
| self.refs_stack: list[str] = [] | ||
|
|
@@ -73,6 +79,10 @@ def walk(self) -> JsonSchema: | |
| return handled | ||
|
|
||
| def _handle(self, schema: JsonSchema) -> JsonSchema: | ||
| # Flatten allOf if requested, before processing the schema | ||
| if self.flatten_allof: | ||
| schema = flatten_allof(schema) | ||
|
|
||
| nested_refs = 0 | ||
| if self.prefer_inlined_defs: | ||
| while ref := schema.get('$ref'): | ||
|
|
@@ -188,3 +198,108 @@ def __init__(self, schema: JsonSchema, *, strict: bool | None = None): | |
|
|
||
def transform(self, schema: JsonSchema) -> JsonSchema:
    """Return ``schema`` unchanged."""
    return schema
|
|
||
|
|
||
def _allof_is_object_like(member: JsonSchema) -> bool:
    """Return whether an ``allOf`` member describes an object schema.

    A member with an explicit ``type`` qualifies only when that type is exactly ``'object'``.
    A member without a ``type`` qualifies when it carries at least one object-only keyword
    (``properties``, ``additionalProperties`` or ``patternProperties``).

    Args:
        member: One entry of an ``allOf`` list.

    Returns:
        ``True`` if the member can be merged as an object, ``False`` otherwise.
    """
    member_type = member.get('type')
    if member_type is not None:
        return member_type == 'object'
    # `any` already yields a bool, so no extra `bool(...)` wrapper is needed.
    return any(key in member for key in ('properties', 'additionalProperties', 'patternProperties'))
|
|
||
|
|
||
def _merge_additional_properties_values(values: list[Any]) -> bool:
    """Combine the ``additionalProperties`` values collected from ``allOf`` members.

    The merge is deliberately conservative: the result is ``False`` only when at least one value
    was collected and every collected value is literally ``False``. Anything else (``True``, a
    sub-schema dict, or an empty list) yields ``True``.

    Args:
        values: The raw ``additionalProperties`` values, one per member that declared the keyword.

    Returns:
        The boolean to store as the merged ``additionalProperties``.
    """
    if not values:
        return True
    # A dict (sub-schema) is never `False`, so this single test also covers the schema case.
    return any(value is not False for value in values)
|
|
||
|
|
||
def _flatten_current_level(s: JsonSchema) -> JsonSchema:
    """Merge the object-like members of ``s['allOf']`` into a single object schema.

    ``s`` is returned as-is when ``allOf`` is absent, empty, not a list, or has any member that is
    not an object-like dict. Otherwise a new dict is built from ``s`` without ``allOf``, with
    ``type`` forced to ``'object'`` and with each member's ``properties``, ``required``,
    ``patternProperties`` and ``additionalProperties`` folded in. No other member keyword is
    carried over, and on a name clash a later member's property replaces an earlier one.

    Boolean (non-dict) property schemas are kept verbatim instead of being recursed into.

    Args:
        s: The schema whose top-level ``allOf`` should be merged.

    Returns:
        Either ``s`` itself (nothing to merge) or a new merged schema dict.
    """
    raw_members = s.get('allOf')
    if not isinstance(raw_members, list) or not raw_members:
        return s

    members = cast(list[JsonSchema], raw_members)
    # Only plain object merges are supported; anything else keeps its `allOf` untouched.
    if not all(isinstance(member, dict) and _allof_is_object_like(member) for member in members):
        return s

    processed_members = [_recurse_flatten_allof(member) for member in members]
    merged: JsonSchema = {key: value for key, value in s.items() if key != 'allOf'}
    merged['type'] = 'object'

    # Start from the keywords already present next to `allOf` so members extend them.
    properties: dict[str, Any] = {}
    if isinstance(merged.get('properties'), dict):
        properties.update(merged['properties'])
    required: set[str] = set(merged.get('required', []) or [])
    pattern_properties: dict[str, Any] = dict(merged.get('patternProperties', {}) or {})
    additional_values: list[Any] = []

    for member in processed_members:
        if isinstance(member.get('properties'), dict):
            properties.update(member['properties'])
        if isinstance(member.get('required'), list):
            required.update(member['required'])
        if isinstance(member.get('patternProperties'), dict):
            pattern_properties.update(member['patternProperties'])
        if 'additionalProperties' in member:
            additional_values.append(member['additionalProperties'])

    def flatten_child(child: Any) -> Any:
        # Boolean schemas (`True`/`False`) have nothing to flatten, and handing them to the
        # dict-based helpers would raise AttributeError on `.get`.
        if isinstance(child, dict):
            return _recurse_flatten_allof(cast(JsonSchema, child))
        return child

    if properties:
        merged['properties'] = {name: flatten_child(child) for name, child in properties.items()}
    if required:
        # Sorted so the output does not depend on set iteration order.
        merged['required'] = sorted(required)
    if pattern_properties:
        merged['patternProperties'] = {
            pattern: flatten_child(child) for pattern, child in pattern_properties.items()
        }
    if additional_values:
        merged['additionalProperties'] = _merge_additional_properties_values(additional_values)

    return merged
|
|
||
|
|
||
def _recurse_children(s: JsonSchema) -> JsonSchema:
    """Flatten ``allOf`` inside the direct children of an object or array schema.

    For ``type == 'object'`` the values of ``properties`` and ``patternProperties`` and a
    dict-valued ``additionalProperties`` are processed; for ``type == 'array'`` a dict-valued
    ``items`` is processed. Schemas with any other ``type`` are returned as they are.

    Child entries that are not dicts (for example boolean schemas) are preserved unchanged rather
    than being dropped from the mapping.

    Args:
        s: The schema whose children should be processed. It is updated in place.

    Returns:
        The same ``s`` object, after its children have been replaced.
    """
    schema_type = s.get('type')
    if schema_type == 'object':
        for keyword in ('properties', 'patternProperties'):
            children = s.get(keyword)
            if isinstance(children, dict):
                s[keyword] = {
                    name: _recurse_flatten_allof(cast(JsonSchema, child)) if isinstance(child, dict) else child
                    for name, child in children.items()
                }
        additional = s.get('additionalProperties')
        if isinstance(additional, dict):
            s['additionalProperties'] = _recurse_flatten_allof(cast(JsonSchema, additional))
    elif schema_type == 'array':
        items = s.get('items')
        if isinstance(items, dict):
            s['items'] = _recurse_flatten_allof(cast(JsonSchema, items))
    return s
|
|
||
|
|
||
def _recurse_flatten_allof(schema: JsonSchema) -> JsonSchema:
    """Return a flattened deep copy of ``schema``.

    The copy first has its own ``allOf`` merged and then has its child schemas processed.
    """
    # Copy before anything else: `_recurse_children` rewrites the dict it is given in place.
    working_copy = deepcopy(schema)
    return _recurse_children(_flatten_current_level(working_copy))
|
|
||
|
|
||
# NOTE(review): reviewers of this change questioned whether this public wrapper is needed at all
# and asked for it to be made private and not exported; the name and signature are left unchanged
# here pending that decision.
def flatten_allof(schema: JsonSchema) -> JsonSchema:
    """Flatten simple object-only allOf combinations by merging object members.

    - Merges properties and unions required lists.
    - Combines additionalProperties conservatively: only False if all are False; otherwise True.
    - Recurses into nested object/array members.
    - Leaves non-object allOfs untouched.

    Args:
        schema: The JSON schema to flatten.

    Returns:
        The result of `_recurse_flatten_allof(schema)`.
    """
    return _recurse_flatten_allof(schema)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1915,6 +1915,111 @@ class MyModel(BaseModel): | |
| ) | ||
|
|
||
|
|
||
def test_openai_transformer_with_recursive_ref() -> None:
    """Test that OpenAIJsonSchemaTransformer correctly handles recursive models with $ref root."""
    from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer

    # Create a schema with $ref root (recursive model scenario)
    schema: dict[str, Any] = {
        '$ref': '#/$defs/MyModel',
        '$defs': {
            'MyModel': {
                'type': 'object',
                'properties': {'foo': {'type': 'string'}},
                'required': ['foo'],
            },
        },
    }

    transformer = OpenAIJsonSchemaTransformer(schema, strict=True)
    result = transformer.walk()

    # The transformer should resolve the $ref and use the transformed schema from $defs
    # (not the original self.defs, which was the bug we fixed)
    # NOTE(review): a reviewer asked for the assertions below to be written differently; the
    # specific suggestion is cut off in this copy of the thread.
    assert isinstance(result, dict)
    # In strict mode, all properties should be required
    assert 'properties' in result
    assert 'required' in result
    # The transformed schema should have strict mode applied (additionalProperties: False)
    assert result.get('additionalProperties') is False
    # All properties should be in required list (strict mode requirement)
    assert 'foo' in result['required']
|
|
||
|
|
||
def test_openai_transformer_fallback_when_defs_missing() -> None:
    """Test fallback path when root_key is not in result['$defs'].

    This tests the safety net fallback that shouldn't happen in normal flow.
    The fallback uses self.defs (original schema) when the transformed $defs
    doesn't contain the root_key. This edge case is simulated using a mock.
    """
    from unittest.mock import patch

    from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer

    schema: dict[str, Any] = {
        '$ref': '#/$defs/MyModel',
        '$defs': {
            'MyModel': {
                'type': 'object',
                'properties': {'foo': {'type': 'string'}},
                'required': ['foo'],
            },
        },
    }

    transformer = OpenAIJsonSchemaTransformer(schema, strict=True)

    # Simulate edge case: super().walk() returns $defs without root_key
    # This shouldn't happen in normal flow, but we test the fallback path
    # NOTE(review): a reviewer was uncomfortable with patching the base class here and asked
    # whether the fallback is needed at all if no real schema can reach it; the author replied
    # that the fallback is the code that existed before this change and was kept for that reason.
    with patch.object(
        # The first base class of the transformer's class is the one whose `walk` gets replaced.
        transformer.__class__.__bases__[0],
        'walk',
        return_value={'$defs': {'OtherModel': {'type': 'object'}}},
    ):
        result = transformer.walk()
        # Fallback should use self.defs.get(root_key) which contains MyModel
        assert isinstance(result, dict)
        assert 'properties' in result or 'type' in result
|
|
||
|
|
||
def test_openai_transformer_flattens_allof() -> None:
    """Check that a strict OpenAI transform merges two object ``allOf`` members into one object."""
    from pydantic_ai._json_schema import JsonSchema
    from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer

    # Two object members, each contributing one required property.
    foo_member: JsonSchema = {
        'type': 'object',
        'properties': {'foo': {'type': 'string'}},
        'required': ['foo'],
    }
    bar_member: JsonSchema = {
        'type': 'object',
        'properties': {'bar': {'type': 'integer'}},
        'required': ['bar'],
    }
    allof_schema: JsonSchema = {'type': 'object', 'allOf': [foo_member, bar_member]}

    flattened = OpenAIJsonSchemaTransformer(allof_schema, strict=True).walk()

    assert flattened == snapshot(
        {
            'type': 'object',
            'properties': {
                'foo': {'type': 'string'},
                'bar': {'type': 'integer'},
            },
            'required': ['foo', 'bar'],
            'additionalProperties': False,
        }
    )
|
|
||
|
|
||
| def test_native_output_strict_mode(allow_model_requests: None): | ||
| class CityLocation(BaseModel): | ||
| city: str | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's make `flatten_allof` private and not export it.