diff --git a/Makefile b/Makefile index 529ff45e..511ba897 100644 --- a/Makefile +++ b/Makefile @@ -14,7 +14,7 @@ lint: @flake8 graphene_mongo test: clean lint - py.test --cov=graphene_mongo + py.test graphene_mongo/tests --cov=graphene_mongo --cov-report=html --cov-report=term register-pypitest: python setup.py register -r pypitest diff --git a/graphene_mongo/fields.py b/graphene_mongo/fields.py index fb3f94b5..7a3d844b 100644 --- a/graphene_mongo/fields.py +++ b/graphene_mongo/fields.py @@ -19,6 +19,7 @@ class MongoengineConnectionField(ConnectionField): + def __init__(self, type, *args, **kwargs): get_queryset = kwargs.pop("get_queryset", None) if get_queryset: @@ -186,6 +187,9 @@ def get_queryset(self, model, info, **args): def default_resolver(self, _root, info, **args): args = args or {} + if _root is not None: + args["pk__in"] = [r.pk for r in getattr(_root, info.field_name, [])] + connection_args = { "first": args.pop("first", None), "last": args.pop("last", None), diff --git a/graphene_mongo/tests/setup.py b/graphene_mongo/tests/setup.py index f3dceb28..80559713 100644 --- a/graphene_mongo/tests/setup.py +++ b/graphene_mongo/tests/setup.py @@ -78,9 +78,16 @@ def fixtures(): reporter1.save() Player.drop_collection() - player1 = Player(first_name="Michael", last_name="Jordan") + player1 = Player( + first_name="Michael", + last_name="Jordan", + articles=[article1, article2]) player1.save() - player2 = Player(first_name="Magic", last_name="Johnson", opponent=player1) + player2 = Player( + first_name="Magic", + last_name="Johnson", + opponent=player1, + articles=[article3]) player2.save() player3 = Player(first_name="Larry", last_name="Bird", players=[player1, player2]) player3.save() diff --git a/graphene_mongo/tests/test_relay_query.py b/graphene_mongo/tests/test_relay_query.py index 67316154..825fcdfc 100644 --- a/graphene_mongo/tests/test_relay_query.py +++ b/graphene_mongo/tests/test_relay_query.py @@ -976,3 +976,52 @@ class Query(graphene.ObjectType): assert json.dumps(result.data, sort_keys=True) == json.dumps( expected, sort_keys=True ) + + +def test_should_get_correct_list_of_documents(fixtures): + class Query(graphene.ObjectType): + players = MongoengineConnectionField(nodes.PlayerNode) + + query = """ + query players { + players(firstName: "Michael") { + edges { + node { + firstName, + articles(first: 3) { + edges { + node { + headline + } + } + } + } + } + } + } + """ + expected = { + "players": { + "edges": [{ + "node": { + "firstName": "Michael", + "articles": { + "edges": [{ + "node": { + "headline": "Hello" + } + }, { + "node": { + "headline": "World" + } + }] + } + } + }] + } + } + schema = graphene.Schema(query=Query) + result = schema.execute(query) + + assert not result.errors + assert result.data == expected diff --git a/graphene_mongo/types.py b/graphene_mongo/types.py index 8f85bba9..a676dc1a 100644 --- a/graphene_mongo/types.py +++ b/graphene_mongo/types.py @@ -194,7 +194,6 @@ def rescan_fields(cls): cls._meta.fields.update({field: mongoengine_fields[field]}) # Self-referenced fields can't change between scans! - # noqa @classmethod def is_type_of(cls, root, info): if isinstance(root, cls): @@ -212,7 +211,3 @@ def get_node(cls, info, id): def resolve_id(self, info): return str(self.id) - - # @classmethod - # def get_connection(cls): - # return connection_for_type(cls)