From bf3c4c6a8b304e0e32469b515a88ee2b09144277 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas=20Br=C3=A9mond?= Date: Sun, 17 Jan 2021 18:46:33 -0800 Subject: [PATCH 1/2] Support contains/overlap filters --- graphene_django/filter/tests/conftest.py | 128 ++++++++++++++++++ .../filter/tests/test_contains_filter.py | 82 +++++++++++ .../filter/tests/test_overlap_filter.py | 84 ++++++++++++ graphene_django/filter/utils.py | 16 +-- 4 files changed, 301 insertions(+), 9 deletions(-) create mode 100644 graphene_django/filter/tests/conftest.py create mode 100644 graphene_django/filter/tests/test_contains_filter.py create mode 100644 graphene_django/filter/tests/test_overlap_filter.py diff --git a/graphene_django/filter/tests/conftest.py b/graphene_django/filter/tests/conftest.py new file mode 100644 index 000000000..031364519 --- /dev/null +++ b/graphene_django/filter/tests/conftest.py @@ -0,0 +1,128 @@ +from mock import MagicMock +import pytest + +from django.db import models +from django.db.models.query import QuerySet +from django_filters import filters +from django_filters import FilterSet +import graphene +from graphene.relay import Node +from graphene_django import DjangoObjectType +from graphene_django.utils import DJANGO_FILTER_INSTALLED + +from ...compat import ArrayField + +pytestmark = [] + +if DJANGO_FILTER_INSTALLED: + from graphene_django.filter import DjangoFilterConnectionField +else: + pytestmark.append( + pytest.mark.skipif( + True, reason="django_filters not installed or not compatible" + ) + ) + + +STORE = {"events": []} + + +@pytest.fixture +def Event(): + class Event(models.Model): + name = models.CharField(max_length=50) + tags = ArrayField(models.CharField(max_length=50)) + + return Event + + +@pytest.fixture +def EventFilterSet(Event): + + from django.contrib.postgres.forms import SimpleArrayField + + class ArrayFilter(filters.Filter): + base_field_class = SimpleArrayField + + class EventFilterSet(FilterSet): + class Meta: + model = Event + fields = { + "name": ["exact"], + } + + tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains") + tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap") + + return EventFilterSet + + +@pytest.fixture +def EventType(Event, EventFilterSet): + class EventType(DjangoObjectType): + class Meta: + model = Event + interfaces = (Node,) + filterset_class = EventFilterSet + + return EventType + + +@pytest.fixture +def Query(Event, EventType): + class Query(graphene.ObjectType): + events = DjangoFilterConnectionField(EventType) + + def resolve_events(self, info, **kwargs): + + events = [ + Event(name="Live Show", tags=["concert", "music", "rock"],), + Event(name="Musical", tags=["movie", "music"],), + Event(name="Ballet", tags=["concert", "dance"],), + ] + + STORE["events"] = events + + m_queryset = MagicMock(spec=QuerySet) + m_queryset.model = Event + + def filter_events(**kwargs): + if "tags__contains" in kwargs: + STORE["events"] = list( + filter( + lambda e: set(kwargs["tags__contains"]).issubset( + set(e.tags) + ), + STORE["events"], + ) + ) + if "tags__overlap" in kwargs: + STORE["events"] = list( + filter( + lambda e: not set(kwargs["tags__overlap"]).isdisjoint( + set(e.tags) + ), + STORE["events"], + ) + ) + + def mock_queryset_filter(*args, **kwargs): + filter_events(**kwargs) + return m_queryset + + def mock_queryset_none(*args, **kwargs): + STORE["events"] = [] + return m_queryset + + def mock_queryset_count(*args, **kwargs): + return len(STORE["events"]) + + m_queryset.all.return_value = m_queryset + m_queryset.filter.side_effect = mock_queryset_filter + m_queryset.none.side_effect = mock_queryset_none + m_queryset.count.side_effect = mock_queryset_count + m_queryset.__getitem__.side_effect = STORE["events"].__getitem__ + + return m_queryset + + return Query diff --git a/graphene_django/filter/tests/test_contains_filter.py b/graphene_django/filter/tests/test_contains_filter.py new file mode 100644 index 000000000..3e90a3bc0 --- /dev/null +++ b/graphene_django/filter/tests/test_contains_filter.py @@ -0,0 +1,82 @@ +import pytest + +from graphene import Schema + +from ...compat import ArrayField, MissingType + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_string_contains_multiple(Event, Query): + """ + Test contains filter on a string field. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags_Contains: ["concert", "music"]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [ + {"node": {"name": "Live Show"}}, + ] + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_string_contains_one(Event, Query): + """ + Test contains filter on a string field. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags_Contains: ["music"]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [ + {"node": {"name": "Live Show"}}, + {"node": {"name": "Musical"}}, + ] + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_string_contains_none(Event, Query): + """ + Test contains filter on a string field. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags_Contains: []) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [] diff --git a/graphene_django/filter/tests/test_overlap_filter.py b/graphene_django/filter/tests/test_overlap_filter.py new file mode 100644 index 000000000..90e825f80 --- /dev/null +++ b/graphene_django/filter/tests/test_overlap_filter.py @@ -0,0 +1,84 @@ +import pytest + +from graphene import Schema + +from ...compat import ArrayField, MissingType + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_string_overlap_multiple(Event, Query): + """ + Test overlap filter on a string field. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags_Overlap: ["concert", "music"]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [ + {"node": {"name": "Live Show"}}, + {"node": {"name": "Musical"}}, + {"node": {"name": "Ballet"}}, + ] + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_string_overlap_one(Event, Query): + """ + Test overlap filter on a string field. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags_Overlap: ["music"]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [ + {"node": {"name": "Live Show"}}, + {"node": {"name": "Musical"}}, + ] + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_string_overlap_none(Event, Query): + """ + Test overlap filter on a string field. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags_Overlap: []) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [] diff --git a/graphene_django/filter/utils.py b/graphene_django/filter/utils.py index 76cbf4650..4530599e5 100644 --- a/graphene_django/filter/utils.py +++ b/graphene_django/filter/utils.py @@ -1,4 +1,4 @@ -from graphene import List +import graphene from django_filters.utils import get_model_field from django_filters.filters import Filter, BaseCSVFilter @@ -39,11 +39,11 @@ def get_filtering_args_from_filterset(filterset_class, type): field = convert_form_field(form_field) - if filter_type in ["in", "range"]: - # Replace CSV filters (`in`, `range`) argument type to be a list of + if filter_type in {"in", "range", "contains", "overlap"}: + # Replace CSV filters (`in`, `range`, `contains`, `overlap`) argument type to be a list of # the same type as the field. See comments in # `replace_csv_filters` method for more details. - field = List(field.get_type()) + field = graphene.List(field.get_type()) field_type = field.Argument() field_type.description = str(filter_field.label) if filter_field.label else None @@ -69,7 +69,7 @@ def get_filterset_class(filterset_class, **meta): def replace_csv_filters(filterset_class): """ - Replace the "in" and "range" filters (that are not explicitly declared) to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore + Replace the "in", "contains", "overlap" and "range" filters (that are not explicitly declared) to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore but regular Filter objects that simply use the input value as filter argument on the queryset. This is because those BaseCSVFilter are expecting a string as input with comma separated value but with GraphQl we @@ -79,8 +79,7 @@ def replace_csv_filters(filterset_class): """ for name, filter_field in list(filterset_class.base_filters.items()): filter_type = filter_field.lookup_expr - if filter_type == "in": - assert isinstance(filter_field, BaseCSVFilter) + if filter_type in {"in", "contains", "overlap"}: filterset_class.base_filters[name] = InFilter( field_name=filter_field.field_name, lookup_expr=filter_field.lookup_expr, @@ -90,8 +89,7 @@ def replace_csv_filters(filterset_class): **filter_field.extra ) - if filter_type == "range": - assert isinstance(filter_field, BaseCSVFilter) + elif filter_type == "range": filterset_class.base_filters[name] = RangeFilter( field_name=filter_field.field_name, lookup_expr=filter_field.lookup_expr, From b45286c727adc36e3e19d299eb96fae3f24ccaff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas=20Br=C3=A9mond?= Date: Sun, 17 Jan 2021 22:25:40 -0800 Subject: [PATCH 2/2] Remove unused fixtures --- graphene_django/filter/tests/test_contains_filter.py | 6 +++--- graphene_django/filter/tests/test_overlap_filter.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/graphene_django/filter/tests/test_contains_filter.py b/graphene_django/filter/tests/test_contains_filter.py index 3e90a3bc0..35e775ef5 100644 --- a/graphene_django/filter/tests/test_contains_filter.py +++ b/graphene_django/filter/tests/test_contains_filter.py @@ -6,7 +6,7 @@ @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_string_contains_multiple(Event, Query): +def test_string_contains_multiple(Query): """ Test contains filter on a string field. """ @@ -32,7 +32,7 @@ def test_string_contains_multiple(Event, Query): @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_string_contains_one(Event, Query): +def test_string_contains_one(Query): """ Test contains filter on a string field. """ @@ -59,7 +59,7 @@ def test_string_contains_one(Event, Query): @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_string_contains_none(Event, Query): +def test_string_contains_none(Query): """ Test contains filter on a string field. """ diff --git a/graphene_django/filter/tests/test_overlap_filter.py b/graphene_django/filter/tests/test_overlap_filter.py index 90e825f80..32dfa44a1 100644 --- a/graphene_django/filter/tests/test_overlap_filter.py +++ b/graphene_django/filter/tests/test_overlap_filter.py @@ -6,7 +6,7 @@ @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_string_overlap_multiple(Event, Query): +def test_string_overlap_multiple(Query): """ Test overlap filter on a string field. """ @@ -34,7 +34,7 @@ def test_string_overlap_multiple(Event, Query): @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_string_overlap_one(Event, Query): +def test_string_overlap_one(Query): """ Test overlap filter on a string field. """ @@ -61,7 +61,7 @@ def test_string_overlap_one(Event, Query): @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_string_overlap_none(Event, Query): +def test_string_overlap_none(Query): """ Test overlap filter on a string field. """