Skip to content
Merged
1 change: 0 additions & 1 deletion .github/workflows/documentation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ jobs:
- uses: actions/setup-python@v3
- name: Install dependencies
run: |
pip install sphinx sphinx_rtd_theme
pip install -r docs/requirements.txt
- name: Sphinx build
run: |
Expand Down
18 changes: 17 additions & 1 deletion docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,22 @@ Changelog
#########


**2.3.1**
*********

Pydantic validators inheritance fix
====================================

* fix schema validators passthrough `#45 <https://github.com/mts-ai/FastAPI-JSONAPI/pull/45>`_
* fix doc build

Authors
"""""""

* `@CosmoV`_
* `@mahenzon`_


**2.3.0**
*********

Expand All @@ -15,9 +31,9 @@ Current Atomic Operation context var
Authors
"""""""


* `@mahenzon`_


**2.2.2**
*********

Expand Down
2 changes: 2 additions & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
fastapi<0.100.0
pydantic<2
simplejson>=3.17.6
sphinx
sphinx_rtd_theme
sqlalchemy<2
tortoise-orm>=0.19.3
122 changes: 11 additions & 111 deletions fastapi_jsonapi/schema_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,16 @@
Iterable,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
)

import pydantic
from pydantic import BaseConfig, root_validator, validator
from pydantic import BaseConfig
from pydantic import BaseModel as PydanticBaseModel
from pydantic.class_validators import ROOT_VALIDATOR_CONFIG_KEY, VALIDATOR_CONFIG_KEY
from pydantic.fields import FieldInfo, ModelField, Validator
from pydantic.fields import FieldInfo, ModelField

from fastapi_jsonapi.data_typing import TypeSchema
from fastapi_jsonapi.schema import (
Expand All @@ -35,6 +33,10 @@
)
from fastapi_jsonapi.schema_base import BaseModel, Field, RelationshipInfo, registry
from fastapi_jsonapi.splitter import SPLIT_REL
from fastapi_jsonapi.validation_utils import (
extract_field_validators,
extract_validators,
)

JSON_API_RESPONSE_TYPE = Dict[Union[int, str], Dict[str, Any]]

Expand Down Expand Up @@ -291,7 +293,10 @@ def _get_info_from_schema_for_building(
# works both for to-one and to-many
included_schemas.append((name, field.type_, relationship.resource_type))
elif name == "id":
id_validators = self._extract_field_validators(schema, target_field_name="id")
id_validators = extract_field_validators(
schema,
include_for_field_names={"id"},
)
resource_id_field = (*(resource_id_field[:-1]), id_validators)

if not field.field_info.extra.get("client_can_set_id"):
Expand All @@ -310,7 +315,7 @@ class ConfigOrmMode(BaseConfig):
f"{base_name}AttributesJSONAPI",
**attributes_schema_fields,
__config__=ConfigOrmMode,
__validators__=self._extract_validators(schema, exclude_for_field_names={"id"}),
__validators__=extract_validators(schema, exclude_for_field_names={"id"}),
)

relationships_schema = pydantic.create_model(
Expand Down Expand Up @@ -378,111 +383,6 @@ def create_relationship_data_schema(
self.relationship_schema_cache[cache_key] = relationship_data_schema
return relationship_data_schema

def _is_target_validator(self, attr_name: str, value: Any, validator_config_key: str) -> bool:
"""
True if passed object is validator of type identified by "validator_config_key" arg

:param attr_name:
:param value:
:param validator_config_key: Choice field, available options are pydantic consts
VALIDATOR_CONFIG_KEY, ROOT_VALIDATOR_CONFIG_KEY
"""
return (
# also with private items
not attr_name.startswith("__")
and getattr(value, validator_config_key, None)
)

def _unpack_validators(self, model: Type[BaseModel], validator_config_key: str) -> Dict[str, Validator]:
"""
Selects all validators from model attrs and unpack them from class methods

:param model: Type[BaseModel]
:param validator_config_key: Choice field, available options are pydantic consts
VALIDATOR_CONFIG_KEY, ROOT_VALIDATOR_CONFIG_KEY
"""
root_validator_class_methods = {
# validators only
attr_name: value
for attr_name, value in model.__dict__.items()
if self._is_target_validator(attr_name, value, validator_config_key)
}

return {
validator_name: getattr(validator_method, validator_config_key)
for validator_name, validator_method in root_validator_class_methods.items()
}

def _extract_root_validators(self, model: Type[BaseModel]) -> Dict[str, Callable]:
validators = {}

unpacked_validators = self._unpack_validators(model, ROOT_VALIDATOR_CONFIG_KEY)
for validator_name, validator_instance in unpacked_validators.items():
validators[validator_name] = root_validator(
pre=validator_instance.pre,
skip_on_failure=validator_instance.skip_on_failure,
allow_reuse=True,
)(validator_instance.func)

return validators

def _extract_field_validators(
self,
model: Type[BaseModel],
target_field_name: str = None,
exclude_for_field_names: Set[str] = None,
) -> Dict[str, Callable]:
"""
:param model: Type[BaseModel]
:param target_field_name: Name of field for which validators will be returned.
If not set the function will return validators for all fields.
"""
validators = {}
validator_origin_param_keys = ("pre", "each_item", "always", "check_fields")

unpacked_validators = self._unpack_validators(model, VALIDATOR_CONFIG_KEY)
for validator_name, (field_names, validator_instance) in unpacked_validators.items():
if target_field_name and target_field_name not in field_names:
continue
elif target_field_name:
field_names = [target_field_name] # noqa: PLW2901

if exclude_for_field_names:
field_names = [ # noqa: PLW2901
# filter names
field_name
for field_name in field_names
if field_name not in exclude_for_field_names
]

if not field_names:
continue

validators[validator_name] = validator(
*field_names,
allow_reuse=True,
**{
# copy origin params
param_name: getattr(validator_instance, param_name)
for param_name in validator_origin_param_keys
},
)(validator_instance.func)

return validators

def _extract_validators(
self,
model: Type[BaseModel],
exclude_for_field_names: Set[str] = None,
) -> Dict[str, Callable]:
return {
**self._extract_field_validators(
model,
exclude_for_field_names=exclude_for_field_names,
),
**self._extract_root_validators(model),
}

def _build_jsonapi_object(
self,
base_name: str,
Expand Down
123 changes: 123 additions & 0 deletions fastapi_jsonapi/validation_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from copy import deepcopy
from typing import (
Callable,
Dict,
Set,
Type,
)

from pydantic import (
class_validators,
root_validator,
validator,
)
from pydantic.fields import Validator
from pydantic.utils import unique_list

from fastapi_jsonapi.schema_base import BaseModel


def extract_root_validators(model: Type[BaseModel]) -> Dict[str, Callable]:
pre_rv_new, post_rv_new = class_validators.extract_root_validators(model.__dict__)
pre_root_validators = unique_list(
model.__pre_root_validators__ + pre_rv_new,
name_factory=lambda v: v.__name__,
)
post_root_validators = unique_list(
model.__post_root_validators__ + post_rv_new,
name_factory=lambda skip_on_failure_and_v: skip_on_failure_and_v[1].__name__,
)

result_validators = {}

for validator_func in pre_root_validators:
result_validators[validator_func.__name__] = root_validator(
pre=True,
allow_reuse=True,
)(validator_func)

for skip_on_failure, validator_func in post_root_validators:
result_validators[validator_func.__name__] = root_validator(
allow_reuse=True,
skip_on_failure=skip_on_failure,
)(validator_func)

return result_validators


def _deduplicate_field_validators(validators: Dict) -> Dict:
result_validators = {}

for field_name, field_validators in validators.items():
result_validators[field_name] = list(
{
# override in definition order
field_validator.func.__name__: field_validator
for field_validator in field_validators
}.values(),
)

return result_validators


def extract_field_validators(
model: Type[BaseModel],
*,
include_for_field_names: Set[str] = None,
exclude_for_field_names: Set[str] = None,
):
validators = class_validators.inherit_validators(
class_validators.extract_validators(model.__dict__),
deepcopy(model.__validators__),
)
validators = _deduplicate_field_validators(validators)
validator_origin_param_keys = (
"pre",
"each_item",
"always",
"check_fields",
)

exclude_for_field_names = exclude_for_field_names or set()

if include_for_field_names and exclude_for_field_names:
include_for_field_names = include_for_field_names.difference(
exclude_for_field_names,
)

result_validators = {}
for field_name, field_validators in validators.items():
if field_name in exclude_for_field_names:
continue

if include_for_field_names and field_name not in include_for_field_names:
continue

field_validator: Validator
for field_validator in field_validators:
validator_name = f"{field_name}_{field_validator.func.__name__}_validator"
validator_params = {
# copy validator params
param_key: getattr(field_validator, param_key)
for param_key in validator_origin_param_keys
}
result_validators[validator_name] = validator(
field_name,
**validator_params,
allow_reuse=True,
)(field_validator.func)

return result_validators


def extract_validators(
model: Type[BaseModel],
exclude_for_field_names: Set[str] = None,
) -> Dict[str, Callable]:
return {
**extract_field_validators(
model,
exclude_for_field_names=exclude_for_field_names,
),
**extract_root_validators(model),
}
17 changes: 17 additions & 0 deletions tests/fixtures/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ParentToChildAssociation,
Post,
PostComment,
Task,
User,
UserBio,
)
Expand All @@ -36,6 +37,9 @@
PostInSchema,
PostPatchSchema,
PostSchema,
TaskInSchema,
TaskPatchSchema,
TaskSchema,
UserBioSchema,
UserInSchema,
UserPatchSchema,
Expand Down Expand Up @@ -161,6 +165,19 @@ def add_routers(app_plain: FastAPI):
schema_in_post=ComputerInSchema,
)

RoutersJSONAPI(
router=router,
path="/tasks",
tags=["Task"],
class_detail=DetailViewBaseGeneric,
class_list=ListViewBaseGeneric,
model=Task,
schema=TaskSchema,
resource_type="task",
schema_in_patch=TaskPatchSchema,
schema_in_post=TaskInSchema,
)

atomic = AtomicOperations()

app_plain.include_router(router, prefix="")
Expand Down
9 changes: 9 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,15 @@ def __repr__(self):
return f"{self.__class__.__name__}(id={self.id}, name={self.name!r}, user_id={self.user_id})"


class Task(Base):
__tablename__ = "tasks"
id = Column(Integer, primary_key=True)
task_ids = Column(JSON, nullable=True, unique=False)


# uuid below


class CustomUUIDType(TypeDecorator):
cache_ok = True

Expand Down
Loading