diff --git a/example/models.py b/example/models.py index 7dfbc1ab..7fd8a3d7 100644 --- a/example/models.py +++ b/example/models.py @@ -48,3 +48,13 @@ class Entry(BaseModel): def __str__(self): return self.headline + + +@python_2_unicode_compatible +class Comment(BaseModel): + entry = models.ForeignKey(Entry) + body = models.TextField() + author = models.ForeignKey(Author) + + def __str__(self): + return self.body diff --git a/example/tests/test_relations.py b/example/tests/test_relations.py new file mode 100644 index 00000000..ea60ae9a --- /dev/null +++ b/example/tests/test_relations.py @@ -0,0 +1,131 @@ +from __future__ import absolute_import + +from django.utils import timezone + +from rest_framework import serializers + +from . import TestBase +from rest_framework_json_api.utils import format_relation_name +from example.models import Blog, Entry, Comment, Author +from rest_framework_json_api.relations import ResourceRelatedField + + +class TestResourceRelatedField(TestBase): + + def setUp(self): + super(TestResourceRelatedField, self).setUp() + self.blog = Blog.objects.create(name='Some Blog', tagline="It's a blog") + self.entry = Entry.objects.create( + blog=self.blog, + headline='headline', + body_text='body_text', + pub_date=timezone.now(), + mod_date=timezone.now(), + n_comments=0, + n_pingbacks=0, + rating=3 + ) + for i in range(1,6): + name = 'some_author{}'.format(i) + self.entry.authors.add( + Author.objects.create(name=name, email='{}@example.org'.format(name)) + ) + + self.comment = Comment.objects.create( + entry=self.entry, + body='testing one two three', + author=Author.objects.first() + ) + + def test_data_in_correct_format_when_instantiated_with_blog_object(self): + serializer = BlogFKSerializer(instance={'blog': self.blog}) + + expected_data = { + 'type': format_relation_name('Blog'), + 'id': str(self.blog.id) + } + + actual_data = serializer.data['blog'] + + self.assertEqual(actual_data, expected_data) + + def test_data_in_correct_format_when_instantiated_with_entry_object(self): + serializer = EntryFKSerializer(instance={'entry': self.entry}) + + expected_data = { + 'type': format_relation_name('Entry'), + 'id': str(self.entry.id) + } + + actual_data = serializer.data['entry'] + + self.assertEqual(actual_data, expected_data) + + def test_deserialize_primitive_data_blog(self): + serializer = BlogFKSerializer(data={ + 'blog': { + 'type': format_relation_name('Blog'), + 'id': str(self.blog.id) + } + } + ) + + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.validated_data['blog'], self.blog) + + def test_validation_fails_for_wrong_type(self): + serializer = BlogFKSerializer(data={ + 'blog': { + 'type': 'Entries', + 'id': str(self.blog.id) + } + } + ) + + self.assertFalse(serializer.is_valid()) + + def test_serialize_many_to_many_relation(self): + serializer = EntryModelSerializer(instance=self.entry) + + type_string = format_relation_name('Author') + author_pks = Author.objects.values_list('pk', flat=True) + expected_data = [{'type': type_string, 'id': str(pk)} for pk in author_pks] + + self.assertEqual( + serializer.data['authors'], + expected_data + ) + + def test_deserialize_many_to_many_relation(self): + type_string = format_relation_name('Author') + author_pks = Author.objects.values_list('pk', flat=True) + authors = [{'type': type_string, 'id': pk} for pk in author_pks] + + serializer = EntryModelSerializer(data={'authors': authors, 'comment_set': []}) + + self.assertTrue(serializer.is_valid()) + self.assertEqual(len(serializer.validated_data['authors']), Author.objects.count()) + for author in serializer.validated_data['authors']: + self.assertIsInstance(author, Author) + + def test_read_only(self): + serializer = EntryModelSerializer(data={'authors': [], 'comment_set': [{'type': 'Comments', 'id': 2}]}) + serializer.is_valid(raise_exception=True) + self.assertNotIn('comment_set', serializer.validated_data) + + +class BlogFKSerializer(serializers.Serializer): + blog = ResourceRelatedField(queryset=Blog.objects) + + +class EntryFKSerializer(serializers.Serializer): + entry = ResourceRelatedField(queryset=Entry.objects) + + +class EntryModelSerializer(serializers.ModelSerializer): + authors = ResourceRelatedField(many=True, queryset=Author.objects) + comment_set = ResourceRelatedField(many=True, read_only=True) + + class Meta: + model = Entry + fields = ('authors', 'comment_set') diff --git a/rest_framework_json_api/relations.py b/rest_framework_json_api/relations.py index dd818df8..4bf6896f 100644 --- a/rest_framework_json_api/relations.py +++ b/rest_framework_json_api/relations.py @@ -1,5 +1,6 @@ from rest_framework.exceptions import ValidationError from rest_framework.relations import * +from rest_framework_json_api.utils import format_relation_name, get_related_resource_type from django.utils.translation import ugettext_lazy as _ @@ -35,3 +36,25 @@ def to_internal_value(self, data): self.fail('pk_does_not_exist', pk_value=data) except (TypeError, ValueError): self.fail('incorrect_pk_type', data_type=type(data).__name__) + + +class ResourceRelatedField(PrimaryKeyRelatedField): + default_error_messages = { + 'required': _('This field is required.'), + 'does_not_exist': _('Invalid pk "{pk_value}" - object does not exist.'), + 'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'), + 'incorrect_relation_type': _('Incorrect relation type. Expected {relation_type}, received {received_type}.'), + } + + def to_internal_value(self, data): + expected_relation_type = format_relation_name(get_related_resource_type(self)) + if data['type'] != expected_relation_type: + self.fail('incorrect_relation_type', relation_type=expected_relation_type, received_type=data['type']) + return super(ResourceRelatedField, self).to_internal_value(data['id']) + + def to_representation(self, value): + return { + 'type': format_relation_name(get_related_resource_type(self)), + 'id': str(value.pk) + } + diff --git a/rest_framework_json_api/utils.py b/rest_framework_json_api/utils.py index b3e75126..007d387a 100644 --- a/rest_framework_json_api/utils.py +++ b/rest_framework_json_api/utils.py @@ -179,6 +179,10 @@ def get_related_resource_type(relation): return format_relation_name(relation_model.__name__) +def get_model_name_from_queryset(qs): + return qs.model._meta.model_name + + def extract_attributes(fields, resource): data = OrderedDict() for field_name, field in six.iteritems(fields):