|
2 | 2 | Pagination serializers determine the structure of the output that should |
3 | 3 | be used for paginated responses. |
4 | 4 | """ |
| 5 | +import json |
| 6 | +import operator |
| 7 | + |
5 | 8 | from base64 import b64decode, b64encode |
6 | 9 | from collections import OrderedDict, namedtuple |
7 | 10 | from urllib import parse |
| 11 | +from functools import reduce |
8 | 12 |
|
9 | 13 | from django.core.paginator import InvalidPage |
10 | 14 | from django.core.paginator import Paginator as DjangoPaginator |
11 | 15 | from django.template import loader |
12 | 16 | from django.utils.encoding import force_str |
13 | 17 | from django.utils.translation import gettext_lazy as _ |
| 18 | +from django.db.models.query import Q |
14 | 19 |
|
15 | 20 | from rest_framework.compat import coreapi, coreschema |
16 | 21 | from rest_framework.exceptions import NotFound |
@@ -616,25 +621,41 @@ def paginate_queryset(self, queryset, request, view=None): |
616 | 621 | else: |
617 | 622 | (offset, reverse, current_position) = self.cursor |
618 | 623 |
|
619 | | - # Cursor pagination always enforces an ordering. |
620 | | - if reverse: |
621 | | - queryset = queryset.order_by(*_reverse_ordering(self.ordering)) |
622 | | - else: |
623 | | - queryset = queryset.order_by(*self.ordering) |
624 | | - |
625 | 624 | # If we have a cursor with a fixed position then filter by that. |
626 | 625 | if current_position is not None: |
627 | | - order = self.ordering[0] |
628 | | - is_reversed = order.startswith('-') |
629 | | - order_attr = order.lstrip('-') |
| 626 | + current_position_list = json.loads(current_position) |
630 | 627 |
|
631 | | - # Test for: (cursor reversed) XOR (queryset reversed) |
632 | | - if self.cursor.reverse != is_reversed: |
633 | | - kwargs = {order_attr + '__lt': current_position} |
634 | | - else: |
635 | | - kwargs = {order_attr + '__gt': current_position} |
| 628 | + q_objects_equals = {} |
| 629 | + q_objects_compare = {} |
| 630 | + |
| 631 | + for order, position in zip(self.ordering, current_position_list): |
| 632 | + is_reversed = order.startswith("-") |
| 633 | + order_attr = order.lstrip("-") |
| 634 | + |
| 635 | + q_objects_equals[order] = Q(**{order_attr: position}) |
| 636 | + |
| 637 | + # Test for: (cursor reversed) XOR (queryset reversed) |
| 638 | + if self.cursor.reverse != is_reversed: |
| 639 | + q_objects_compare[order] = Q( |
| 640 | + **{(order_attr + "__lt"): position} |
| 641 | + ) |
| 642 | + else: |
| 643 | + q_objects_compare[order] = Q( |
| 644 | + **{(order_attr + "__gt"): position} |
| 645 | + ) |
636 | 646 |
|
637 | | - queryset = queryset.filter(**kwargs) |
| 647 | + filter_list = [] |
| 648 | + # starting with the second field |
| 649 | + for i in range(len(self.ordering)): |
| 650 | + # The first operands need to be equals |
| 651 | + # the last operands need to be gt |
| 652 | + equals = list(self.ordering[:i+2]) |
| 653 | + greater_than_q = q_objects_compare[equals.pop()] |
| 654 | + sub_filters = [q_objects_equals[e] for e in equals] |
| 655 | + sub_filters.append(greater_than_q) |
| 656 | + filter_list.append(reduce(operator.and_, sub_filters)) |
| 657 | + |
| 658 | + queryset = queryset.filter(reduce(operator.or_, filter_list)) |
638 | 659 |
|
639 | 660 | # If we have an offset cursor then offset the entire page by that amount. |
640 | 661 | # We also always fetch an extra item in order to determine if there is a |
@@ -839,7 +860,14 @@ def get_ordering(self, request, queryset, view): |
839 | 860 | ) |
840 | 861 |
|
841 | 862 | if isinstance(ordering, str): |
842 | | - return (ordering,) |
| 863 | + ordering = (ordering,) |
| 864 | + |
| 865 | + pk_name = queryset.model._meta.pk.name |
| 866 | + |
| 867 | + # Always include a unique key to order by |
| 868 | + if not {f"-{pk_name}", pk_name, "pk", "-pk"} & set(ordering): |
| 869 | + ordering = ordering + (pk_name,) |
| 870 | + |
843 | 871 | return tuple(ordering) |
844 | 872 |
|
845 | 873 | def decode_cursor(self, request): |
@@ -884,12 +912,18 @@ def encode_cursor(self, cursor): |
884 | 912 | return replace_query_param(self.base_url, self.cursor_query_param, encoded) |
885 | 913 |
|
886 | 914 | def _get_position_from_instance(self, instance, ordering): |
887 | | - field_name = ordering[0].lstrip('-') |
888 | | - if isinstance(instance, dict): |
889 | | - attr = instance[field_name] |
890 | | - else: |
891 | | - attr = getattr(instance, field_name) |
892 | | - return str(attr) |
| 915 | + fields = [] |
| 916 | + |
| 917 | + for o in ordering: |
| 918 | + field_name = o.lstrip("-") |
| 919 | + if isinstance(instance, dict): |
| 920 | + attr = instance[field_name] |
| 921 | + else: |
| 922 | + attr = getattr(instance, field_name) |
| 923 | + |
| 924 | + fields.append(str(attr)) |
| 925 | + |
| 926 | + return json.dumps(fields).encode() |
893 | 927 |
|
894 | 928 | def get_paginated_response(self, data): |
895 | 929 | return Response(OrderedDict([ |
|
0 commit comments