From 52cea03564c584e6da394703cdc4aacafaeb4b6f Mon Sep 17 00:00:00 2001 From: xzhang2 Date: Wed, 12 Aug 2020 19:42:51 -0700 Subject: [PATCH] Fixed the bug where a nested GraphQLInputObjectType causing infinite recursive calls to `get_arg_serializer`. --- gql/dsl.py | 21 ++++++--- tests/nested_input/__init__.py | 0 tests/nested_input/schema.py | 30 ++++++++++++ tests/nested_input/test_nested_input.py | 63 +++++++++++++++++++++++++ 4 files changed, 107 insertions(+), 7 deletions(-) create mode 100644 tests/nested_input/__init__.py create mode 100644 tests/nested_input/schema.py create mode 100644 tests/nested_input/test_nested_input.py diff --git a/gql/dsl.py b/gql/dsl.py index 0f66ff5e..bd592ee1 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -105,7 +105,7 @@ def args(self, **kwargs): arg = self.field.args.get(name) if not arg: raise KeyError(f"Argument {name} does not exist in {self.field}.") - arg_type_serializer = get_arg_serializer(arg.type) + arg_type_serializer = get_arg_serializer(arg.type, known_serializers=dict()) serialized_value = arg_type_serializer(value) added_args.append( ArgumentNode(name=NameNode(value=name), value=serialized_value) @@ -151,21 +151,28 @@ def serialize_list(serializer, list_values): return ListValueNode(values=FrozenList(serializer(v) for v in list_values)) -def get_arg_serializer(arg_type): +def get_arg_serializer(arg_type, known_serializers): if isinstance(arg_type, GraphQLNonNull): - return get_arg_serializer(arg_type.of_type) + return get_arg_serializer(arg_type.of_type, known_serializers) if isinstance(arg_type, GraphQLInputField): - return get_arg_serializer(arg_type.type) + return get_arg_serializer(arg_type.type, known_serializers) if isinstance(arg_type, GraphQLInputObjectType): - serializers = {k: get_arg_serializer(v) for k, v in arg_type.fields.items()} - return lambda value: ObjectValueNode( + if arg_type in known_serializers: + return known_serializers[arg_type] + known_serializers[arg_type] = None + serializers = { + k: get_arg_serializer(v, known_serializers) + for k, v in arg_type.fields.items() + } + known_serializers[arg_type] = lambda value: ObjectValueNode( fields=FrozenList( ObjectFieldNode(name=NameNode(value=k), value=serializers[k](v)) for k, v in value.items() ) ) + return known_serializers[arg_type] if isinstance(arg_type, GraphQLList): - inner_serializer = get_arg_serializer(arg_type.of_type) + inner_serializer = get_arg_serializer(arg_type.of_type, known_serializers) return partial(serialize_list, inner_serializer) if isinstance(arg_type, GraphQLEnumType): return lambda value: EnumValueNode(value=arg_type.serialize(value)) diff --git a/tests/nested_input/__init__.py b/tests/nested_input/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/nested_input/schema.py b/tests/nested_input/schema.py new file mode 100644 index 00000000..f27a94e8 --- /dev/null +++ b/tests/nested_input/schema.py @@ -0,0 +1,30 @@ +from graphql import ( + GraphQLArgument, + GraphQLField, + GraphQLInputField, + GraphQLInputObjectType, + GraphQLInt, + GraphQLObjectType, + GraphQLSchema, +) + +nestedInput = GraphQLInputObjectType( + "Nested", + description="The input object that has a field pointing to itself", + fields={"foo": GraphQLInputField(GraphQLInt, description="foo")}, +) + +nestedInput.fields["child"] = GraphQLInputField(nestedInput, description="child") + +queryType = GraphQLObjectType( + "Query", + fields=lambda: { + "foo": GraphQLField( + args={"nested": GraphQLArgument(type_=nestedInput)}, + resolve=lambda *args, **kwargs: 1, + type_=GraphQLInt, + ), + }, +) + +NestedInputSchema = GraphQLSchema(query=queryType, types=[nestedInput],) diff --git a/tests/nested_input/test_nested_input.py b/tests/nested_input/test_nested_input.py new file mode 100644 index 00000000..037d1518 --- /dev/null +++ b/tests/nested_input/test_nested_input.py @@ -0,0 +1,63 @@ +from functools import partial + +import pytest +from graphql import ( + EnumValueNode, + GraphQLEnumType, + GraphQLInputField, + GraphQLInputObjectType, + GraphQLList, + GraphQLNonNull, + NameNode, + ObjectFieldNode, + ObjectValueNode, + ast_from_value, +) +from graphql.pyutils import FrozenList + +import gql.dsl as dsl +from gql import Client +from gql.dsl import DSLSchema, serialize_list +from tests.nested_input.schema import NestedInputSchema + +# back up the new func +new_get_arg_serializer = dsl.get_arg_serializer + + +def old_get_arg_serializer(arg_type, known_serializers=None): + if isinstance(arg_type, GraphQLNonNull): + return old_get_arg_serializer(arg_type.of_type) + if isinstance(arg_type, GraphQLInputField): + return old_get_arg_serializer(arg_type.type) + if isinstance(arg_type, GraphQLInputObjectType): + serializers = {k: old_get_arg_serializer(v) for k, v in arg_type.fields.items()} + return lambda value: ObjectValueNode( + fields=FrozenList( + ObjectFieldNode(name=NameNode(value=k), value=serializers[k](v)) + for k, v in value.items() + ) + ) + if isinstance(arg_type, GraphQLList): + inner_serializer = old_get_arg_serializer(arg_type.of_type) + return partial(serialize_list, inner_serializer) + if isinstance(arg_type, GraphQLEnumType): + return lambda value: EnumValueNode(value=arg_type.serialize(value)) + return lambda value: ast_from_value(arg_type.serialize(value), arg_type) + + +@pytest.fixture +def ds(): + client = Client(schema=NestedInputSchema) + ds = DSLSchema(client) + return ds + + +def test_nested_input_with_old_get_arg_serializer(ds): + dsl.get_arg_serializer = old_get_arg_serializer + with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): + ds.query(ds.Query.foo.args(nested={"foo": 1})) + + +def test_nested_input_with_new_get_arg_serializer(ds): + dsl.get_arg_serializer = new_get_arg_serializer + assert ds.query(ds.Query.foo.args(nested={"foo": 1})) == {"foo": 1}