
Commit 47707cd

Flatten allOf properties for OpenAI compatibility
1 parent 359c6d2 commit 47707cd

File tree

7 files changed: +301 -4 lines changed

docs/models/openai.md

Lines changed: 4 additions & 0 deletions
@@ -99,6 +99,10 @@ agent = Agent(model)
 ```

 ## OpenAI Responses API
+### Tips for JSON Schema compatibility
+
+Strict Structured Outputs prefer flat object schemas without combinators. If your tool or output schema contains `allOf`/`oneOf`, consider flattening `allOf` ahead of time. See the Tools documentation section "Flattening allOf for provider compatibility" for a sample prepare hook.
+

 Pydantic AI also supports OpenAI's [Responses API](https://platform.openai.com/docs/api-reference/responses) through the
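To make the tip concrete, here is an illustrative sketch (not content from this commit) of a schema using `allOf` next to the flat equivalent that strict mode accepts; the field names are hypothetical:

```python
# Hypothetical tool-parameter schema using a combinator (may be rejected in strict mode):
combined = {
    'type': 'object',
    'allOf': [
        {'type': 'object', 'properties': {'id': {'type': 'string'}}, 'required': ['id']},
        {'type': 'object', 'properties': {'limit': {'type': 'integer'}}, 'required': ['limit']},
    ],
}

# Flattened equivalent: a single object with merged properties and a unioned required list.
flat = {
    'type': 'object',
    'properties': {'id': {'type': 'string'}, 'limit': {'type': 'integer'}},
    'required': ['id', 'limit'],
}
```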

docs/tools.md

Lines changed: 24 additions & 0 deletions
@@ -357,6 +357,30 @@ print(test_model.last_model_request_parameters.function_tools)
 _(This example is complete, it can be run "as is")_


+### Flattening allOf for provider compatibility
+
+Some providers (especially when using strict Structured Outputs) reject JSON Schema combinators like `allOf`/`oneOf`.
+If your tool parameter schema includes `allOf`, you can flatten it before it is sent to the model using a prepare hook.
+
+Example prepare hook that flattens `allOf`:
+
+```python
+from pydantic_ai import Agent
+from pydantic_ai.tools import RunContext, ToolDefinition
+from pydantic_ai._json_schema import flatten_allof
+
+
+async def flatten_prepare(ctx: RunContext[None], tool: ToolDefinition) -> ToolDefinition:
+    tool.parameters_json_schema = flatten_allof(tool.parameters_json_schema)
+    return tool
+
+
+async def flatten_all_tools(ctx: RunContext[None], tools: list[ToolDefinition]) -> list[ToolDefinition]:
+    return [await flatten_prepare(ctx, tool) for tool in tools]
+
+
+# Register your tools normally, then pass `prepare_tools=flatten_all_tools` to `Agent`
+# to apply the flattening to every tool.
+agent = Agent('openai:gpt-4o', prepare_tools=flatten_all_tools)
+```
+
+Alternatively, you can construct tools with a flattened schema at source (e.g., for MCP-exposed tools) using `Tool.from_schema`; see the sketch after this section.
+
 ## See Also

 For more tool features and integrations, see:
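Following up on the `Tool.from_schema` alternative mentioned above, a hedged sketch of flattening at construction time; the keyword arguments shown (`function`, `name`, `description`, `json_schema`) are assumptions based on the function-tools docs rather than this commit, and the tool body and schema are hypothetical:

```python
from pydantic_ai import Agent
from pydantic_ai.tools import Tool
from pydantic_ai._json_schema import flatten_allof

# Hypothetical schema as it might arrive from an MCP server, still containing allOf.
raw_schema = {
    'type': 'object',
    'allOf': [
        {'type': 'object', 'properties': {'city': {'type': 'string'}}, 'required': ['city']},
        {'type': 'object', 'properties': {'days': {'type': 'integer'}}, 'required': ['days']},
    ],
}


def forecast(**kwargs) -> str:  # hypothetical tool implementation
    return f"{kwargs['days']}-day forecast for {kwargs['city']}"


tool = Tool.from_schema(
    function=forecast,
    name='forecast',
    description='Get a weather forecast for a city.',
    json_schema=flatten_allof(raw_schema),  # flatten once, at source
)
agent = Agent('openai:gpt-4o', tools=[tool])
```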

pydantic_ai_slim/pydantic_ai/_json_schema.py

Lines changed: 111 additions & 2 deletions
@@ -4,12 +4,14 @@
 from abc import ABC, abstractmethod
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import Any, Literal
+from typing import Any, Literal, cast

 from .exceptions import UserError

 JsonSchema = dict[str, Any]

+__all__ = ['JsonSchemaTransformer', 'InlineDefsJsonSchemaTransformer', 'flatten_allof']
+

 @dataclass(init=False)
 class JsonSchemaTransformer(ABC):
@@ -30,7 +32,9 @@ def __init__(
         self.schema = schema

         self.strict = strict
-        self.is_strict_compatible = True  # Can be set to False by subclasses to set `strict` on `ToolDefinition` when set not set by user explicitly
+        # Can be set to False by subclasses to set `strict` on `ToolDefinition`
+        # when not set explicitly by the user.
+        self.is_strict_compatible = True

         self.prefer_inlined_defs = prefer_inlined_defs
         self.simplify_nullable_unions = simplify_nullable_unions
@@ -188,3 +192,108 @@ def __init__(self, schema: JsonSchema, *, strict: bool | None = None):

     def transform(self, schema: JsonSchema) -> JsonSchema:
         return schema
+
+
+def _allof_is_object_like(member: JsonSchema) -> bool:
+    member_type = member.get('type')
+    if member_type is None:
+        keys = ('properties', 'additionalProperties', 'patternProperties')
+        return bool(any(k in member for k in keys))
+    return member_type == 'object'
+
+
+def _merge_additional_properties_values(values: list[Any]) -> bool | JsonSchema:
+    if any(isinstance(v, dict) for v in values):
+        return True
+    return False if values and all(v is False for v in values) else True
+
+
+def _flatten_current_level(s: JsonSchema) -> JsonSchema:
+    raw_members = s.get('allOf')
+    if not isinstance(raw_members, list) or not raw_members:
+        return s
+
+    members = cast(list[JsonSchema], raw_members)
+    for raw in members:
+        if not isinstance(raw, dict):
+            return s
+    if not all(_allof_is_object_like(member) for member in members):
+        return s
+
+    processed_members = [_recurse_flatten_allof(member) for member in members]
+    merged: JsonSchema = {k: v for k, v in s.items() if k != 'allOf'}
+    merged['type'] = 'object'
+
+    properties: dict[str, JsonSchema] = {}
+    if isinstance(merged.get('properties'), dict):
+        properties.update(merged['properties'])
+
+    required: set[str] = set(merged.get('required', []) or [])
+    pattern_properties: dict[str, JsonSchema] = dict(merged.get('patternProperties', {}) or {})
+    additional_values: list[Any] = []
+
+    for m in processed_members:
+        if isinstance(m.get('properties'), dict):
+            properties.update(m['properties'])
+        if isinstance(m.get('required'), list):
+            required.update(m['required'])
+        if isinstance(m.get('patternProperties'), dict):
+            pattern_properties.update(m['patternProperties'])
+        if 'additionalProperties' in m:
+            additional_values.append(m['additionalProperties'])
+
+    if properties:
+        merged['properties'] = {k: _recurse_flatten_allof(v) for k, v in properties.items()}
+    if required:
+        merged['required'] = sorted(required)
+    if pattern_properties:
+        merged['patternProperties'] = {k: _recurse_flatten_allof(v) for k, v in pattern_properties.items()}
+
+    if additional_values:
+        merged['additionalProperties'] = _merge_additional_properties_values(additional_values)
+
+    return merged
+
+
+def _recurse_children(s: JsonSchema) -> JsonSchema:
+    t = s.get('type')
+    if t == 'object':
+        if isinstance(s.get('properties'), dict):
+            s['properties'] = {
+                k: _recurse_flatten_allof(cast(JsonSchema, v))
+                for k, v in s['properties'].items()
+                if isinstance(v, dict)
+            }
+        ap = s.get('additionalProperties')
+        if isinstance(ap, dict):
+            ap_schema = cast(JsonSchema, ap)
+            s['additionalProperties'] = _recurse_flatten_allof(ap_schema)
+        if isinstance(s.get('patternProperties'), dict):
+            s['patternProperties'] = {
+                k: _recurse_flatten_allof(cast(JsonSchema, v))
+                for k, v in s['patternProperties'].items()
+                if isinstance(v, dict)
+            }
+    elif t == 'array':
+        items = s.get('items')
+        if isinstance(items, dict):
+            s['items'] = _recurse_flatten_allof(cast(JsonSchema, items))
+    return s
+
+
+def _recurse_flatten_allof(schema: JsonSchema) -> JsonSchema:
+    s = deepcopy(schema)
+    s = _flatten_current_level(s)
+    s = _recurse_children(s)
+    return s
+
+
+def flatten_allof(schema: JsonSchema) -> JsonSchema:
+    """Flatten simple object-only allOf combinations by merging object members.

+    - Merges properties and unions required lists.
+    - Combines additionalProperties conservatively: only False if all are False; otherwise True.
+    - Recurses into nested object/array members.
+    - Leaves non-object allOfs untouched.
+    """
+    return _recurse_flatten_allof(schema)
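As an illustrative aside (not part of the commit), a small sketch of how those merge rules combine: properties are merged, `required` is unioned, and `additionalProperties` widens to `True` as soon as any member allows extras.

```python
from pydantic_ai._json_schema import flatten_allof

schema = {
    'type': 'object',
    'allOf': [
        {
            'type': 'object',
            'properties': {'a': {'type': 'string'}},
            'required': ['a'],
            'additionalProperties': False,
        },
        {
            'type': 'object',
            'properties': {'b': {'type': 'integer'}},
            'additionalProperties': {'type': 'string'},
        },
    ],
}

flat = flatten_allof(schema)
assert 'allOf' not in flat
assert set(flat['properties']) == {'a', 'b'}
assert flat['required'] == ['a']
# One member allows typed extras, so the conservative merge widens this to True.
assert flat['additionalProperties'] is True
```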

pydantic_ai_slim/pydantic_ai/profiles/google.py

Lines changed: 3 additions & 1 deletion
@@ -4,7 +4,7 @@

 from pydantic_ai.exceptions import UserError

-from .._json_schema import JsonSchema, JsonSchemaTransformer
+from .._json_schema import JsonSchema, JsonSchemaTransformer, flatten_allof
 from . import ModelProfile

@@ -35,6 +35,8 @@ def __init__(self, schema: JsonSchema, *, strict: bool | None = None):
         super().__init__(schema, strict=strict, prefer_inlined_defs=True, simplify_nullable_unions=True)

     def transform(self, schema: JsonSchema) -> JsonSchema:
+        # Flatten object-only allOf to improve compatibility
+        schema = flatten_allof(schema)
         # Note: we need to remove `additionalProperties: False` since it is currently mishandled by Gemini
         additional_properties = schema.pop(
             'additionalProperties', None

pydantic_ai_slim/pydantic_ai/profiles/openai.py

Lines changed: 3 additions & 1 deletion
@@ -6,7 +6,7 @@
 from dataclasses import dataclass
 from typing import Any, Literal

-from .._json_schema import JsonSchema, JsonSchemaTransformer
+from .._json_schema import JsonSchema, JsonSchemaTransformer, flatten_allof
 from . import ModelProfile

 OpenAISystemPromptRole = Literal['system', 'developer', 'user']
@@ -162,6 +162,8 @@ def walk(self) -> JsonSchema:
         return result

     def transform(self, schema: JsonSchema) -> JsonSchema:  # noqa C901
+        # First, flatten object-only allOf constructs to avoid unsupported combinators in strict mode
+        schema = flatten_allof(schema)
         # Remove unnecessary keys
         schema.pop('title', None)
         schema.pop('$schema', None)
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+from __future__ import annotations
+
+from collections.abc import Callable
+
+import pytest
+
+from pydantic_ai._json_schema import JsonSchema
+from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer
+
+
+def _openai_transformer_factory(schema: JsonSchema) -> OpenAIJsonSchemaTransformer:
+    return OpenAIJsonSchemaTransformer(schema, strict=True)
+
+
+TransformerFactory = Callable[[JsonSchema], OpenAIJsonSchemaTransformer]
+
+
+@pytest.mark.parametrize('transformer_factory', [_openai_transformer_factory])
+def test_openai_compatible_transformers_flatten_allof(
+    transformer_factory: TransformerFactory,
+) -> None:
+    schema: JsonSchema = {
+        'type': 'object',
+        'allOf': [
+            {
+                'type': 'object',
+                'properties': {'foo': {'type': 'string'}},
+                'required': ['foo'],
+            },
+            {
+                'type': 'object',
+                'properties': {'bar': {'type': 'integer'}},
+                'required': ['bar'],
+            },
+        ],
+    }
+
+    transformer = transformer_factory(schema)
+    transformed = transformer.walk()
+
+    # allOf should have been flattened by the transformer
+    assert 'allOf' not in transformed
+    assert transformed['type'] == 'object'
+    assert set(transformed.get('required', [])) == {'foo', 'bar'}
+    assert transformed['properties']['foo']['type'] == 'string'
+    assert transformed['properties']['bar']['type'] == 'integer'
Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+from __future__ import annotations
+
+import copy
+from typing import Any
+
+
+def test_flatten_allof_simple_merge() -> None:
+    # Import inline to avoid import errors before implementation exists in editors
+    from pydantic_ai._json_schema import flatten_allof
+
+    schema: dict[str, Any] = {
+        'type': 'object',
+        'allOf': [
+            {
+                'type': 'object',
+                'properties': {'a': {'type': 'string'}},
+                'required': ['a'],
+                'additionalProperties': False,
+            },
+            {
+                'type': 'object',
+                'properties': {'b': {'type': 'integer'}},
+                'required': ['b'],
+                'additionalProperties': False,
+            },
+        ],
+    }
+
+    flattened = flatten_allof(copy.deepcopy(schema))
+
+    assert 'allOf' not in flattened
+    assert flattened['type'] == 'object'
+    assert flattened['properties']['a']['type'] == 'string'
+    assert flattened['properties']['b']['type'] == 'integer'
+    # union of required keys
+    assert set(flattened['required']) == {'a', 'b'}
+    # boolean AP should remain False when all are False
+    assert flattened.get('additionalProperties') is False
+
+
+def test_flatten_allof_nested_objects_and_pass_through_keywords() -> None:
+    from pydantic_ai._json_schema import flatten_allof
+
+    schema: dict[str, Any] = {
+        'type': 'object',
+        'title': 'Root',
+        'allOf': [
+            {
+                'type': 'object',
+                'properties': {
+                    'user': {
+                        'type': 'object',
+                        'properties': {'id': {'type': 'string'}},
+                        'required': ['id'],
+                    }
+                },
+                'required': ['user'],
+            },
+            {
+                'type': 'object',
+                'properties': {'age': {'type': 'integer'}},
+                'required': ['age'],
+            },
+        ],
+        'description': 'test',
+    }
+
+    flattened = flatten_allof(copy.deepcopy(schema))
+    assert flattened.get('title') == 'Root'
+    assert flattened.get('description') == 'test'
+    assert 'allOf' not in flattened
+    assert set(flattened['required']) == {'user', 'age'}
+    assert flattened['properties']['user']['type'] == 'object'
+    assert set(flattened['properties']['user']['required']) == {'id'}
+
+
+def test_flatten_allof_does_not_touch_unrelated_unions() -> None:
+    from pydantic_ai._json_schema import flatten_allof
+
+    schema: dict[str, Any] = {
+        'type': 'object',
+        'properties': {
+            'x': {
+                'anyOf': [
+                    {'type': 'string'},
+                    {'type': 'null'},
+                ]
+            }
+        },
+        'required': ['x'],
+    }
+
+    flattened = flatten_allof(copy.deepcopy(schema))
+    assert flattened['properties']['x'].get('anyOf') is not None
+
+
+def test_flatten_allof_non_object_members_are_left_as_is() -> None:
+    from pydantic_ai._json_schema import flatten_allof
+
+    schema: dict[str, Any] = {
+        'type': 'object',
+        'allOf': [
+            {'type': 'string'},
+            {'type': 'integer'},
+        ],
+    }
+
+    # Expect: we cannot sensibly merge non-object members; keep allOf
+    flattened = flatten_allof(copy.deepcopy(schema))
+    assert 'allOf' in flattened

0 commit comments