Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 120 additions & 11 deletions fastapi_jsonapi/data_layers/filtering/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,33 @@
"""Helper to create sqlalchemy filters according to filter querystring parameter"""
from typing import Any, List, Tuple, Type, Union

from pydantic import BaseModel
import logging
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
)

from pydantic import BaseConfig, BaseModel
from pydantic.fields import ModelField
from pydantic.validators import _VALIDATORS, find_validators
from sqlalchemy import and_, not_, or_
from sqlalchemy.orm import InstrumentedAttribute, aliased
from sqlalchemy.sql.elements import BinaryExpression

from fastapi_jsonapi.data_layers.shared import create_filters_or_sorts
from fastapi_jsonapi.data_typing import TypeModel, TypeSchema
from fastapi_jsonapi.exceptions import InvalidFilters, InvalidType
from fastapi_jsonapi.exceptions.json_api import HTTPException
from fastapi_jsonapi.schema import get_model_field, get_relationships
from fastapi_jsonapi.splitter import SPLIT_REL
from fastapi_jsonapi.utils.sqla import get_related_model_cls

log = logging.getLogger(__name__)

Filter = BinaryExpression
Join = List[Any]

Expand All @@ -22,6 +36,11 @@
List[Join],
]

# Mapping of types to the pydantic validators used to cast raw values to instances of the target type
REGISTERED_PYDANTIC_TYPES: Dict[Type, List[Callable]] = dict(_VALIDATORS)

cast_failed = object()


def create_filters(model: Type[TypeModel], filter_info: Union[list, dict], schema: Type[TypeSchema]):
"""
Expand All @@ -48,6 +67,21 @@ def __init__(self, model: Type[TypeModel], filter_: dict, schema: Type[TypeSchem
self.filter_ = filter_
self.schema = schema

def _cast_value_with_scheme(self, field_types: List[Type], value: Any) -> Tuple[Any, List[str]]:
    """
    Try to cast ``value`` to each of the given types, in order.

    Every type is attempted (there is no early exit), so when several types
    accept the value the result of the LAST successful cast wins.

    :param field_types: plain types to try as constructors — callers pass
        ``[field.type_ for field in fields]``, i.e. ``Type`` objects, not
        ``ModelField`` instances (the previous annotation was wrong: a
        ``ModelField`` is not callable as a constructor).
    :param value: raw filter value; a list is casted element-wise.
    :return: tuple of (casted value, collected error messages). When no type
        accepted the value, the first element is the module-level
        ``cast_failed`` sentinel so callers can distinguish "cast failed"
        from a legitimate ``None`` result.
    """
    errors: List[str] = []
    casted_value = cast_failed

    for field_type in field_types:
        try:
            if isinstance(value, list):  # noqa: SIM108
                # element-wise cast; on failure mid-list the previous
                # casted_value is kept (the comprehension raises before assignment)
                casted_value = [field_type(item) for item in value]
            else:
                casted_value = field_type(value)
        except (TypeError, ValueError) as ex:
            errors.append(str(ex))

    return casted_value, errors

def create_filter(self, schema_field: ModelField, model_column, operator, value):
"""
Create sqlalchemy filter
Expand Down Expand Up @@ -78,19 +112,94 @@ def create_filter(self, schema_field: ModelField, model_column, operator, value)
types = [i.type_ for i in fields]
clear_value = None
errors: List[str] = []
for i_type in types:
try:
if isinstance(value, list): # noqa: SIM108
clear_value = [i_type(item) for item in value]
else:
clear_value = i_type(value)
except (TypeError, ValueError) as ex:
errors.append(str(ex))

pydantic_types, userspace_types = self._separate_types(types)

if pydantic_types:
if isinstance(value, list):
clear_value, errors = self._cast_iterable_with_pydantic(pydantic_types, value)
else:
clear_value, errors = self._cast_value_with_pydantic(pydantic_types, value)

if clear_value is None and userspace_types:
log.warning("Filtering by user type values is not properly tested yet. Use this on your own risk.")

clear_value, errors = self._cast_value_with_scheme(types, value)

if clear_value is cast_failed:
raise InvalidType(
detail=f"Can't cast filter value `{value}` to arbitrary type.",
errors=[HTTPException(status_code=InvalidType.status_code, detail=str(err)) for err in errors],
)

# If clear_value is None while the field is required (None is not among the annotated types), raise an error
if clear_value is None and not any(not i_f.required for i_f in fields):
raise InvalidType(detail=", ".join(errors))
return getattr(model_column, self.operator)(clear_value)

def _separate_types(self, types: List[Type]) -> Tuple[List[Type], List[Type]]:
    """
    Split ``types`` into two groups, preserving order.

    The first group contains types pydantic ships validators for out of the
    box (str, int, datetime and other built-ins); the second contains all
    remaining userspace types, for which pydantic models rely on the
    ``arbitrary_types_allowed`` config.
    """
    pydantic_types: List[Type] = []
    userspace_types: List[Type] = []

    for type_ in types:
        if type_ in REGISTERED_PYDANTIC_TYPES:
            pydantic_types.append(type_)
        else:
            userspace_types.append(type_)

    return pydantic_types, userspace_types

def _cast_value_with_pydantic(
    self,
    types: List[Type],
    value: Any,
) -> Tuple[Optional[Any], List[str]]:
    """
    Cast ``value`` using pydantic's registered validators for the given types.

    Validators are tried in order; the first one that accepts the value wins
    and its result is returned immediately. When every validator rejects the
    value, ``None`` is returned together with the accumulated error messages.
    """
    errors: List[str] = []

    for candidate_type in types:
        for validate in find_validators(candidate_type, BaseConfig):
            try:
                return validate(value), errors
            except Exception as exc:  # validators may raise arbitrary errors
                errors.append(str(exc))

    return None, errors

def _cast_iterable_with_pydantic(self, types: List[Type], values: List) -> Tuple[List, List[str]]:
    """
    Cast every item of ``values`` with pydantic validators for ``types``.

    Raises ``InvalidFilters`` when at least one item could not be casted
    (i.e. the per-item cast returned ``None``); otherwise returns the casted
    items together with all error messages accumulated along the way.
    """
    casted_items: List[Any] = []
    errors: List[str] = []
    uncastable: List[Any] = []

    for item in values:
        casted, item_errors = self._cast_value_with_pydantic(types, item)
        errors.extend(item_errors)

        if casted is None:
            uncastable.append(item)
        else:
            casted_items.append(casted)

    if uncastable:
        msg = f"Can't parse items {uncastable} of value {values}"
        raise InvalidFilters(msg)

    return casted_items, errors

def resolve(self) -> FilterAndJoins: # noqa: PLR0911
"""Create filter for a particular node of the filter tree"""
if "or" in self.filter_:
Expand Down
4 changes: 4 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ def sqla_uri():
db_dir = Path(__file__).resolve().parent
testing_db_url = f"sqlite+aiosqlite:///{db_dir}/db.sqlite3"
return testing_db_url


def is_postgres_tests() -> bool:
    """Tell whether the test suite is configured to run against PostgreSQL."""
    db_url = sqla_uri()
    return "postgres" in db_url
7 changes: 6 additions & 1 deletion tests/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import TYPE_CHECKING, Dict, List, Optional
from uuid import UUID

from sqlalchemy import JSON, Column, ForeignKey, Index, Integer, String, Text
from sqlalchemy import JSON, Column, DateTime, ForeignKey, Index, Integer, String, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import declared_attr, relationship
from sqlalchemy.types import CHAR, TypeDecorator
Expand Down Expand Up @@ -296,3 +296,8 @@ class SelfRelationship(Base):
)
# parent = relationship("SelfRelationship", back_populates="s")
self_relationship = relationship("SelfRelationship", remote_side=[id])


class ContainsTimestamp(Base):
    """Model with a timezone-aware datetime column, used to test timestamp filtering."""

    id = Column(Integer, primary_key=True)
    # DateTime(True) == DateTime(timezone=True): timezone-aware column
    timestamp = Column(DateTime(True), nullable=False)
101 changes: 100 additions & 1 deletion tests/test_api/test_api_sqla_with_includes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import logging
from collections import defaultdict
from datetime import datetime, timezone
from itertools import chain, zip_longest
from json import dumps
from typing import Dict, List
Expand All @@ -8,15 +10,18 @@
from fastapi import FastAPI, status
from httpx import AsyncClient
from pydantic import BaseModel, Field
from pytest import fixture, mark, param # noqa PT013
from pytest import fixture, mark, param, raises # noqa PT013
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from fastapi_jsonapi.views.view_base import ViewBase
from tests.common import is_postgres_tests
from tests.fixtures.app import build_app_custom
from tests.fixtures.entities import build_workplace, create_user
from tests.misc.utils import fake
from tests.models import (
Computer,
ContainsTimestamp,
IdCast,
Post,
PostComment,
Expand Down Expand Up @@ -1215,6 +1220,100 @@ async def test_create_with_relationship_to_the_same_table(self):
"meta": None,
}

async def test_create_with_timestamp_and_fetch(self, async_session: AsyncSession):
    """
    Create an entity with a timezone-aware timestamp, then verify it can be
    fetched back and filtered by an exact `eq` timestamp match.
    """
    resource_type = "contains_timestamp_model"

    # Schema exposing a single datetime attribute for the test-only resource.
    class ContainsTimestampAttrsSchema(BaseModel):
        timestamp: datetime

    app = build_app_custom(
        model=ContainsTimestamp,
        schema=ContainsTimestampAttrsSchema,
        schema_in_post=ContainsTimestampAttrsSchema,
        schema_in_patch=ContainsTimestampAttrsSchema,
        resource_type=resource_type,
    )

    create_timestamp = datetime.now(tz=timezone.utc)
    create_body = {
        "data": {
            "attributes": {
                "timestamp": create_timestamp.isoformat(),
            },
        },
    }

    async with AsyncClient(app=app, base_url="http://test") as client:
        url = app.url_path_for(f"get_{resource_type}_list")
        res = await client.post(url, json=create_body)
        assert res.status_code == status.HTTP_201_CREATED, res.text
        response_json = res.json()

        # The POST response must echo the attributes and carry a new id.
        assert (entity_id := response_json["data"]["id"])
        assert response_json == {
            "meta": None,
            "jsonapi": {"version": "1.0"},
            "data": {
                "type": resource_type,
                "attributes": {"timestamp": create_timestamp.isoformat()},
                "id": entity_id,
            },
        }

        # Confirm the row actually exists in the database (scalar_one raises otherwise).
        stms = select(ContainsTimestamp).where(ContainsTimestamp.id == int(entity_id))
        (await async_session.execute(stms)).scalar_one()

        # NOTE(review): SQLite is assumed to return naive datetimes (tzinfo
        # stripped), while Postgres keeps the offset — hence the branch.
        # `.replace()` with no arguments is a plain copy of the aware value.
        expected_response_timestamp = create_timestamp.replace(tzinfo=None).isoformat()
        if is_postgres_tests():
            expected_response_timestamp = create_timestamp.replace().isoformat()

        params = {
            "filter": json.dumps(
                [
                    {
                        "name": "timestamp",
                        "op": "eq",
                        "val": create_timestamp.isoformat(),
                    },
                ],
            ),
        }

        # successfully filtered: the exact timestamp matches the created row
        res = await client.get(url, params=params)
        assert res.status_code == status.HTTP_200_OK, res.text
        assert res.json() == {
            "meta": {"count": 1, "totalPages": 1},
            "jsonapi": {"version": "1.0"},
            "data": [
                {
                    "type": resource_type,
                    "attributes": {"timestamp": expected_response_timestamp},
                    "id": entity_id,
                },
            ],
        }

        # check the filter really works: a different timestamp must match nothing
        params = {
            "filter": json.dumps(
                [
                    {
                        "name": "timestamp",
                        "op": "eq",
                        "val": datetime.now(tz=timezone.utc).isoformat(),
                    },
                ],
            ),
        }
        res = await client.get(url, params=params)
        assert res.status_code == status.HTTP_200_OK, res.text
        assert res.json() == {
            "meta": {"count": 0, "totalPages": 1},
            "jsonapi": {"version": "1.0"},
            "data": [],
        }


class TestPatchObjects:
async def test_patch_object(
Expand Down
Empty file.
Empty file.
Loading