Skip to content

Commit 6c1f9fd

Browse files
authored
Merge pull request #70 from mts-ai/run-validator-pass-model-field
fix run validator: sometimes it requires model field
2 parents 7095b40 + 13a8228 commit 6c1f9fd

File tree

7 files changed

+267
-43
lines changed

7 files changed

+267
-43
lines changed

fastapi_jsonapi/data_layers/filtering/sqlalchemy.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Helper to create sqlalchemy filters according to filter querystring parameter"""
2+
import inspect
23
import logging
34
from typing import (
45
Any,
@@ -133,10 +134,10 @@ def create_filter(self, schema_field: ModelField, model_column, operator, value)
133134
pydantic_types, userspace_types = self._separate_types(types)
134135

135136
if pydantic_types:
137+
func = self._cast_value_with_pydantic
136138
if isinstance(value, list):
137-
clear_value, errors = self._cast_iterable_with_pydantic(pydantic_types, value)
138-
else:
139-
clear_value, errors = self._cast_value_with_pydantic(pydantic_types, value)
139+
func = self._cast_iterable_with_pydantic
140+
clear_value, errors = func(pydantic_types, value, schema_field)
140141

141142
if clear_value is None and userspace_types:
142143
log.warning("Filtering by user type values is not properly tested yet. Use this on your own risk.")
@@ -151,7 +152,10 @@ def create_filter(self, schema_field: ModelField, model_column, operator, value)
151152

152153
# Если None, при этом поле обязательное (среди типов в аннотации нет None, то кидаем ошибку)
153154
if clear_value is None and not can_be_none:
154-
raise InvalidType(detail=", ".join(errors))
155+
raise InvalidType(
156+
detail=", ".join(errors),
157+
pointer=schema_field.name,
158+
)
155159

156160
return getattr(model_column, self.operator)(clear_value)
157161

@@ -179,32 +183,65 @@ def _separate_types(self, types: List[Type]) -> Tuple[List[Type], List[Type]]:
179183
]
180184
return pydantic_types, userspace_types
181185

186+
def _validator_requires_model_field(self, validator: Callable) -> bool:
187+
"""
188+
Check if validator accepts the `field` param
189+
190+
:param validator:
191+
:return:
192+
"""
193+
signature = inspect.signature(validator)
194+
parameters = signature.parameters
195+
196+
if "field" not in parameters:
197+
return False
198+
199+
field_param = parameters["field"]
200+
field_type = field_param.annotation
201+
202+
return field_type == "ModelField" or field_type is ModelField
203+
182204
def _cast_value_with_pydantic(
183205
self,
184206
types: List[Type],
185207
value: Any,
208+
schema_field: ModelField,
186209
) -> Tuple[Optional[Any], List[str]]:
187210
result_value, errors = None, []
188211

189212
for type_to_cast in types:
190213
for validator in find_validators(type_to_cast, BaseConfig):
214+
args = [value]
215+
# TODO: some other way to get all the validator's dependencies?
216+
if self._validator_requires_model_field(validator):
217+
args.append(schema_field)
191218
try:
192-
result_value = validator(value)
193-
return result_value, errors
219+
result_value = validator(*args)
194220
except Exception as ex:
195221
errors.append(str(ex))
222+
else:
223+
return result_value, errors
196224

197225
return None, errors
198226

199-
def _cast_iterable_with_pydantic(self, types: List[Type], values: List) -> Tuple[List, List[str]]:
227+
def _cast_iterable_with_pydantic(
228+
self,
229+
types: List[Type],
230+
values: List,
231+
schema_field: ModelField,
232+
) -> Tuple[List, List[str]]:
200233
type_cast_failed = False
201234
failed_values = []
202235

203236
result_values: List[Any] = []
204237
errors: List[str] = []
205238

206239
for value in values:
207-
casted_value, cast_errors = self._cast_value_with_pydantic(types, value)
240+
casted_value, cast_errors = self._cast_value_with_pydantic(
241+
types,
242+
value,
243+
schema_field,
244+
)
208245
errors.extend(cast_errors)
209246

210247
if casted_value is None:
@@ -217,7 +254,7 @@ def _cast_iterable_with_pydantic(self, types: List[Type], values: List) -> Tuple
217254

218255
if type_cast_failed:
219256
msg = f"Can't parse items {failed_values} of value {values}"
220-
raise InvalidFilters(msg)
257+
raise InvalidFilters(msg, pointer=schema_field.name)
221258

222259
return result_values, errors
223260

fastapi_jsonapi/exceptions/json_api.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,11 @@ def __init__(
5353
parameter = parameter or self.parameter
5454
if not errors:
5555
if pointer:
56-
pointer = pointer if pointer.startswith("/") else "/data/" + pointer
56+
pointer = (
57+
pointer
58+
if pointer.startswith("/")
59+
else "/data/" + (pointer if pointer == "id" else "attributes/" + pointer)
60+
)
5761
self.source = {"pointer": pointer}
5862
elif parameter:
5963
self.source = {"parameter": parameter}

tests/fixtures/app.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from tests.models import (
1616
Child,
1717
Computer,
18+
CustomUUIDItem,
1819
Parent,
1920
ParentToChildAssociation,
2021
Post,
@@ -30,6 +31,7 @@
3031
ComputerInSchema,
3132
ComputerPatchSchema,
3233
ComputerSchema,
34+
CustomUUIDItemSchema,
3335
ParentPatchSchema,
3436
ParentSchema,
3537
ParentToChildAssociationSchema,
@@ -178,6 +180,17 @@ def add_routers(app_plain: FastAPI):
178180
schema_in_post=TaskInSchema,
179181
)
180182

183+
RoutersJSONAPI(
184+
router=router,
185+
path="/custom-uuid-item",
186+
tags=["Custom UUID Item"],
187+
class_detail=DetailViewBaseGeneric,
188+
class_list=ListViewBaseGeneric,
189+
model=CustomUUIDItem,
190+
schema=CustomUUIDItemSchema,
191+
resource_type="custom_uuid_item",
192+
)
193+
181194
atomic = AtomicOperations()
182195

183196
app_plain.include_router(router, prefix="")

tests/fixtures/db_connection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,6 @@ async def async_session_plain(async_engine):
5151

5252
@async_fixture(scope="class")
5353
async def async_session(async_session_plain):
54-
async with async_session_plain() as session:
54+
async with async_session_plain() as session: # type: AsyncSession
5555
yield session
56+
await session.rollback()

tests/models.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from sqlalchemy.orm import declared_attr, relationship
77
from sqlalchemy.types import CHAR, TypeDecorator
88

9-
from tests.common import sqla_uri
9+
from tests.common import is_postgres_tests, sqla_uri
1010

1111

1212
class Base:
@@ -253,33 +253,43 @@ def load_dialect_impl(self, dialect):
253253
return CHAR(32)
254254

255255
def process_bind_param(self, value, dialect):
256+
if value is None:
257+
return value
258+
256259
if not isinstance(value, UUID):
257260
msg = f"Incorrect type got {type(value).__name__}, expected {UUID.__name__}"
258261
raise Exception(msg)
259262

260263
return str(value)
261264

262265
def process_result_value(self, value, dialect):
263-
return UUID(value)
266+
return value and UUID(value)
264267

265268
@property
266269
def python_type(self):
267270
return UUID if self.as_uuid else str
268271

269272

270273
db_uri = sqla_uri()
271-
if "postgres" in db_uri:
274+
if is_postgres_tests():
272275
# noinspection PyPep8Naming
273-
from sqlalchemy.dialects.postgresql import UUID as UUIDType
276+
from sqlalchemy.dialects.postgresql.asyncpg import AsyncpgUUID as UUIDType
274277
elif "sqlite" in db_uri:
275278
UUIDType = CustomUUIDType
276279
else:
277280
msg = "unsupported dialect (custom uuid?)"
278281
raise ValueError(msg)
279282

280283

281-
class IdCast(Base):
282-
id = Column(UUIDType, primary_key=True)
284+
class CustomUUIDItem(Base):
285+
__tablename__ = "custom_uuid_item"
286+
id = Column(UUIDType(as_uuid=True), primary_key=True)
287+
288+
extra_id = Column(
289+
UUIDType(as_uuid=True),
290+
nullable=True,
291+
unique=True,
292+
)
283293

284294

285295
class SelfRelationship(Base):

tests/schemas.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,14 @@ class TaskSchema(TaskBaseSchema):
389389
# uuid below
390390

391391

392-
class IdCastSchema(BaseModel):
392+
class CustomUUIDItemAttributesSchema(BaseModel):
393+
extra_id: Optional[UUID] = None
394+
395+
class Config:
396+
orm_mode = True
397+
398+
399+
class CustomUUIDItemSchema(CustomUUIDItemAttributesSchema):
393400
id: UUID = Field(client_can_set_id=True)
394401

395402

0 commit comments

Comments
 (0)