diff --git a/example/tests/test_relations.py b/example/tests/test_relations.py index ea60ae9a..bf2bfed7 100644 --- a/example/tests/test_relations.py +++ b/example/tests/test_relations.py @@ -1,10 +1,10 @@ from __future__ import absolute_import from django.utils import timezone - from rest_framework import serializers from . import TestBase +from rest_framework_json_api.exceptions import Conflict from rest_framework_json_api.utils import format_relation_name from example.models import Blog, Entry, Comment, Author from rest_framework_json_api.relations import ResourceRelatedField @@ -74,15 +74,17 @@ def test_deserialize_primitive_data_blog(self): self.assertEqual(serializer.validated_data['blog'], self.blog) def test_validation_fails_for_wrong_type(self): - serializer = BlogFKSerializer(data={ - 'blog': { - 'type': 'Entries', - 'id': str(self.blog.id) + with self.assertRaises(Conflict) as cm: + serializer = BlogFKSerializer(data={ + 'blog': { + 'type': 'Entries', + 'id': str(self.blog.id) + } } - } - ) - - self.assertFalse(serializer.is_valid()) + ) + serializer.is_valid() + the_exception = cm.exception + self.assertEqual(the_exception.status_code, 409) def test_serialize_many_to_many_relation(self): serializer = EntryModelSerializer(instance=self.entry) diff --git a/rest_framework_json_api/parsers.py b/rest_framework_json_api/parsers.py index 0afc3551..5aa81f9c 100644 --- a/rest_framework_json_api/parsers.py +++ b/rest_framework_json_api/parsers.py @@ -48,9 +48,12 @@ def parse(self, stream, media_type=None, parser_context=None): raise ParseError('Received data is not a valid JSONAPI Resource Identifier Object') return data + + request = parser_context.get('request') + # Check for inconsistencies resource_name = utils.get_resource_name(parser_context) - if data.get('type') != resource_name: + if data.get('type') != resource_name and request.method in ('PUT', 'POST', 'PATCH'): raise exceptions.Conflict( "The resource object's type ({data_type}) is not the type " "that constitute the collection represented by the endpoint ({resource_type}).".format( @@ -72,9 +75,9 @@ def parse(self, stream, media_type=None, parser_context=None): for field_name, field_data in relationships.items(): field_data = field_data.get('data') if isinstance(field_data, dict): - parsed_relationships[field_name] = field_data.get('id') + parsed_relationships[field_name] = field_data elif isinstance(field_data, list): - parsed_relationships[field_name] = list(relation.get('id') for relation in field_data) + parsed_relationships[field_name] = list(relation for relation in field_data) # Construct the return data parsed_data = {'id': data_id} diff --git a/rest_framework_json_api/relations.py b/rest_framework_json_api/relations.py index 2f33af0a..f81bb45f 100644 --- a/rest_framework_json_api/relations.py +++ b/rest_framework_json_api/relations.py @@ -1,9 +1,12 @@ from rest_framework.exceptions import ValidationError +from rest_framework.fields import MISSING_ERROR_MESSAGE from rest_framework.relations import * -from rest_framework_json_api.utils import format_relation_name, get_related_resource_type, \ - get_resource_type_from_queryset, get_resource_type_from_instance from django.utils.translation import ugettext_lazy as _ +from rest_framework_json_api.exceptions import Conflict +from rest_framework_json_api.utils import format_relation_name, Hyperlink, \ + get_resource_type_from_queryset, get_resource_type_from_instance + class HyperlinkedRelatedField(HyperlinkedRelatedField): """ @@ -40,22 +43,108 @@ def to_internal_value(self, data): class ResourceRelatedField(PrimaryKeyRelatedField): + self_link_view_name = None + related_link_view_name = None + related_link_lookup_field = 'pk' + default_error_messages = { 'required': _('This field is required.'), 'does_not_exist': _('Invalid pk "{pk_value}" - object does not exist.'), - 'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'), + 'incorrect_type': _('Incorrect type. Expected resource identifier object, received {data_type}.'), 'incorrect_relation_type': _('Incorrect relation type. Expected {relation_type}, received {received_type}.'), + 'no_match': _('Invalid hyperlink - No URL match.'), } + def __init__(self, self_link_view_name=None, related_link_view_name=None, **kwargs): + if self_link_view_name is not None: + self.self_link_view_name = self_link_view_name + if related_link_view_name is not None: + self.related_link_view_name = related_link_view_name + + self.related_link_lookup_field = kwargs.pop('related_link_lookup_field', self.related_link_lookup_field) + self.related_link_url_kwarg = kwargs.pop('related_link_url_kwarg', self.related_link_lookup_field) + + # We include this simply for dependency injection in tests. + # We can't add it as a class attributes or it would expect an + # implicit `self` argument to be passed. + self.reverse = reverse + + super(ResourceRelatedField, self).__init__(**kwargs) + + def use_pk_only_optimization(self): + # We need the real object to determine its type... + return False + + def conflict(self, key, **kwargs): + """ + A helper method that simply raises a validation error. + """ + try: + msg = self.error_messages[key] + except KeyError: + class_name = self.__class__.__name__ + msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) + raise AssertionError(msg) + message_string = msg.format(**kwargs) + raise Conflict(message_string) + + def get_url(self, name, view_name, kwargs, request): + """ + Given a name, view name and kwargs, return the URL that hyperlinks to the object. + + May raise a `NoReverseMatch` if the `view_name` and `lookup_field` + attributes are not configured to correctly match the URL conf. + """ + + # Return None if the view name is not supplied + if not view_name: + return None + + # Return the hyperlink, or error if incorrectly configured. + try: + url = self.reverse(view_name, kwargs=kwargs, request=request) + except NoReverseMatch: + msg = ( + 'Could not resolve URL for hyperlinked relationship using ' + 'view name "%s".' + ) + raise ImproperlyConfigured(msg % view_name) + + if url is None: + return None + + return Hyperlink(url, name) + + def get_links(self): + request = self.context.get('request', None) + view = self.context.get('view', None) + return_data = OrderedDict() + self_kwargs = view.kwargs.copy() + self_kwargs.update({'related_field': self.field_name if self.field_name else self.parent.field_name}) + self_link = self.get_url('self', self.self_link_view_name, self_kwargs, request) + + related_kwargs = {self.related_link_url_kwarg: view.kwargs[self.related_link_lookup_field]} + related_link = self.get_url('related', self.related_link_view_name, related_kwargs, request) + + if self_link: + return_data.update({'self': self_link}) + if related_link: + return_data.update({'related': related_link}) + return return_data + def to_internal_value(self, data): expected_relation_type = get_resource_type_from_queryset(self.queryset) + if not isinstance(data, dict): + self.fail('incorrect_type', data_type=type(data).__name__) if data['type'] != expected_relation_type: - self.fail('incorrect_relation_type', relation_type=expected_relation_type, received_type=data['type']) + self.conflict('incorrect_relation_type', relation_type=expected_relation_type, received_type=data['type']) return super(ResourceRelatedField, self).to_internal_value(data['id']) def to_representation(self, value): - return { - 'type': format_relation_name(get_resource_type_from_instance(value)), - 'id': str(value.pk) - } + if getattr(self, 'pk_field', None) is not None: + pk = self.pk_field.to_representation(value.pk) + else: + pk = value.pk + + return OrderedDict([('type', format_relation_name(get_resource_type_from_instance(value))), ('id', str(pk))]) diff --git a/rest_framework_json_api/utils.py b/rest_framework_json_api/utils.py index 17da3941..362d70f5 100644 --- a/rest_framework_json_api/utils.py +++ b/rest_framework_json_api/utils.py @@ -11,6 +11,7 @@ from rest_framework.settings import api_settings from rest_framework.exceptions import APIException + try: from rest_framework.compat import OrderedDict except ImportError: @@ -237,6 +238,9 @@ def extract_attributes(fields, resource): def extract_relationships(fields, resource, resource_instance): + # Avoid circular deps + from rest_framework_json_api.relations import ResourceRelatedField + data = OrderedDict() # Don't try to extract relationships from a non-existent resource @@ -254,7 +258,7 @@ def extract_relationships(fields, resource, resource_instance): try: relation_instance_or_manager = getattr(resource_instance, field_name) - except AttributeError: # Skip fields defined on the serializer that don't correspond to a field on the model + except AttributeError: # Skip fields defined on the serializer that don't correspond to a field on the model continue relation_type = get_related_resource_type(field) @@ -282,6 +286,20 @@ def extract_relationships(fields, resource, resource_instance): }}) continue + if isinstance(field, ResourceRelatedField): + # special case for ResourceRelatedField + relation_data = { + 'data': resource.get(field_name) + } + + field_links = field.get_links() + relation_data.update( + {'links': field_links} + if field_links else dict() + ) + data.update({field_name: relation_data}) + continue + if isinstance(field, (PrimaryKeyRelatedField, HyperlinkedRelatedField)): relation_id = relation_instance_or_manager.pk if resource.get(field_name) else None @@ -299,6 +317,28 @@ def extract_relationships(fields, resource, resource_instance): continue if isinstance(field, ManyRelatedField): + + if isinstance(field.child_relation, ResourceRelatedField): + # special case for ResourceRelatedField + relation_data = { + 'data': resource.get(field_name) + } + + field_links = field.child_relation.get_links() + relation_data.update( + {'links': field_links} + if field_links else dict() + ) + relation_data.update( + { + 'meta': { + 'count': len(resource.get(field_name)) + } + } + ) + data.update({field_name: relation_data}) + continue + relation_data = list() for related_object in relation_instance_or_manager.all(): related_object_type = get_instance_or_manager_resource_type(relation_instance_or_manager) @@ -395,3 +435,21 @@ def extract_included(fields, resource, resource_instance): ) return format_keys(included_data) + + +class Hyperlink(six.text_type): + """ + A string like object that additionally has an associated name. + We use this for hyperlinked URLs that may render as a named link + in some contexts, or render as a plain URL in others. + + Comes from Django REST framework 3.2 + https://github.com/tomchristie/django-rest-framework + """ + + def __new__(self, url, name): + ret = six.text_type.__new__(self, url) + ret.name = name + return ret + + is_hyperlink = True diff --git a/rest_framework_json_api/views.py b/rest_framework_json_api/views.py index 9fc2e67a..cbd37303 100644 --- a/rest_framework_json_api/views.py +++ b/rest_framework_json_api/views.py @@ -3,7 +3,6 @@ from django.db.models import Model from django.db.models.query import QuerySet from django.db.models.manager import Manager -from django.utils import six from rest_framework import generics from rest_framework.response import Response from rest_framework.exceptions import NotFound, MethodNotAllowed @@ -11,7 +10,7 @@ from rest_framework_json_api.exceptions import Conflict from rest_framework_json_api.serializers import ResourceIdentifierObjectSerializer -from rest_framework_json_api.utils import format_relation_name, get_resource_type_from_instance, OrderedDict +from rest_framework_json_api.utils import format_relation_name, get_resource_type_from_instance, OrderedDict, Hyperlink class RelationshipView(generics.GenericAPIView): @@ -28,29 +27,12 @@ def __init__(self, **kwargs): def get_url(self, name, view_name, kwargs, request): """ - Given an object, return the URL that hyperlinks to the object. + Given a name, view name and kwargs, return the URL that hyperlinks to the object. May raise a `NoReverseMatch` if the `view_name` and `lookup_field` attributes are not configured to correctly match the URL conf. """ - class Hyperlink(six.text_type): - """ - A string like object that additionally has an associated name. - We use this for hyperlinked URLs that may render as a named link - in some contexts, or render as a plain URL in others. - - Comes from Django REST framework 3.2 - https://github.com/tomchristie/django-rest-framework - """ - - def __new__(self, url, name): - ret = six.text_type.__new__(self, url) - ret.name = name - return ret - - is_hyperlink = True - # Return None if the view name is not supplied if not view_name: return None