Skip to content

Allow recursive connections with DjangoFilterConnectionField #58

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 3, 2017
76 changes: 59 additions & 17 deletions graphene_django/filter/fields.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,76 @@
import inspect

from collections import OrderedDict
from functools import partial

from graphene.types.argument import to_arguments
from ..fields import DjangoConnectionField
from graphene.relay import is_node
from .utils import get_filtering_args_from_filterset, get_filterset_class


class DjangoFilterConnectionField(DjangoConnectionField):

def __init__(self, type, fields=None, extra_filter_meta=None,
filterset_class=None, *args, **kwargs):
def __init__(self, type, fields=None, order_by=None,
extra_filter_meta=None, filterset_class=None,
*args, **kwargs):
self._fields = fields
self._type = type
self._filterset_class = filterset_class
self._extra_filter_meta = extra_filter_meta
self._base_args = None
super(DjangoFilterConnectionField, self).__init__(type, *args, **kwargs)

@property
def node_type(self):
if inspect.isfunction(self._type) or inspect.ismethod(self._type):
return self._type()
return self._type

if is_node(type):
_fields = type._meta.filter_fields
_model = type._meta.model
@property
def meta(self):
if is_node(self.node_type):
_model = self.node_type._meta.model
else:
# ConnectionFields can also be passed Connections,
# in which case, we need to use the Node of the connection
# to get our relevant args.
_fields = type._meta.node._meta.filter_fields
_model = type._meta.node._meta.model

self.fields = fields or _fields
meta = dict(model=_model, fields=self.fields)
if extra_filter_meta:
meta.update(extra_filter_meta)
self.filterset_class = get_filterset_class(filterset_class, **meta)
self.filtering_args = get_filtering_args_from_filterset(self.filterset_class, type)
kwargs.setdefault('args', {})
kwargs['args'].update(self.filtering_args)
super(DjangoFilterConnectionField, self).__init__(type, *args, **kwargs)
_model = self.node_type._meta.node._meta.model

meta = dict(model=_model,
fields=self.fields)
if self._extra_filter_meta:
meta.update(self._extra_filter_meta)
return meta

@property
def fields(self):
if self._fields:
return self._fields

if is_node(self.node_type):
return self.node_type._meta.filter_fields
else:
# ConnectionFields can also be passed Connections,
# in which case, we need to use the Node of the connection
# to get our relevant args.
return self.node_type._meta.node._meta.filter_fields

@property
def args(self):
return to_arguments(self._base_args or OrderedDict(), self.filtering_args)

@args.setter
def args(self, args):
self._base_args = args

@property
def filterset_class(self):
return get_filterset_class(self._filterset_class, **self.meta)

@property
def filtering_args(self):
return get_filtering_args_from_filterset(self.filterset_class, self.node_type)

@staticmethod
def connection_resolver(resolver, connection, default_manager, filterset_class, filtering_args,
Expand Down
17 changes: 17 additions & 0 deletions graphene_django/filter/tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,3 +348,20 @@ class Query(ObjectType):
assert not result.errors
# We should only get two reporters
assert len(result.data['allReporters']['edges']) == 2


def test_recursive_filter_connection():
class ReporterFilterNode(DjangoObjectType):
child_reporters = DjangoFilterConnectionField(lambda: ReporterFilterNode)

def resolve_child_reporters(self, args, context, info):
return []

class Meta:
model = Reporter
interfaces = (Node, )

class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)

assert ReporterFilterNode._meta.fields['child_reporters'].node_type == ReporterFilterNode