Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions fastapi_jsonapi/data_layers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,49 @@ async def delete_relationship(
"""
raise NotImplementedError

def get_related_model_query_base(
self,
related_model: Type[TypeModel],
):
"""
Prepare query for the related model

:param related_model: Related ORM model class (not instance)
:return:
"""
raise NotImplementedError

def get_related_object_query(
self,
related_model: Type[TypeModel],
related_id_field: str,
id_value: str,
):
"""
Prepare query to get related object
:param related_model:
:param related_id_field:
:param id_value:
:return:
"""
raise NotImplementedError

def get_related_objects_list_query(
self,
related_model: Type[TypeModel],
related_id_field: str,
ids: list[str],
):
"""
Prepare query to get related objects list
:param related_model:
:param related_id_field:
:param ids:
:return:
"""
raise NotImplementedError

# async def get_related_object_query(self):
async def get_related_object(
self,
related_model: Type[TypeModel],
Expand Down
49 changes: 43 additions & 6 deletions fastapi_jsonapi/data_layers/sqla_orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,38 @@ async def delete_relationship(
:param view_kwargs: kwargs from the resource view.
"""

def get_related_model_query_base(
self,
related_model: Type[TypeModel],
) -> "Select":
"""
:param related_model:
:return:
"""
return select(related_model)

def get_related_object_query(
self,
related_model: Type[TypeModel],
related_id_field: str,
id_value: str,
):
id_field = getattr(related_model, related_id_field)
id_value = self.prepare_id_value(id_field, id_value)
stmt: "Select" = self.get_related_model_query_base(related_model)
return stmt.where(id_field == id_value)

def get_related_objects_list_query(
self,
related_model: Type[TypeModel],
related_id_field: str,
ids: list[str],
) -> Tuple["Select", list[str]]:
id_field = getattr(related_model, related_id_field)
prepared_ids = [self.prepare_id_value(id_field, _id) for _id in ids]
stmt: "Select" = self.get_related_model_query_base(related_model)
return stmt.where(id_field.in_(prepared_ids)), prepared_ids

async def get_related_object(
self,
related_model: Type[TypeModel],
Expand All @@ -532,9 +564,12 @@ async def get_related_object(
:param id_value: related object id value
:return: a related SQLA ORM object
"""
id_field = getattr(related_model, related_id_field)
id_value = self.prepare_id_value(id_field, id_value)
stmt = select(related_model).where(id_field == id_value)
stmt = self.get_related_object_query(
related_model=related_model,
related_id_field=related_id_field,
id_value=id_value,
)

try:
related_object = (await self.session.execute(stmt)).scalar_one()
except NoResultFound:
Expand All @@ -556,9 +591,11 @@ async def get_related_objects_list(
:param ids:
:return:
"""
id_field = getattr(related_model, related_id_field)
ids = [self.prepare_id_value(id_field, _id) for _id in ids]
stmt = select(related_model).where(id_field.in_(ids))
stmt, ids = self.get_related_objects_list_query(
related_model=related_model,
related_id_field=related_id_field,
ids=ids,
)

related_objects = (await self.session.execute(stmt)).scalars().all()
object_ids = [getattr(obj, related_id_field) for obj in related_objects]
Expand Down