diff --git a/.changeset/support_all_text_content_types_in_responses.md b/.changeset/support_all_text_content_types_in_responses.md new file mode 100644 index 000000000..36c06a97f --- /dev/null +++ b/.changeset/support_all_text_content_types_in_responses.md @@ -0,0 +1,11 @@ +--- +default: minor +--- + +# Support all `text/*` content types in responses + +Within an API response, any content type which starts with `text/` will now be treated the same as `text/html` already was—they will return the `response.text` attribute from the [httpx Response](https://www.python-httpx.org/api/#response). + +Thanks to @fdintino for the initial implementation, and thanks for the discussions from @kairntech, @rubenfiszel, and @antoneladestito. + +Closes #797 and #821. diff --git a/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/responses/__init__.py b/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/responses/__init__.py index c2e39c16a..353c41b9b 100644 --- a/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/responses/__init__.py +++ b/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/responses/__init__.py @@ -2,7 +2,7 @@ import types -from . import post_responses_unions_simple_before_complex +from . import post_responses_unions_simple_before_complex, text_response class ResponsesEndpoints: @@ -12,3 +12,10 @@ def post_responses_unions_simple_before_complex(cls) -> types.ModuleType: Regression test for #603 """ return post_responses_unions_simple_before_complex + + @classmethod + def text_response(cls) -> types.ModuleType: + """ + Text Response + """ + return text_response diff --git a/end_to_end_tests/golden-record/my_test_api_client/api/responses/text_response.py b/end_to_end_tests/golden-record/my_test_api_client/api/responses/text_response.py new file mode 100644 index 000000000..ce3f87e78 --- /dev/null +++ b/end_to_end_tests/golden-record/my_test_api_client/api/responses/text_response.py @@ -0,0 +1,118 @@ +from http import HTTPStatus +from typing import Any, Dict, Optional, Union + +import httpx + +from ... import errors +from ...client import AuthenticatedClient, Client +from ...types import Response + + +def _get_kwargs() -> Dict[str, Any]: + return { + "method": "post", + "url": "/responses/text", + } + + +def _parse_response(*, client: Union[AuthenticatedClient, Client], response: httpx.Response) -> Optional[str]: + if response.status_code == HTTPStatus.OK: + response_200 = response.text + return response_200 + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response(*, client: Union[AuthenticatedClient, Client], response: httpx.Response) -> Response[str]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + *, + client: Union[AuthenticatedClient, Client], +) -> Response[str]: + """Text Response + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[str] + """ + + kwargs = _get_kwargs() + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + *, + client: Union[AuthenticatedClient, Client], +) -> Optional[str]: + """Text Response + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + str + """ + + return sync_detailed( + client=client, + ).parsed + + +async def asyncio_detailed( + *, + client: Union[AuthenticatedClient, Client], +) -> Response[str]: + """Text Response + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[str] + """ + + kwargs = _get_kwargs() + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + *, + client: Union[AuthenticatedClient, Client], +) -> Optional[str]: + """Text Response + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + str + """ + + return ( + await asyncio_detailed( + client=client, + ) + ).parsed diff --git a/end_to_end_tests/golden-record/my_test_api_client/api/tests/callback_test.py b/end_to_end_tests/golden-record/my_test_api_client/api/tests/callback_test.py index e46cf0e56..643e9c0f6 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/api/tests/callback_test.py +++ b/end_to_end_tests/golden-record/my_test_api_client/api/tests/callback_test.py @@ -1,5 +1,5 @@ from http import HTTPStatus -from typing import Any, Dict, Optional, Union, cast +from typing import Any, Dict, Optional, Union import httpx @@ -27,7 +27,7 @@ def _parse_response( *, client: Union[AuthenticatedClient, Client], response: httpx.Response ) -> Optional[Union[Any, HTTPValidationError]]: if response.status_code == HTTPStatus.OK: - response_200 = cast(Any, response.json()) + response_200 = response.json() return response_200 if response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY: response_422 = HTTPValidationError.from_dict(response.json()) diff --git a/end_to_end_tests/golden-record/my_test_api_client/api/tests/defaults_tests_defaults_post.py b/end_to_end_tests/golden-record/my_test_api_client/api/tests/defaults_tests_defaults_post.py index f958a03e1..8d1702b71 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/api/tests/defaults_tests_defaults_post.py +++ b/end_to_end_tests/golden-record/my_test_api_client/api/tests/defaults_tests_defaults_post.py @@ -1,6 +1,6 @@ import datetime from http import HTTPStatus -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Dict, List, Optional, Union import httpx from dateutil.parser import isoparse @@ -94,7 +94,7 @@ def _parse_response( *, client: Union[AuthenticatedClient, Client], response: httpx.Response ) -> Optional[Union[Any, HTTPValidationError]]: if response.status_code == HTTPStatus.OK: - response_200 = cast(Any, response.json()) + response_200 = response.json() return response_200 if response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY: response_422 = HTTPValidationError.from_dict(response.json()) diff --git a/end_to_end_tests/golden-record/my_test_api_client/api/tests/int_enum_tests_int_enum_post.py b/end_to_end_tests/golden-record/my_test_api_client/api/tests/int_enum_tests_int_enum_post.py index e22287b08..bbfa1b885 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/api/tests/int_enum_tests_int_enum_post.py +++ b/end_to_end_tests/golden-record/my_test_api_client/api/tests/int_enum_tests_int_enum_post.py @@ -1,5 +1,5 @@ from http import HTTPStatus -from typing import Any, Dict, Optional, Union, cast +from typing import Any, Dict, Optional, Union import httpx @@ -32,7 +32,7 @@ def _parse_response( *, client: Union[AuthenticatedClient, Client], response: httpx.Response ) -> Optional[Union[Any, HTTPValidationError]]: if response.status_code == HTTPStatus.OK: - response_200 = cast(Any, response.json()) + response_200 = response.json() return response_200 if response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY: response_422 = HTTPValidationError.from_dict(response.json()) diff --git a/end_to_end_tests/golden-record/my_test_api_client/api/tests/json_body_tests_json_body_post.py b/end_to_end_tests/golden-record/my_test_api_client/api/tests/json_body_tests_json_body_post.py index 995c4c4d6..e146b2ade 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/api/tests/json_body_tests_json_body_post.py +++ b/end_to_end_tests/golden-record/my_test_api_client/api/tests/json_body_tests_json_body_post.py @@ -1,5 +1,5 @@ from http import HTTPStatus -from typing import Any, Dict, Optional, Union, cast +from typing import Any, Dict, Optional, Union import httpx @@ -27,7 +27,7 @@ def _parse_response( *, client: Union[AuthenticatedClient, Client], response: httpx.Response ) -> Optional[Union[Any, HTTPValidationError]]: if response.status_code == HTTPStatus.OK: - response_200 = cast(Any, response.json()) + response_200 = response.json() return response_200 if response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY: response_422 = HTTPValidationError.from_dict(response.json()) diff --git a/end_to_end_tests/golden-record/my_test_api_client/api/tests/upload_file_tests_upload_post.py b/end_to_end_tests/golden-record/my_test_api_client/api/tests/upload_file_tests_upload_post.py index 60b436985..30a90393b 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/api/tests/upload_file_tests_upload_post.py +++ b/end_to_end_tests/golden-record/my_test_api_client/api/tests/upload_file_tests_upload_post.py @@ -1,5 +1,5 @@ from http import HTTPStatus -from typing import Any, Dict, Optional, Union, cast +from typing import Any, Dict, Optional, Union import httpx @@ -27,7 +27,7 @@ def _parse_response( *, client: Union[AuthenticatedClient, Client], response: httpx.Response ) -> Optional[Union[Any, HTTPValidationError]]: if response.status_code == HTTPStatus.OK: - response_200 = cast(Any, response.json()) + response_200 = response.json() return response_200 if response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY: response_422 = HTTPValidationError.from_dict(response.json()) diff --git a/end_to_end_tests/golden-record/my_test_api_client/api/tests/upload_multiple_files_tests_upload_post.py b/end_to_end_tests/golden-record/my_test_api_client/api/tests/upload_multiple_files_tests_upload_post.py index 9b62342f2..a5cb511d1 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/api/tests/upload_multiple_files_tests_upload_post.py +++ b/end_to_end_tests/golden-record/my_test_api_client/api/tests/upload_multiple_files_tests_upload_post.py @@ -1,5 +1,5 @@ from http import HTTPStatus -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Dict, List, Optional, Union import httpx @@ -30,7 +30,7 @@ def _parse_response( *, client: Union[AuthenticatedClient, Client], response: httpx.Response ) -> Optional[Union[Any, HTTPValidationError]]: if response.status_code == HTTPStatus.OK: - response_200 = cast(Any, response.json()) + response_200 = response.json() return response_200 if response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY: response_422 = HTTPValidationError.from_dict(response.json()) diff --git a/end_to_end_tests/openapi.json b/end_to_end_tests/openapi.json index 9c4334d46..2402da17c 100644 --- a/end_to_end_tests/openapi.json +++ b/end_to_end_tests/openapi.json @@ -781,6 +781,27 @@ } } }, + "/responses/text": { + "post": { + "tags": [ + "responses" + ], + "summary": "Text Response", + "operationId": "text_response", + "responses": { + "200": { + "description": "Text response", + "content": { + "text/plain": { + "schema": { + "type": "string" + } + } + } + } + } + } + }, "/auth/token_with_cookie": { "get": { "tags": [ diff --git a/openapi_python_client/parser/openapi.py b/openapi_python_client/parser/openapi.py index dc6e9f47b..94a42998e 100644 --- a/openapi_python_client/parser/openapi.py +++ b/openapi_python_client/parser/openapi.py @@ -46,7 +46,11 @@ class EndpointCollection: @staticmethod def from_data( - *, data: Dict[str, oai.PathItem], schemas: Schemas, parameters: Parameters, config: Config + *, + data: Dict[str, oai.PathItem], + schemas: Schemas, + parameters: Parameters, + config: Config, ) -> Tuple[Dict[utils.PythonIdentifier, "EndpointCollection"], Schemas, Parameters]: """Parse the openapi paths data to get EndpointCollections by tag""" endpoints_by_tag: Dict[utils.PythonIdentifier, EndpointCollection] = {} @@ -72,7 +76,11 @@ def from_data( # Add `PathItem` parameters if not isinstance(endpoint, ParseError): endpoint, schemas, parameters = Endpoint.add_parameters( - endpoint=endpoint, data=path_data, schemas=schemas, parameters=parameters, config=config + endpoint=endpoint, + data=path_data, + schemas=schemas, + parameters=parameters, + config=config, ) if not isinstance(endpoint, ParseError): endpoint = Endpoint.sort_parameters(endpoint=endpoint) @@ -145,7 +153,13 @@ def parse_request_form_body( config=config, ) if isinstance(prop, ModelProperty): - schemas = attr.evolve(schemas, classes_by_name={**schemas.classes_by_name, prop.class_info.name: prop}) + schemas = attr.evolve( + schemas, + classes_by_name={ + **schemas.classes_by_name, + prop.class_info.name: prop, + }, + ) return prop, schemas return None, schemas @@ -167,7 +181,13 @@ def parse_multipart_body( ) if isinstance(prop, ModelProperty): prop = attr.evolve(prop, is_multipart_body=True) - schemas = attr.evolve(schemas, classes_by_name={**schemas.classes_by_name, prop.class_info.name: prop}) + schemas = attr.evolve( + schemas, + classes_by_name={ + **schemas.classes_by_name, + prop.class_info.name: prop, + }, + ) return prop, schemas return None, schemas @@ -178,9 +198,11 @@ def parse_request_json_body( """Return json_body""" json_body = None for content_type, schema in body.content.items(): - content_type = get_content_type(content_type) # noqa: PLW2901 + parsed_content_type = get_content_type(content_type) - if content_type == "application/json" or content_type.endswith("+json"): + if parsed_content_type is not None and ( + parsed_content_type == "application/json" or parsed_content_type.endswith("+json") + ): json_body = schema break @@ -209,7 +231,10 @@ def _add_body( return endpoint, schemas form_body, schemas = Endpoint.parse_request_form_body( - body=data.requestBody, schemas=schemas, parent_name=endpoint.name, config=config + body=data.requestBody, + schemas=schemas, + parent_name=endpoint.name, + config=config, ) if isinstance(form_body, ParseError): @@ -223,7 +248,10 @@ def _add_body( ) json_body, schemas = Endpoint.parse_request_json_body( - body=data.requestBody, schemas=schemas, parent_name=endpoint.name, config=config + body=data.requestBody, + schemas=schemas, + parent_name=endpoint.name, + config=config, ) if isinstance(json_body, ParseError): return ( @@ -236,7 +264,10 @@ def _add_body( ) multipart_body, schemas = Endpoint.parse_multipart_body( - body=data.requestBody, schemas=schemas, parent_name=endpoint.name, config=config + body=data.requestBody, + schemas=schemas, + parent_name=endpoint.name, + config=config, ) if isinstance(multipart_body, ParseError): return ( @@ -285,7 +316,11 @@ def _add_responses( continue response, schemas = response_from_data( - status_code=status_code, data=response_data, schemas=schemas, parent_name=endpoint.name, config=config + status_code=status_code, + data=response_data, + schemas=schemas, + parent_name=endpoint.name, + config=config, ) if isinstance(response, ParseError): detail_suffix = "" if response.detail is None else f" ({response.detail})" @@ -350,7 +385,15 @@ def add_parameters( # noqa: PLR0911, PLR0912 oai.ParameterLocation.HEADER: endpoint.header_parameters, oai.ParameterLocation.COOKIE: endpoint.cookie_parameters, "RESERVED": { # These can't be param names because codegen needs them as vars, the properties don't matter - "client": AnyProperty("client", True, False, None, PythonIdentifier("client", ""), None, None), + "client": AnyProperty( + "client", + True, + False, + None, + PythonIdentifier("client", ""), + None, + None, + ), "url": AnyProperty("url", True, False, None, PythonIdentifier("url", ""), None, None), }, } @@ -393,7 +436,10 @@ def add_parameters( # noqa: PLR0911, PLR0912 if isinstance(prop, ParseError): return ( - ParseError(detail=f"cannot parse parameter of endpoint {endpoint.name}", data=prop.data), + ParseError( + detail=f"cannot parse parameter of endpoint {endpoint.name}", + data=prop.data, + ), schemas, parameters, ) @@ -432,7 +478,8 @@ def add_parameters( # noqa: PLR0911, PLR0912 if prop.python_name in endpoint.used_python_identifiers: return ( ParseError( - detail=f"Parameters with same Python identifier `{prop.python_name}` detected", data=data + detail=f"Parameters with same Python identifier `{prop.python_name}` detected", + data=data, ), schemas, parameters, @@ -465,7 +512,8 @@ def sort_parameters(*, endpoint: "Endpoint") -> Union["Endpoint", ParseError]: parameters_from_path = re.findall(_PATH_PARAM_REGEX, endpoint.path) try: sorted_params = sorted( - endpoint.path_parameters.values(), key=lambda param: parameters_from_path.index(param.name) + endpoint.path_parameters.values(), + key=lambda param: parameters_from_path.index(param.name), ) endpoint.path_parameters = OrderedDict((param.name, param) for param in sorted_params) except ValueError: @@ -506,7 +554,11 @@ def from_data( ) result, schemas, parameters = Endpoint.add_parameters( - endpoint=endpoint, data=data, schemas=schemas, parameters=parameters, config=config + endpoint=endpoint, + data=data, + schemas=schemas, + parameters=parameters, + config=config, ) if isinstance(result, ParseError): return result, schemas, parameters @@ -570,7 +622,9 @@ def from_dict(data: Dict[str, Any], *, config: Config) -> Union["GeneratorData", schemas = build_schemas(components=openapi.components.schemas, schemas=schemas, config=config) if openapi.components and openapi.components.parameters: parameters = build_parameters( - components=openapi.components.parameters, parameters=parameters, config=config + components=openapi.components.parameters, + parameters=parameters, + config=config, ) endpoint_collections_by_tag, schemas, parameters = EndpointCollection.from_data( data=openapi.paths, schemas=schemas, parameters=parameters, config=config diff --git a/openapi_python_client/parser/responses.py b/openapi_python_client/parser/responses.py index 2b41eac8d..97909a40c 100644 --- a/openapi_python_client/parser/responses.py +++ b/openapi_python_client/parser/responses.py @@ -1,7 +1,7 @@ __all__ = ["Response", "response_from_data"] from http import HTTPStatus -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, TypedDict, Union from attrs import define @@ -14,32 +14,53 @@ from .properties import AnyProperty, Property, Schemas, property_from_data +class _ResponseSource(TypedDict): + """What data should be pulled from the httpx Response object""" + + attribute: str + return_type: str + + +JSON_SOURCE = _ResponseSource(attribute="response.json()", return_type="Any") +BYTES_SOURCE = _ResponseSource(attribute="response.content", return_type="bytes") +TEXT_SOURCE = _ResponseSource(attribute="response.text", return_type="str") +NONE_SOURCE = _ResponseSource(attribute="None", return_type="None") + + @define class Response: """Describes a single response for an endpoint""" status_code: HTTPStatus prop: Property - source: str + source: _ResponseSource + +def _source_by_content_type(content_type: str) -> Optional[_ResponseSource]: + parsed_content_type = utils.get_content_type(content_type) + if parsed_content_type is None: + return None -def _source_by_content_type(content_type: str) -> Optional[str]: - content_type = utils.get_content_type(content_type) + if parsed_content_type.startswith("text/"): + return TEXT_SOURCE known_content_types = { - "application/json": "response.json()", - "application/octet-stream": "response.content", - "text/html": "response.text", + "application/json": JSON_SOURCE, + "application/octet-stream": BYTES_SOURCE, } - source = known_content_types.get(content_type) - if source is None and content_type.endswith("+json"): + source = known_content_types.get(parsed_content_type) + if source is None and parsed_content_type.endswith("+json"): # Implements https://www.rfc-editor.org/rfc/rfc6838#section-4.2.8 for the +json suffix - source = "response.json()" + source = JSON_SOURCE return source def empty_response( - *, status_code: HTTPStatus, response_name: str, config: Config, description: Optional[str] + *, + status_code: HTTPStatus, + response_name: str, + config: Config, + description: Optional[str], ) -> Response: """Return an untyped response, for when no response type is defined""" return Response( @@ -53,7 +74,7 @@ def empty_response( description=description, example=None, ), - source="None", + source=NONE_SOURCE, ) @@ -70,7 +91,12 @@ def response_from_data( response_name = f"response_{status_code}" if isinstance(data, oai.Reference): return ( - empty_response(status_code=status_code, response_name=response_name, config=config, description=None), + empty_response( + status_code=status_code, + response_name=response_name, + config=config, + description=None, + ), schemas, ) @@ -78,7 +104,10 @@ def response_from_data( if not content: return ( empty_response( - status_code=status_code, response_name=response_name, config=config, description=data.description + status_code=status_code, + response_name=response_name, + config=config, + description=data.description, ), schemas, ) @@ -89,12 +118,18 @@ def response_from_data( schema_data = media_type.media_type_schema break else: - return ParseError(data=data, detail=f"Unsupported content_type {content}"), schemas + return ( + ParseError(data=data, detail=f"Unsupported content_type {content}"), + schemas, + ) if schema_data is None: return ( empty_response( - status_code=status_code, response_name=response_name, config=config, description=data.description + status_code=status_code, + response_name=response_name, + config=config, + description=data.description, ), schemas, ) diff --git a/openapi_python_client/templates/endpoint_module.py.jinja b/openapi_python_client/templates/endpoint_module.py.jinja index c2b738ced..6a9921e8a 100644 --- a/openapi_python_client/templates/endpoint_module.py.jinja +++ b/openapi_python_client/templates/endpoint_module.py.jinja @@ -65,9 +65,11 @@ def _parse_response(*, client: Union[AuthenticatedClient, Client], response: htt if response.status_code == HTTPStatus.{{ response.status_code.name }}: {% if parsed_responses %}{% import "property_templates/" + response.prop.template as prop_template %} {% if prop_template.construct %} - {{ prop_template.construct(response.prop, response.source) | indent(8) }} + {{ prop_template.construct(response.prop, response.source.attribute) | indent(8) }} + {% elif response.source.return_type == response.prop.get_type_string() %} + {{ response.prop.python_name }} = {{ response.source.attribute }} {% else %} - {{ response.prop.python_name }} = cast({{ response.prop.get_type_string() }}, {{ response.source }}) + {{ response.prop.python_name }} = cast({{ response.prop.get_type_string() }}, {{ response.source.attribute }}) {% endif %} return {{ response.prop.python_name }} {% else %} diff --git a/openapi_python_client/utils.py b/openapi_python_client/utils.py index 8d54de096..ea19622c4 100644 --- a/openapi_python_client/utils.py +++ b/openapi_python_client/utils.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import builtins import re from email.message import Message from keyword import iskeyword -from typing import Any, List +from typing import Any DELIMITERS = r"\. _-" @@ -10,21 +12,21 @@ class PythonIdentifier(str): """A snake_case string which has been validated / transformed into a valid identifier for Python""" - def __new__(cls, value: str, prefix: str) -> "PythonIdentifier": + def __new__(cls, value: str, prefix: str) -> PythonIdentifier: new_value = fix_reserved_words(snake_case(sanitize(value))) if not new_value.isidentifier() or value.startswith("_"): new_value = f"{prefix}{new_value}" return str.__new__(cls, new_value) - def __deepcopy__(self, _: Any) -> "PythonIdentifier": + def __deepcopy__(self, _: Any) -> PythonIdentifier: return self class ClassName(str): """A PascalCase string which has been validated / transformed into a valid class name for Python""" - def __new__(cls, value: str, prefix: str) -> "ClassName": + def __new__(cls, value: str, prefix: str) -> ClassName: new_value = fix_reserved_words(pascal_case(sanitize(value))) if not new_value.isidentifier(): @@ -32,7 +34,7 @@ def __new__(cls, value: str, prefix: str) -> "ClassName": new_value = fix_reserved_words(pascal_case(sanitize(value))) return str.__new__(cls, new_value) - def __deepcopy__(self, _: Any) -> "ClassName": + def __deepcopy__(self, _: Any) -> ClassName: return self @@ -41,7 +43,7 @@ def sanitize(value: str) -> str: return re.sub(rf"[^\w{DELIMITERS}]+", "", value) -def split_words(value: str) -> List[str]: +def split_words(value: str) -> list[str]: """Split a string on words and known delimiters""" # We can't guess words if there is no capital letter if any(c.isupper() for c in value): @@ -49,7 +51,10 @@ def split_words(value: str) -> List[str]: return re.findall(rf"[^{DELIMITERS}]+", value) -RESERVED_WORDS = (set(dir(builtins)) | {"self", "true", "false", "datetime"}) - {"type", "id"} +RESERVED_WORDS = (set(dir(builtins)) | {"self", "true", "false", "datetime"}) - { + "type", + "id", +} def fix_reserved_words(value: str) -> str: @@ -97,13 +102,16 @@ def remove_string_escapes(value: str) -> str: return value.replace('"', r"\"") -def get_content_type(content_type: str) -> str: +def get_content_type(content_type: str) -> str | None: """ Given a string representing a content type with optional parameters, returns the content type only """ message = Message() message.add_header("Content-Type", content_type) - content_type = message.get_content_type() + parsed_content_type = message.get_content_type() + if not content_type.startswith(parsed_content_type): + # Always defaults to `text/plain` if it's not recognized. We want to return an error, not default. + return None - return content_type + return parsed_content_type diff --git a/tests/test_parser/test_responses.py b/tests/test_parser/test_responses.py index 11a7ebd66..8a0836fd0 100644 --- a/tests/test_parser/test_responses.py +++ b/tests/test_parser/test_responses.py @@ -3,6 +3,7 @@ import openapi_python_client.schema as oai from openapi_python_client.parser.errors import ParseError, PropertyError from openapi_python_client.parser.properties import Schemas +from openapi_python_client.parser.responses import JSON_SOURCE, NONE_SOURCE MODULE_NAME = "openapi_python_client.parser.responses" @@ -27,7 +28,7 @@ def test_response_from_data_no_content(any_property_factory): required=True, description="", ), - source="None", + source=NONE_SOURCE, ) @@ -50,7 +51,7 @@ def test_response_from_data_reference(any_property_factory): nullable=False, required=True, ), - source="None", + source=NONE_SOURCE, ) @@ -59,7 +60,11 @@ def test_response_from_data_unsupported_content_type(): data = oai.Response.model_construct(description="", content={"blah": None}) response, schemas = response_from_data( - status_code=200, data=data, schemas=Schemas(), parent_name="parent", config=MagicMock() + status_code=200, + data=data, + schemas=Schemas(), + parent_name="parent", + config=MagicMock(), ) assert response == ParseError(data=data, detail="Unsupported content_type {'blah': None}") @@ -69,10 +74,15 @@ def test_response_from_data_no_content_schema(any_property_factory): from openapi_python_client.parser.responses import Response, response_from_data data = oai.Response.model_construct( - description="", content={"application/vnd.api+json; version=2.2": oai.MediaType.model_construct()} + description="", + content={"application/vnd.api+json; version=2.2": oai.MediaType.model_construct()}, ) response, schemas = response_from_data( - status_code=200, data=data, schemas=Schemas(), parent_name="parent", config=MagicMock() + status_code=200, + data=data, + schemas=Schemas(), + parent_name="parent", + config=MagicMock(), ) assert response == Response( @@ -84,7 +94,7 @@ def test_response_from_data_no_content_schema(any_property_factory): required=True, description=data.description, ), - source="None", + source=NONE_SOURCE, ) @@ -93,17 +103,27 @@ def test_response_from_data_property_error(mocker): property_from_data = mocker.patch.object(responses, "property_from_data", return_value=(PropertyError(), Schemas())) data = oai.Response.model_construct( - description="", content={"application/json": oai.MediaType.model_construct(media_type_schema="something")} + description="", + content={"application/json": oai.MediaType.model_construct(media_type_schema="something")}, ) config = MagicMock() response, schemas = responses.response_from_data( - status_code=400, data=data, schemas=Schemas(), parent_name="parent", config=config + status_code=400, + data=data, + schemas=Schemas(), + parent_name="parent", + config=config, ) assert response == PropertyError() property_from_data.assert_called_once_with( - name="response_400", required=True, data="something", schemas=Schemas(), parent_name="parent", config=config + name="response_400", + required=True, + data="something", + schemas=Schemas(), + parent_name="parent", + config=config, ) @@ -113,19 +133,29 @@ def test_response_from_data_property(mocker, property_factory): prop = property_factory() property_from_data = mocker.patch.object(responses, "property_from_data", return_value=(prop, Schemas())) data = oai.Response.model_construct( - description="", content={"application/json": oai.MediaType.model_construct(media_type_schema="something")} + description="", + content={"application/json": oai.MediaType.model_construct(media_type_schema="something")}, ) config = MagicMock() response, schemas = responses.response_from_data( - status_code=400, data=data, schemas=Schemas(), parent_name="parent", config=config + status_code=400, + data=data, + schemas=Schemas(), + parent_name="parent", + config=config, ) assert response == responses.Response( status_code=400, prop=prop, - source="response.json()", + source=JSON_SOURCE, ) property_from_data.assert_called_once_with( - name="response_400", required=True, data="something", schemas=Schemas(), parent_name="parent", config=config + name="response_400", + required=True, + data="something", + schemas=Schemas(), + parent_name="parent", + config=config, )