diff --git a/graphql/execution/__init__.py b/graphql/execution/__init__.py index 2f7f2830..d29f727c 100644 --- a/graphql/execution/__init__.py +++ b/graphql/execution/__init__.py @@ -18,42 +18,11 @@ 2) fragment "spreads" e.g. "...c" 3) inline fragment "spreads" e.g. "...on Type { a }" """ - +from .execute import execute as _execute from .base import ExecutionResult -from .executor import Executor -from .middlewares.sync import SynchronousExecutionMiddleware def execute(schema, root, ast, operation_name='', args=None): - """ - Executes an AST synchronously. Assumes that the AST is already validated. - """ - return get_default_executor().execute(schema, ast, root, args, operation_name, validate_ast=False) - - -_default_executor = None - - -def get_default_executor(): - """ - Gets the default executor to be used in the `execute` function above. - """ - global _default_executor - if _default_executor is None: - _default_executor = Executor([SynchronousExecutionMiddleware()]) - - return _default_executor - - -def set_default_executor(executor): - """ - Sets the default executor to be used in the `execute` function above. - - If passed `None` will reset to the original default synchronous executor. - """ - assert isinstance(executor, Executor) or executor is None - global _default_executor - _default_executor = executor - + return _execute(schema, ast, root, variable_values=args, operation_name=operation_name) -__all__ = ['ExecutionResult', 'Executor', 'execute', 'get_default_executor', 'set_default_executor'] +__all__ = ['execute', 'ExecutionResult'] diff --git a/graphql/execution/base.py b/graphql/execution/base.py index 66c6a269..e7741930 100644 --- a/graphql/execution/base.py +++ b/graphql/execution/base.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- from ..error import GraphQLError from ..language import ast -from ..pyutils.defer import DeferredException from ..type.definition import GraphQLInterfaceType, GraphQLUnionType from ..type.directives import GraphQLIncludeDirective, GraphQLSkipDirective from ..type.introspection import (SchemaMetaFieldDef, TypeMetaFieldDef, @@ -18,10 +17,10 @@ class ExecutionContext(object): Namely, schema of the type system that is currently executing, and the fragments defined in the query document""" - __slots__ = 'schema', 'fragments', 'root', 'operation', 'variables', 'errors', 'request_context', \ - 'argument_values_cache' + __slots__ = 'schema', 'fragments', 'root_value', 'operation', 'variable_values', 'errors', 'context_value', \ + 'argument_values_cache', 'executor' - def __init__(self, schema, root, document_ast, operation_name, args, request_context): + def __init__(self, schema, document_ast, root_value, context_value, variable_values, operation_name, executor): """Constructs a ExecutionContext object from the arguments passed to execute, which we will pass throughout the other execution methods.""" @@ -53,16 +52,17 @@ def __init__(self, schema, root, document_ast, operation_name, args, request_con else: raise GraphQLError('Must provide an operation.') - variables = get_variable_values(schema, operation.variable_definitions or [], args) + variable_values = get_variable_values(schema, operation.variable_definitions or [], variable_values) self.schema = schema self.fragments = fragments - self.root = root + self.root_value = root_value self.operation = operation - self.variables = variables + self.variable_values = variable_values self.errors = errors - self.request_context = request_context + self.context_value = context_value self.argument_values_cache = {} + self.executor = executor def get_argument_values(self, field_def, field_ast): k = field_def, field_ast @@ -70,7 +70,7 @@ def get_argument_values(self, field_def, field_ast): if not result: result = self.argument_values_cache[k] = get_argument_values(field_def.args, field_ast.arguments, - self.variables) + self.variable_values) return result @@ -84,12 +84,6 @@ class ExecutionResult(object): def __init__(self, data=None, errors=None, invalid=False): self.data = data - if errors: - errors = [ - error.value if isinstance(error, DeferredException) else error - for error in errors - ] - self.errors = errors if invalid: @@ -190,6 +184,7 @@ def collect_fields(ctx, runtime_type, selection_set, fields, prev_fragment_names def should_include_node(ctx, directives): """Determines if a field should be included based on the @include and @skip directives, where @skip has higher precidence than @include.""" + # TODO: Refactor based on latest code if directives: skip_ast = None @@ -202,7 +197,7 @@ def should_include_node(ctx, directives): args = get_argument_values( GraphQLSkipDirective.args, skip_ast.arguments, - ctx.variables, + ctx.variable_values, ) return not args.get('if') @@ -217,7 +212,7 @@ def should_include_node(ctx, directives): args = get_argument_values( GraphQLIncludeDirective.args, include_ast.arguments, - ctx.variables, + ctx.variable_values, ) return bool(args.get('if')) @@ -249,36 +244,17 @@ def get_field_entry_key(node): class ResolveInfo(object): - def __init__(self, field_name, field_asts, return_type, parent_type, context): + def __init__(self, field_name, field_asts, return_type, parent_type, + schema, fragments, root_value, operation, variable_values): self.field_name = field_name self.field_asts = field_asts self.return_type = return_type self.parent_type = parent_type - self.context = context - - @property - def schema(self): - return self.context.schema - - @property - def fragments(self): - return self.context.fragments - - @property - def root_value(self): - return self.context.root - - @property - def operation(self): - return self.context.operation - - @property - def variable_values(self): - return self.context.variables - - @property - def request_context(self): - return self.context.request_context + self.schema = schema + self.fragments = fragments + self.root_value = root_value + self.operation = operation + self.variable_values = variable_values def default_resolve_fn(source, args, info): diff --git a/graphql/execution/execute.py b/graphql/execution/execute.py new file mode 100644 index 00000000..def252d4 --- /dev/null +++ b/graphql/execution/execute.py @@ -0,0 +1,336 @@ +import collections +import functools + +from ..error import GraphQLError +from ..pyutils.aplus import Promise, is_thenable, promise_for_dict, promisify +from ..pyutils.default_ordered_dict import DefaultOrderedDict +from ..type import (GraphQLEnumType, GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType, GraphQLScalarType, + GraphQLSchema, GraphQLUnionType) +from .base import (ExecutionContext, ExecutionResult, ResolveInfo, Undefined, + collect_fields, default_resolve_fn, get_field_def, + get_operation_root_type) +from .executors.sync import SyncExecutor + + +def execute(schema, document_ast, root_value=None, context_value=None, + variable_values=None, operation_name=None, executor=None): + assert schema, 'Must provide schema' + assert isinstance(schema, GraphQLSchema), ( + 'Schema must be an instance of GraphQLSchema. Also ensure that there are ' + + 'not multiple versions of GraphQL installed in your node_modules directory.' + ) + + if executor is None: + executor = SyncExecutor() + + context = ExecutionContext( + schema, + document_ast, + root_value, + context_value, + variable_values, + operation_name, + executor + ) + + def executor(resolve, reject): + return resolve(execute_operation(context, context.operation, root_value)) + + def on_rejected(error): + context.errors.append(error) + return None + + def on_resolve(data): + return ExecutionResult(data=data, errors=context.errors) + + p = Promise(executor).catch(on_rejected).then(on_resolve) + context.executor.wait_until_finished() + return p.value + + +def execute_operation(exe_context, operation, root_value): + type = get_operation_root_type(exe_context.schema, operation) + fields = collect_fields( + exe_context, + type, + operation.selection_set, + DefaultOrderedDict(list), + set() + ) + + if operation.operation == 'mutation': + return execute_fields_serially(exe_context, type, root_value, fields) + + return execute_fields(exe_context, type, root_value, fields) + + +def execute_fields_serially(exe_context, parent_type, source_value, fields): + def execute_field_callback(results, response_name): + field_asts = fields[response_name] + result = resolve_field( + exe_context, + parent_type, + source_value, + field_asts + ) + if result is Undefined: + return results + + if is_thenable(result): + def collect_result(resolved_result): + results[response_name] = resolved_result + return results + + return promisify(result).then(collect_result, None) + + results[response_name] = result + return results + + def execute_field(prev_promise, response_name): + return prev_promise.then(lambda results: execute_field_callback(results, response_name)) + + return functools.reduce(execute_field, fields.keys(), Promise.resolve(collections.OrderedDict())) + + +def execute_fields(exe_context, parent_type, source_value, fields): + contains_promise = False + + final_results = collections.OrderedDict() + + for response_name, field_asts in fields.items(): + result = resolve_field(exe_context, parent_type, source_value, field_asts) + if result is Undefined: + continue + + final_results[response_name] = result + if is_thenable(result): + contains_promise = True + + if not contains_promise: + return final_results + + return promise_for_dict(final_results) + + +def resolve_field(exe_context, parent_type, source, field_asts): + field_ast = field_asts[0] + field_name = field_ast.name.value + + field_def = get_field_def(exe_context.schema, parent_type, field_name) + if not field_def: + return Undefined + + return_type = field_def.type + resolve_fn = field_def.resolver or default_resolve_fn + + # Build a dict of arguments from the field.arguments AST, using the variables scope to + # fulfill any variable references. + args = exe_context.get_argument_values(field_def, field_ast) + + # The resolve function's optional third argument is a collection of + # information about the current execution state. + info = ResolveInfo( + field_name, + field_asts, + return_type, + parent_type, + schema=exe_context.schema, + fragments=exe_context.fragments, + root_value=exe_context.root_value, + operation=exe_context.operation, + variable_values=exe_context.variable_values, + ) + + result = resolve_or_error(resolve_fn, source, args, exe_context, info) + + return complete_value_catching_error( + exe_context, + return_type, + field_asts, + info, + result + ) + + +def resolve_or_error(resolve_fn, source, args, exe_context, info): + try: + # return resolve_fn(source, args, exe_context, info) + return exe_context.executor.execute(resolve_fn, source, args, info) + except Exception as e: + return e + + +def complete_value_catching_error(exe_context, return_type, field_asts, info, result): + # If the field type is non-nullable, then it is resolved without any + # protection from errors. + if isinstance(return_type, GraphQLNonNull): + return complete_value(exe_context, return_type, field_asts, info, result) + + # Otherwise, error protection is applied, logging the error and + # resolving a null value for this field if one is encountered. + try: + completed = complete_value(exe_context, return_type, field_asts, info, result) + if is_thenable(completed): + def handle_error(error): + exe_context.errors.append(error) + return Promise.fulfilled(None) + + return promisify(completed).then(None, handle_error) + + return completed + except Exception as e: + exe_context.errors.append(e) + return None + + +def complete_value(exe_context, return_type, field_asts, info, result): + """ + Implements the instructions for completeValue as defined in the + "Field entries" section of the spec. + + If the field type is Non-Null, then this recursively completes the value for the inner type. It throws a field + error if that completion returns null, as per the "Nullability" section of the spec. + + If the field type is a List, then this recursively completes the value for the inner type on each item in the + list. + + If the field type is a Scalar or Enum, ensures the completed value is a legal value of the type by calling the + `serialize` method of GraphQL type definition. + + If the field is an abstract type, determine the runtime type of the value and then complete based on that type. + + Otherwise, the field type expects a sub-selection set, and will complete the value by evaluating all + sub-selections. + """ + # If field type is NonNull, complete for inner type, and throw field error if result is null. + + if is_thenable(result): + return promisify(result).then( + lambda resolved: complete_value( + exe_context, + return_type, + field_asts, + info, + resolved + ), + lambda error: Promise.rejected(GraphQLError(error and str(error), field_asts, error)) + ) + + if isinstance(result, Exception): + raise GraphQLError(str(result), field_asts, result) + + if isinstance(return_type, GraphQLNonNull): + completed = complete_value( + exe_context, return_type.of_type, field_asts, info, result + ) + if completed is None: + raise GraphQLError( + 'Cannot return null for non-nullable field {}.{}.'.format(info.parent_type, info.field_name), + field_asts + ) + + return completed + + # If result is null-like, return null. + if result is None: + return None + + # If field type is List, complete each item in the list with the inner type + if isinstance(return_type, GraphQLList): + return complete_list_value(exe_context, return_type, field_asts, info, result) + + # If field type is Scalar or Enum, serialize to a valid value, returning null if coercion is not possible. + if isinstance(return_type, (GraphQLScalarType, GraphQLEnumType)): + return complete_leaf_value(return_type, result) + + if isinstance(return_type, (GraphQLInterfaceType, GraphQLUnionType)): + return complete_abstract_value(exe_context, return_type, field_asts, info, result) + + if isinstance(return_type, GraphQLObjectType): + return complete_object_value(exe_context, return_type, field_asts, info, result) + + assert False, u'Cannot complete value of unexpected type "{}".'.format(return_type) + + +def complete_list_value(exe_context, return_type, field_asts, info, result): + """ + Complete a list value by completing each item in the list with the inner type + """ + assert isinstance(result, collections.Iterable), \ + ('User Error: expected iterable, but did not find one ' + + 'for field {}.{}.').format(info.parent_type, info.field_name) + + item_type = return_type.of_type + completed_results = [] + contains_promise = False + for item in result: + completed_item = complete_value_catching_error(exe_context, item_type, field_asts, info, item) + if not contains_promise and is_thenable(completed_item): + contains_promise = True + + completed_results.append(completed_item) + + return Promise.all(completed_results) if contains_promise else completed_results + + +def complete_leaf_value(return_type, result): + """ + Complete a Scalar or Enum by serializing to a valid value, returning null if serialization is not possible. + """ + serialize = getattr(return_type, 'serialize', None) + assert serialize, 'Missing serialize method on type' + + serialized_result = serialize(result) + + if serialized_result is None: + return None + + return serialized_result + + +# TODO: Refactor based on js implementation +def complete_abstract_value(exe_context, return_type, field_asts, info, result): + """ + Complete an value of an abstract type by determining the runtime type of that value, then completing based + on that type. + """ + # Field type must be Object, Interface or Union and expect sub-selections. + runtime_type = None + + if isinstance(return_type, (GraphQLInterfaceType, GraphQLUnionType)): + runtime_type = return_type.resolve_type(result, info) + if runtime_type and not return_type.is_possible_type(runtime_type): + raise GraphQLError( + u'Runtime Object type "{}" is not a possible type for "{}".'.format(runtime_type, return_type), + field_asts + ) + + if not runtime_type: + return None + + return complete_object_value(exe_context, runtime_type, field_asts, info, result) + + +def complete_object_value(exe_context, return_type, field_asts, info, result): + """ + Complete an Object value by evaluating all sub-selections. + """ + if return_type.is_type_of and not return_type.is_type_of(result, info): + raise GraphQLError( + u'Expected value of type "{}" but got: {}.'.format(return_type, type(result).__name__), + field_asts + ) + + # Collect sub-fields to execute to complete this value. + subfield_asts = DefaultOrderedDict(list) + visited_fragment_names = set() + for field_ast in field_asts: + selection_set = field_ast.selection_set + if selection_set: + subfield_asts = collect_fields( + exe_context, return_type, selection_set, + subfield_asts, visited_fragment_names + ) + + return execute_fields(exe_context, return_type, result, subfield_asts) diff --git a/graphql/execution/executor.py b/graphql/execution/executor.py deleted file mode 100644 index b34188eb..00000000 --- a/graphql/execution/executor.py +++ /dev/null @@ -1,355 +0,0 @@ -import collections -import functools - -from ..error import GraphQLError -from ..language import ast -from ..language.parser import parse -from ..language.source import Source -from ..pyutils.default_ordered_dict import DefaultOrderedDict -from ..pyutils.defer import (Deferred, DeferredDict, DeferredList, defer, - succeed) -from ..type import (GraphQLEnumType, GraphQLInterfaceType, GraphQLList, - GraphQLNonNull, GraphQLObjectType, GraphQLScalarType, - GraphQLUnionType) -from ..validation import validate -from .base import (ExecutionContext, ExecutionResult, ResolveInfo, Undefined, - collect_fields, default_resolve_fn, get_field_def, - get_operation_root_type) - - -class Executor(object): - - def __init__(self, execution_middlewares=None, default_resolver=default_resolve_fn, map_type=dict): - assert issubclass(map_type, collections.MutableMapping) - - self._execution_middlewares = execution_middlewares or [] - self._default_resolve_fn = default_resolver - self._map_type = map_type - self._enforce_strict_ordering = issubclass(map_type, collections.OrderedDict) - - @property - def enforce_strict_ordering(self): - return self._enforce_strict_ordering - - @property - def map_type(self): - return self._map_type - - def execute(self, schema, request='', root=None, args=None, operation_name=None, request_context=None, - execute_serially=False, validate_ast=True): - - curried_execution_function = functools.partial( - self._execute, - schema, - request, - root, - args, - operation_name, - request_context, - execute_serially, - validate_ast - ) - - for middleware in self._execution_middlewares: - if hasattr(middleware, 'execution_result'): - curried_execution_function = functools.partial(middleware.execution_result, curried_execution_function) - - return curried_execution_function() - - def _execute(self, schema, request, root, args, operation_name, request_context, execute_serially, validate_ast): - if not isinstance(request, ast.Document): - if not isinstance(request, Source): - request = Source(request, 'GraphQL request') - - request = parse(request) - - if validate_ast: - validation_errors = validate(schema, request) - if validation_errors: - return succeed(ExecutionResult( - errors=validation_errors, - invalid=True, - )) - - return self._execute_graphql_query( - schema, - root or object(), - request, - operation_name, - args or {}, - request_context or {}, - execute_serially) - - def _execute_graphql_query(self, schema, root, ast, operation_name, args, request_context, execute_serially=False): - ctx = ExecutionContext(schema, root, ast, operation_name, args, request_context) - - return defer(self._execute_operation, ctx, root, ctx.operation, execute_serially) \ - .add_errback( - lambda error: ctx.errors.append(error) - ) \ - .add_callback( - lambda data: ExecutionResult(data, ctx.errors), - ) - - def _execute_operation(self, ctx, root, operation, execute_serially): - type = get_operation_root_type(ctx.schema, operation) - - if operation.operation == 'mutation' or execute_serially: - execute_serially = True - - fields = DefaultOrderedDict(list) \ - if (execute_serially or self._enforce_strict_ordering) \ - else collections.defaultdict(list) - - fields = collect_fields(ctx, type, operation.selection_set, fields, set()) - - if execute_serially: - return self._execute_fields_serially(ctx, type, root, fields) - - return self._execute_fields(ctx, type, root, fields) - - def _execute_fields_serially(self, execution_context, parent_type, source_value, fields): - def execute_field_callback(results, response_name): - field_asts = fields[response_name] - result = self._resolve_field(execution_context, parent_type, source_value, field_asts) - if result is Undefined: - return results - - def collect_result(resolved_result): - results[response_name] = resolved_result - return results - - if isinstance(result, Deferred): - return succeed(result).add_callback(collect_result) - - else: - return collect_result(result) - - def execute_field(prev_deferred, response_name): - return prev_deferred.add_callback(execute_field_callback, response_name) - - return functools.reduce(execute_field, fields.keys(), succeed(self._map_type())) - - def _execute_fields(self, execution_context, parent_type, source_value, fields): - contains_deferred = False - - results = self._map_type() - for response_name, field_asts in fields.items(): - result = self._resolve_field(execution_context, parent_type, source_value, field_asts) - if result is Undefined: - continue - - results[response_name] = result - if isinstance(result, Deferred): - contains_deferred = True - - if not contains_deferred: - return results - - return DeferredDict(results) - - def _resolve_field(self, execution_context, parent_type, source, field_asts): - field_ast = field_asts[0] - field_name = field_ast.name.value - - field_def = get_field_def(execution_context.schema, parent_type, field_name) - if not field_def: - return Undefined - - return_type = field_def.type - resolve_fn = field_def.resolver or self._default_resolve_fn - - # Build a dict of arguments from the field.arguments AST, using the variables scope to - # fulfill any variable references. - args = execution_context.get_argument_values(field_def, field_ast) - - # The resolve function's optional third argument is a collection of - # information about the current execution state. - info = ResolveInfo( - field_name, - field_asts, - return_type, - parent_type, - execution_context - ) - - result = self.resolve_or_error(resolve_fn, source, args, info) - return self.complete_value_catching_error( - execution_context, return_type, field_asts, info, result - ) - - def complete_value_catching_error(self, ctx, return_type, field_asts, info, result): - # If the field type is non-nullable, then it is resolved without any - # protection from errors. - if isinstance(return_type, GraphQLNonNull): - return self.complete_value(ctx, return_type, field_asts, info, result) - - # Otherwise, error protection is applied, logging the error and - # resolving a null value for this field if one is encountered. - try: - completed = self.complete_value(ctx, return_type, field_asts, info, result) - if isinstance(completed, Deferred): - def handle_error(error): - ctx.errors.append(error) - return None - - return completed.add_errback(handle_error) - - return completed - except Exception as e: - ctx.errors.append(e) - return None - - def complete_value(self, ctx, return_type, field_asts, info, result): - """ - Implements the instructions for completeValue as defined in the - "Field entries" section of the spec. - - If the field type is Non-Null, then this recursively completes the value for the inner type. It throws a field - error if that completion returns null, as per the "Nullability" section of the spec. - - If the field type is a List, then this recursively completes the value for the inner type on each item in the - list. - - If the field type is a Scalar or Enum, ensures the completed value is a legal value of the type by calling the - `serialize` method of GraphQL type definition. - - If the field is an abstract type, determine the runtime type of the value and then complete based on that type. - - Otherwise, the field type expects a sub-selection set, and will complete the value by evaluating all - sub-selections. - """ - # If field type is NonNull, complete for inner type, and throw field error if result is null. - if isinstance(result, Deferred): - return result.add_callbacks( - lambda resolved: self.complete_value( - ctx, - return_type, - field_asts, - info, - resolved - ), - lambda error: GraphQLError(error.value and str(error.value), field_asts, error) - ) - - if isinstance(result, Exception): - raise GraphQLError(str(result), field_asts, result) - - if isinstance(return_type, GraphQLNonNull): - completed = self.complete_value( - ctx, return_type.of_type, field_asts, info, result - ) - if completed is None: - raise GraphQLError( - 'Cannot return null for non-nullable field {}.{}.'.format(info.parent_type, info.field_name), - field_asts - ) - - return completed - - # If result is null-like, return null. - if result is None: - return None - - # If field type is List, complete each item in the list with the inner type - if isinstance(return_type, GraphQLList): - return self.complete_list_value(ctx, return_type, field_asts, info, result) - - # If field type is Scalar or Enum, serialize to a valid value, returning null if coercion is not possible. - if isinstance(return_type, (GraphQLScalarType, GraphQLEnumType)): - return self.complete_leaf_value(ctx, return_type, field_asts, info, result) - - if isinstance(return_type, GraphQLObjectType): - return self.complete_object_value(ctx, return_type, field_asts, info, result) - - if isinstance(return_type, (GraphQLInterfaceType, GraphQLUnionType)): - return self.complete_abstract_value(ctx, return_type, field_asts, info, result) - - assert False, u'Cannot complete value of unexpected type "{}"'.format(return_type) - - def complete_list_value(self, ctx, return_type, field_asts, info, result): - """ - Complete a list value by completing each item in the list with the inner type - """ - assert isinstance(result, collections.Iterable), \ - ('User Error: expected iterable, but did not find one' + - 'for field {}.{}').format(info.parent_type, info.field_name) - - item_type = return_type.of_type - completed_results = [] - contains_deferred = False - for item in result: - completed_item = self.complete_value_catching_error(ctx, item_type, field_asts, info, item) - if not contains_deferred and isinstance(completed_item, Deferred): - contains_deferred = True - - completed_results.append(completed_item) - - return DeferredList(completed_results) if contains_deferred else completed_results - - def complete_leaf_value(self, ctx, return_type, field_asts, info, result): - """ - Complete a Scalar or Enum by serializing to a valid value, returning null if serialization is not possible. - """ - serialized_result = return_type.serialize(result) - - if serialized_result is None: - return None - - return serialized_result - - def complete_object_value(self, ctx, return_type, field_asts, info, result): - """ - Complete an Object value by evaluating all sub-selections. - """ - if return_type.is_type_of and not return_type.is_type_of(result, info): - raise GraphQLError( - u'Expected value of type "{}" but got {}.'.format(return_type, type(result).__name__), - field_asts - ) - - # Collect sub-fields to execute to complete this value. - subfield_asts = DefaultOrderedDict(list) if self._enforce_strict_ordering else collections.defaultdict(list) - visited_fragment_names = set() - for field_ast in field_asts: - selection_set = field_ast.selection_set - if selection_set: - subfield_asts = collect_fields( - ctx, return_type, selection_set, - subfield_asts, visited_fragment_names - ) - - return self._execute_fields(ctx, return_type, result, subfield_asts) - - def complete_abstract_value(self, ctx, return_type, field_asts, info, result): - """ - Complete an value of an abstract type by determining the runtime type of that value, then completing based - on that type. - """ - # Field type must be Object, Interface or Union and expect sub-selections. - runtime_type = None - - if isinstance(return_type, (GraphQLInterfaceType, GraphQLUnionType)): - runtime_type = return_type.resolve_type(result, info) - if runtime_type and not return_type.is_possible_type(runtime_type): - raise GraphQLError( - u'Runtime Object type "{}" is not a possible type for "{}".'.format(runtime_type, return_type), - field_asts - ) - - if not runtime_type: - return None - - return self.complete_object_value(ctx, runtime_type, field_asts, info, result) - - def resolve_or_error(self, resolve_fn, source, args, info): - curried_resolve_fn = functools.partial(resolve_fn, source, args, info) - - try: - for middleware in self._execution_middlewares: - if hasattr(middleware, 'run_resolve_fn'): - curried_resolve_fn = functools.partial(middleware.run_resolve_fn, curried_resolve_fn, resolve_fn) - - return curried_resolve_fn() - except Exception as e: - return e diff --git a/graphql/execution/executors/__init__.py b/graphql/execution/executors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphql/execution/executors/asyncio.py b/graphql/execution/executors/asyncio.py new file mode 100644 index 00000000..8c91bf71 --- /dev/null +++ b/graphql/execution/executors/asyncio.py @@ -0,0 +1,36 @@ +from __future__ import absolute_import + +from asyncio import Future, ensure_future, get_event_loop, iscoroutine, wait + +from graphql.pyutils.aplus import Promise + + +def process_future_result(promise): + def handle_future_result(future): + exception = future.exception() + if exception: + promise.reject(exception) + else: + promise.fulfill(future.result()) + + return handle_future_result + + +class AsyncioExecutor(object): + + def __init__(self): + self.loop = get_event_loop() + self.futures = [] + + def wait_until_finished(self): + self.loop.run_until_complete(wait(self.futures)) + + def execute(self, fn, *args, **kwargs): + result = fn(*args, **kwargs) + if isinstance(result, Future) or iscoroutine(result): + promise = Promise() + future = ensure_future(result) + self.futures.append(future) + future.add_done_callback(process_future_result(promise)) + return promise + return result diff --git a/graphql/execution/executors/gevent.py b/graphql/execution/executors/gevent.py new file mode 100644 index 00000000..e6e04052 --- /dev/null +++ b/graphql/execution/executors/gevent.py @@ -0,0 +1,22 @@ +from __future__ import absolute_import + +import gevent + +from ...pyutils.aplus import Promise +from .utils import process + + +class GeventExecutor(object): + + def __init__(self): + self.jobs = [] + + def wait_until_finished(self): + [j.join() for j in self.jobs] + # gevent.joinall(self.jobs) + + def execute(self, fn, *args, **kwargs): + promise = Promise() + job = gevent.spawn(process, promise, fn, args, kwargs) + self.jobs.append(job) + return promise diff --git a/graphql/execution/executors/process.py b/graphql/execution/executors/process.py new file mode 100644 index 00000000..3f770450 --- /dev/null +++ b/graphql/execution/executors/process.py @@ -0,0 +1,31 @@ +from multiprocessing import Process, Queue + +from ...pyutils.aplus import Promise +from .utils import process + + +def queue_process(q): + promise, fn, args, kwargs = q.get() + process(promise, fn, args, kwargs) + + +class ProcessExecutor(object): + + def __init__(self): + self.processes = [] + self.q = Queue() + + def wait_until_finished(self): + for _process in self.processes: + _process.join() + self.q.close() + self.q.join_thread() + + def execute(self, fn, *args, **kwargs): + promise = Promise() + + self.q.put([promise, fn, args, kwargs], False) + _process = Process(target=queue_process, args=(self.q)) + _process.start() + self.processes.append(_process) + return promise diff --git a/graphql/execution/executors/sync.py b/graphql/execution/executors/sync.py new file mode 100644 index 00000000..85f8471b --- /dev/null +++ b/graphql/execution/executors/sync.py @@ -0,0 +1,7 @@ +class SyncExecutor(object): + + def wait_until_finished(self): + pass + + def execute(self, fn, *args, **kwargs): + return fn(*args, **kwargs) diff --git a/graphql/execution/executors/thread.py b/graphql/execution/executors/thread.py new file mode 100644 index 00000000..b40ae2b4 --- /dev/null +++ b/graphql/execution/executors/thread.py @@ -0,0 +1,34 @@ +from multiprocessing.pool import ThreadPool +from threading import Thread + +from ...pyutils.aplus import Promise +from .utils import process + + +class ThreadExecutor(object): + + pool = None + + def __init__(self, pool=False): + self.threads = [] + if pool: + self.execute = self.execute_in_pool + self.pool = ThreadPool(processes=pool) + else: + self.execute = self.execute_in_thread + + def wait_until_finished(self): + for thread in self.threads: + thread.join() + + def execute_in_thread(self, fn, *args, **kwargs): + promise = Promise() + thread = Thread(target=process, args=(promise, fn, args, kwargs)) + thread.start() + self.threads.append(thread) + return promise + + def execute_in_pool(self, fn, *args, **kwargs): + promise = Promise() + self.pool.map(lambda input: process(*input), [(promise, fn, args, kwargs)]) + return promise diff --git a/graphql/execution/executors/utils.py b/graphql/execution/executors/utils.py new file mode 100644 index 00000000..79b67cbe --- /dev/null +++ b/graphql/execution/executors/utils.py @@ -0,0 +1,6 @@ +def process(p, f, args, kwargs): + try: + val = f(*args, **kwargs) + p.fulfill(val) + except Exception as e: + p.reject(e) diff --git a/graphql/execution/middlewares/__init__.py b/graphql/execution/middlewares/__init__.py deleted file mode 100644 index 9db7df9a..00000000 --- a/graphql/execution/middlewares/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__author__ = 'jake' diff --git a/graphql/execution/middlewares/asyncio.py b/graphql/execution/middlewares/asyncio.py deleted file mode 100644 index ef95602e..00000000 --- a/graphql/execution/middlewares/asyncio.py +++ /dev/null @@ -1,40 +0,0 @@ -# flake8: noqa -from asyncio import Future, ensure_future, iscoroutine - -from ...pyutils.defer import Deferred - - -def process_future_result(deferred): - def handle_future_result(future): - exception = future.exception() - if exception: - deferred.errback(exception) - - else: - deferred.callback(future.result()) - - return handle_future_result - - -class AsyncioExecutionMiddleware(object): - - @staticmethod - def run_resolve_fn(resolver, original_resolver): - result = resolver() - if isinstance(result, Future) or iscoroutine(result): - future = ensure_future(result) - d = Deferred() - future.add_done_callback(process_future_result(d)) - return d - - return result - - @staticmethod - def execution_result(executor): - future = Future() - result = executor() - assert isinstance(result, Deferred), 'Another middleware has converted the execution result ' \ - 'away from a Deferred.' - - result.add_callbacks(future.set_result, future.set_exception) - return future diff --git a/graphql/execution/middlewares/gevent.py b/graphql/execution/middlewares/gevent.py deleted file mode 100644 index 1b000c02..00000000 --- a/graphql/execution/middlewares/gevent.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import absolute_import - -from gevent import get_hub, spawn -from gevent.event import AsyncResult - -from ...pyutils.defer import Deferred, DeferredException -from .utils import resolver_has_tag, tag_resolver - - -def _run_resolver_in_greenlet(d, resolver): - try: - result = resolver() - get_hub().loop.run_callback(d.callback, result) - except: - e = DeferredException() - get_hub().loop.run_callback(d.errback, e) - - -def run_in_greenlet(f): - """ - Marks a resolver to run inside a greenlet. - - @run_in_greenlet - def resolve_something(context, _*): - gevent.sleep(1) - return 5 - - """ - return tag_resolver(f, 'run_in_greenlet') - - -class GeventExecutionMiddleware(object): - - @staticmethod - def run_resolve_fn(resolver, original_resolver): - if resolver_has_tag(original_resolver, 'run_in_greenlet'): - d = Deferred() - spawn(_run_resolver_in_greenlet, d, resolver) - return d - - return resolver() - - @staticmethod - def execution_result(executor): - result = AsyncResult() - deferred = executor() - assert isinstance(deferred, Deferred), 'Another middleware has converted the execution result ' \ - 'away from a Deferred.' - - deferred.add_callbacks(result.set, lambda e: result.set_exception(e.value, (e.type, e.value, e.traceback))) - return result.get() diff --git a/graphql/execution/middlewares/sync.py b/graphql/execution/middlewares/sync.py deleted file mode 100644 index a0fa8bbc..00000000 --- a/graphql/execution/middlewares/sync.py +++ /dev/null @@ -1,18 +0,0 @@ -from ...error import GraphQLError -from ...pyutils.defer import Deferred - - -class SynchronousExecutionMiddleware(object): - - @staticmethod - def run_resolve_fn(resolver, original_resolver): - result = resolver() - if isinstance(result, Deferred): - raise GraphQLError('You cannot return a Deferred from a resolver when using SynchronousExecutionMiddleware') - - return result - - @staticmethod - def execution_result(executor): - result = executor() - return result.result diff --git a/graphql/execution/middlewares/utils.py b/graphql/execution/middlewares/utils.py deleted file mode 100644 index 1f64db10..00000000 --- a/graphql/execution/middlewares/utils.py +++ /dev/null @@ -1,33 +0,0 @@ -def tag_resolver(f, tag): - """ - Tags a resolver function with a specific tag that can be read by a Middleware to denote specific functionality. - :param f: The function to tag. - :param tag: The tag to add to the function. - :return: The function with the tag added. - """ - if not hasattr(f, '_resolver_tags'): - f._resolver_tags = set() - - f._resolver_tags.add(tag) - return f - - -def resolver_has_tag(f, tag): - """ - Checks to see if a function has a specific tag. - """ - if not hasattr(f, '_resolver_tags'): - return False - - return tag in f._resolver_tags - - -def merge_resolver_tags(source_resolver, target_resolver): - if not hasattr(source_resolver, '_resolver_tags'): - return target_resolver - - if not hasattr(target_resolver, '_resolver_tags'): - target_resolver._resolver_tags = set() - - target_resolver._resolver_tags |= source_resolver._resolver_tags - return target_resolver diff --git a/graphql/execution/tests/test_concurrent_executor.py b/graphql/execution/tests/test_concurrent_executor.py deleted file mode 100644 index 70a69726..00000000 --- a/graphql/execution/tests/test_concurrent_executor.py +++ /dev/null @@ -1,360 +0,0 @@ -from collections import OrderedDict - -from graphql.error import format_error -from graphql.execution import Executor -from graphql.execution.middlewares.sync import SynchronousExecutionMiddleware -from graphql.pyutils.defer import Deferred, fail, succeed -from graphql.type import (GraphQLArgument, GraphQLField, GraphQLInt, - GraphQLList, GraphQLObjectType, GraphQLSchema, - GraphQLString) -from graphql.type.definition import GraphQLNonNull - -from .utils import raise_callback_results - - -def test_executes_arbitary_code(): - class Data(object): - a = 'Apple' - b = 'Banana' - c = 'Cookie' - d = 'Donut' - e = 'Egg' - - @property - def f(self): - return succeed('Fish') - - def pic(self, size=50): - return succeed('Pic of size: {}'.format(size)) - - def deep(self): - return DeepData() - - def promise(self): - return succeed(Data()) - - class DeepData(object): - a = 'Already Been Done' - b = 'Boring' - c = ['Contrived', None, succeed('Confusing')] - - def deeper(self): - return [Data(), None, succeed(Data())] - - doc = ''' - query Example($size: Int) { - a, - b, - x: c - ...c - f - ...on DataType { - pic(size: $size) - promise { - a - } - } - deep { - a - b - c - deeper { - a - b - } - } - } - fragment c on DataType { - d - e - } - ''' - - expected = { - 'a': 'Apple', - 'b': 'Banana', - 'x': 'Cookie', - 'd': 'Donut', - 'e': 'Egg', - 'f': 'Fish', - 'pic': 'Pic of size: 100', - 'promise': {'a': 'Apple'}, - 'deep': { - 'a': 'Already Been Done', - 'b': 'Boring', - 'c': ['Contrived', None, 'Confusing'], - 'deeper': [ - {'a': 'Apple', 'b': 'Banana'}, - None, - {'a': 'Apple', 'b': 'Banana'}]} - } - - DataType = GraphQLObjectType('DataType', lambda: { - 'a': GraphQLField(GraphQLString), - 'b': GraphQLField(GraphQLString), - 'c': GraphQLField(GraphQLString), - 'd': GraphQLField(GraphQLString), - 'e': GraphQLField(GraphQLString), - 'f': GraphQLField(GraphQLString), - 'pic': GraphQLField( - args={'size': GraphQLArgument(GraphQLInt)}, - type=GraphQLString, - resolver=lambda obj, args, *_: obj.pic(args['size']), - ), - 'deep': GraphQLField(DeepDataType), - 'promise': GraphQLField(DataType), - }) - - DeepDataType = GraphQLObjectType('DeepDataType', { - 'a': GraphQLField(GraphQLString), - 'b': GraphQLField(GraphQLString), - 'c': GraphQLField(GraphQLList(GraphQLString)), - 'deeper': GraphQLField(GraphQLList(DataType)), - }) - - schema = GraphQLSchema(query=DataType) - executor = Executor() - - def handle_result(result): - assert not result.errors - assert result.data == expected - - raise_callback_results(executor.execute(schema, doc, Data(), {'size': 100}, 'Example'), handle_result) - raise_callback_results(executor.execute(schema, doc, Data(), {'size': 100}, 'Example', execute_serially=True), - handle_result) - - -def test_synchronous_executor_doesnt_support_defers_with_nullable_type_getting_set_to_null(): - class Data(object): - - def promise(self): - return succeed('i shouldn\'nt work') - - def notPromise(self): - return 'i should work' - - DataType = GraphQLObjectType('DataType', { - 'promise': GraphQLField(GraphQLString), - 'notPromise': GraphQLField(GraphQLString), - }) - doc = ''' - query Example { - promise - notPromise - } - ''' - schema = GraphQLSchema(query=DataType) - executor = Executor([SynchronousExecutionMiddleware()]) - - result = executor.execute(schema, doc, Data(), operation_name='Example') - assert not isinstance(result, Deferred) - assert result.data == {"promise": None, 'notPromise': 'i should work'} - formatted_errors = list(map(format_error, result.errors)) - assert formatted_errors == [{'locations': [dict(line=3, column=9)], - 'message': 'You cannot return a Deferred from a resolver ' - 'when using SynchronousExecutionMiddleware'}] - - -def test_synchronous_executor_doesnt_support_defers(): - class Data(object): - - def promise(self): - return succeed('i shouldn\'nt work') - - def notPromise(self): - return 'i should work' - - DataType = GraphQLObjectType('DataType', { - 'promise': GraphQLField(GraphQLNonNull(GraphQLString)), - 'notPromise': GraphQLField(GraphQLString), - }) - doc = ''' - query Example { - promise - notPromise - } - ''' - schema = GraphQLSchema(query=DataType) - executor = Executor([SynchronousExecutionMiddleware()]) - - result = executor.execute(schema, doc, Data(), operation_name='Example') - assert not isinstance(result, Deferred) - assert result.data is None - formatted_errors = list(map(format_error, result.errors)) - assert formatted_errors == [{'locations': [dict(line=3, column=9)], - 'message': 'You cannot return a Deferred from a resolver ' - 'when using SynchronousExecutionMiddleware'}] - - -def test_executor_defer_failure(): - class Data(object): - - def promise(self): - return fail(Exception('Something bad happened! Sucks :(')) - - def notPromise(self): - return 'i should work' - - DataType = GraphQLObjectType('DataType', { - 'promise': GraphQLField(GraphQLNonNull(GraphQLString)), - 'notPromise': GraphQLField(GraphQLString), - }) - doc = ''' - query Example { - promise - notPromise - } - ''' - schema = GraphQLSchema(query=DataType) - executor = Executor() - - result = executor.execute(schema, doc, Data(), operation_name='Example') - assert result.called - result = result.result - assert result.data is None - formatted_errors = list(map(format_error, result.errors)) - assert formatted_errors == [{'locations': [dict(line=3, column=9)], - 'message': "Something bad happened! Sucks :("}] - - -def test_synchronous_executor_will_synchronously_resolve(): - class Data(object): - - def promise(self): - return 'I should work' - - DataType = GraphQLObjectType('DataType', { - 'promise': GraphQLField(GraphQLString), - }) - doc = ''' - query Example { - promise - } - ''' - schema = GraphQLSchema(query=DataType) - executor = Executor([SynchronousExecutionMiddleware()]) - - result = executor.execute(schema, doc, Data(), operation_name='Example') - assert not isinstance(result, Deferred) - assert result.data == {"promise": 'I should work'} - assert not result.errors - - -def test_synchronous_error_nulls_out_error_subtrees(): - doc = ''' - { - sync - syncError - syncReturnError - syncReturnErrorList - async - asyncReject - asyncEmptyReject - asyncReturnError - } - ''' - - class Data: - - def sync(self): - return 'sync' - - def syncError(self): - raise Exception('Error getting syncError') - - def syncReturnError(self): - return Exception("Error getting syncReturnError") - - def syncReturnErrorList(self): - return [ - 'sync0', - Exception('Error getting syncReturnErrorList1'), - 'sync2', - Exception('Error getting syncReturnErrorList3') - ] - - def async(self): - return succeed('async') - - def asyncReject(self): - return fail(Exception('Error getting asyncReject')) - - def asyncEmptyReject(self): - return fail() - - def asyncReturnError(self): - return succeed(Exception('Error getting asyncReturnError')) - - schema = GraphQLSchema( - query=GraphQLObjectType( - name='Type', - fields={ - 'sync': GraphQLField(GraphQLString), - 'syncError': GraphQLField(GraphQLString), - 'syncReturnError': GraphQLField(GraphQLString), - 'syncReturnErrorList': GraphQLField(GraphQLList(GraphQLString)), - 'async': GraphQLField(GraphQLString), - 'asyncReject': GraphQLField(GraphQLString), - 'asyncEmptyReject': GraphQLField(GraphQLString), - 'asyncReturnError': GraphQLField(GraphQLString), - } - ) - ) - - executor = Executor(map_type=OrderedDict) - - def handle_results(result): - assert result.data == { - 'async': 'async', - 'asyncEmptyReject': None, - 'asyncReject': None, - 'asyncReturnError': None, - 'sync': 'sync', - 'syncError': None, - 'syncReturnError': None, - 'syncReturnErrorList': ['sync0', None, 'sync2', None] - } - assert list(map(format_error, result.errors)) == [ - {'locations': [{'line': 4, 'column': 9}], 'message': 'Error getting syncError'}, - {'locations': [{'line': 5, 'column': 9}], 'message': 'Error getting syncReturnError'}, - {'locations': [{'line': 6, 'column': 9}], 'message': 'Error getting syncReturnErrorList1'}, - {'locations': [{'line': 6, 'column': 9}], 'message': 'Error getting syncReturnErrorList3'}, - {'locations': [{'line': 8, 'column': 9}], 'message': 'Error getting asyncReject'}, - {'locations': [{'line': 9, 'column': 9}], 'message': 'An unknown error occurred.'}, - {'locations': [{'line': 10, 'column': 9}], 'message': 'Error getting asyncReturnError'} - ] - - raise_callback_results(executor.execute(schema, doc, Data()), handle_results) - - -def test_executor_can_enforce_strict_ordering(): - Type = GraphQLObjectType('Type', lambda: { - 'a': GraphQLField(GraphQLString, - resolver=lambda *_: succeed('Apple')), - 'b': GraphQLField(GraphQLString, - resolver=lambda *_: succeed('Banana')), - 'c': GraphQLField(GraphQLString, - resolver=lambda *_: succeed('Cherry')), - 'deep': GraphQLField(Type, resolver=lambda *_: succeed({})), - }) - schema = GraphQLSchema(query=Type) - executor = Executor(map_type=OrderedDict) - - query = '{ a b c aa: c cc: c bb: b aaz: a bbz: b deep { b a c deeper: deep { c a b } } ' \ - 'ccz: c zzz: c aaa: a }' - - def handle_results(result): - assert not result.errors - - data = result.data - assert isinstance(data, OrderedDict) - assert list(data.keys()) == ['a', 'b', 'c', 'aa', 'cc', 'bb', 'aaz', 'bbz', 'deep', 'ccz', 'zzz', 'aaa'] - deep = data['deep'] - assert isinstance(deep, OrderedDict) - assert list(deep.keys()) == ['b', 'a', 'c', 'deeper'] - deeper = deep['deeper'] - assert isinstance(deeper, OrderedDict) - assert list(deeper.keys()) == ['c', 'a', 'b'] - - raise_callback_results(executor.execute(schema, query), handle_results) - raise_callback_results(executor.execute(schema, query, execute_serially=True), handle_results) diff --git a/graphql/execution/tests/test_default_executor.py b/graphql/execution/tests/test_default_executor.py deleted file mode 100644 index 10d32762..00000000 --- a/graphql/execution/tests/test_default_executor.py +++ /dev/null @@ -1,17 +0,0 @@ -from graphql.execution import (Executor, get_default_executor, - set_default_executor) - - -def test_get_and_set_default_executor(): - e1 = get_default_executor() - e2 = get_default_executor() - assert e1 is e2 - - new_executor = Executor() - - set_default_executor(new_executor) - assert get_default_executor() is new_executor - - set_default_executor(None) - assert get_default_executor() is not e1 - assert get_default_executor() is not new_executor diff --git a/graphql/execution/tests/test_deferred.py b/graphql/execution/tests/test_deferred.py deleted file mode 100644 index 00111905..00000000 --- a/graphql/execution/tests/test_deferred.py +++ /dev/null @@ -1,278 +0,0 @@ -from pytest import raises - -from graphql.pyutils.defer import (AlreadyCalledDeferred, Deferred, - DeferredDict, DeferredException, - DeferredList, fail, succeed) - - -def test_succeed(): - d = succeed("123") - assert d.result == "123" - assert d.called - assert not d.callbacks - - -def test_fail_none(): - d = fail() - assert isinstance(d.result, DeferredException) - assert d.called - assert not d.callbacks - - -def test_fail_none_catches_exception(): - e = Exception('will be raised') - try: - raise e - except: - d = fail() - assert d.called - assert isinstance(d.result, DeferredException) - assert d.result.value == e - - -def test_fail(): - e = Exception('failed') - d = fail(e) - assert isinstance(d.result, DeferredException) - assert d.result.value == e - assert d.called - assert not d.callbacks - - -def test_nested_succeed(): - d = succeed(succeed('123')) - assert d.result == "123" - assert d.called - assert not d.callbacks - - d = succeed(succeed(succeed('123'))) - assert d.result == "123" - assert d.called - assert not d.callbacks - - -def test_callback_result_transformation(): - d = succeed(5) - d.add_callback(lambda r: r + 5) - assert d.result == 10 - - d.add_callback(lambda r: succeed(r + 5)) - - assert d.result == 15 - - -def test_deferred_list(): - d = Deferred() - - dl = DeferredList([ - 1, - d - ]) - - assert not dl.called - d.callback(2) - - assert dl.called - assert dl.result == [1, 2] - - -def test_deferred_list_with_already_resolved_deferred_values(): - dl = DeferredList([ - 1, - succeed(2), - succeed(3) - ]) - - assert dl.called - assert dl.result == [1, 2, 3] - - -def test_deferred_dict(): - d = Deferred() - - dd = DeferredDict({ - 'a': 1, - 'b': d - }) - - assert not dd.called - d.callback(2) - - assert dd.called - assert dd.result == {'a': 1, 'b': 2} - - -def test_deferred_list_of_no_defers(): - dl = DeferredList([ - {'ab': 1}, - 2, - [1, 2, 3], - "345" - ]) - - assert dl.called - assert dl.result == [ - {'ab': 1}, - 2, - [1, 2, 3], - "345" - ] - - -def test_callback_resolution(): - d = Deferred() - d.add_callback(lambda r: fail(Exception(r + "b"))) - d.add_errback(lambda e: e.value.args[0] + "c") - d.add_callbacks(lambda r: r + "d", lambda e: e.value.args[0] + 'f') - - d.callback("a") - - assert d.result == "abcd" - - -def test_callback_resolution_weaving(): - d = Deferred() - d.add_callbacks(lambda r: fail(Exception(r + "b")), lambda e: e.value.args[0] + 'w') - d.add_callbacks(lambda e: Exception(e + "x"), lambda e: e.value.args[0] + "c") - d.add_callbacks(lambda r: Exception(r + "d"), lambda e: e.value.args[0] + 'y') - d.add_callbacks(lambda r: r + "z", lambda e: e.value.args[0] + 'e') - - d.callback("a") - - assert d.result == "abcde" - - -def test_callback_resolution_weaving_2(): - d = Deferred() - d.add_callbacks(lambda r: fail(Exception(r + "b")), lambda e: e.value.args[0] + 'w') - d.add_callbacks(lambda e: Exception(e + "x"), lambda e: e.value.args[0] + "c") - d.add_callbacks(lambda r: Exception(r + "d"), lambda e: e.value.args[0] + 'y') - d.add_callbacks(lambda r: fail(ValueError(r + "z")), lambda e: e.value.args[0] + 'e') - - d.errback(Exception('v')) - - assert isinstance(d.result, DeferredException) - assert isinstance(d.result.value, ValueError) - assert d.result.value.args[0] == "vwxyz" - - -def test_callback_raises_exception(): - def callback(val): - raise AttributeError(val) - - d = Deferred() - d.add_callback(callback) - d.callback('test') - - assert isinstance(d.result, DeferredException) - assert isinstance(d.result.value, AttributeError) - assert d.result.value.args[0] == "test" - - -def test_errback(): - holder = [] - d = Deferred() - e = Exception('errback test') - d.add_errback(lambda e: holder.append(e)) - d.errback(e) - - assert isinstance(holder[0], DeferredException) - assert holder[0].value == e - - -def test_errback_chain(): - holder = [] - d = Deferred() - e = Exception('a') - d.add_callbacks(holder.append, lambda e: Exception(e.value.args[0] + 'b')) - d.add_callbacks(holder.append, lambda e: Exception(e.value.args[0] + 'c')) - - d.errback(e) - - assert d.result.value.args[0] == 'abc' - assert len(holder) == 0 - - -def test_deferred_list_fails(): - d1 = Deferred() - d2 = Deferred() - d3 = Deferred() - - dl = DeferredList([ - 1, - succeed(2), - d1, - d2, - d3 - ]) - - assert not dl.called - - e1 = Exception('d1 failed') - d1.errback(e1) - d2.errback(Exception('d2 failed')) - d3.callback('hello') - - assert dl.called - assert isinstance(dl.result, DeferredException) - assert dl.result.value == e1 - - -def test_cant_callback_twice(): - d1 = Deferred() - d1.callback('hello') - - with raises(AlreadyCalledDeferred): - d1.callback('world') - - -def test_cant_errback_twice(): - d1 = Deferred() - d1.errback(Exception('hello')) - - with raises(AlreadyCalledDeferred): - d1.errback(Exception('world')) - - -def test_callbacks_and_errbacks_return_original_deferred(): - d = Deferred() - assert d.add_callback(lambda a: None) is d - assert d.add_errback(lambda a: None) is d - assert d.add_callbacks(lambda a: None, lambda a: None) is d - - -def test_callback_var_args(): - holder = [] - d = Deferred() - d.add_callback(lambda *args, **kwargs: holder.append((args, kwargs)), 2, 3, a=4, b=5) - d.callback(1) - - assert holder[0] == ((1, 2, 3), {'a': 4, 'b': 5}) - - -def test_deferred_callback_returns_another_deferred(): - d = Deferred() - d2 = Deferred() - - d.add_callback(lambda r: succeed(r + 5).add_callback(lambda v: v + 5)) - d.add_callback(lambda r: d2) - d.callback(5) - - assert d.result is d2 - assert d.paused - assert d.called - - d2.callback(7) - assert d.result == 7 - assert d2.result == 7 - - -def test_deferred_exception_catch(): - def dummy_errback(deferred_exception): - deferred_exception.catch(OSError) - return "caught" - - deferred = Deferred() - deferred.add_errback(dummy_errback) - deferred.errback(OSError()) - assert deferred.result == 'caught' diff --git a/graphql/execution/tests/test_executor_schema.py b/graphql/execution/tests/test_execute_schema.py similarity index 100% rename from graphql/execution/tests/test_executor_schema.py rename to graphql/execution/tests/test_execute_schema.py diff --git a/graphql/execution/tests/test_executor.py b/graphql/execution/tests/test_executor.py index 36b52e89..05fa5b2a 100644 --- a/graphql/execution/tests/test_executor.py +++ b/graphql/execution/tests/test_executor.py @@ -1,11 +1,9 @@ import json -from collections import OrderedDict from pytest import raises from graphql.error import GraphQLError -from graphql.execution import Executor, execute -from graphql.execution.middlewares.sync import SynchronousExecutionMiddleware +from graphql.execution import execute from graphql.language.parser import parse from graphql.type import (GraphQLArgument, GraphQLBoolean, GraphQLField, GraphQLInt, GraphQLList, GraphQLObjectType, @@ -456,7 +454,7 @@ def __init__(self, value): ] } - assert 'Expected value of type "SpecialType" but got NotSpecial.' in str(result.errors) + assert 'Expected value of type "SpecialType" but got: NotSpecial.' in [str(e) for e in result.errors] def test_fails_to_execute_a_query_containing_a_type_definition(): @@ -479,45 +477,3 @@ def test_fails_to_execute_a_query_containing_a_type_definition(): execute(schema, None, query) assert excinfo.value.message == 'GraphQL cannot execute a request containing a ObjectTypeDefinition.' - - -def test_executor_detects_strict_ordering(): - executor = Executor() - assert not executor.enforce_strict_ordering - assert executor.map_type is dict - - executor = Executor(map_type=OrderedDict) - assert executor.enforce_strict_ordering - assert executor.map_type is OrderedDict - - -def test_executor_can_enforce_strict_ordering(): - Type = GraphQLObjectType('Type', lambda: { - 'a': GraphQLField(GraphQLString, - resolver=lambda *_: 'Apple'), - 'b': GraphQLField(GraphQLString, - resolver=lambda *_: 'Banana'), - 'c': GraphQLField(GraphQLString, - resolver=lambda *_: 'Cherry'), - 'deep': GraphQLField(Type, resolver=lambda *_: {}), - }) - schema = GraphQLSchema(query=Type) - executor = Executor(execution_middlewares=[SynchronousExecutionMiddleware], map_type=OrderedDict) - query = '{ a b c aa: c cc: c bb: b aaz: a bbz: b deep { b a c deeper: deep { c a b } } ' \ - 'ccz: c zzz: c aaa: a }' - - def check_result(result): - assert not result.errors - - data = result.data - assert isinstance(data, OrderedDict) - assert list(data.keys()) == ['a', 'b', 'c', 'aa', 'cc', 'bb', 'aaz', 'bbz', 'deep', 'ccz', 'zzz', 'aaa'] - deep = data['deep'] - assert isinstance(deep, OrderedDict) - assert list(deep.keys()) == ['b', 'a', 'c', 'deeper'] - deeper = deep['deeper'] - assert isinstance(deeper, OrderedDict) - assert list(deeper.keys()) == ['c', 'a', 'b'] - - check_result(executor.execute(schema, query)) - check_result(executor.execute(schema, query, execute_serially=True)) diff --git a/graphql/execution/tests/test_gevent.py b/graphql/execution/tests/test_executor_gevent.py similarity index 72% rename from graphql/execution/tests/test_gevent.py rename to graphql/execution/tests/test_executor_gevent.py index 743a3cf0..3a3ff4fd 100644 --- a/graphql/execution/tests/test_gevent.py +++ b/graphql/execution/tests/test_executor_gevent.py @@ -2,21 +2,21 @@ import gevent from graphql.error import format_error -from graphql.execution import Executor -from graphql.execution.middlewares.gevent import (GeventExecutionMiddleware, - run_in_greenlet) from graphql.language.location import SourceLocation +from graphql.language.parser import parse from graphql.type import (GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLString) +from ..execute import execute +from ..executors.gevent import GeventExecutor +from .test_mutations import assert_evaluate_mutations_serially + def test_gevent_executor(): - @run_in_greenlet def resolver(context, *_): gevent.sleep(0.001) return 'hey' - @run_in_greenlet def resolver_2(context, *_): gevent.sleep(0.003) return 'hey2' @@ -30,22 +30,19 @@ def resolver_3(contest, *_): 'c': GraphQLField(GraphQLString, resolver=resolver_3) }) - doc = '{ a b c }' - executor = Executor([GeventExecutionMiddleware()]) - result = executor.execute(GraphQLSchema(Type), doc) + ast = parse('{ a b c }') + result = execute(GraphQLSchema(Type), ast, executor=GeventExecutor()) assert not result.errors assert result.data == {'a': 'hey', 'b': 'hey2', 'c': 'hey3'} def test_gevent_executor_with_error(): - doc = 'query Example { a, b }' + ast = parse('query Example { a, b }') - @run_in_greenlet def resolver(context, *_): gevent.sleep(0.001) return 'hey' - @run_in_greenlet def resolver_2(context, *_): gevent.sleep(0.003) raise Exception('resolver_2 failed!') @@ -55,8 +52,11 @@ def resolver_2(context, *_): 'b': GraphQLField(GraphQLString, resolver=resolver_2) }) - executor = Executor([GeventExecutionMiddleware()]) - result = executor.execute(GraphQLSchema(Type), doc) + result = execute(GraphQLSchema(Type), ast, executor=GeventExecutor()) formatted_errors = list(map(format_error, result.errors)) assert formatted_errors == [{'locations': [{'line': 1, 'column': 20}], 'message': 'resolver_2 failed!'}] assert result.data == {'a': 'hey', 'b': None} + + +def test_evaluates_mutations_serially(): + assert_evaluate_mutations_serially(executor=GeventExecutor()) diff --git a/graphql/execution/tests/test_executor_thread.py b/graphql/execution/tests/test_executor_thread.py new file mode 100644 index 00000000..dca1739a --- /dev/null +++ b/graphql/execution/tests/test_executor_thread.py @@ -0,0 +1,218 @@ + +from graphql.error import format_error +from graphql.execution.execute import execute +from graphql.language.parser import parse +from graphql.type import (GraphQLArgument, GraphQLField, GraphQLInt, + GraphQLList, GraphQLObjectType, GraphQLSchema, + GraphQLString) + +from ..executors.thread import ThreadExecutor +from .test_mutations import assert_evaluate_mutations_serially +from .utils import rejected, resolved + + +def test_executes_arbitary_code(): + class Data(object): + a = 'Apple' + b = 'Banana' + c = 'Cookie' + d = 'Donut' + e = 'Egg' + + @property + def f(self): + return resolved('Fish') + + def pic(self, size=50): + return resolved('Pic of size: {}'.format(size)) + + def deep(self): + return DeepData() + + def promise(self): + return resolved(Data()) + + class DeepData(object): + a = 'Already Been Done' + b = 'Boring' + c = ['Contrived', None, resolved('Confusing')] + + def deeper(self): + return [Data(), None, resolved(Data())] + + ast = parse(''' + query Example($size: Int) { + a, + b, + x: c + ...c + f + ...on DataType { + pic(size: $size) + promise { + a + } + } + deep { + a + b + c + deeper { + a + b + } + } + } + fragment c on DataType { + d + e + } + ''') + + expected = { + 'a': 'Apple', + 'b': 'Banana', + 'x': 'Cookie', + 'd': 'Donut', + 'e': 'Egg', + 'f': 'Fish', + 'pic': 'Pic of size: 100', + 'promise': {'a': 'Apple'}, + 'deep': { + 'a': 'Already Been Done', + 'b': 'Boring', + 'c': ['Contrived', None, 'Confusing'], + 'deeper': [ + {'a': 'Apple', 'b': 'Banana'}, + None, + {'a': 'Apple', 'b': 'Banana'}]} + } + + DataType = GraphQLObjectType('DataType', lambda: { + 'a': GraphQLField(GraphQLString), + 'b': GraphQLField(GraphQLString), + 'c': GraphQLField(GraphQLString), + 'd': GraphQLField(GraphQLString), + 'e': GraphQLField(GraphQLString), + 'f': GraphQLField(GraphQLString), + 'pic': GraphQLField( + args={'size': GraphQLArgument(GraphQLInt)}, + type=GraphQLString, + resolver=lambda obj, args, *_: obj.pic(args['size']), + ), + 'deep': GraphQLField(DeepDataType), + 'promise': GraphQLField(DataType), + }) + + DeepDataType = GraphQLObjectType('DeepDataType', { + 'a': GraphQLField(GraphQLString), + 'b': GraphQLField(GraphQLString), + 'c': GraphQLField(GraphQLList(GraphQLString)), + 'deeper': GraphQLField(GraphQLList(DataType)), + }) + + schema = GraphQLSchema(query=DataType) + + def handle_result(result): + assert not result.errors + assert result.data == expected + + handle_result( + execute( + schema, + ast, + Data(), + variable_values={ + 'size': 100}, + operation_name='Example', + executor=ThreadExecutor())) + handle_result(execute(schema, ast, Data(), variable_values={'size': 100}, operation_name='Example')) + + +def test_synchronous_error_nulls_out_error_subtrees(): + ast = parse(''' + { + sync + syncError + syncReturnError + syncReturnErrorList + async + asyncReject + asyncEmptyReject + asyncReturnError + } + ''') + + class Data: + + def sync(self): + return 'sync' + + def syncError(self): + raise Exception('Error getting syncError') + + def syncReturnError(self): + return Exception("Error getting syncReturnError") + + def syncReturnErrorList(self): + return [ + 'sync0', + Exception('Error getting syncReturnErrorList1'), + 'sync2', + Exception('Error getting syncReturnErrorList3') + ] + + def async(self): + return resolved('async') + + def asyncReject(self): + return rejected(Exception('Error getting asyncReject')) + + def asyncEmptyReject(self): + return rejected(Exception('An unknown error occurred.')) + + def asyncReturnError(self): + return resolved(Exception('Error getting asyncReturnError')) + + schema = GraphQLSchema( + query=GraphQLObjectType( + name='Type', + fields={ + 'sync': GraphQLField(GraphQLString), + 'syncError': GraphQLField(GraphQLString), + 'syncReturnError': GraphQLField(GraphQLString), + 'syncReturnErrorList': GraphQLField(GraphQLList(GraphQLString)), + 'async': GraphQLField(GraphQLString), + 'asyncReject': GraphQLField(GraphQLString), + 'asyncEmptyReject': GraphQLField(GraphQLString), + 'asyncReturnError': GraphQLField(GraphQLString), + } + ) + ) + + def handle_results(result): + assert result.data == { + 'async': 'async', + 'asyncEmptyReject': None, + 'asyncReject': None, + 'asyncReturnError': None, + 'sync': 'sync', + 'syncError': None, + 'syncReturnError': None, + 'syncReturnErrorList': ['sync0', None, 'sync2', None] + } + assert list(map(format_error, result.errors)) == [ + {'locations': [{'line': 4, 'column': 9}], 'message': 'Error getting syncError'}, + {'locations': [{'line': 5, 'column': 9}], 'message': 'Error getting syncReturnError'}, + {'locations': [{'line': 6, 'column': 9}], 'message': 'Error getting syncReturnErrorList1'}, + {'locations': [{'line': 6, 'column': 9}], 'message': 'Error getting syncReturnErrorList3'}, + {'locations': [{'line': 8, 'column': 9}], 'message': 'Error getting asyncReject'}, + {'locations': [{'line': 9, 'column': 9}], 'message': 'An unknown error occurred.'}, + {'locations': [{'line': 10, 'column': 9}], 'message': 'Error getting asyncReturnError'} + ] + + handle_results(execute(schema, ast, Data(), executor=ThreadExecutor())) + + +def test_evaluates_mutations_serially(): + assert_evaluate_mutations_serially(executor=ThreadExecutor()) diff --git a/graphql/execution/tests/test_lists.py b/graphql/execution/tests/test_lists.py index 9478f76f..867d5f47 100644 --- a/graphql/execution/tests/test_lists.py +++ b/graphql/execution/tests/test_lists.py @@ -1,15 +1,15 @@ from collections import namedtuple from graphql.error import format_error -from graphql.execution import Executor, execute +from graphql.execution import execute from graphql.language.parser import parse -from graphql.pyutils.defer import fail, succeed from graphql.type import (GraphQLField, GraphQLInt, GraphQLList, GraphQLNonNull, GraphQLObjectType, GraphQLSchema) +from .utils import rejected, resolved + Data = namedtuple('Data', 'test') ast = parse('{ nest { test } }') -executor = Executor() def check(test_data, expected): @@ -26,9 +26,7 @@ def run_check(self): ) schema = GraphQLSchema(query=DataType) - response = executor.execute(schema, ast, data) - assert response.called - response = response.result + response = execute(schema, data, ast) if response.errors: result = { @@ -56,10 +54,10 @@ class Test_ListOfT_Array_T: # [T] Array class Test_ListOfT_Promise_Array_T: # [T] Promise> type = GraphQLList(GraphQLInt) - test_contains_values = check(succeed([1, 2]), {'data': {'nest': {'test': [1, 2]}}}) - test_contains_null = check(succeed([1, None, 2]), {'data': {'nest': {'test': [1, None, 2]}}}) - test_returns_null = check(succeed(None), {'data': {'nest': {'test': None}}}) - test_rejected = check(lambda: fail(Exception('bad')), { + test_contains_values = check(resolved([1, 2]), {'data': {'nest': {'test': [1, 2]}}}) + test_contains_null = check(resolved([1, None, 2]), {'data': {'nest': {'test': [1, None, 2]}}}) + test_returns_null = check(resolved(None), {'data': {'nest': {'test': None}}}) + test_rejected = check(lambda: rejected(Exception('bad')), { 'data': {'nest': {'test': None}}, 'errors': [{'locations': [{'column': 10, 'line': 1}], 'message': 'bad'}] }) @@ -68,9 +66,9 @@ class Test_ListOfT_Promise_Array_T: # [T] Promise> class Test_ListOfT_Array_Promise_T: # [T] Array> type = GraphQLList(GraphQLInt) - test_contains_values = check([succeed(1), succeed(2)], {'data': {'nest': {'test': [1, 2]}}}) - test_contains_null = check([succeed(1), succeed(None), succeed(2)], {'data': {'nest': {'test': [1, None, 2]}}}) - test_contains_reject = check(lambda: [succeed(1), fail(Exception('bad')), succeed(2)], { + test_contains_values = check([resolved(1), resolved(2)], {'data': {'nest': {'test': [1, 2]}}}) + test_contains_null = check([resolved(1), resolved(None), resolved(2)], {'data': {'nest': {'test': [1, None, 2]}}}) + test_contains_reject = check(lambda: [resolved(1), rejected(Exception('bad')), resolved(2)], { 'data': {'nest': {'test': [1, None, 2]}}, 'errors': [{'locations': [{'column': 10, 'line': 1}], 'message': 'bad'}] }) @@ -79,9 +77,9 @@ class Test_ListOfT_Array_Promise_T: # [T] Array> class Test_NotNullListOfT_Array_T: # [T]! Array type = GraphQLNonNull(GraphQLList(GraphQLInt)) - test_contains_values = check(succeed([1, 2]), {'data': {'nest': {'test': [1, 2]}}}) - test_contains_null = check(succeed([1, None, 2]), {'data': {'nest': {'test': [1, None, 2]}}}) - test_returns_null = check(succeed(None), { + test_contains_values = check(resolved([1, 2]), {'data': {'nest': {'test': [1, 2]}}}) + test_contains_null = check(resolved([1, None, 2]), {'data': {'nest': {'test': [1, None, 2]}}}) + test_returns_null = check(resolved(None), { 'data': {'nest': None}, 'errors': [{'locations': [{'column': 10, 'line': 1}], 'message': 'Cannot return null for non-nullable field DataType.test.'}] @@ -91,15 +89,15 @@ class Test_NotNullListOfT_Array_T: # [T]! Array class Test_NotNullListOfT_Promise_Array_T: # [T]! Promise>> type = GraphQLNonNull(GraphQLList(GraphQLInt)) - test_contains_values = check(succeed([1, 2]), {'data': {'nest': {'test': [1, 2]}}}) - test_contains_null = check(succeed([1, None, 2]), {'data': {'nest': {'test': [1, None, 2]}}}) - test_returns_null = check(succeed(None), { + test_contains_values = check(resolved([1, 2]), {'data': {'nest': {'test': [1, 2]}}}) + test_contains_null = check(resolved([1, None, 2]), {'data': {'nest': {'test': [1, None, 2]}}}) + test_returns_null = check(resolved(None), { 'data': {'nest': None}, 'errors': [{'locations': [{'column': 10, 'line': 1}], 'message': 'Cannot return null for non-nullable field DataType.test.'}] }) - test_rejected = check(lambda: fail(Exception('bad')), { + test_rejected = check(lambda: rejected(Exception('bad')), { 'data': {'nest': None}, 'errors': [{'locations': [{'column': 10, 'line': 1}], 'message': 'bad'}] }) @@ -107,9 +105,9 @@ class Test_NotNullListOfT_Promise_Array_T: # [T]! Promise>> class Test_NotNullListOfT_Array_Promise_T: # [T]! Promise>> type = GraphQLNonNull(GraphQLList(GraphQLInt)) - test_contains_values = check([succeed(1), succeed(2)], {'data': {'nest': {'test': [1, 2]}}}) - test_contains_null = check([succeed(1), succeed(None), succeed(2)], {'data': {'nest': {'test': [1, None, 2]}}}) - test_contains_reject = check(lambda: [succeed(1), fail(Exception('bad')), succeed(2)], { + test_contains_values = check([resolved(1), resolved(2)], {'data': {'nest': {'test': [1, 2]}}}) + test_contains_null = check([resolved(1), resolved(None), resolved(2)], {'data': {'nest': {'test': [1, None, 2]}}}) + test_contains_reject = check(lambda: [resolved(1), rejected(Exception('bad')), resolved(2)], { 'data': {'nest': {'test': [1, None, 2]}}, 'errors': [{'locations': [{'column': 10, 'line': 1}], 'message': 'bad'}] }) @@ -130,16 +128,16 @@ class TestListOfNotNullT_Array_T: # [T!] Array class TestListOfNotNullT_Promise_Array_T: # [T!] Promise> type = GraphQLList(GraphQLNonNull(GraphQLInt)) - test_contains_value = check(succeed([1, 2]), {'data': {'nest': {'test': [1, 2]}}}) - test_contains_null = check(succeed([1, None, 2]), { + test_contains_value = check(resolved([1, 2]), {'data': {'nest': {'test': [1, 2]}}}) + test_contains_null = check(resolved([1, None, 2]), { 'data': {'nest': {'test': None}}, 'errors': [{'locations': [{'column': 10, 'line': 1}], 'message': 'Cannot return null for non-nullable field DataType.test.'}] }) - test_returns_null = check(succeed(None), {'data': {'nest': {'test': None}}}) + test_returns_null = check(resolved(None), {'data': {'nest': {'test': None}}}) - test_rejected = check(lambda: fail(Exception('bad')), { + test_rejected = check(lambda: rejected(Exception('bad')), { 'data': {'nest': {'test': None}}, 'errors': [{'locations': [{'column': 10, 'line': 1}], 'message': 'bad'}] }) @@ -148,13 +146,13 @@ class TestListOfNotNullT_Promise_Array_T: # [T!] Promise> class TestListOfNotNullT_Array_Promise_T: # [T!] Array> type = GraphQLList(GraphQLNonNull(GraphQLInt)) - test_contains_values = check([succeed(1), succeed(2)], {'data': {'nest': {'test': [1, 2]}}}) - test_contains_null = check([succeed(1), succeed(None), succeed(2)], { + test_contains_values = check([resolved(1), resolved(2)], {'data': {'nest': {'test': [1, 2]}}}) + test_contains_null = check([resolved(1), resolved(None), resolved(2)], { 'data': {'nest': {'test': None}}, 'errors': [{'locations': [{'column': 10, 'line': 1}], 'message': 'Cannot return null for non-nullable field DataType.test.'}] }) - test_contains_reject = check(lambda: [succeed(1), fail(Exception('bad')), succeed(2)], { + test_contains_reject = check(lambda: [resolved(1), rejected(Exception('bad')), resolved(2)], { 'data': {'nest': {'test': None}}, 'errors': [{'locations': [{'column': 10, 'line': 1}], 'message': 'bad'}] }) @@ -179,20 +177,20 @@ class TestNotNullListOfNotNullT_Array_T: # [T!]! Array class TestNotNullListOfNotNullT_Promise_Array_T: # [T!]! Promise> type = GraphQLNonNull(GraphQLList(GraphQLNonNull(GraphQLInt))) - test_contains_value = check(succeed([1, 2]), {'data': {'nest': {'test': [1, 2]}}}) - test_contains_null = check(succeed([1, None, 2]), { + test_contains_value = check(resolved([1, 2]), {'data': {'nest': {'test': [1, 2]}}}) + test_contains_null = check(resolved([1, None, 2]), { 'data': {'nest': None}, 'errors': [{'locations': [{'column': 10, 'line': 1}], 'message': 'Cannot return null for non-nullable field DataType.test.'}] }) - test_returns_null = check(succeed(None), { + test_returns_null = check(resolved(None), { 'data': {'nest': None}, 'errors': [{'locations': [{'column': 10, 'line': 1}], 'message': 'Cannot return null for non-nullable field DataType.test.'}] }) - test_rejected = check(lambda: fail(Exception('bad')), { + test_rejected = check(lambda: rejected(Exception('bad')), { 'data': {'nest': None}, 'errors': [{'locations': [{'column': 10, 'line': 1}], 'message': 'bad'}] }) @@ -201,13 +199,13 @@ class TestNotNullListOfNotNullT_Promise_Array_T: # [T!]! Promise> class TestNotNullListOfNotNullT_Array_Promise_T: # [T!]! Array> type = GraphQLNonNull(GraphQLList(GraphQLNonNull(GraphQLInt))) - test_contains_values = check([succeed(1), succeed(2)], {'data': {'nest': {'test': [1, 2]}}}) - test_contains_null = check([succeed(1), succeed(None), succeed(2)], { + test_contains_values = check([resolved(1), resolved(2)], {'data': {'nest': {'test': [1, 2]}}}) + test_contains_null = check([resolved(1), resolved(None), resolved(2)], { 'data': {'nest': None}, 'errors': [{'locations': [{'column': 10, 'line': 1}], 'message': 'Cannot return null for non-nullable field DataType.test.'}] }) - test_contains_reject = check(lambda: [succeed(1), fail(Exception('bad')), succeed(2)], { + test_contains_reject = check(lambda: [resolved(1), rejected(Exception('bad')), resolved(2)], { 'data': {'nest': None}, 'errors': [{'locations': [{'column': 10, 'line': 1}], 'message': 'bad'}] }) diff --git a/graphql/execution/tests/test_middleware.py b/graphql/execution/tests/test_middleware.py deleted file mode 100644 index 1d234270..00000000 --- a/graphql/execution/tests/test_middleware.py +++ /dev/null @@ -1,51 +0,0 @@ -from graphql.execution.middlewares.utils import (merge_resolver_tags, - resolver_has_tag, - tag_resolver) - - -def test_tag_resolver(): - resolver = lambda: None - - tag_resolver(resolver, 'test') - assert resolver_has_tag(resolver, 'test') - assert not resolver_has_tag(resolver, 'not test') - - -def test_merge_resolver_tags(): - a = lambda: None - b = lambda: None - - tag_resolver(a, 'a') - tag_resolver(b, 'b') - - merge_resolver_tags(a, b) - - assert resolver_has_tag(a, 'a') - assert not resolver_has_tag(a, 'b') - - assert resolver_has_tag(b, 'a') - assert resolver_has_tag(b, 'b') - - -def test_resolver_has_tag_with_untagged_resolver(): - a = lambda: None - - assert not resolver_has_tag(a, 'anything') - - -def test_merge_resolver_from_untagged_source(): - a = lambda: None - b = lambda: None - - merge_resolver_tags(a, b) - assert not hasattr(b, '_resolver_tags') - - -def test_merge_resolver_to_untagged_target(): - a = lambda: None - b = lambda: None - - tag_resolver(a, 'test') - merge_resolver_tags(a, b) - - assert resolver_has_tag(b, 'test') diff --git a/graphql/execution/tests/test_mutations.py b/graphql/execution/tests/test_mutations.py index 79d18c46..8e99d997 100644 --- a/graphql/execution/tests/test_mutations.py +++ b/graphql/execution/tests/test_mutations.py @@ -1,4 +1,4 @@ -from graphql.execution import execute +from graphql.execution.execute import execute from graphql.language.parser import parse from graphql.type import (GraphQLArgument, GraphQLField, GraphQLInt, GraphQLList, GraphQLObjectType, GraphQLSchema, @@ -66,7 +66,7 @@ def promise_and_fail_to_change_the_number(self, n): schema = GraphQLSchema(QueryType, MutationType) -def test_evaluates_mutations_serially(): +def assert_evaluate_mutations_serially(executor=None): doc = '''mutation M { first: immediatelyChangeTheNumber(newNumber: 1) { theNumber @@ -85,7 +85,7 @@ def test_evaluates_mutations_serially(): } }''' ast = parse(doc) - result = execute(schema, Root(6), ast, 'M') + result = execute(schema, ast, Root(6), operation_name='M', executor=executor) assert not result.errors assert result.data == \ { @@ -97,6 +97,10 @@ def test_evaluates_mutations_serially(): } +def test_evaluates_mutations_serially(): + assert_evaluate_mutations_serially() + + def test_evaluates_mutations_correctly_in_the_presense_of_a_failed_mutation(): doc = '''mutation M { first: immediatelyChangeTheNumber(newNumber: 1) { @@ -119,7 +123,7 @@ def test_evaluates_mutations_correctly_in_the_presense_of_a_failed_mutation(): } }''' ast = parse(doc) - result = execute(schema, Root(6), ast, 'M') + result = execute(schema, ast, Root(6), operation_name='M') assert result.data == \ { 'first': {'theNumber': 1}, diff --git a/graphql/execution/tests/test_nonnull.py b/graphql/execution/tests/test_nonnull.py index e4b35143..d43fee61 100644 --- a/graphql/execution/tests/test_nonnull.py +++ b/graphql/execution/tests/test_nonnull.py @@ -1,19 +1,17 @@ -from collections import OrderedDict from graphql.error import format_error -from graphql.execution import Executor, execute +from graphql.execution import execute from graphql.language.parser import parse -from graphql.pyutils.defer import fail, succeed from graphql.type import (GraphQLField, GraphQLNonNull, GraphQLObjectType, GraphQLSchema, GraphQLString) +from .utils import rejected, resolved + sync_error = Exception('sync') non_null_sync_error = Exception('nonNullSync') promise_error = Exception('promise') non_null_promise_error = Exception('nonNullPromise') -executor = Executor(map_type=OrderedDict) - class ThrowingData(object): @@ -24,10 +22,10 @@ def nonNullSync(self): raise non_null_sync_error def promise(self): - return fail(promise_error) + return rejected(promise_error) def nonNullPromise(self): - return fail(non_null_promise_error) + return rejected(non_null_promise_error) def nest(self): return ThrowingData() @@ -36,10 +34,10 @@ def nonNullNest(self): return ThrowingData() def promiseNest(self): - return succeed(ThrowingData()) + return resolved(ThrowingData()) def nonNullPromiseNest(self): - return succeed(ThrowingData()) + return resolved(ThrowingData()) class NullingData(object): @@ -51,7 +49,10 @@ def nonNullSync(self): return None def promise(self): - return succeed(None) + return resolved(None) + + def nonNullPromise(self): + return resolved(None) def nest(self): return NullingData() @@ -60,10 +61,10 @@ def nonNullNest(self): return NullingData() def promiseNest(self): - return succeed(NullingData()) + return resolved(NullingData()) def nonNullPromiseNest(self): - return succeed(NullingData()) + return resolved(NullingData()) DataType = GraphQLObjectType('DataType', lambda: { @@ -82,9 +83,7 @@ def nonNullPromiseNest(self): def check(doc, data, expected): ast = parse(doc) - response = executor.execute(schema, ast, data) - assert response.called - response = response.result + response = execute(schema, data, ast) if response.errors: result = { diff --git a/graphql/execution/tests/utils.py b/graphql/execution/tests/utils.py index fb93eeea..74e70fc4 100644 --- a/graphql/execution/tests/utils.py +++ b/graphql/execution/tests/utils.py @@ -1,42 +1,9 @@ -from graphql.pyutils.defer import Deferred, DeferredException, _passthrough +from graphql.pyutils.aplus import Promise -class RaisingDeferred(Deferred): +def resolved(value): + return Promise.fulfilled(value) - def _next(self): - """Process the next callback.""" - if self._running or self.paused: - return - while self.callbacks: - # Get the next callback pair - next_pair = self.callbacks.pop(0) - # Continue with the errback if the last result was an exception - callback, args, kwargs = next_pair[isinstance(self.result, - DeferredException)] - - if callback is not _passthrough: - self._running = True - try: - self.result = callback(self.result, *args, **kwargs) - - except: - self.result = DeferredException() - - finally: - self._running = False - - if isinstance(self.result, Exception): - self.result = DeferredException(self.result) - - if isinstance(self.result, DeferredException): - # Print the exception to stderr and stop if there aren't any - # further errbacks to process - self.result.raise_exception() - - -def raise_callback_results(deferred, callback): - d = RaisingDeferred() - d.add_callback(lambda r: r) - d.callback(deferred) - d.add_callback(callback) +def rejected(error): + return Promise.rejected(error) diff --git a/graphql/execution/values.py b/graphql/execution/values.py index c31f2145..cf1bff1e 100644 --- a/graphql/execution/values.py +++ b/graphql/execution/values.py @@ -32,9 +32,11 @@ def get_variable_values(schema, definition_asts, inputs): def get_argument_values(arg_defs, arg_asts, variables): """Prepares an object map of argument values given a list of argument definitions and list of argument AST nodes.""" + if not arg_defs: + return {} + if arg_asts: arg_ast_map = {arg.name.value: arg for arg in arg_asts} - else: arg_ast_map = {} @@ -75,14 +77,15 @@ def get_variable_value(schema, definition_ast, input): [definition_ast] ) - errors = is_valid_value(input, type) + input_type = type + errors = is_valid_value(input, input_type) if not errors: if input is None: default_value = definition_ast.default_value if default_value: - return value_from_ast(default_value, type) + return value_from_ast(default_value, input_type) - return coerce_value(type, input) + return coerce_value(input_type, input) if input is None: raise GraphQLError( diff --git a/graphql/language/tests/test_visitor.py b/graphql/language/tests/test_visitor.py index 043dabe5..5ede5294 100644 --- a/graphql/language/tests/test_visitor.py +++ b/graphql/language/tests/test_visitor.py @@ -15,6 +15,7 @@ def test_allows_editing_a_node_both_on_enter_and_on_leave(): ast = parse('{ a, b, c { a, b, c } }', no_location=True) class TestVisitor(Visitor): + def __init__(self): self.did_enter = False self.did_leave = False @@ -65,6 +66,7 @@ def test_allows_editing_the_root_node_on_enter_and_on_leave(): definitions = ast.definitions class TestVisitor(Visitor): + def __init__(self): self.did_enter = False self.did_leave = False diff --git a/graphql/pyutils/aplus.py b/graphql/pyutils/aplus.py new file mode 100644 index 00000000..fe8ed02e --- /dev/null +++ b/graphql/pyutils/aplus.py @@ -0,0 +1,451 @@ +from threading import Event, RLock + + +class CountdownLatch(object): + + def __init__(self, count): + assert count >= 0 + + self._lock = RLock() + self._count = count + + def dec(self): + with self._lock: + assert self._count > 0 + + self._count -= 1 + + # Return inside lock to return the correct value, + # otherwise an other thread could already have + # decremented again. + return self._count + + @property + def count(self): + return self._count + + +class Promise(object): + """ + This is a class that attempts to comply with the + Promises/A+ specification and test suite: + http://promises-aplus.github.io/promises-spec/ + """ + + # These are the potential states of a promise + PENDING = -1 + REJECTED = 0 + FULFILLED = 1 + + def __init__(self, fn=None): + """ + Initialize the Promise into a pending state. + """ + self._state = self.PENDING + self._value = None + self._reason = None + self._cb_lock = RLock() + self._callbacks = [] + self._errbacks = [] + self._event = Event() + if fn: + self.do_resolve(fn) + + def do_resolve(self, fn): + self._done = False + + def resolve_fn(x): + if self._done: + return + self._done = True + self.fulfill(x) + + def reject_fn(x): + if self._done: + return + self._done = True + self.reject(x) + try: + fn(resolve_fn, reject_fn) + except Exception as e: + self.reject(e) + + @staticmethod + def fulfilled(x): + p = Promise() + p.fulfill(x) + return p + + @staticmethod + def rejected(reason): + p = Promise() + p.reject(reason) + return p + + def fulfill(self, x): + """ + Fulfill the promise with a given value. + """ + + if self is x: + raise TypeError("Cannot resolve promise with itself.") + elif is_thenable(x): + try: + promisify(x).done(self.fulfill, self.reject) + except Exception as e: + self.reject(e) + else: + self._fulfill(x) + + resolve = fulfilled + + def _fulfill(self, value): + with self._cb_lock: + if self._state != Promise.PENDING: + return + + self._value = value + self._state = self.FULFILLED + + callbacks = self._callbacks + # We will never call these callbacks again, so allow + # them to be garbage collected. This is important since + # they probably include closures which are binding variables + # that might otherwise be garbage collected. + # + # Prevent future appending + self._callbacks = None + + # Notify all waiting + self._event.set() + + for callback in callbacks: + try: + callback(value) + except Exception: + # Ignore errors in callbacks + pass + + def reject(self, reason): + """ + Reject this promise for a given reason. + """ + assert isinstance(reason, Exception) + + with self._cb_lock: + if self._state != Promise.PENDING: + return + + self._reason = reason + self._state = self.REJECTED + + errbacks = self._errbacks + # We will never call these errbacks again, so allow + # them to be garbage collected. This is important since + # they probably include closures which are binding variables + # that might otherwise be garbage collected. + # + # Prevent future appending + self._errbacks = None + + # Notify all waiting + self._event.set() + + for errback in errbacks: + try: + errback(reason) + except Exception: + # Ignore errors in errback + pass + + @property + def is_pending(self): + """Indicate whether the Promise is still pending. Could be wrong the moment the function returns.""" + return self._state == self.PENDING + + @property + def is_fulfilled(self): + """Indicate whether the Promise has been fulfilled. Could be wrong the moment the function returns.""" + return self._state == self.FULFILLED + + @property + def is_rejected(self): + """Indicate whether the Promise has been rejected. Could be wrong the moment the function returns.""" + return self._state == self.REJECTED + + @property + def value(self): + return self._value + + @property + def reason(self): + return self._reason + + def get(self, timeout=None): + """Get the value of the promise, waiting if necessary.""" + self.wait(timeout) + + if self._state == self.PENDING: + raise ValueError("Value not available, promise is still pending") + elif self._state == self.FULFILLED: + return self._value + else: + raise self._reason + + def wait(self, timeout=None): + """ + An implementation of the wait method which doesn't involve + polling but instead utilizes a "real" synchronization + scheme. + """ + self._event.wait(timeout) + + def add_callback(self, f): + """ + Add a callback for when this promis is fulfilled. Note that + if you intend to use the value of the promise somehow in + the callback, it is more convenient to use the 'then' method. + """ + assert _is_function(f) + + with self._cb_lock: + if self._state == self.PENDING: + self._callbacks.append(f) + return + + # This is a correct performance optimization in case of concurrency. + # State can never change once it is not PENDING anymore and is thus safe to read + # without acquiring the lock. + if self._state == self.FULFILLED: + f(self._value) + else: + pass + + def add_errback(self, f): + """ + Add a callback for when this promis is rejected. Note that + if you intend to use the rejection reason of the promise + somehow in the callback, it is more convenient to use + the 'then' method. + """ + assert _is_function(f) + + with self._cb_lock: + if self._state == self.PENDING: + self._errbacks.append(f) + return + + # This is a correct performance optimization in case of concurrency. + # State can never change once it is not PENDING anymore and is thus safe to read + # without acquiring the lock. + if self._state == self.REJECTED: + f(self._reason) + else: + pass + + def catch(self, f): + return self.then(None, f) + + def done(self, success=None, failure=None): + """ + This method takes two optional arguments. The first argument + is used if the "self promise" is fulfilled and the other is + used if the "self promise" is rejected. In contrast to then, + the return value of these callback is ignored and nothing is + returned. + """ + with self._cb_lock: + if success is not None: + self.add_callback(success) + if failure is not None: + self.add_errback(failure) + + def done_all(self, *handlers): + """ + :type handlers: list[(object) -> object] | list[((object) -> object, (object) -> object)] + """ + if len(handlers) == 0: + return + elif len(handlers) == 1 and isinstance(handlers[0], list): + handlers = handlers[0] + + for handler in handlers: + if isinstance(handler, tuple): + s, f = handler + + self.done(s, f) + elif isinstance(handler, dict): + s = handler.get('success') + f = handler.get('failure') + + self.done(s, f) + else: + self.done(success=handler) + + def then(self, success=None, failure=None): + """ + This method takes two optional arguments. The first argument + is used if the "self promise" is fulfilled and the other is + used if the "self promise" is rejected. In either case, this + method returns another promise that effectively represents + the result of either the first of the second argument (in the + case that the "self promise" is fulfilled or rejected, + respectively). + Each argument can be either: + * None - Meaning no action is taken + * A function - which will be called with either the value + of the "self promise" or the reason for rejection of + the "self promise". The function may return: + * A value - which will be used to fulfill the promise + returned by this method. + * A promise - which, when fulfilled or rejected, will + cascade its value or reason to the promise returned + by this method. + * A value - which will be assigned as either the value + or the reason for the promise returned by this method + when the "self promise" is either fulfilled or rejected, + respectively. + :type success: (object) -> object + :type failure: (object) -> object + :rtype : Promise + """ + ret = Promise() + + def call_and_fulfill(v): + """ + A callback to be invoked if the "self promise" + is fulfilled. + """ + try: + if _is_function(success): + ret.fulfill(success(v)) + else: + ret.fulfill(v) + except Exception as e: + ret.reject(e) + + def call_and_reject(r): + """ + A callback to be invoked if the "self promise" + is rejected. + """ + try: + if _is_function(failure): + ret.fulfill(failure(r)) + else: + ret.reject(r) + except Exception as e: + ret.reject(e) + + self.done(call_and_fulfill, call_and_reject) + + return ret + + def then_all(self, *handlers): + """ + Utility function which calls 'then' for each handler provided. Handler can either + be a function in which case it is used as success handler, or a tuple containing + the success and the failure handler, where each of them could be None. + :type handlers: list[(object) -> object] | list[((object) -> object, (object) -> object)] + :param handlers + :rtype : list[Promise] + """ + if len(handlers) == 0: + return [] + elif len(handlers) == 1 and isinstance(handlers[0], list): + handlers = handlers[0] + + promises = [] + + for handler in handlers: + if isinstance(handler, tuple): + s, f = handler + + promises.append(self.then(s, f)) + elif isinstance(handler, dict): + s = handler.get('success') + f = handler.get('failure') + + promises.append(self.then(s, f)) + else: + promises.append(self.then(success=handler)) + + return promises + + @staticmethod + def all(values_or_promises): + """ + A special function that takes a bunch of promises + and turns them into a promise for a vector of values. + In other words, this turns an list of promises for values + into a promise for a list of values. + """ + promises = list(filter(is_thenable, values_or_promises)) + if len(promises) == 0: + # All the values or promises are resolved + return Promise.fulfilled(values_or_promises) + + all_promise = Promise() + counter = CountdownLatch(len(promises)) + + def handleSuccess(_): + if counter.dec() == 0: + values = list(map(lambda p: p.value if p in promises else p, values_or_promises)) + all_promise.fulfill(values) + + for p in promises: + promisify(p).done(handleSuccess, all_promise.reject) + + return all_promise + + +def _is_function(v): + """ + A utility function to determine if the specified + value is a function. + """ + return v is not None and hasattr(v, "__call__") + + +def is_thenable(obj): + """ + A utility function to determine if the specified + object is a promise using "duck typing". + """ + return isinstance(obj, Promise) or ( + hasattr(obj, "done") and _is_function(getattr(obj, "done"))) or ( + hasattr(obj, "then") and _is_function(getattr(obj, "then"))) + + +def promisify(obj): + if isinstance(obj, Promise): + return obj + elif hasattr(obj, "done") and _is_function(getattr(obj, "done")): + p = Promise() + obj.done(p.fulfill, p.reject) + return p + elif hasattr(obj, "then") and _is_function(getattr(obj, "then")): + p = Promise() + obj.then(p.fulfill, p.reject) + return p + else: + raise TypeError("Object is not a Promise like object.") + + +def promise_for_dict(m): + """ + A special function that takes a dictionary of promises + and turns them into a promise for a dictionary of values. + In other words, this turns an dictionary of promises for values + into a promise for a dictionary of values. + """ + if not m: + return Promise.fulfilled({}) + + keys, values = zip(*m.items()) + dict_type = type(m) + + def handleSuccess(resolved_values): + return dict_type(zip(keys, resolved_values)) + + return Promise.all(values).then(handleSuccess) diff --git a/graphql/pyutils/defer.py b/graphql/pyutils/defer.py deleted file mode 100644 index 0bafe88f..00000000 --- a/graphql/pyutils/defer.py +++ /dev/null @@ -1,529 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -"""Small framework for asynchronous programming.""" -# Copyright (C) 2008-2010 Sebastian Heinlein -# Copyright (c) 2001-2010 -# Allen Short -# Andy Gayton -# Andrew Bennetts -# Antoine Pitrou -# Apple Computer, Inc. -# Benjamin Bruheim -# Bob Ippolito -# Canonical Limited -# Christopher Armstrong -# David Reid -# Donovan Preston -# Eric Mangold -# Eyal Lotem -# Itamar Shtull-Trauring -# James Knight -# Jason A. Mobarak -# Jean-Paul Calderone -# Jessica McKellar -# Jonathan Jacobs -# Jonathan Lange -# Jonathan D. Simms -# Jürgen Hermann -# Kevin Horn -# Kevin Turner -# Mary Gardiner -# Matthew Lefkowitz -# Massachusetts Institute of Technology -# Moshe Zadka -# Paul Swartz -# Pavel Pergamenshchik -# Ralph Meijer -# Sean Riley -# Software Freedom Conservancy -# Travis B. Hartwell -# Thijs Triemstra -# Thomas Herve -# Timothy Allen -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and associated documentation files (the -# "Software"), to deal in the Software without restriction, including -# without limitation the rights to use, copy, modify, merge, publish, -# distribute, sublicense, and/or sell copies of the Software, and to -# permit persons to whom the Software is furnished to do so, subject to -# the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. -import collections -import sys - -from six import reraise - -__all__ = ("Deferred", "AlreadyCalledDeferred", "DeferredException", - "defer", "succeed", "fail", "DeferredDict", "DeferredList") - - -class AlreadyCalledDeferred(Exception): - """The Deferred is already running a callback.""" - - -class DeferredException(object): - """Allows to defer exceptions.""" - __slots__ = 'type', 'value', 'traceback' - - def __init__(self, type=None, value=None, traceback=None): - """Return a new DeferredException instance. - - If type, value and traceback are not specified the infotmation - will be retreieved from the last caught exception: - - >>> try: - ... raise Exception("Test") - ... except: - ... deferred_exc = DeferredException() - >>> deferred_exc.raise_exception() - Traceback (most recent call last): - ... - Exception: Test - - Alternatively you can set the exception manually: - - >>> exception = Exception("Test 2") - >>> deferred_exc = DeferredException(exception) - >>> deferred_exc.raise_exception() - Traceback (most recent call last): - ... - Exception: Test 2 - """ - self.type = type - self.value = value - self.traceback = traceback - if isinstance(type, Exception): - self.type = type.__class__ - self.value = type - elif not type or not value: - self.type, self.value, self.traceback = sys.exc_info() - - def raise_exception(self): - """Raise the stored exception.""" - reraise(self.type, self.value, self.traceback) - - def catch(self, *errors): - """Check if the stored exception is a subclass of one of the - provided exception classes. If this is the case return the - matching exception class. Otherwise raise the stored exception. - - >>> exc = DeferredException(SystemError()) - >>> exc.catch(Exception) # Will catch the exception and return it - - >>> exc.catch(OSError) # Won't catch and raise the stored exception - Traceback (most recent call last): - ... - SystemError - - This method can be used in errbacks of a Deferred: - - >>> def dummy_errback(deferred_exception): - ... '''Error handler for OSError''' - ... deferred_exception.catch(OSError) - ... return "catched" - - The above errback can handle an OSError: - - >>> deferred = Deferred() - >>> deferred.add_errback(dummy_errback) - >>> deferred.errback(OSError()) - >>> deferred.result - 'catched' - - But fails to handle a SystemError: - - >>> deferred2 = Deferred() - >>> deferred2.add_errback(dummy_errback) - >>> deferred2.errback(SystemError()) - >>> deferred2.result #doctest: +ELLIPSIS - - >>> deferred2.result.value - SystemError() - """ - for err in errors: - if issubclass(self.type, err): - return err - self.raise_exception() - - -class Deferred(object): - """The Deferred allows to chain callbacks. - - There are two type of callbacks: normal callbacks and errbacks, which - handle an exception in a normal callback. - - The callbacks are processed in pairs consisting of a normal callback - and an errback. A normal callback will return its result to the - callback of the next pair. If an exception occurs, it will be handled - by the errback of the next pair. If an errback doesn't raise an error - again, the callback of the next pair will be called with the return - value of the errback. Otherwise the exception of the errback will be - returned to the errback of the next pair:: - - CALLBACK1 ERRBACK1 - | \ / | - result failure result failure - | \ / | - | \ / | - | X | - | / \ | - | / \ | - | / \ | - CALLBACK2 ERRBACK2 - | \ / | - result failure result failure - | \ / | - | \ / | - | X | - | / \ | - | / \ | - | / \ | - CALLBACK3 ERRBACK3 - """ - - __slots__ = 'callbacks', 'errbacks', 'called', 'paused', '_running', 'result' - - def __init__(self): - """Return a new Deferred instance.""" - self.callbacks = [] - self.errbacks = [] - self.called = False - self.paused = False - self._running = False - - def add_callbacks(self, callback, errback=None, - callback_args=None, callback_kwargs=None, - errback_args=None, errback_kwargs=None): - """Add a pair of callables (function or method) to the callback and - errback chain. - - Keyword arguments: - callback -- the next chained challback - errback -- the next chained errback - callback_args -- list of additional arguments for the callback - callback_kwargs -- dict of additional arguments for the callback - errback_args -- list of additional arguments for the errback - errback_kwargs -- dict of additional arguments for the errback - - In the following example the first callback pairs raises an - exception that is catched by the errback of the second one and - processed by the third one. - - >>> def callback(previous): - ... '''Return the previous result.''' - ... return "Got: %s" % previous - >>> def callback_raise(previous): - ... '''Fail and raise an exception.''' - ... raise Exception("Test") - >>> def errback(error): - ... '''Recover from an exception.''' - ... #error.catch(Exception) - ... return "catched" - >>> deferred = Deferred() - >>> deferred.callback("start") - >>> deferred.result - 'start' - >>> deferred.add_callbacks(callback_raise, errback) - >>> deferred.result #doctest: +ELLIPSIS - - >>> deferred.add_callbacks(callback, errback) - >>> deferred.result - 'catched' - >>> deferred.add_callbacks(callback, errback) - >>> deferred.result - 'Got: catched' - """ - assert callback is _passthrough or isinstance(callback, collections.Callable) - assert errback is None or errback is _passthrough or isinstance(errback, collections.Callable) - if errback is None: - errback = _passthrough - self.callbacks.append(((callback, - callback_args or ([]), - callback_kwargs or ({})), - (errback or (_passthrough), - errback_args or ([]), - errback_kwargs or ({})))) - if self.called: - self._next() - - return self - - def add_errback(self, func, *args, **kwargs): - """Add a callable (function or method) to the errback chain only. - - If there isn't any exception the result will be passed through to - the callback of the next pair. - - The first argument is the callable instance followed by any - additional argument that will be passed to the errback. - - The errback method will get the most recent DeferredException and - and any additional arguments that was specified in add_errback. - - If the errback can catch the exception it can return a value that - will be passed to the next callback in the chain. Otherwise the - errback chain will not be processed anymore. - - See the documentation of defer.DeferredException.catch for - further information. - - >>> def catch_error(deferred_error, ignore=False): - ... if ignore: - ... return "ignored" - ... deferred_error.catch(Exception) - ... return "catched" - >>> deferred = Deferred() - >>> deferred.errback(SystemError()) - >>> deferred.add_errback(catch_error, ignore=True) - >>> deferred.result - 'ignored' - """ - return self.add_callbacks(_passthrough, func, errback_args=args, - errback_kwargs=kwargs) - - def add_callback(self, func, *args, **kwargs): - """Add a callable (function or method) to the callback chain only. - - An error would be passed through to the next errback. - - The first argument is the callable instance followed by any - additional argument that will be passed to the callback. - - The callback method will get the result of the previous callback - and any additional arguments that was specified in add_callback. - - >>> def callback(previous, counter=False): - ... if counter: - ... return previous + 1 - ... return previous - >>> deferred = Deferred() - >>> deferred.add_callback(callback, counter=True) - >>> deferred.callback(1) - >>> deferred.result - 2 - """ - return self.add_callbacks(func, _passthrough, callback_args=args, - callback_kwargs=kwargs) - - def errback(self, error=None): - """Start processing the errorback chain starting with the - provided exception or DeferredException. - - If an exception is specified it will be wrapped into a - DeferredException. It will be send to the first errback or stored - as finally result if not any further errback has been specified yet. - - >>> deferred = Deferred() - >>> deferred.errback(Exception("Test Error")) - >>> deferred.result #doctest: +ELLIPSIS - - >>> deferred.result.raise_exception() - Traceback (most recent call last): - ... - Exception: Test Error - """ - if self.called: - raise AlreadyCalledDeferred() - if not error: - error = DeferredException() - elif not isinstance(error, DeferredException): - assert isinstance(error, Exception) - error = DeferredException(error.__class__, error, None) - - self.called = True - self.result = error - self._next() - - def callback(self, result=None): - """Start processing the callback chain starting with the - provided result. - - It will be send to the first callback or stored as finally - one if not any further callback has been specified yet. - - >>> deferred = Deferred() - >>> deferred.callback("done") - >>> deferred.result - 'done' - """ - if self.called: - raise AlreadyCalledDeferred() - self.called = True - - if isinstance(result, Deferred): - self.paused = True - return result.add_callbacks(self._continue, self._continue) - - self.result = result - self._next() - - def _continue(self, result): - """Continue processing the Deferred with the given result.""" - # If the result of the deferred is another deferred, we will need to wait for - # it to resolve again. - if isinstance(result, Deferred): - return result.add_callbacks(self._continue, self._continue) - - self.result = result - self.paused = False - if self.called: - self._next() - - return result - - def _next(self): - """Process the next callback.""" - if self._running or self.paused: - return - - while self.callbacks: - # Get the next callback pair - next_pair = self.callbacks.pop(0) - # Continue with the errback if the last result was an exception - callback, args, kwargs = next_pair[isinstance(self.result, - DeferredException)] - - if callback is not _passthrough: - self._running = True - try: - self.result = callback(self.result, *args, **kwargs) - - except: - self.result = DeferredException() - - finally: - self._running = False - - if isinstance(self.result, Exception): - self.result = DeferredException(self.result) - - if isinstance(self.result, Deferred): - # If a Deferred was returned add this deferred as callbacks to - # the returned one. As a result the processing of this Deferred - # will be paused until all callbacks of the returned Deferred - # have been performed - self.paused = True - self.result.add_callbacks(self._continue, self._continue) - break - - -def defer(func, *args, **kwargs): - """Invoke the given function that may or not may be a Deferred. - - If the return object of the function call is a Deferred return, it. - Otherwise wrap it into a Deferred. - - >>> defer(lambda x: x, 10) #doctest: +ELLIPSIS - - - >>> deferred = defer(lambda x: x, "done") - >>> deferred.result - 'done' - - >>> deferred = Deferred() - >>> defer(lambda: deferred) == deferred - True - """ - assert isinstance(func, collections.Callable) - - try: - result = func(*args, **kwargs) - except: - result = DeferredException() - - if isinstance(result, Deferred): - return result - - deferred = Deferred() - deferred.callback(result) - return deferred - - -_passthrough = object() - - -def succeed(result): - d = Deferred() - d.callback(result) - return d - - -def fail(result=None): - d = Deferred() - d.errback(result) - return d - - -class _ResultCollector(Deferred): - objects_remaining_to_resolve = 0 - _result = None - - def _schedule_callbacks(self, items, result, objects_remaining_to_resolve=None, preserve_insert_ordering=False): - self.objects_remaining_to_resolve = \ - objects_remaining_to_resolve if objects_remaining_to_resolve is not None else len(items) - self._result = result - for key, value in items: - if isinstance(value, Deferred): - # We will place a value in place of the resolved key, so that insert order is preserved. - if preserve_insert_ordering: - result[key] = None - - value.add_callbacks(self._cb_deferred, self._cb_deferred, - callback_args=(key, True), - errback_args=(key, False)) - else: - self.objects_remaining_to_resolve -= 1 - result[key] = value - - if self.objects_remaining_to_resolve == 0 and not self.called: - self.callback(self._result) - self._result = None - - def _cb_deferred(self, result, key, succeeded): - # If one item fails, we are going to errback right away with the error. - # This follows the Promise.all(...) spec in ES6. - if self.called: - return result - - if not succeeded: - self.errback(result) - self._result = None - return result - - self.objects_remaining_to_resolve -= 1 - self._result[key] = result - - if self.objects_remaining_to_resolve == 0: - self.callback(self._result) - self._result = None - - return result - - -class DeferredDict(_ResultCollector): - - def __init__(self, mapping): - super(DeferredDict, self).__init__() - assert isinstance(mapping, collections.Mapping) - self._schedule_callbacks(mapping.items(), type(mapping)(), - preserve_insert_ordering=isinstance(mapping, collections.OrderedDict)) - - -class DeferredList(_ResultCollector): - - def __init__(self, sequence): - super(DeferredList, self).__init__() - assert isinstance(sequence, collections.Sequence) - sequence_len = len(sequence) - self._schedule_callbacks(enumerate(sequence), [None] * sequence_len, sequence_len) diff --git a/tests_py35/core_execution/test_asyncio_executor.py b/tests_py35/core_execution/test_asyncio_executor.py index 42ca57ff..36c00d14 100644 --- a/tests_py35/core_execution/test_asyncio_executor.py +++ b/tests_py35/core_execution/test_asyncio_executor.py @@ -3,8 +3,9 @@ import asyncio import functools from graphql.error import format_error -from graphql.execution import Executor -from graphql.execution.middlewares.asyncio import AsyncioExecutionMiddleware +from graphql.execution.execute import execute +from graphql.language.parser import parse +from graphql.execution.executors.asyncio import AsyncioExecutor from graphql.type import ( GraphQLSchema, GraphQLObjectType, @@ -13,17 +14,8 @@ ) -def run_until_complete(fun): - @functools.wraps(fun) - def wrapper(*args, **kwargs): - coro = fun(*args, **kwargs) - return asyncio.get_event_loop().run_until_complete(coro) - return wrapper - - -@run_until_complete -async def test_asyncio_py35_executor(): - doc = 'query Example { a, b, c }' +def test_asyncio_py35_executor(): + ast = parse('query Example { a, b, c }') async def resolver(context, *_): await asyncio.sleep(0.001) @@ -42,14 +34,13 @@ def resolver_3(context, *_): 'c': GraphQLField(GraphQLString, resolver=resolver_3) }) - executor = Executor([AsyncioExecutionMiddleware()]) - result = await executor.execute(GraphQLSchema(Type), doc) + result = execute(GraphQLSchema(Type), ast, executor=AsyncioExecutor()) assert not result.errors assert result.data == {'a': 'hey', 'b': 'hey2', 'c': 'hey3'} -@run_until_complete -async def test_asyncio_py35_executor_with_error(): - doc = 'query Example { a, b }' + +def test_asyncio_py35_executor_with_error(): + ast = parse('query Example { a, b }') async def resolver(context, *_): await asyncio.sleep(0.001) @@ -64,8 +55,7 @@ async def resolver_2(context, *_): 'b': GraphQLField(GraphQLString, resolver=resolver_2) }) - executor = Executor([AsyncioExecutionMiddleware()]) - result = await executor.execute(GraphQLSchema(Type), doc) + result = execute(GraphQLSchema(Type), ast, executor=AsyncioExecutor()) formatted_errors = list(map(format_error, result.errors)) assert formatted_errors == [{'locations': [{'line': 1, 'column': 20}], 'message': 'resolver_2 failed!'}] assert result.data == {'a': 'hey', 'b': None}