Skip to content

fix: empty list is not an empty value for list filters even when a custom filtering method is provided #1450

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion graphene_django/compat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import sys
from pathlib import PurePath

# For backwards compatibility, we import JSONField to have it available for import via
# this compat module (https://github.com/graphql-python/graphene-django/issues/1428).
# Django's JSONField is available in Django 3.2+ (the minimum version we support)
Expand All @@ -19,4 +22,23 @@ def __init__(self, *args, **kwargs):
RangeField,
)
except ImportError:
IntegerRangeField, ArrayField, HStoreField, RangeField = (MissingType,) * 4
IntegerRangeField, HStoreField, RangeField = (MissingType,) * 3

# For unit tests we fake ArrayField using JSONFields
if any(
PurePath(sys.argv[0]).match(p)
for p in [
"**/pytest",
"**/py.test",
"**/pytest/__main__.py",
]
):

class ArrayField(JSONField):
def __init__(self, *args, **kwargs):
if len(args) > 0:
self.base_field = args[0]
super().__init__(**kwargs)

else:
ArrayField = MissingType
23 changes: 23 additions & 0 deletions graphene_django/filter/filters/array_filter.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,36 @@
from django_filters.constants import EMPTY_VALUES
from django_filters.filters import FilterMethod

from .typed_filter import TypedFilter


class ArrayFilterMethod(FilterMethod):
def __call__(self, qs, value):
if value is None:
return qs
return self.method(qs, self.f.field_name, value)


class ArrayFilter(TypedFilter):
"""
Filter made for PostgreSQL ArrayField.
"""

@TypedFilter.method.setter
def method(self, value):
"""
Override method setter so that in case a custom `method` is provided
(see documentation https://django-filter.readthedocs.io/en/stable/ref/filters.html#method),
it doesn't fall back to checking if the value is in `EMPTY_VALUES` (from the `__call__` method
of the `FilterMethod` class) and instead use our ArrayFilterMethod that consider empty lists as values.

Indeed when providing a `method` the `filter` method below is overridden and replaced by `FilterMethod(self)`
which means that the validation of the empty value is made by the `FilterMethod.__call__` method instead.
"""
TypedFilter.method.fset(self, value)
if value is not None:
self.filter = ArrayFilterMethod(self)

def filter(self, qs, value):
"""
Override the default filter class to check first whether the list is
Expand Down
24 changes: 24 additions & 0 deletions graphene_django/filter/filters/list_filter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,36 @@
from django_filters.filters import FilterMethod

from .typed_filter import TypedFilter


class ListFilterMethod(FilterMethod):
def __call__(self, qs, value):
if value is None:
return qs
return self.method(qs, self.f.field_name, value)


class ListFilter(TypedFilter):
"""
Filter that takes a list of value as input.
It is for example used for `__in` filters.
"""

@TypedFilter.method.setter
def method(self, value):
"""
Override method setter so that in case a custom `method` is provided
(see documentation https://django-filter.readthedocs.io/en/stable/ref/filters.html#method),
it doesn't fall back to checking if the value is in `EMPTY_VALUES` (from the `__call__` method
of the `FilterMethod` class) and instead use our ListFilterMethod that consider empty lists as values.

Indeed when providing a `method` the `filter` method below is overridden and replaced by `FilterMethod(self)`
which means that the validation of the empty value is made by the `FilterMethod.__call__` method instead.
"""
TypedFilter.method.fset(self, value)
if value is not None:
self.filter = ListFilterMethod(self)

def filter(self, qs, value):
"""
Override the default filter class to check first whether the list is
Expand Down
152 changes: 97 additions & 55 deletions graphene_django/filter/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import MagicMock
from functools import reduce

import pytest
from django.db import models
Expand All @@ -25,15 +25,15 @@
)


STORE = {"events": []}


class Event(models.Model):
name = models.CharField(max_length=50)
tags = ArrayField(models.CharField(max_length=50))
tag_ids = ArrayField(models.IntegerField())
random_field = ArrayField(models.BooleanField())

def __repr__(self):
return f"Event [{self.name}]"


@pytest.fixture
def EventFilterSet():
Expand All @@ -48,6 +48,14 @@ class Meta:
tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains")
tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap")
tags = ArrayFilter(field_name="tags", lookup_expr="exact")
tags__len = ArrayFilter(
field_name="tags", lookup_expr="len", input_type=graphene.Int
)
tags__len__in = ArrayFilter(
field_name="tags",
method="tags__len__in_filter",
input_type=graphene.List(graphene.Int),
)

# Those are actually not usable and only to check type declarations
tags_ids__contains = ArrayFilter(field_name="tag_ids", lookup_expr="contains")
Expand All @@ -61,6 +69,14 @@ class Meta:
)
random_field = ArrayFilter(field_name="random_field", lookup_expr="exact")

def tags__len__in_filter(self, queryset, _name, value):
if not value:
return queryset.none()
return reduce(
lambda q1, q2: q1.union(q2),
[queryset.filter(tags__len=v) for v in value],
).distinct()

return EventFilterSet


Expand All @@ -83,68 +99,94 @@ def Query(EventType):
we are running unit tests in sqlite which does not have ArrayFields.
"""

events = [
Event(name="Live Show", tags=["concert", "music", "rock"]),
Event(name="Musical", tags=["movie", "music"]),
Event(name="Ballet", tags=["concert", "dance"]),
Event(name="Speech", tags=[]),
]

class Query(graphene.ObjectType):
events = DjangoFilterConnectionField(EventType)

def resolve_events(self, info, **kwargs):
events = [
Event(name="Live Show", tags=["concert", "music", "rock"]),
Event(name="Musical", tags=["movie", "music"]),
Event(name="Ballet", tags=["concert", "dance"]),
Event(name="Speech", tags=[]),
]

STORE["events"] = events

m_queryset = MagicMock(spec=QuerySet)
m_queryset.model = Event

def filter_events(**kwargs):
if "tags__contains" in kwargs:
STORE["events"] = list(
filter(
lambda e: set(kwargs["tags__contains"]).issubset(
set(e.tags)
),
STORE["events"],
class FakeQuerySet(QuerySet):
def __init__(self, model=None):
self.model = Event
self.__store = list(events)

def all(self):
return self

def filter(self, **kwargs):
queryset = FakeQuerySet()
queryset.__store = list(self.__store)
if "tags__contains" in kwargs:
queryset.__store = list(
filter(
lambda e: set(kwargs["tags__contains"]).issubset(
set(e.tags)
),
queryset.__store,
)
)
if "tags__overlap" in kwargs:
queryset.__store = list(
filter(
lambda e: not set(kwargs["tags__overlap"]).isdisjoint(
set(e.tags)
),
queryset.__store,
)
)
)
if "tags__overlap" in kwargs:
STORE["events"] = list(
filter(
lambda e: not set(kwargs["tags__overlap"]).isdisjoint(
set(e.tags)
),
STORE["events"],
if "tags__exact" in kwargs:
queryset.__store = list(
filter(
lambda e: set(kwargs["tags__exact"]) == set(e.tags),
queryset.__store,
)
)
)
if "tags__exact" in kwargs:
STORE["events"] = list(
filter(
lambda e: set(kwargs["tags__exact"]) == set(e.tags),
STORE["events"],
if "tags__len" in kwargs:
queryset.__store = list(
filter(
lambda e: len(e.tags) == kwargs["tags__len"],
queryset.__store,
)
)
)
return queryset

def union(self, *args):
queryset = FakeQuerySet()
queryset.__store = self.__store
for arg in args:
queryset.__store += arg.__store
return queryset

def mock_queryset_filter(*args, **kwargs):
filter_events(**kwargs)
return m_queryset
def none(self):
queryset = FakeQuerySet()
queryset.__store = []
return queryset

def mock_queryset_none(*args, **kwargs):
STORE["events"] = []
return m_queryset
def count(self):
return len(self.__store)

def mock_queryset_count(*args, **kwargs):
return len(STORE["events"])
def distinct(self):
queryset = FakeQuerySet()
queryset.__store = []
for event in self.__store:
if event not in queryset.__store:
queryset.__store.append(event)
queryset.__store = sorted(queryset.__store, key=lambda e: e.name)
return queryset

m_queryset.all.return_value = m_queryset
m_queryset.filter.side_effect = mock_queryset_filter
m_queryset.none.side_effect = mock_queryset_none
m_queryset.count.side_effect = mock_queryset_count
m_queryset.__getitem__.side_effect = lambda index: STORE[
"events"
].__getitem__(index)
def __getitem__(self, index):
return self.__store[index]

return m_queryset
return FakeQuerySet()

return Query


@pytest.fixture
def schema(Query):
return graphene.Schema(query=Query)
14 changes: 3 additions & 11 deletions graphene_django/filter/tests/test_array_field_contains_filter.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
import pytest

from graphene import Schema

from ...compat import ArrayField, MissingType


@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_contains_multiple(Query):
def test_array_field_contains_multiple(schema):
"""
Test contains filter on a array field of string.
"""

schema = Schema(query=Query)

query = """
query {
events (tags_Contains: ["concert", "music"]) {
Expand All @@ -32,13 +28,11 @@ def test_array_field_contains_multiple(Query):


@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_contains_one(Query):
def test_array_field_contains_one(schema):
"""
Test contains filter on a array field of string.
"""

schema = Schema(query=Query)

query = """
query {
events (tags_Contains: ["music"]) {
Expand All @@ -59,13 +53,11 @@ def test_array_field_contains_one(Query):


@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_contains_empty_list(Query):
def test_array_field_contains_empty_list(schema):
"""
Test contains filter on a array field of string.
"""

schema = Schema(query=Query)

query = """
query {
events (tags_Contains: []) {
Expand Down
Loading