diff --git a/.flake8 b/.flake8 index 5960dc31..c8479de9 100644 --- a/.flake8 +++ b/.flake8 @@ -1,2 +1,3 @@ [flake8] exclude = .git,.mypy_cache,.pytest_cache,.tox,.venv,__pycache__,build,dist,docs +max-line-length = 88 diff --git a/graphql/execution/__init__.py b/graphql/execution/__init__.py index 10398898..10e9b86f 100644 --- a/graphql/execution/__init__.py +++ b/graphql/execution/__init__.py @@ -5,11 +5,21 @@ """ from .execute import ( - execute, default_field_resolver, response_path_as_list, - ExecutionContext, ExecutionResult) + execute, + default_field_resolver, + response_path_as_list, + ExecutionContext, + ExecutionResult, +) +from .middleware import MiddlewareManager from .values import get_directive_values __all__ = [ - 'execute', 'default_field_resolver', 'response_path_as_list', - 'ExecutionContext', 'ExecutionResult', - 'get_directive_values'] + "execute", + "default_field_resolver", + "response_path_as_list", + "ExecutionContext", + "ExecutionResult", + "MiddlewareManager", + "get_directive_values", +] diff --git a/graphql/execution/execute.py b/graphql/execution/execute.py index e8808de1..699d13a3 100644 --- a/graphql/execution/execute.py +++ b/graphql/execution/execute.py @@ -20,6 +20,8 @@ is_non_null_type, is_object_type) from .values import ( get_argument_values, get_directive_values, get_variable_values) +from .middleware import MiddlewareManager + __all__ = [ 'add_path', 'assert_valid_execution_arguments', 'default_field_resolver', @@ -64,7 +66,8 @@ def execute( schema: GraphQLSchema, document: DocumentNode, root_value: Any=None, context_value: Any=None, variable_values: Dict[str, Any]=None, - operation_name: str=None, field_resolver: GraphQLFieldResolver=None + operation_name: str=None, field_resolver: GraphQLFieldResolver=None, + middleware: Optional[Union[Iterable[Any], MiddlewareManager]]=None ) -> MaybeAwaitable[ExecutionResult]: """Execute a GraphQL operation. @@ -84,7 +87,7 @@ def execute( # arguments, a "Response" with only errors is returned. exe_context = ExecutionContext.build( schema, document, root_value, context_value, - variable_values, operation_name, field_resolver) + variable_values, operation_name, field_resolver, middleware) # Return early errors if execution context failed. if isinstance(exe_context, list): @@ -116,6 +119,7 @@ class ExecutionContext: operation: OperationDefinitionNode variable_values: Dict[str, Any] field_resolver: GraphQLFieldResolver + middleware_manager: Optional[MiddlewareManager] errors: List[GraphQLError] def __init__( @@ -125,6 +129,7 @@ def __init__( operation: OperationDefinitionNode, variable_values: Dict[str, Any], field_resolver: GraphQLFieldResolver, + middleware_manager: Optional[MiddlewareManager], errors: List[GraphQLError]) -> None: self.schema = schema self.fragments = fragments @@ -133,6 +138,7 @@ def __init__( self.operation = operation self.variable_values = variable_values self.field_resolver = field_resolver # type: ignore + self.middleware_manager = middleware_manager self.errors = errors self._subfields_cache: Dict[ Tuple[GraphQLObjectType, Tuple[FieldNode, ...]], @@ -144,7 +150,8 @@ def build( root_value: Any=None, context_value: Any=None, raw_variable_values: Dict[str, Any]=None, operation_name: str=None, - field_resolver: GraphQLFieldResolver=None + field_resolver: GraphQLFieldResolver=None, + middleware: Optional[Union[Iterable[Any], MiddlewareManager]]=None ) -> Union[List[GraphQLError], 'ExecutionContext']: """Build an execution context @@ -157,6 +164,18 @@ def build( operation: Optional[OperationDefinitionNode] = None has_multiple_assumed_operations = False fragments: Dict[str, FragmentDefinitionNode] = {} + middleware_manager: Optional[MiddlewareManager] = None + if middleware: + if isinstance(middleware, Iterable): + middleware_manager = MiddlewareManager(*middleware) + elif isinstance(middleware, MiddlewareManager): + middleware_manager = middleware + else: + raise TypeError( + f"middlewares have to be an instance" + "of MiddlewareManager. Received \"{middleware}\"" + ) + for definition in document.definitions: if isinstance(definition, OperationDefinitionNode): if not operation_name and operation: @@ -201,7 +220,8 @@ def build( return cls( schema, fragments, root_value, context_value, operation, - variable_values, field_resolver or default_field_resolver, errors) + variable_values, field_resolver or default_field_resolver, + middleware_manager, errors) def build_response( self, data: MaybeAwaitable[Optional[Dict[str, Any]]] @@ -447,6 +467,9 @@ def resolve_field( resolve_fn = field_def.resolve or self.field_resolver + if self.middleware_manager: + resolve_fn = self.middleware_manager.get_field_resolver(resolve_fn) + info = self.build_resolve_info( field_def, field_nodes, parent_type, path) diff --git a/graphql/execution/middleware.py b/graphql/execution/middleware.py new file mode 100644 index 00000000..aedce55d --- /dev/null +++ b/graphql/execution/middleware.py @@ -0,0 +1,76 @@ +from typing import Callable, Iterator, Dict, Tuple, Any, Iterable, Optional, cast + +from inspect import isfunction +from functools import partial +from itertools import chain + + +from ..type import GraphQLFieldResolver + + +__all__ = ["MiddlewareManager", "middlewares"] + +# If the provided middleware is a class, this is the attribute we will look at +MIDDLEWARE_RESOLVER_FUNCTION = "resolve" + + +class MiddlewareManager: + """MiddlewareManager helps to chain resolver functions with the provided + middleware functions and classes + """ + + __slots__ = ("middlewares", "_middleware_resolvers", "_cached_resolvers") + + _cached_resolvers: Dict[GraphQLFieldResolver, GraphQLFieldResolver] + _middleware_resolvers: Optional[Tuple[Callable, ...]] + + def __init__(self, *middlewares: Any) -> None: + self.middlewares = middlewares + if middlewares: + self._middleware_resolvers = tuple(get_middleware_resolvers(middlewares)) + else: + self.__middleware_resolvers = None + self._cached_resolvers = {} + + def get_field_resolver( + self, field_resolver: GraphQLFieldResolver + ) -> GraphQLFieldResolver: + """Wraps the provided resolver returning a function that + executes chains the middleware functions with the resolver function""" + if self._middleware_resolvers is None: + return field_resolver + if field_resolver not in self._cached_resolvers: + self._cached_resolvers[field_resolver] = middleware_chain( + field_resolver, self._middleware_resolvers + ) + + return self._cached_resolvers[field_resolver] + + +middlewares = MiddlewareManager + + +def get_middleware_resolvers(middlewares: Tuple[Any, ...]) -> Iterator[Callable]: + """Returns the functions related to the middleware classes or functions""" + for middleware in middlewares: + # If the middleware is a function instead of a class + if isfunction(middleware): + yield middleware + resolver_func = getattr(middleware, MIDDLEWARE_RESOLVER_FUNCTION, None) + if resolver_func is not None: + yield resolver_func + + +def middleware_chain( + func: GraphQLFieldResolver, middlewares: Iterable[Callable] +) -> GraphQLFieldResolver: + """Reduces the current function with the provided middlewares, + returning a new resolver function""" + if not middlewares: + return func + middlewares = chain((func,), middlewares) + last_func: Optional[GraphQLFieldResolver] = None + for middleware in middlewares: + last_func = partial(middleware, last_func) if last_func else middleware + + return cast(GraphQLFieldResolver, last_func) diff --git a/graphql/graphql.py b/graphql/graphql.py index a5de20f6..98becc97 100644 --- a/graphql/graphql.py +++ b/graphql/graphql.py @@ -1,25 +1,27 @@ from asyncio import ensure_future from inspect import isawaitable -from typing import Any, Awaitable, Callable, Dict, Union, cast +from typing import Any, Awaitable, Callable, Dict, Union, Optional, Iterable, cast from .error import GraphQLError from .execution import execute from .language import parse, Source from .pyutils import MaybeAwaitable from .type import GraphQLSchema, validate_schema -from .execution.execute import ExecutionResult +from .execution import ExecutionResult, MiddlewareManager -__all__ = ['graphql', 'graphql_sync'] +__all__ = ["graphql", "graphql_sync"] async def graphql( - schema: GraphQLSchema, - source: Union[str, Source], - root_value: Any=None, - context_value: Any=None, - variable_values: Dict[str, Any]=None, - operation_name: str=None, - field_resolver: Callable=None) -> ExecutionResult: + schema: GraphQLSchema, + source: Union[str, Source], + root_value: Any = None, + context_value: Any = None, + variable_values: Dict[str, Any] = None, + operation_name: str = None, + field_resolver: Callable = None, + middleware: Optional[Union[Iterable[Any], MiddlewareManager]] = None, +) -> ExecutionResult: """Execute a GraphQL operation asynchronously. This is the primary entry point function for fulfilling GraphQL operations @@ -56,6 +58,8 @@ async def graphql( A resolver function to use when one is not provided by the schema. If not provided, the default field resolver is used (which looks for a value or method on the source value with the field's name). + :arg middleware: + The middleware to wrap the resolvers with """ # Always return asynchronously for a consistent API. result = graphql_impl( @@ -65,7 +69,9 @@ async def graphql( context_value, variable_values, operation_name, - field_resolver) + field_resolver, + middleware, + ) if isawaitable(result): return await cast(Awaitable[ExecutionResult], result) @@ -74,13 +80,15 @@ async def graphql( def graphql_sync( - schema: GraphQLSchema, - source: Union[str, Source], - root_value: Any = None, - context_value: Any = None, - variable_values: Dict[str, Any] = None, - operation_name: str = None, - field_resolver: Callable = None) -> ExecutionResult: + schema: GraphQLSchema, + source: Union[str, Source], + root_value: Any = None, + context_value: Any = None, + variable_values: Dict[str, Any] = None, + operation_name: str = None, + field_resolver: Callable = None, + middleware: Optional[Union[Iterable[Any], MiddlewareManager]] = None, +) -> ExecutionResult: """Execute a GraphQL operation synchronously. The graphql_sync function also fulfills GraphQL operations by parsing, @@ -95,26 +103,28 @@ def graphql_sync( context_value, variable_values, operation_name, - field_resolver) + field_resolver, + middleware, + ) # Assert that the execution was synchronous. if isawaitable(result): ensure_future(cast(Awaitable[ExecutionResult], result)).cancel() - raise RuntimeError( - 'GraphQL execution failed to complete synchronously.') + raise RuntimeError("GraphQL execution failed to complete synchronously.") return cast(ExecutionResult, result) def graphql_impl( - schema, - source, - root_value, - context_value, - variable_values, - operation_name, - field_resolver - ) -> MaybeAwaitable[ExecutionResult]: + schema, + source, + root_value, + context_value, + variable_values, + operation_name, + field_resolver, + middleware, +) -> MaybeAwaitable[ExecutionResult]: """Execute a query, return asynchronously only if necessary.""" # Validate Schema schema_validation_errors = validate_schema(schema) @@ -132,6 +142,7 @@ def graphql_impl( # Validate from .validation import validate + validation_errors = validate(schema, document) if validation_errors: return ExecutionResult(data=None, errors=validation_errors) @@ -144,4 +155,6 @@ def graphql_impl( context_value, variable_values, operation_name, - field_resolver) + field_resolver, + middleware, + ) diff --git a/tests/execution/test_middleware.py b/tests/execution/test_middleware.py new file mode 100644 index 00000000..964017e8 --- /dev/null +++ b/tests/execution/test_middleware.py @@ -0,0 +1,104 @@ +from pytest import raises +from graphql.execution import MiddlewareManager, execute +from graphql.execution.middleware import get_middleware_resolvers, middleware_chain +from graphql.language.parser import parse +from graphql.type import GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLString + + +def test_middleware(): + doc = """{ + ok + not_ok + }""" + + class Data(object): + def ok(self, info): + return "ok" + + def not_ok(self, info): + return "not_ok" + + doc_ast = parse(doc) + + Type = GraphQLObjectType( + "Type", + {"ok": GraphQLField(GraphQLString), "not_ok": GraphQLField(GraphQLString)}, + ) + + def reversed_middleware(next, *args, **kwargs): + p = next(*args, **kwargs) + return p[::-1] + + middlewares = MiddlewareManager(reversed_middleware) + result = execute(GraphQLSchema(Type), doc_ast, Data(), middleware=middlewares) + assert result.data == {"ok": "ko", "not_ok": "ko_ton"} + + +def test_middleware_class(): + doc = """{ + ok + not_ok + }""" + + class Data(object): + def ok(self, info): + return "ok" + + def not_ok(self, info): + return "not_ok" + + doc_ast = parse(doc) + + Type = GraphQLObjectType( + "Type", + {"ok": GraphQLField(GraphQLString), "not_ok": GraphQLField(GraphQLString)}, + ) + + class MyMiddleware(object): + def resolve(self, next, *args, **kwargs): + p = next(*args, **kwargs) + return p[::-1] + + middlewares = MiddlewareManager(MyMiddleware()) + result = execute(GraphQLSchema(Type), doc_ast, Data(), middleware=middlewares) + assert result.data == {"ok": "ko", "not_ok": "ko_ton"} + + +def test_middleware_chain(): + call_order = [] + + class CharPrintingMiddleware(object): + def __init__(self, char): + self.char = char + + def resolve(self, next, *args, **kwargs): + call_order.append(f"resolve() called for middleware {self.char}") + value = next(*args, **kwargs) + call_order.append(f"then() for {self.char}") + return value + + middlewares = [ + CharPrintingMiddleware("a"), + CharPrintingMiddleware("b"), + CharPrintingMiddleware("c"), + ] + + middlewares_resolvers = get_middleware_resolvers(middlewares) + + def func(): + return + + chain_iter = middleware_chain(func, middlewares_resolvers) + + assert call_order == [] + + chain_iter() + + assert call_order == [ + "resolve() called for middleware c", + "resolve() called for middleware b", + "resolve() called for middleware a", + "then() for a", + "then() for b", + "then() for c", + ]