diff --git a/graphql/execution/tests/test_middleware.py b/graphql/execution/tests/test_middleware.py index 528e64ab..9715288d 100644 --- a/graphql/execution/tests/test_middleware.py +++ b/graphql/execution/tests/test_middleware.py @@ -1,10 +1,12 @@ # type: ignore +from __future__ import print_function + import json from pytest import raises - from graphql.error import GraphQLError from graphql.execution import MiddlewareManager, execute +from graphql.execution.middleware import get_middleware_resolvers, middleware_chain from graphql.language.parser import parse from graphql.type import ( GraphQLArgument, @@ -138,3 +140,52 @@ def resolve(self, next, *args, **kwargs): "ok": "ok", "not_ok": "not_ok", } + + +def test_middleware_chain(capsys): + # type: (Any) -> None + class CharPrintingMiddleware(object): + def __init__(self, char): + # type: (str) -> None + self.char = char + + def resolve(self, next, *args, **kwargs): + # type: (Callable, *Any, **Any) -> str + print("resolve() called for middleware {}".format(self.char)) + return next(*args, **kwargs).then( + lambda x: print("then() for {}".format(self.char)) + ) + + middlewares = [ + CharPrintingMiddleware("a"), + CharPrintingMiddleware("b"), + CharPrintingMiddleware("c"), + ] + + middlewares_resolvers = get_middleware_resolvers(middlewares) + + def func(): + # type: () -> None + return + + chain_iter = middleware_chain(func, middlewares_resolvers, wrap_in_promise=True) + + assert_stdout(capsys, "") + + chain_iter() + + expected_stdout = ( + "resolve() called for middleware c\n" + "resolve() called for middleware b\n" + "resolve() called for middleware a\n" + "then() for a\n" + "then() for b\n" + "then() for c\n" + ) + assert_stdout(capsys, expected_stdout) + + +def assert_stdout(capsys, expected_stdout): + # type: (Any, str) -> None + captured = capsys.readouterr() + assert captured.out == expected_stdout