diff --git a/fastapi_jsonapi/views/view_base.py b/fastapi_jsonapi/views/view_base.py index a44a8052..8108011e 100644 --- a/fastapi_jsonapi/views/view_base.py +++ b/fastapi_jsonapi/views/view_base.py @@ -1,10 +1,12 @@ import inspect import logging +from collections import defaultdict from contextvars import ContextVar from functools import partial from typing import ( Any, Callable, + ClassVar, Dict, Iterable, List, @@ -47,6 +49,9 @@ included_object_schema_ctx_var: ContextVar[Type[TypeSchema]] = ContextVar("included_object_schema_ctx_var") relationship_info_ctx_var: ContextVar[RelationshipInfo] = ContextVar("relationship_info_ctx_var") +# TODO: just change state on `self`!! (refactor) +included_objects_ctx_var: ContextVar[Dict[Tuple[str, str], TypeSchema]] = ContextVar("included_objects_ctx_var") + class ViewBase: """ @@ -54,7 +59,7 @@ class ViewBase: """ data_layer_cls = BaseDataLayer - method_dependencies: Dict[HTTPMethod, HTTPMethodConfig] = {} + method_dependencies: ClassVar[Dict[HTTPMethod, HTTPMethodConfig]] = {} def __init__(self, *, request: Request, jsonapi: RoutersJSONAPI, **options): self.request: Request = request @@ -240,12 +245,12 @@ def prepare_data_for_relationship( def update_related_object( cls, relationship_data: Union[Dict[str, str], List[Dict[str, str]]], - included_objects: Dict[Tuple[str, str], TypeSchema], cache_key: Tuple[str, str], related_field_name: str, ): relationships_schema: Type[BaseModel] = relationships_schema_ctx_var.get() object_schema: Type[JSONAPIObjectSchema] = object_schema_ctx_var.get() + included_objects: Dict[Tuple[str, str], TypeSchema] = included_objects_ctx_var.get() relationship_data_schema = get_related_schema(relationships_schema, related_field_name) parent_included_object = included_objects.get(cache_key) @@ -256,12 +261,10 @@ def update_related_object( existing = existing.dict() new_relationships.update(existing) new_relationships.update( - { - **{ - related_field_name: relationship_data_schema( - data=relationship_data, - ), - }, + **{ + related_field_name: relationship_data_schema( + data=relationship_data, + ), }, ) included_objects[cache_key] = object_schema.parse_obj( @@ -273,17 +276,19 @@ def update_related_object( @classmethod def update_known_included( cls, - included_objects: Dict[Tuple[str, str], TypeSchema], new_included: List[TypeSchema], ): + included_objects: Dict[Tuple[str, str], TypeSchema] = included_objects_ctx_var.get() + for included in new_included: - included_objects[(included.id, included.type)] = included + key = (included.id, included.type) + if key not in included_objects: + included_objects[key] = included @classmethod def process_single_db_item_and_prepare_includes( cls, parent_db_item: TypeModel, - included_objects: Dict[Tuple[str, str], TypeSchema], ): previous_resource_type: str = previous_resource_type_ctx_var.get() related_field_name: str = related_field_name_ctx_var.get() @@ -305,7 +310,6 @@ def process_single_db_item_and_prepare_includes( ) cls.update_known_included( - included_objects=included_objects, new_included=new_included, ) relationship_data_items.append(data_for_relationship) @@ -317,7 +321,6 @@ def process_single_db_item_and_prepare_includes( cls.update_related_object( relationship_data=relationship_data_items, - included_objects=included_objects, cache_key=cache_key, related_field_name=related_field_name, ) @@ -328,14 +331,12 @@ def process_single_db_item_and_prepare_includes( def process_db_items_and_prepare_includes( cls, parent_db_items: List[TypeModel], - included_objects: Dict[Tuple[str, str], TypeSchema], ): next_current_db_item = [] for parent_db_item in parent_db_items: new_next_items = cls.process_single_db_item_and_prepare_includes( parent_db_item=parent_db_item, - included_objects=included_objects, ) next_current_db_item.extend(new_next_items) return next_current_db_item @@ -346,18 +347,21 @@ def process_include_with_nested( current_db_item: Union[List[TypeModel], TypeModel], item_as_schema: TypeSchema, current_relation_schema: Type[TypeSchema], + included_objects: Dict[Tuple[str, str], TypeSchema], + requested_includes: Dict[str, Iterable[str]], ) -> Tuple[Dict[str, TypeSchema], List[JSONAPIObjectSchema]]: root_item_key = (item_as_schema.id, item_as_schema.type) - included_objects: Dict[Tuple[str, str], TypeSchema] = { - root_item_key: item_as_schema, - } + + if root_item_key not in included_objects: + included_objects[root_item_key] = item_as_schema previous_resource_type = item_as_schema.type + previous_related_field_name = previous_resource_type for related_field_name in include.split(SPLIT_REL): object_schemas = self.jsonapi.schema_builder.create_jsonapi_object_schemas( schema=current_relation_schema, - includes=[related_field_name], - compute_included_schemas=bool([related_field_name]), + includes=requested_includes[previous_related_field_name], + compute_included_schemas=True, ) relationships_schema = object_schemas.relationships_schema schemas_include = object_schemas.can_be_included_schemas @@ -379,16 +383,28 @@ def process_include_with_nested( related_field_name_ctx_var.set(related_field_name) relationship_info_ctx_var.set(relationship_info) included_object_schema_ctx_var.set(included_object_schema) + included_objects_ctx_var.set(included_objects) current_db_item = self.process_db_items_and_prepare_includes( parent_db_items=current_db_item, - included_objects=included_objects, ) previous_resource_type = relationship_info.resource_type + previous_related_field_name = related_field_name return included_objects.pop(root_item_key), list(included_objects.values()) + def prep_requested_includes(self, includes: Iterable[str]): + requested_includes: Dict[str, set[str]] = defaultdict(set) + default: str = self.jsonapi.type_ + for include in includes: + prev = default + for related_field_name in include.split(SPLIT_REL): + requested_includes[prev].add(related_field_name) + prev = related_field_name + + return requested_includes + def process_db_object( self, includes: List[str], @@ -403,12 +419,17 @@ def process_db_object( attributes=object_schemas.attributes_schema.from_orm(item), ) + cache_included_objects: Dict[Tuple[str, str], TypeSchema] = {} + requested_includes = self.prep_requested_includes(includes) + for include in includes: item_as_schema, new_included_objects = self.process_include_with_nested( include=include, current_db_item=item, item_as_schema=item_as_schema, current_relation_schema=item_schema, + included_objects=cache_included_objects, + requested_includes=requested_includes, ) included_objects.extend(new_included_objects) diff --git a/pyproject.toml b/pyproject.toml index a9c5251e..a4b07c4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -219,6 +219,7 @@ extend-ignore = [ "RUF001", # String contains ambiguous unicode character {confusable} (did you mean {representant}?) "RUF002", # Docstring contains ambiguous unicode character {confusable} (did you mean {representant}?) "RUF003", # Comment contains ambiguous unicode character {confusable} (did you mean {representant}?) + "PT006", # pytest parametrize tuple args ] [tool.ruff.per-file-ignores] diff --git a/tests/test_api/test_api_sqla_with_includes.py b/tests/test_api/test_api_sqla_with_includes.py index b245132b..c2a47956 100644 --- a/tests/test_api/test_api_sqla_with_includes.py +++ b/tests/test_api/test_api_sqla_with_includes.py @@ -1,4 +1,5 @@ import logging +from collections import defaultdict from itertools import chain, zip_longest from json import dumps from typing import Dict, List @@ -598,6 +599,180 @@ async def test_many_to_many_load_inner_includes_to_parents( assert ("child", ViewBase.get_db_item_id(child_4)) not in included_data +class TestUserWithPostsWithInnerIncludes: + @mark.parametrize( + "include, expected_relationships_inner_relations, expect_user_include", + [ + ( + ["posts", "posts.user"], + {"post": ["user"], "user": []}, + False, + ), + ( + ["posts", "posts.comments"], + {"post": ["comments"], "post_comment": []}, + False, + ), + ( + ["posts", "posts.user", "posts.comments"], + {"post": ["user", "comments"], "user": [], "post_comment": []}, + False, + ), + ( + ["posts", "posts.user", "posts.comments", "posts.comments.author"], + {"post": ["user", "comments"], "post_comment": ["author"], "user": []}, + True, + ), + ], + ) + async def test_get_users_with_posts_and_inner_includes( + self, + app: FastAPI, + client: AsyncClient, + user_1: User, + user_2: User, + user_1_posts: list[PostComment], + user_1_post_for_comments: Post, + user_2_comment_for_one_u1_post: PostComment, + include: list[str], + expected_relationships_inner_relations: dict[str, list[str]], + expect_user_include: bool, + ): + """ + Test if requesting `posts.user` and `posts.comments` + returns posts with both `user` and `comments` + """ + assert user_1_posts + assert user_2_comment_for_one_u1_post.author_id == user_2.id + include_param = ",".join(include) + resource_type = "user" + url = app.url_path_for(f"get_{resource_type}_list") + url = f"{url}?filter[name]={user_1.name}&include={include_param}" + response = await client.get(url) + assert response.status_code == status.HTTP_200_OK, response.text + response_json = response.json() + + result_data = response_json["data"] + + assert result_data == [ + { + "id": str(user_1.id), + "type": resource_type, + "attributes": UserAttributesBaseSchema.from_orm(user_1).dict(), + "relationships": { + "posts": { + "data": [ + # relationship info + {"id": str(p.id), "type": "post"} + # for every post + for p in user_1_posts + ], + }, + }, + }, + ] + included_data = response_json["included"] + included_as_map = defaultdict(list) + for item in included_data: + included_as_map[item["type"]].append(item) + + for item_type, items in included_as_map.items(): + expected_relationships = expected_relationships_inner_relations[item_type] + for item in items: + relationships = set(item.get("relationships", {})) + assert relationships.intersection(expected_relationships) == set( + expected_relationships, + ), f"Expected relationships {expected_relationships} not found in {item_type} {item['id']}" + + expected_includes = self.prepare_expected_includes( + user_1=user_1, + user_2=user_2, + user_1_posts=user_1_posts, + user_2_comment_for_one_u1_post=user_2_comment_for_one_u1_post, + ) + + for item_type, includes_names in expected_relationships_inner_relations.items(): + items = expected_includes[item_type] + have_to_be_present = set(includes_names) + for item in items: # type: dict + item_relationships = item.get("relationships", {}) + for key in tuple(item_relationships.keys()): + if key not in have_to_be_present: + item_relationships.pop(key) + if not item_relationships: + item.pop("relationships", None) + + for key in set(expected_includes).difference(expected_relationships_inner_relations): + expected_includes.pop(key) + + # XXX + if not expect_user_include: + expected_includes.pop("user", None) + assert included_as_map == expected_includes + + def prepare_expected_includes( + self, + user_1: User, + user_2: User, + user_1_posts: list[PostComment], + user_2_comment_for_one_u1_post: PostComment, + ): + expected_includes = { + "post": [ + # + { + "id": str(p.id), + "type": "post", + "attributes": PostAttributesBaseSchema.from_orm(p).dict(), + "relationships": { + "user": { + "data": { + "id": str(user_1.id), + "type": "user", + }, + }, + "comments": { + "data": [ + { + "id": str(user_2_comment_for_one_u1_post.id), + "type": "post_comment", + }, + ] + if p.id == user_2_comment_for_one_u1_post.post_id + else [], + }, + }, + } + # + for p in user_1_posts + ], + "post_comment": [ + { + "id": str(user_2_comment_for_one_u1_post.id), + "type": "post_comment", + "attributes": PostCommentAttributesBaseSchema.from_orm(user_2_comment_for_one_u1_post).dict(), + "relationships": { + "author": { + "data": { + "id": str(user_2.id), + "type": "user", + }, + }, + }, + }, + ], + "user": [ + { + "id": str(user_2.id), + "type": "user", + "attributes": UserAttributesBaseSchema.from_orm(user_2).dict(), + }, + ], + } + + return expected_includes + + async def test_method_not_allowed(app: FastAPI, client: AsyncClient): url = app.url_path_for("get_user_list") res = await client.put(url, json={})