Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 42 additions & 21 deletions fastapi_jsonapi/views/view_base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import inspect
import logging
from collections import defaultdict
from contextvars import ContextVar
from functools import partial
from typing import (
Any,
Callable,
ClassVar,
Dict,
Iterable,
List,
Expand Down Expand Up @@ -47,14 +49,17 @@
included_object_schema_ctx_var: ContextVar[Type[TypeSchema]] = ContextVar("included_object_schema_ctx_var")
relationship_info_ctx_var: ContextVar[RelationshipInfo] = ContextVar("relationship_info_ctx_var")

# TODO: just change state on `self`!! (refactor)
included_objects_ctx_var: ContextVar[Dict[Tuple[str, str], TypeSchema]] = ContextVar("included_objects_ctx_var")


class ViewBase:
"""
Views are inited for each request
"""

data_layer_cls = BaseDataLayer
method_dependencies: Dict[HTTPMethod, HTTPMethodConfig] = {}
method_dependencies: ClassVar[Dict[HTTPMethod, HTTPMethodConfig]] = {}

def __init__(self, *, request: Request, jsonapi: RoutersJSONAPI, **options):
self.request: Request = request
Expand Down Expand Up @@ -240,12 +245,12 @@ def prepare_data_for_relationship(
def update_related_object(
cls,
relationship_data: Union[Dict[str, str], List[Dict[str, str]]],
included_objects: Dict[Tuple[str, str], TypeSchema],
cache_key: Tuple[str, str],
related_field_name: str,
):
relationships_schema: Type[BaseModel] = relationships_schema_ctx_var.get()
object_schema: Type[JSONAPIObjectSchema] = object_schema_ctx_var.get()
included_objects: Dict[Tuple[str, str], TypeSchema] = included_objects_ctx_var.get()

relationship_data_schema = get_related_schema(relationships_schema, related_field_name)
parent_included_object = included_objects.get(cache_key)
Expand All @@ -256,12 +261,10 @@ def update_related_object(
existing = existing.dict()
new_relationships.update(existing)
new_relationships.update(
{
**{
related_field_name: relationship_data_schema(
data=relationship_data,
),
},
**{
related_field_name: relationship_data_schema(
data=relationship_data,
),
},
)
included_objects[cache_key] = object_schema.parse_obj(
Expand All @@ -273,17 +276,19 @@ def update_related_object(
@classmethod
def update_known_included(
cls,
included_objects: Dict[Tuple[str, str], TypeSchema],
new_included: List[TypeSchema],
):
included_objects: Dict[Tuple[str, str], TypeSchema] = included_objects_ctx_var.get()

for included in new_included:
included_objects[(included.id, included.type)] = included
key = (included.id, included.type)
if key not in included_objects:
included_objects[key] = included

@classmethod
def process_single_db_item_and_prepare_includes(
cls,
parent_db_item: TypeModel,
included_objects: Dict[Tuple[str, str], TypeSchema],
):
previous_resource_type: str = previous_resource_type_ctx_var.get()
related_field_name: str = related_field_name_ctx_var.get()
Expand All @@ -305,7 +310,6 @@ def process_single_db_item_and_prepare_includes(
)

cls.update_known_included(
included_objects=included_objects,
new_included=new_included,
)
relationship_data_items.append(data_for_relationship)
Expand All @@ -317,7 +321,6 @@ def process_single_db_item_and_prepare_includes(

cls.update_related_object(
relationship_data=relationship_data_items,
included_objects=included_objects,
cache_key=cache_key,
related_field_name=related_field_name,
)
Expand All @@ -328,14 +331,12 @@ def process_single_db_item_and_prepare_includes(
def process_db_items_and_prepare_includes(
cls,
parent_db_items: List[TypeModel],
included_objects: Dict[Tuple[str, str], TypeSchema],
):
next_current_db_item = []

for parent_db_item in parent_db_items:
new_next_items = cls.process_single_db_item_and_prepare_includes(
parent_db_item=parent_db_item,
included_objects=included_objects,
)
next_current_db_item.extend(new_next_items)
return next_current_db_item
Expand All @@ -346,18 +347,21 @@ def process_include_with_nested(
current_db_item: Union[List[TypeModel], TypeModel],
item_as_schema: TypeSchema,
current_relation_schema: Type[TypeSchema],
included_objects: Dict[Tuple[str, str], TypeSchema],
requested_includes: Dict[str, Iterable[str]],
) -> Tuple[Dict[str, TypeSchema], List[JSONAPIObjectSchema]]:
root_item_key = (item_as_schema.id, item_as_schema.type)
included_objects: Dict[Tuple[str, str], TypeSchema] = {
root_item_key: item_as_schema,
}

if root_item_key not in included_objects:
included_objects[root_item_key] = item_as_schema
previous_resource_type = item_as_schema.type

previous_related_field_name = previous_resource_type
for related_field_name in include.split(SPLIT_REL):
object_schemas = self.jsonapi.schema_builder.create_jsonapi_object_schemas(
schema=current_relation_schema,
includes=[related_field_name],
compute_included_schemas=bool([related_field_name]),
includes=requested_includes[previous_related_field_name],
compute_included_schemas=True,
)
relationships_schema = object_schemas.relationships_schema
schemas_include = object_schemas.can_be_included_schemas
Expand All @@ -379,16 +383,28 @@ def process_include_with_nested(
related_field_name_ctx_var.set(related_field_name)
relationship_info_ctx_var.set(relationship_info)
included_object_schema_ctx_var.set(included_object_schema)
included_objects_ctx_var.set(included_objects)

current_db_item = self.process_db_items_and_prepare_includes(
parent_db_items=current_db_item,
included_objects=included_objects,
)

previous_resource_type = relationship_info.resource_type
previous_related_field_name = related_field_name

return included_objects.pop(root_item_key), list(included_objects.values())

def prep_requested_includes(self, includes: Iterable[str]):
requested_includes: Dict[str, set[str]] = defaultdict(set)
default: str = self.jsonapi.type_
for include in includes:
prev = default
for related_field_name in include.split(SPLIT_REL):
requested_includes[prev].add(related_field_name)
prev = related_field_name

return requested_includes

def process_db_object(
self,
includes: List[str],
Expand All @@ -403,12 +419,17 @@ def process_db_object(
attributes=object_schemas.attributes_schema.from_orm(item),
)

cache_included_objects: Dict[Tuple[str, str], TypeSchema] = {}
requested_includes = self.prep_requested_includes(includes)

for include in includes:
item_as_schema, new_included_objects = self.process_include_with_nested(
include=include,
current_db_item=item,
item_as_schema=item_as_schema,
current_relation_schema=item_schema,
included_objects=cache_included_objects,
requested_includes=requested_includes,
)

included_objects.extend(new_included_objects)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ extend-ignore = [
"RUF001", # String contains ambiguous unicode character {confusable} (did you mean {representant}?)
"RUF002", # Docstring contains ambiguous unicode character {confusable} (did you mean {representant}?)
"RUF003", # Comment contains ambiguous unicode character {confusable} (did you mean {representant}?)
"PT006", # pytest parametrize tuple args
]

[tool.ruff.per-file-ignores]
Expand Down
175 changes: 175 additions & 0 deletions tests/test_api/test_api_sqla_with_includes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from collections import defaultdict
from itertools import chain, zip_longest
from json import dumps
from typing import Dict, List
Expand Down Expand Up @@ -598,6 +599,180 @@ async def test_many_to_many_load_inner_includes_to_parents(
assert ("child", ViewBase.get_db_item_id(child_4)) not in included_data


class TestUserWithPostsWithInnerIncludes:
@mark.parametrize(
"include, expected_relationships_inner_relations, expect_user_include",
[
(
["posts", "posts.user"],
{"post": ["user"], "user": []},
False,
),
(
["posts", "posts.comments"],
{"post": ["comments"], "post_comment": []},
False,
),
(
["posts", "posts.user", "posts.comments"],
{"post": ["user", "comments"], "user": [], "post_comment": []},
False,
),
(
["posts", "posts.user", "posts.comments", "posts.comments.author"],
{"post": ["user", "comments"], "post_comment": ["author"], "user": []},
True,
),
],
)
async def test_get_users_with_posts_and_inner_includes(
self,
app: FastAPI,
client: AsyncClient,
user_1: User,
user_2: User,
user_1_posts: list[PostComment],
user_1_post_for_comments: Post,
user_2_comment_for_one_u1_post: PostComment,
include: list[str],
expected_relationships_inner_relations: dict[str, list[str]],
expect_user_include: bool,
):
"""
Test if requesting `posts.user` and `posts.comments`
returns posts with both `user` and `comments`
"""
assert user_1_posts
assert user_2_comment_for_one_u1_post.author_id == user_2.id
include_param = ",".join(include)
resource_type = "user"
url = app.url_path_for(f"get_{resource_type}_list")
url = f"{url}?filter[name]={user_1.name}&include={include_param}"
response = await client.get(url)
assert response.status_code == status.HTTP_200_OK, response.text
response_json = response.json()

result_data = response_json["data"]

assert result_data == [
{
"id": str(user_1.id),
"type": resource_type,
"attributes": UserAttributesBaseSchema.from_orm(user_1).dict(),
"relationships": {
"posts": {
"data": [
# relationship info
{"id": str(p.id), "type": "post"}
# for every post
for p in user_1_posts
],
},
},
},
]
included_data = response_json["included"]
included_as_map = defaultdict(list)
for item in included_data:
included_as_map[item["type"]].append(item)

for item_type, items in included_as_map.items():
expected_relationships = expected_relationships_inner_relations[item_type]
for item in items:
relationships = set(item.get("relationships", {}))
assert relationships.intersection(expected_relationships) == set(
expected_relationships,
), f"Expected relationships {expected_relationships} not found in {item_type} {item['id']}"

expected_includes = self.prepare_expected_includes(
user_1=user_1,
user_2=user_2,
user_1_posts=user_1_posts,
user_2_comment_for_one_u1_post=user_2_comment_for_one_u1_post,
)

for item_type, includes_names in expected_relationships_inner_relations.items():
items = expected_includes[item_type]
have_to_be_present = set(includes_names)
for item in items: # type: dict
item_relationships = item.get("relationships", {})
for key in tuple(item_relationships.keys()):
if key not in have_to_be_present:
item_relationships.pop(key)
if not item_relationships:
item.pop("relationships", None)

for key in set(expected_includes).difference(expected_relationships_inner_relations):
expected_includes.pop(key)

# XXX
if not expect_user_include:
expected_includes.pop("user", None)
assert included_as_map == expected_includes

def prepare_expected_includes(
self,
user_1: User,
user_2: User,
user_1_posts: list[PostComment],
user_2_comment_for_one_u1_post: PostComment,
):
expected_includes = {
"post": [
#
{
"id": str(p.id),
"type": "post",
"attributes": PostAttributesBaseSchema.from_orm(p).dict(),
"relationships": {
"user": {
"data": {
"id": str(user_1.id),
"type": "user",
},
},
"comments": {
"data": [
{
"id": str(user_2_comment_for_one_u1_post.id),
"type": "post_comment",
},
]
if p.id == user_2_comment_for_one_u1_post.post_id
else [],
},
},
}
#
for p in user_1_posts
],
"post_comment": [
{
"id": str(user_2_comment_for_one_u1_post.id),
"type": "post_comment",
"attributes": PostCommentAttributesBaseSchema.from_orm(user_2_comment_for_one_u1_post).dict(),
"relationships": {
"author": {
"data": {
"id": str(user_2.id),
"type": "user",
},
},
},
},
],
"user": [
{
"id": str(user_2.id),
"type": "user",
"attributes": UserAttributesBaseSchema.from_orm(user_2).dict(),
},
],
}

return expected_includes


async def test_method_not_allowed(app: FastAPI, client: AsyncClient):
url = app.url_path_for("get_user_list")
res = await client.put(url, json={})
Expand Down