diff --git a/openapi_python_client/parser/properties/__init__.py b/openapi_python_client/parser/properties/__init__.py index c4fe245e0..cece6f033 100644 --- a/openapi_python_client/parser/properties/__init__.py +++ b/openapi_python_client/parser/properties/__init__.py @@ -223,6 +223,8 @@ class UnionProperty(Property): """A property representing a Union (anyOf) of other properties""" inner_properties: List[Property] + discriminator: Optional[oai.openapi_schema_pydantic.Discriminator] = None + template: ClassVar[str] = "union_property.py.jinja" def _get_inner_type_strings(self, json: bool = False) -> Set[str]: @@ -300,6 +302,22 @@ def get_lazy_imports(self, *, prefix: str) -> Set[str]: lazy_imports.update(inner_prop.get_lazy_imports(prefix=prefix)) return lazy_imports + def get_discriminator_value(self, sub_model: ModelProperty) -> str: + """ + Get discriminator's property value for sub_model + """ + if self.discriminator: + if self.discriminator.mapping: + for property_value, schema_path in self.discriminator.mapping.items(): + ref_path = parse_reference_path(schema_path) + if isinstance(ref_path, ParseError): + raise TypeError() + if ref_path in sub_model.roots: + return property_value + else: + return sub_model.get_base_type_string() + raise TypeError() + def _string_based_property( name: str, required: bool, data: oai.Schema, config: Config @@ -509,6 +527,7 @@ def build_union_property( python_name=utils.PythonIdentifier(value=name, prefix=config.field_prefix), description=data.description, example=data.example, + discriminator=data.discriminator, ), schemas, ) diff --git a/openapi_python_client/templates/property_templates/model_property.py.jinja b/openapi_python_client/templates/property_templates/model_property.py.jinja index 903aeefaa..55dfaa84c 100644 --- a/openapi_python_client/templates/property_templates/model_property.py.jinja +++ b/openapi_python_client/templates/property_templates/model_property.py.jinja @@ -9,6 +9,7 @@ {% endmacro %} {% macro check_type_for_construct(property, source) %}isinstance({{ source }}, dict){% endmacro %} +{% macro check_discriminator(source, property_name, discriminator_value) %}{{ source }}["{{ property_name }}"] == "{{ discriminator_value }}"{% endmacro %} {% macro transform(property, source, destination, declare_type=True, multipart=False, transform_method="to_dict") %} {% set transformed = source + "." + transform_method + "()" %} diff --git a/openapi_python_client/templates/property_templates/union_property.py.jinja b/openapi_python_client/templates/property_templates/union_property.py.jinja index 4d43fafc0..75536506f 100644 --- a/openapi_python_client/templates/property_templates/union_property.py.jinja +++ b/openapi_python_client/templates/property_templates/union_property.py.jinja @@ -19,6 +19,10 @@ def _parse_{{ property.python_name }}(data: object) -> {{ property.get_type_stri try: if not {{ inner_template.check_type_for_construct(inner_property, "data") }}: raise TypeError() + {% if property.discriminator and inner_template.check_discriminator %} + if not {{ inner_template.check_discriminator("data", property.discriminator.propertyName, property.get_discriminator_value(inner_property)) }}: + raise TypeError() + {% endif %} {{ inner_template.construct(inner_property, "data", initial_value="UNSET") | indent(8) }} return {{ inner_property.python_name }} except: # noqa: E722 @@ -28,6 +32,10 @@ def _parse_{{ property.python_name }}(data: object) -> {{ property.get_type_stri if not {{ inner_template.check_type_for_construct(inner_property, "data") }}: raise TypeError() {% endif %} + {% if property.discriminator and inner_template.check_discriminator %} + if not {{ inner_template.check_discriminator("data", property.discriminator.propertyName, property.get_discriminator_value(inner_property)) }}: + raise TypeError() + {% endif %} {{ inner_template.construct(inner_property, "data", initial_value="UNSET") | indent(4) }} return {{ inner_property.python_name }} {% endif %}