diff --git a/graphene_mongo/fields.py b/graphene_mongo/fields.py index e89d0d40..cec99cdc 100644 --- a/graphene_mongo/fields.py +++ b/graphene_mongo/fields.py @@ -1,5 +1,6 @@ from __future__ import absolute_import +import mongoengine from collections import OrderedDict from functools import partial, reduce @@ -72,8 +73,7 @@ def args(self): def args(self, args): self._base_args = args - @property - def field_args(self): + def _field_args(self, items): def is_filterable(v): return not isinstance(v, (ConnectionField, Dynamic)) @@ -82,15 +82,19 @@ def get_type(v): return v.type.of_type() return v.type() - return {k: get_type(v) for k, v in self.fields.items() - if is_filterable(v)} + return {k: get_type(v) for k, v in items if is_filterable(v)} + + @property + def field_args(self): + return self._field_args(self.fields.items()) @property def reference_args(self): def get_reference_field(r, kv): if callable(getattr(kv[1], 'get_type', None)): node = kv[1].get_type()._type._meta - r.update({kv[0]: node.fields['id']._type.of_type()}) + if not issubclass(node.model, mongoengine.EmbeddedDocument): + r.update({kv[0]: node.fields['id']._type.of_type()}) return r return reduce(get_reference_field, self.fields.items(), {}) @@ -105,7 +109,6 @@ def get_query(cls, model, info, **args): return [], 0 objs = model.objects() - if args: reference_fields = get_model_reference_fields(model) reference_args = {} diff --git a/graphene_mongo/tests/test_relay_query.py b/graphene_mongo/tests/test_relay_query.py index 58f28751..65638bff 100644 --- a/graphene_mongo/tests/test_relay_query.py +++ b/graphene_mongo/tests/test_relay_query.py @@ -1,4 +1,5 @@ import json +import pytest import graphene @@ -11,7 +12,8 @@ PlayerNode, ReporterNode, ChildNode, - ParentWithRelationshipNode) + ParentWithRelationshipNode, + ProfessorVectorNode,) from ..fields import MongoengineConnectionField @@ -726,3 +728,44 @@ class Query(graphene.ObjectType): assert not result.errors assert json.dumps(result.data, sort_keys=True) == json.dumps( expected, sort_keys=True) + + +def test_should_query_with_embedded_document(fixtures): + + class Query(graphene.ObjectType): + + all_professors = MongoengineConnectionField(ProfessorVectorNode) + + query = ''' + query { + allProfessors { + edges { + node { + vec, + metadata { + firstName + } + } + } + } + } + ''' + expected = { + 'allProfessors': { + 'edges': [ + { + 'node': { + 'vec': [1.0, 2.3], + 'metadata': { + 'firstName': 'Steven' + } + } + + } + ] + } + } + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + assert dict(result.data['allProfessors']) == expected['allProfessors'] diff --git a/graphene_mongo/tests/types.py b/graphene_mongo/tests/types.py index e432944f..f174377d 100644 --- a/graphene_mongo/tests/types.py +++ b/graphene_mongo/tests/types.py @@ -106,24 +106,35 @@ class Meta: class ChildNode(MongoengineObjectType): + class Meta: model = Child interfaces = (Node,) class ChildRegisteredBeforeNode(MongoengineObjectType): + class Meta: model = ChildRegisteredBefore interfaces = (Node, ) class ParentWithRelationshipNode(MongoengineObjectType): + class Meta: model = ParentWithRelationship interfaces = (Node, ) class ChildRegisteredAfterNode(MongoengineObjectType): + class Meta: model = ChildRegisteredAfter interfaces = (Node, ) + + +class ProfessorVectorNode(MongoengineObjectType): + + class Meta: + model = ProfessorVector + interfaces = (Node, )