Skip to content

Commit ea9b350

Browse files
fix: Invalid code generation with some oneOf and anyOf combinations [#603, #642]. Thanks @jselig-rigetti!
Co-authored-by: Jake Selig <[email protected]>
1 parent 3f1f951 commit ea9b350

File tree

13 files changed

+309
-9
lines changed

13 files changed

+309
-9
lines changed

end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .default import DefaultEndpoints
66
from .location import LocationEndpoints
77
from .parameters import ParametersEndpoints
8+
from .responses import ResponsesEndpoints
89
from .tag1 import Tag1Endpoints
910
from .tests import TestsEndpoints
1011
from .true_ import True_Endpoints
@@ -15,6 +16,10 @@ class MyTestApiClientApi:
1516
def tests(cls) -> Type[TestsEndpoints]:
1617
return TestsEndpoints
1718

19+
@classmethod
20+
def responses(cls) -> Type[ResponsesEndpoints]:
21+
return ResponsesEndpoints
22+
1823
@classmethod
1924
def default(cls) -> Type[DefaultEndpoints]:
2025
return DefaultEndpoints
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
""" Contains methods for accessing the API Endpoints """
2+
3+
import types
4+
5+
from . import post_responses_unions_simple_before_complex
6+
7+
8+
class ResponsesEndpoints:
9+
@classmethod
10+
def post_responses_unions_simple_before_complex(cls) -> types.ModuleType:
11+
"""
12+
Regression test for #603
13+
"""
14+
return post_responses_unions_simple_before_complex

end_to_end_tests/golden-record/my_test_api_client/api/responses/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from typing import Any, Dict, Optional
2+
3+
import httpx
4+
5+
from ...client import Client
6+
from ...models.post_responses_unions_simple_before_complex_response_200 import (
7+
PostResponsesUnionsSimpleBeforeComplexResponse200,
8+
)
9+
from ...types import Response
10+
11+
12+
def _get_kwargs(
13+
*,
14+
client: Client,
15+
) -> Dict[str, Any]:
16+
url = "{}/responses/unions/simple_before_complex".format(client.base_url)
17+
18+
headers: Dict[str, str] = client.get_headers()
19+
cookies: Dict[str, Any] = client.get_cookies()
20+
21+
return {
22+
"method": "post",
23+
"url": url,
24+
"headers": headers,
25+
"cookies": cookies,
26+
"timeout": client.get_timeout(),
27+
}
28+
29+
30+
def _parse_response(*, response: httpx.Response) -> Optional[PostResponsesUnionsSimpleBeforeComplexResponse200]:
31+
if response.status_code == 200:
32+
response_200 = PostResponsesUnionsSimpleBeforeComplexResponse200.from_dict(response.json())
33+
34+
return response_200
35+
return None
36+
37+
38+
def _build_response(*, response: httpx.Response) -> Response[PostResponsesUnionsSimpleBeforeComplexResponse200]:
39+
return Response(
40+
status_code=response.status_code,
41+
content=response.content,
42+
headers=response.headers,
43+
parsed=_parse_response(response=response),
44+
)
45+
46+
47+
def sync_detailed(
48+
*,
49+
client: Client,
50+
) -> Response[PostResponsesUnionsSimpleBeforeComplexResponse200]:
51+
"""Regression test for #603
52+
53+
Returns:
54+
Response[PostResponsesUnionsSimpleBeforeComplexResponse200]
55+
"""
56+
57+
kwargs = _get_kwargs(
58+
client=client,
59+
)
60+
61+
response = httpx.request(
62+
verify=client.verify_ssl,
63+
**kwargs,
64+
)
65+
66+
return _build_response(response=response)
67+
68+
69+
def sync(
70+
*,
71+
client: Client,
72+
) -> Optional[PostResponsesUnionsSimpleBeforeComplexResponse200]:
73+
"""Regression test for #603
74+
75+
Returns:
76+
Response[PostResponsesUnionsSimpleBeforeComplexResponse200]
77+
"""
78+
79+
return sync_detailed(
80+
client=client,
81+
).parsed
82+
83+
84+
async def asyncio_detailed(
85+
*,
86+
client: Client,
87+
) -> Response[PostResponsesUnionsSimpleBeforeComplexResponse200]:
88+
"""Regression test for #603
89+
90+
Returns:
91+
Response[PostResponsesUnionsSimpleBeforeComplexResponse200]
92+
"""
93+
94+
kwargs = _get_kwargs(
95+
client=client,
96+
)
97+
98+
async with httpx.AsyncClient(verify=client.verify_ssl) as _client:
99+
response = await _client.request(**kwargs)
100+
101+
return _build_response(response=response)
102+
103+
104+
async def asyncio(
105+
*,
106+
client: Client,
107+
) -> Optional[PostResponsesUnionsSimpleBeforeComplexResponse200]:
108+
"""Regression test for #603
109+
110+
Returns:
111+
Response[PostResponsesUnionsSimpleBeforeComplexResponse200]
112+
"""
113+
114+
return (
115+
await asyncio_detailed(
116+
client=client,
117+
)
118+
).parsed

end_to_end_tests/golden-record/my_test_api_client/api/tests/defaults_tests_defaults_post.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def _get_kwargs(
5151

5252
params["list_prop"] = json_list_prop
5353

54+
json_union_prop: Union[float, str]
55+
5456
json_union_prop = union_prop
5557

5658
params["union_prop"] = json_union_prop

end_to_end_tests/golden-record/my_test_api_client/api/tests/get_user_list.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def _get_kwargs(
4747

4848
params["an_enum_value_with_only_null"] = json_an_enum_value_with_only_null
4949

50+
json_some_date: str
51+
5052
if isinstance(some_date, datetime.date):
5153
json_some_date = some_date.isoformat()
5254
else:

end_to_end_tests/golden-record/my_test_api_client/models/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@
4040
from .model_with_union_property_inlined_fruit_type_0 import ModelWithUnionPropertyInlinedFruitType0
4141
from .model_with_union_property_inlined_fruit_type_1 import ModelWithUnionPropertyInlinedFruitType1
4242
from .none import None_
43+
from .post_responses_unions_simple_before_complex_response_200 import PostResponsesUnionsSimpleBeforeComplexResponse200
44+
from .post_responses_unions_simple_before_complex_response_200a_type_1 import (
45+
PostResponsesUnionsSimpleBeforeComplexResponse200AType1,
46+
)
4347
from .test_inline_objects_json_body import TestInlineObjectsJsonBody
4448
from .test_inline_objects_response_200 import TestInlineObjectsResponse200
4549
from .validation_error import ValidationError

end_to_end_tests/golden-record/my_test_api_client/models/a_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ def to_dict(self) -> Dict[str, Any]:
7171

7272
an_allof_enum_with_overridden_default = self.an_allof_enum_with_overridden_default.value
7373

74+
a_camel_date_time: str
75+
7476
if isinstance(self.a_camel_date_time, datetime.datetime):
7577
a_camel_date_time = self.a_camel_date_time.isoformat()
7678

@@ -79,6 +81,7 @@ def to_dict(self) -> Dict[str, Any]:
7981

8082
a_date = self.a_date.isoformat()
8183
required_not_nullable = self.required_not_nullable
84+
one_of_models: Union[Any, Dict[str, Any]]
8285

8386
if isinstance(self.one_of_models, FreeFormModel):
8487
one_of_models = self.one_of_models.to_dict()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from typing import Any, Dict, List, Type, TypeVar, Union, cast
2+
3+
import attr
4+
5+
from ..models.post_responses_unions_simple_before_complex_response_200a_type_1 import (
6+
PostResponsesUnionsSimpleBeforeComplexResponse200AType1,
7+
)
8+
9+
T = TypeVar("T", bound="PostResponsesUnionsSimpleBeforeComplexResponse200")
10+
11+
12+
@attr.s(auto_attribs=True)
13+
class PostResponsesUnionsSimpleBeforeComplexResponse200:
14+
"""
15+
Attributes:
16+
a (Union[PostResponsesUnionsSimpleBeforeComplexResponse200AType1, str]):
17+
"""
18+
19+
a: Union[PostResponsesUnionsSimpleBeforeComplexResponse200AType1, str]
20+
additional_properties: Dict[str, Any] = attr.ib(init=False, factory=dict)
21+
22+
def to_dict(self) -> Dict[str, Any]:
23+
a: Union[Dict[str, Any], str]
24+
25+
if isinstance(self.a, PostResponsesUnionsSimpleBeforeComplexResponse200AType1):
26+
a = self.a.to_dict()
27+
28+
else:
29+
a = self.a
30+
31+
field_dict: Dict[str, Any] = {}
32+
field_dict.update(self.additional_properties)
33+
field_dict.update(
34+
{
35+
"a": a,
36+
}
37+
)
38+
39+
return field_dict
40+
41+
@classmethod
42+
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
43+
d = src_dict.copy()
44+
45+
def _parse_a(data: object) -> Union[PostResponsesUnionsSimpleBeforeComplexResponse200AType1, str]:
46+
try:
47+
if not isinstance(data, dict):
48+
raise TypeError()
49+
a_type_1 = PostResponsesUnionsSimpleBeforeComplexResponse200AType1.from_dict(data)
50+
51+
return a_type_1
52+
except: # noqa: E722
53+
pass
54+
return cast(Union[PostResponsesUnionsSimpleBeforeComplexResponse200AType1, str], data)
55+
56+
a = _parse_a(d.pop("a"))
57+
58+
post_responses_unions_simple_before_complex_response_200 = cls(
59+
a=a,
60+
)
61+
62+
post_responses_unions_simple_before_complex_response_200.additional_properties = d
63+
return post_responses_unions_simple_before_complex_response_200
64+
65+
@property
66+
def additional_keys(self) -> List[str]:
67+
return list(self.additional_properties.keys())
68+
69+
def __getitem__(self, key: str) -> Any:
70+
return self.additional_properties[key]
71+
72+
def __setitem__(self, key: str, value: Any) -> None:
73+
self.additional_properties[key] = value
74+
75+
def __delitem__(self, key: str) -> None:
76+
del self.additional_properties[key]
77+
78+
def __contains__(self, key: str) -> bool:
79+
return key in self.additional_properties
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from typing import Any, Dict, List, Type, TypeVar
2+
3+
import attr
4+
5+
T = TypeVar("T", bound="PostResponsesUnionsSimpleBeforeComplexResponse200AType1")
6+
7+
8+
@attr.s(auto_attribs=True)
9+
class PostResponsesUnionsSimpleBeforeComplexResponse200AType1:
10+
""" """
11+
12+
additional_properties: Dict[str, Any] = attr.ib(init=False, factory=dict)
13+
14+
def to_dict(self) -> Dict[str, Any]:
15+
16+
field_dict: Dict[str, Any] = {}
17+
field_dict.update(self.additional_properties)
18+
field_dict.update({})
19+
20+
return field_dict
21+
22+
@classmethod
23+
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
24+
d = src_dict.copy()
25+
post_responses_unions_simple_before_complex_response_200a_type_1 = cls()
26+
27+
post_responses_unions_simple_before_complex_response_200a_type_1.additional_properties = d
28+
return post_responses_unions_simple_before_complex_response_200a_type_1
29+
30+
@property
31+
def additional_keys(self) -> List[str]:
32+
return list(self.additional_properties.keys())
33+
34+
def __getitem__(self, key: str) -> Any:
35+
return self.additional_properties[key]
36+
37+
def __setitem__(self, key: str, value: Any) -> None:
38+
self.additional_properties[key] = value
39+
40+
def __delitem__(self, key: str) -> None:
41+
del self.additional_properties[key]
42+
43+
def __contains__(self, key: str) -> bool:
44+
return key in self.additional_properties

end_to_end_tests/openapi.json

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,33 @@
713713
}
714714
}
715715
},
716+
"/responses/unions/simple_before_complex": {
717+
"post": {
718+
"tags": ["responses"],
719+
"description": "Regression test for #603",
720+
"responses": {
721+
"200": {
722+
"description": "A union with simple types before complex ones.",
723+
"content": {
724+
"application/json": {
725+
"schema": {
726+
"type": "object",
727+
"required": ["a"],
728+
"properties": {
729+
"a": {
730+
"oneOf": [
731+
{"type": "string"},
732+
{"type": "object"}
733+
]
734+
}
735+
}
736+
}
737+
}
738+
}
739+
}
740+
}
741+
}
742+
},
716743
"/auth/token_with_cookie": {
717744
"get": {
718745
"tags": [

openapi_python_client/templates/model.py.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ field_dict: Dict[str, Any] = {}
8585
{% endif %}
8686
{% if prop_template and prop_template.transform %}
8787
for prop_name, prop in self.additional_properties.items():
88-
{{ prop_template.transform(model.additional_properties, "prop", "field_dict[prop_name]", multipart=multipart) | indent(4) }}
88+
{{ prop_template.transform(model.additional_properties, "prop", "field_dict[prop_name]", multipart=multipart, declare_type=false) | indent(4) }}
8989
{% elif multipart %}
9090
field_dict.update({
9191
key: (None, str(value).encode(), "text/plain")

0 commit comments

Comments
 (0)