diff --git a/rest_framework_json_api/utils.py b/rest_framework_json_api/utils.py index f51358f2..1d9ef502 100644 --- a/rest_framework_json_api/utils.py +++ b/rest_framework_json_api/utils.py @@ -131,7 +131,7 @@ def format_value(value, format_type=None): def build_json_resource_obj(fields, resource, resource_instance, resource_name): resource_data = [ ('type', resource_name), - ('id', extract_id(fields, resource)), + ('id', encoding.force_text(resource_instance.pk)), ('attributes', extract_attributes(fields, resource)), ] relationships = extract_relationships(fields, resource, resource_instance) @@ -171,31 +171,10 @@ def get_related_resource_type(relation): return inflection.pluralize(relation_model.__name__).lower() -def extract_id_from_url(url): - http_prefix = url.startswith(('http:', 'https:')) - if http_prefix: - # If needed convert absolute URLs to relative path - data = urlparse(url).path - prefix = urlresolvers.get_script_prefix() - if data.startswith(prefix): - url = '/' + data[len(prefix):] - - match = urlresolvers.resolve(url) - return encoding.force_text(match.kwargs['pk']) - - -def extract_id(fields, resource): - for field_name, field in six.iteritems(fields): - if field_name == 'id': - return encoding.force_text(resource.get(field_name)) - if field_name == api_settings.URL_FIELD_NAME: - return extract_id_from_url(resource.get(field_name)) - - def extract_attributes(fields, resource): data = OrderedDict() for field_name, field in six.iteritems(fields): - # ID is always provided in the root of JSON API so remove it from attrs + # ID is always provided in the root of JSON API so remove it from attributes if field_name == 'id': continue # Skip fields with relations @@ -219,15 +198,21 @@ def extract_relationships(fields, resource, resource_instance): if not isinstance(field, (RelatedField, ManyRelatedField, BaseSerializer)): continue + relation_type = get_related_resource_type(field) + relation_instance_or_manager = getattr(resource_instance, field_name) + if isinstance(field, HyperlinkedIdentityField): # special case for HyperlinkedIdentityField relation_data = list() - relation_type = get_related_resource_type(field) - relation_manager = getattr(resource_instance, field_name) + # Don't try to query an empty relation - related = relation_manager.all() if relation_manager is not None else list() - for relation in related: - relation_data.append(OrderedDict([('type', relation_type), ('id', relation.pk)])) + relation_queryset = relation_instance_or_manager.all() \ + if relation_instance_or_manager is not None else list() + + for related_object in relation_queryset: + relation_data.append( + OrderedDict([('type', relation_type), ('id', encoding.force_text(related_object.pk))]) + ) data.update({field_name: { 'links': { @@ -240,27 +225,26 @@ def extract_relationships(fields, resource, resource_instance): continue if isinstance(field, (PrimaryKeyRelatedField, HyperlinkedRelatedField)): - relation_type = get_related_resource_type(field) - relation_id = getattr(resource_instance, field_name).pk if resource.get(field_name) else None + relation_id = relation_instance_or_manager.pk if resource.get(field_name) else None relation_data = { - 'data': (OrderedDict([ - ('type', relation_type), ('id', relation_id) - ]) if relation_id is not None else None) + 'data': ( + OrderedDict([('type', relation_type), ('id', encoding.force_text(relation_id))]) + if relation_id is not None else None) } relation_data.update( {'links': {'related': resource.get(field_name)}} - if isinstance(field, HyperlinkedRelatedField) and resource.get(field_name) else {} + if isinstance(field, HyperlinkedRelatedField) and resource.get(field_name) else dict() ) data.update({field_name: relation_data}) continue if isinstance(field, ManyRelatedField): relation_data = list() - relation = field.child_relation - relation_type = get_related_resource_type(relation) - for related_object in getattr(resource_instance, field_name).all(): + related_object = field.child_relation + relation_type = get_related_resource_type(related_object) + for related_object in relation_instance_or_manager.all(): relation_data.append(OrderedDict([ ('type', relation_type), ('id', encoding.force_text(related_object.pk)) @@ -277,20 +261,20 @@ def extract_relationships(fields, resource, resource_instance): if isinstance(field, ListSerializer): relation_data = list() - serializer = field.child relation_model = serializer.Meta.model relation_type = inflection.pluralize(relation_model.__name__).lower() - # Get the serializer fields - serializer_fields = get_serializer_fields(serializer) serializer_data = resource.get(field_name) + resource_instance_queryset = relation_instance_or_manager.all() if isinstance(serializer_data, list): - for serializer_resource in serializer_data: + for position in range(len(serializer_data)): + nested_resource_instance = resource_instance_queryset[position] relation_data.append( - OrderedDict([ - ('type', relation_type), ('id', extract_id(serializer_fields, serializer_resource)) - ])) + OrderedDict( + [('type', relation_type), ('id', encoding.force_text(nested_resource_instance.pk))] + ) + ) data.update({field_name: {'data': relation_data}}) continue @@ -299,15 +283,12 @@ def extract_relationships(fields, resource, resource_instance): relation_model = field.Meta.model relation_type = inflection.pluralize(relation_model.__name__).lower() - # Get the serializer fields - serializer_fields = get_serializer_fields(field) - serializer_data = resource.get(field_name) data.update({ field_name: { 'data': ( OrderedDict([ ('type', relation_type), - ('id', extract_id(serializer_fields, serializer_data)) + ('id', encoding.force_text(relation_instance_or_manager.pk)) ]) if resource.get(field_name) else None) } }) @@ -327,20 +308,21 @@ def extract_included(fields, resource, resource_instance): if not isinstance(field, BaseSerializer): continue - if isinstance(field, ListSerializer): + relation_instance_or_manager = getattr(resource_instance, field_name) + relation_queryset = relation_instance_or_manager.all() + serializer_data = resource.get(field_name) + if isinstance(field, ListSerializer): serializer = field.child model = serializer.Meta.model relation_type = inflection.pluralize(model.__name__).lower() # Get the serializer fields serializer_fields = get_serializer_fields(serializer) - serializer_data = resource.get(field_name) - if isinstance(serializer_data, list): + if serializer_data: for position in range(len(serializer_data)): serializer_resource = serializer_data[position] - resource_instance_manager = getattr(resource_instance, field_name).all() - nested_resource_instance = resource_instance_manager[position] + nested_resource_instance = relation_queryset[position] included_data.append( build_json_resource_obj( serializer_fields, serializer_resource, nested_resource_instance, relation_type @@ -348,17 +330,14 @@ def extract_included(fields, resource, resource_instance): ) if isinstance(field, ModelSerializer): - model = field.Meta.model relation_type = inflection.pluralize(model.__name__).lower() # Get the serializer fields serializer_fields = get_serializer_fields(field) - serializer_data = resource.get(field_name) - nested_resource_instance = getattr(resource_instance, field_name).all() if serializer_data: included_data.append( - build_json_resource_obj(serializer_fields, serializer_data, nested_resource_instance, relation_type) + build_json_resource_obj(serializer_fields, serializer_data, relation_queryset, relation_type) ) return format_keys(included_data)