Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 143 additions & 40 deletions graphene_mongo/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,24 @@
from collections import OrderedDict
from functools import partial, reduce

import bson
import graphene
import mongoengine
from bson import DBRef, ObjectId
from graphene import Context
from graphene.types.utils import get_type
from graphene.utils.str_converters import to_snake_case
from graphql import GraphQLResolveInfo
from mongoengine.base import get_document
from promise import Promise
from graphql_relay import from_global_id
from graphene.relay import ConnectionField
from graphene.types.argument import to_arguments
from graphene.types.dynamic import Dynamic
from graphene.types.structures import Structure
from graphql_relay.connection.array_connection import cursor_to_offset
from graphene.types.utils import get_type
from graphene.utils.str_converters import to_snake_case
from graphql import GraphQLResolveInfo
from graphql_relay import from_global_id
from graphql_relay.connection.arrayconnection import cursor_to_offset
from mongoengine import QuerySet
from mongoengine.base import get_document
from promise import Promise
from pymongo.errors import OperationFailure

from .advanced_types import (
FileFieldType,
Expand All @@ -30,6 +32,9 @@
from .registry import get_global_registry
from .utils import get_model_reference_fields, get_query_fields, find_skip_and_limit, \
connection_from_iterables
import pymongo

PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])


class MongoengineConnectionField(ConnectionField):
Expand Down Expand Up @@ -77,9 +82,27 @@ def registry(self):

@property
def args(self):
_field_args = self.field_args
_advance_args = self.advance_args
_filter_args = self.filter_args
_extended_args = self.extended_args
if self._type._meta.non_filter_fields:
for _field in self._type._meta.non_filter_fields:
if _field in _field_args:
_field_args.pop(_field)
if _field in _advance_args:
_advance_args.pop(_field)
if _field in _filter_args:
_filter_args.pop(_field)
if _field in _extended_args:
_filter_args.pop(_field)
extra_args = dict(dict(dict(_field_args, **_advance_args), **_filter_args), **_extended_args)

for key in list(self._base_args.keys()):
extra_args.pop(key, None)
return to_arguments(
self._base_args or OrderedDict(),
dict(dict(dict(self.field_args, **self.advance_args), **self.filter_args), **self.extended_args),
extra_args
)

@args.setter
Expand All @@ -100,6 +123,14 @@ def is_filterable(k):
return False
if not hasattr(self.model, k):
return False
else:
# else section is a patch for federated field error
field_ = self.fields[k]
type_ = field_.type
while hasattr(type_, "of_type"):
type_ = type_.of_type
if hasattr(type_, "_sdl") and "@key" in type_._sdl:
return False
if isinstance(getattr(self.model, k), property):
return False
try:
Expand Down Expand Up @@ -128,6 +159,9 @@ def is_filterable(k):
getattr(converted, "_of_type", None), graphene.Union
):
return False
# below if condition: workaround for DB filterable field redefined as custom graphene type
if hasattr(field_, 'type') and hasattr(converted, 'type') and converted.type != field_.type:
return False
return True

def get_filter_type(_type):
Expand All @@ -150,7 +184,7 @@ def filter_args(self):
if self._type._meta.filter_fields:
for field, filter_collection in self._type._meta.filter_fields.items():
for each in filter_collection:
if str(self._type._meta.fields[field].type) == 'PointFieldType':
if str(self._type._meta.fields[field].type) in ('PointFieldType', 'PointFieldType!'):
if each == 'max_distance':
filter_type = graphene.Int
else:
Expand Down Expand Up @@ -279,17 +313,17 @@ def get_queryset(self, model, info, required_fields=None, skip=None, limit=None,
skip)
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by)

def default_resolver(self, _root, info, required_fields=None, **args):
def default_resolver(self, _root, info, required_fields=None, resolved=None, **args):
if required_fields is None:
required_fields = list()
args = args or {}
for key, value in dict(args).items():
if value is None:
del args[key]
if _root is not None:
if _root is not None and not resolved:
field_name = to_snake_case(info.field_name)
if not hasattr(_root, "_fields_ordered"):
if getattr(_root, field_name, []) is not None:
if isinstance(getattr(_root, field_name, []), list):
args["pk__in"] = [r.id for r in getattr(_root, field_name, [])]
elif field_name in _root._fields_ordered and not (isinstance(_root._fields[field_name].field,
mongoengine.EmbeddedDocumentField) or
Expand All @@ -316,25 +350,33 @@ def default_resolver(self, _root, info, required_fields=None, **args):
before = args.pop("before", None)
if before:
before = cursor_to_offset(before)
if callable(getattr(self.model, "objects", None)):
if "pk__in" in args and args["pk__in"]:
count = len(args["pk__in"])
skip, limit, reverse = find_skip_and_limit(first=first, last=last, after=after, before=before,
count=count)
if limit:
if reverse:
args["pk__in"] = args["pk__in"][::-1][skip:skip + limit]
else:
args["pk__in"] = args["pk__in"][skip:skip + limit]
elif skip:
args["pk__in"] = args["pk__in"][skip:]
iterables = self.get_queryset(self.model, info, required_fields, **args)
list_length = len(iterables)
if isinstance(info, GraphQLResolveInfo):
if not info.context:
info = info._replace(context=Context())
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
elif _root is None or args:

if resolved is not None:
items = resolved

if isinstance(items, QuerySet):
try:
count = items.count(with_limit_and_skip=True)
except OperationFailure:
count = len(items)
else:
count = len(items)

skip, limit, reverse = find_skip_and_limit(first=first, last=last, after=after, before=before,
count=count)

if limit:
if reverse:
items = items[::-1][skip:skip + limit]
else:
items = items[skip:skip + limit]
elif skip:
items = items[skip:]
iterables = items
list_length = len(iterables)

elif callable(getattr(self.model, "objects", None)):
if _root is None or args or isinstance(getattr(_root, field_name, []), MongoengineConnectionField):
args_copy = args.copy()
for key in args.copy():
if key not in self.model._fields_ordered:
Expand All @@ -346,8 +388,20 @@ def default_resolver(self, _root, info, required_fields=None, **args):
mongoengine.fields.LazyReferenceField) or isinstance(getattr(self.model, key),
mongoengine.fields.CachedReferenceField):
if not isinstance(args_copy[key], ObjectId):
args_copy[key] = from_global_id(args_copy[key])[1]
count = mongoengine.get_db()[self.model._get_collection_name()].count_documents(args_copy)
_from_global_id = from_global_id(args_copy[key])[1]
if bson.objectid.ObjectId.is_valid(_from_global_id):
args_copy[key] = ObjectId(_from_global_id)
else:
args_copy[key] = _from_global_id
elif isinstance(getattr(self.model, key),
mongoengine.fields.EnumField):
if getattr(args_copy[key], "value", None):
args_copy[key] = args_copy[key].value

if PYMONGO_VERSION >= (3, 7):
count = (mongoengine.get_db()[self.model._get_collection_name()]).count_documents(args_copy)
else:
count = self.model.objects(args_copy).count()
if count != 0:
skip, limit, reverse = find_skip_and_limit(first=first, after=after, last=last, before=before,
count=count)
Expand All @@ -358,6 +412,24 @@ def default_resolver(self, _root, info, required_fields=None, **args):
info = info._replace(context=Context())
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)

elif "pk__in" in args and args["pk__in"]:
count = len(args["pk__in"])
skip, limit, reverse = find_skip_and_limit(first=first, last=last, after=after, before=before,
count=count)
if limit:
if reverse:
args["pk__in"] = args["pk__in"][::-1][skip:skip + limit]
else:
args["pk__in"] = args["pk__in"][skip:skip + limit]
elif skip:
args["pk__in"] = args["pk__in"][skip:]
iterables = self.get_queryset(self.model, info, required_fields, **args)
list_length = len(iterables)
if isinstance(info, GraphQLResolveInfo):
if not info.context:
info = info._replace(context=Context())
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)

elif _root is not None:
field_name = to_snake_case(info.field_name)
items = getattr(_root, field_name, [])
Expand All @@ -373,6 +445,7 @@ def default_resolver(self, _root, info, required_fields=None, **args):
items = items[skip:]
iterables = items
list_length = len(iterables)

has_next_page = True if (0 if limit is None else limit) + (0 if skip is None else skip) < count else False
has_previous_page = True if skip else False
if reverse:
Expand All @@ -391,31 +464,42 @@ def default_resolver(self, _root, info, required_fields=None, **args):
return connection

def chained_resolver(self, resolver, is_partial, root, info, **args):

for key, value in dict(args).items():
if value is None:
del args[key]

required_fields = list()

for field in self.required_fields:
if field in self.model._fields_ordered:
required_fields.append(field)

for field in get_query_fields(info):
if to_snake_case(field) in self.model._fields_ordered:
required_fields.append(to_snake_case(field))

args_copy = args.copy()

if not bool(args) or not is_partial:
if isinstance(self.model, mongoengine.Document) or isinstance(self.model,
mongoengine.base.metaclasses.TopLevelDocumentMetaclass):

from itertools import filterfalse
connection_fields = [field for field in self.fields if
type(self.fields[field]) == MongoengineConnectionField]
filterable_args = tuple(filterfalse(connection_fields.__contains__, list(self.model._fields_ordered)))
for arg_name, arg in args.copy().items():
if arg_name not in self.model._fields_ordered + tuple(self.filter_args.keys()):
if arg_name not in filterable_args + tuple(self.filter_args.keys()):
args_copy.pop(arg_name)
if isinstance(info, GraphQLResolveInfo):
if not info.context:
info = info._replace(context=Context())
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args_copy)

# XXX: Filter nested args
resolved = resolver(root, info, **args)

if resolved is not None:
if isinstance(resolved, list):
if resolved == list():
Expand All @@ -428,36 +512,55 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
args.update(resolved._query)
args_copy = args.copy()
for arg_name, arg in args.copy().items():
if arg_name not in self.model._fields_ordered + ('first', 'last', 'before', 'after') + tuple(
self.filter_args.keys()):
if "." in arg_name or arg_name not in self.model._fields_ordered + (
'first', 'last', 'before', 'after') + tuple(
self.filter_args.keys()):
args_copy.pop(arg_name)
if arg_name == '_id' and isinstance(arg, dict):
operation = list(arg.keys())[0]
args_copy['pk' + operation.replace('$', '__')] = arg[operation]
if not isinstance(arg, ObjectId) and '.' in arg_name:
operation = list(arg.keys())[0]
args_copy[arg_name.replace('.', '__') + operation.replace('$', '__')] = arg[operation]
if type(arg) == dict:
operation = list(arg.keys())[0]
args_copy[arg_name.replace('.', '__') + operation.replace('$', '__')] = arg[
operation]
else:
args_copy[arg_name.replace('.', '__')] = arg
elif '.' in arg_name and isinstance(arg, ObjectId):
args_copy[arg_name.replace('.', '__')] = arg
else:
operations = ["$lte", "$gte", "$ne", "$in"]
if isinstance(arg, dict) and any(op in arg for op in operations):
operation = list(arg.keys())[0]
args_copy[arg_name + operation.replace('$', '__')] = arg[operation]
del args_copy[arg_name]
return self.default_resolver(root, info, required_fields, **args_copy)
return self.default_resolver(root, info, required_fields, resolved=resolved, **args_copy)
elif isinstance(resolved, Promise):
return resolved.value
else:
return resolved

return self.default_resolver(root, info, required_fields, **args)

@classmethod
def connection_resolver(cls, resolver, connection_type, root, info, **args):
if root:
for key, value in root.__dict__.items():
if value:
try:
setattr(root, key, from_global_id(value)[1])
except Exception as error:
pass
iterable = resolver(root, info, **args)

if isinstance(connection_type, graphene.NonNull):
connection_type = connection_type.of_type

on_resolve = partial(cls.resolve_connection, connection_type, args)

if Promise.is_thenable(iterable):
return Promise.resolve(iterable).then(on_resolve)

return on_resolve(iterable)

def get_resolver(self, parent_resolver):
Expand Down
5 changes: 5 additions & 0 deletions graphene_mongo/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ def register_enum(self, cls):
assert type(cls) == EnumMeta, 'Only EnumMeta can be registered, received "{}"'.format(
cls.__name__
)
if not cls.__name__.endswith('Enum'):
name = cls.__name__ + 'Enum'
else:
name = cls.__name__
cls.__name__ = name
self._registry_enum[cls] = Enum.from_enum(cls)

def get_type_for_model(self, model):
Expand Down
Loading