diff --git a/example/serializers.py b/example/serializers.py index 61812337..e259a10b 100644 --- a/example/serializers.py +++ b/example/serializers.py @@ -25,25 +25,33 @@ class EntrySerializer(serializers.ModelSerializer): def __init__(self, *args, **kwargs): # to make testing more concise we'll only output the - # `suggested` field when it's requested via `include` + # `featured` field when it's requested via `include` request = kwargs.get('context', {}).get('request') - if request and 'suggested' not in request.query_params.get('include', []): - self.fields.pop('suggested') + if request and 'featured' not in request.query_params.get('include', []): + self.fields.pop('featured') super(EntrySerializer, self).__init__(*args, **kwargs) included_serializers = { 'authors': 'example.serializers.AuthorSerializer', 'comments': 'example.serializers.CommentSerializer', - 'suggested': 'example.serializers.EntrySerializer', + 'featured': 'example.serializers.EntrySerializer', } body_format = serializers.SerializerMethodField() + # many related from model comments = relations.ResourceRelatedField( source='comment_set', many=True, read_only=True) + # many related from serializer suggested = relations.SerializerMethodResourceRelatedField( - source='get_suggested', model=Entry, read_only=True) + source='get_suggested', model=Entry, many=True, read_only=True) + # single related from serializer + featured = relations.SerializerMethodResourceRelatedField( + source='get_featured', model=Entry, read_only=True) def get_suggested(self, obj): + return Entry.objects.exclude(pk=obj.pk) + + def get_featured(self, obj): return Entry.objects.exclude(pk=obj.pk).first() def get_body_format(self, obj): @@ -52,7 +60,7 @@ def get_body_format(self, obj): class Meta: model = Entry fields = ('blog', 'headline', 'body_text', 'pub_date', 'mod_date', - 'authors', 'comments', 'suggested',) + 'authors', 'comments', 'featured', 'suggested',) meta_fields = ('body_format',) diff --git a/example/tests/integration/test_includes.py b/example/tests/integration/test_includes.py index 8c4cb587..05c59131 100644 --- a/example/tests/integration/test_includes.py +++ b/example/tests/integration/test_includes.py @@ -31,7 +31,7 @@ def test_included_data_on_detail(single_entry, client): def test_dynamic_related_data_is_included(single_entry, entry_factory, client): entry_factory() - response = client.get(reverse("entry-detail", kwargs={'pk': single_entry.pk}) + '?include=suggested') + response = client.get(reverse("entry-detail", kwargs={'pk': single_entry.pk}) + '?include=featured') included = load_json(response.content).get('included') assert [x.get('type') for x in included] == ['entries'], 'Dynamic included types are incorrect' diff --git a/example/tests/integration/test_non_paginated_responses.py b/example/tests/integration/test_non_paginated_responses.py index f68f2b71..de9e3055 100644 --- a/example/tests/integration/test_non_paginated_responses.py +++ b/example/tests/integration/test_non_paginated_responses.py @@ -41,6 +41,9 @@ def test_multiple_entries_no_pagination(multiple_entries, rf): "comments": { "meta": {"count": 1}, "data": [{"type": "comments", "id": "1"}] + }, + "suggested": { + "data": [{"type": "entries", "id": "2"}] } } }, @@ -69,6 +72,9 @@ def test_multiple_entries_no_pagination(multiple_entries, rf): "comments": { "meta": {"count": 1}, "data": [{"type": "comments", "id": "2"}] + }, + "suggested": { + "data": [{"type": "entries", "id": "1"}] } } }, diff --git a/example/tests/integration/test_pagination.py b/example/tests/integration/test_pagination.py index 0cc5e15e..742be523 100644 --- a/example/tests/integration/test_pagination.py +++ b/example/tests/integration/test_pagination.py @@ -35,6 +35,9 @@ def test_pagination_with_single_entry(single_entry, client): "comments": { "meta": {"count": 1}, "data": [{"type": "comments", "id": "1"}] + }, + "suggested": { + "data": [] } } }], diff --git a/rest_framework_json_api/relations.py b/rest_framework_json_api/relations.py index b7ccce36..0e6594d5 100644 --- a/rest_framework_json_api/relations.py +++ b/rest_framework_json_api/relations.py @@ -3,6 +3,7 @@ from rest_framework.fields import MISSING_ERROR_MESSAGE from rest_framework.relations import * from django.utils.translation import ugettext_lazy as _ +from django.db.models.query import QuerySet from rest_framework_json_api.exceptions import Conflict from rest_framework_json_api.utils import Hyperlink, \ @@ -168,11 +169,50 @@ def choices(self): ]) + class SerializerMethodResourceRelatedField(ResourceRelatedField): + """ + Allows us to use serializer method RelatedFields + with return querysets + """ + def __new__(cls, *args, **kwargs): + """ + We override this because getting serializer methods + fails at the base class when many=True + """ + if kwargs.pop('many', False): + return cls.many_init(*args, **kwargs) + return super(ResourceRelatedField, cls).__new__(cls, *args, **kwargs) + + def __init__(self, child_relation=None, *args, **kwargs): + # DRF 3.1 doesn't expect the `many` kwarg + kwargs.pop('many', None) + model = kwargs.pop('model', None) + if model: + self.model = model + super(SerializerMethodResourceRelatedField, self).__init__(child_relation, *args, **kwargs) + + @classmethod + def many_init(cls, *args, **kwargs): + list_kwargs = {'child_relation': cls(*args, **kwargs)} + for key in kwargs.keys(): + if key in ('model',) + MANY_RELATION_KWARGS: + list_kwargs[key] = kwargs[key] + return SerializerMethodResourceRelatedField(**list_kwargs) + def get_attribute(self, instance): # check for a source fn defined on the serializer instead of the model if self.source and hasattr(self.parent, self.source): serializer_method = getattr(self.parent, self.source) if hasattr(serializer_method, '__call__'): return serializer_method(instance) - return super(ResourceRelatedField, self).get_attribute(instance) + return super(SerializerMethodResourceRelatedField, self).get_attribute(instance) + + def to_representation(self, value): + if isinstance(value, QuerySet): + base = super(SerializerMethodResourceRelatedField, self) + return [base.to_representation(x) for x in value] + return super(SerializerMethodResourceRelatedField, self).to_representation(value) + + def get_links(self, obj=None, lookup_field='pk'): + return OrderedDict()