Skip to content

Commit 99f3869

Browse files
committed
Make collect_sub_fields a part of the public API
Replicates graphql/graphql-js@e95ea9b
1 parent ecbb566 commit 99f3869

File tree

4 files changed

+65
-32
lines changed

4 files changed

+65
-32
lines changed

src/graphql/execution/collect_fields.py

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from ..utilities.type_from_ast import type_from_ast
1919
from .values import get_directive_values
2020

21-
__all__ = ["collect_fields"]
21+
__all__ = ["collect_fields", "collect_sub_fields"]
2222

2323

2424
def collect_fields(
@@ -27,20 +27,68 @@ def collect_fields(
2727
variable_values: Dict[str, Any],
2828
runtime_type: GraphQLObjectType,
2929
selection_set: SelectionSetNode,
30-
fields: Dict[str, List[FieldNode]],
31-
visited_fragment_names: Set[str],
3230
) -> Dict[str, List[FieldNode]]:
3331
"""Collect fields.
3432
35-
Given a selection_set, adds all of the fields in that selection to the passed in
36-
map of fields, and returns it at the end.
33+
Given a selection_set, collects all of the fields returns them at the end.
3734
38-
collect_fields requires the "runtime type" of an object. For a field which
35+
collect_fields requires the "runtime type" of an object. For a field that
3936
returns an Interface or Union type, the "runtime type" will be the actual
40-
Object type returned by that field.
37+
object type returned by that field.
4138
4239
For internal use only.
4340
"""
41+
fields: Dict[str, List[FieldNode]] = {}
42+
collect_fields_impl(
43+
schema, fragments, variable_values, runtime_type, selection_set, fields, set()
44+
)
45+
return fields
46+
47+
48+
def collect_sub_fields(
49+
schema: GraphQLSchema,
50+
fragments: Dict[str, FragmentDefinitionNode],
51+
variable_values: Dict[str, Any],
52+
return_type: GraphQLObjectType,
53+
field_nodes: List[FieldNode],
54+
) -> Dict[str, List[FieldNode]]:
55+
"""Collect sub fields.
56+
57+
Given a list of field nodes, collects all of the subfields of the passed
58+
in fields, and returns them at the end.
59+
60+
collect_sub_fields requires the "return type" of an object. For a field that
61+
returns an Interface or Union type, the "return type" will be the actual
62+
object type returned by that field.
63+
64+
For internal use only.
65+
"""
66+
sub_field_nodes: Dict[str, List[FieldNode]] = {}
67+
visited_fragment_names: Set[str] = set()
68+
for node in field_nodes:
69+
if node.selection_set:
70+
collect_fields_impl(
71+
schema,
72+
fragments,
73+
variable_values,
74+
return_type,
75+
node.selection_set,
76+
sub_field_nodes,
77+
visited_fragment_names,
78+
)
79+
return sub_field_nodes
80+
81+
82+
def collect_fields_impl(
83+
schema: GraphQLSchema,
84+
fragments: Dict[str, FragmentDefinitionNode],
85+
variable_values: Dict[str, Any],
86+
runtime_type: GraphQLObjectType,
87+
selection_set: SelectionSetNode,
88+
fields: Dict[str, List[FieldNode]],
89+
visited_fragment_names: Set[str],
90+
) -> None:
91+
"""Collect fields (internal implementation)."""
4492
for selection in selection_set.selections:
4593
if isinstance(selection, FieldNode):
4694
if not should_include_node(variable_values, selection):
@@ -52,7 +100,7 @@ def collect_fields(
52100
variable_values, selection
53101
) or not does_fragment_condition_match(schema, selection, runtime_type):
54102
continue
55-
collect_fields(
103+
collect_fields_impl(
56104
schema,
57105
fragments,
58106
variable_values,
@@ -73,7 +121,7 @@ def collect_fields(
73121
schema, fragment, runtime_type
74122
):
75123
continue
76-
collect_fields(
124+
collect_fields_impl(
77125
schema,
78126
fragments,
79127
variable_values,
@@ -82,7 +130,6 @@ def collect_fields(
82130
fields,
83131
visited_fragment_names,
84132
)
85-
return fields
86133

87134

88135
def should_include_node(

src/graphql/execution/execute.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
Iterable,
1010
List,
1111
Optional,
12-
Set,
1312
Union,
1413
Tuple,
1514
Type,
@@ -61,7 +60,7 @@
6160
is_non_null_type,
6261
is_object_type,
6362
)
64-
from .collect_fields import collect_fields
63+
from .collect_fields import collect_fields, collect_sub_fields
6564
from .middleware import MiddlewareManager
6665
from .values import get_argument_values, get_variable_values
6766

@@ -343,8 +342,6 @@ def execute_operation(
343342
self.variable_values,
344343
type_,
345344
operation.selection_set,
346-
{},
347-
set(),
348345
)
349346

350347
path = None
@@ -961,20 +958,13 @@ def collect_subfields(
961958
)
962959
sub_field_nodes = cache.get(key)
963960
if sub_field_nodes is None:
964-
sub_field_nodes = {}
965-
visited_fragment_names: Set[str] = set()
966-
for field_node in field_nodes:
967-
selection_set = field_node.selection_set
968-
if selection_set:
969-
sub_field_nodes = collect_fields(
970-
self.schema,
971-
self.fragments,
972-
self.variable_values,
973-
return_type,
974-
selection_set,
975-
sub_field_nodes,
976-
visited_fragment_names,
977-
)
961+
sub_field_nodes = collect_sub_fields(
962+
self.schema,
963+
self.fragments,
964+
self.variable_values,
965+
return_type,
966+
field_nodes,
967+
)
978968
cache[key] = sub_field_nodes
979969
return sub_field_nodes
980970

src/graphql/subscription/subscribe.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,6 @@ async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]:
170170
context.variable_values,
171171
type_,
172172
context.operation.selection_set,
173-
{},
174-
set(),
175173
)
176174
response_name, field_nodes = next(iter(fields.items()))
177175
field_def = get_field_def(schema, type_, field_nodes[0])

src/graphql/validation/rules/single_field_subscriptions.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@ def enter_operation_definition(
4242
variable_values,
4343
subscription_type,
4444
node.selection_set,
45-
{},
46-
set(),
4745
)
4846
if len(fields) > 1:
4947
field_selection_lists = list(fields.values())

0 commit comments

Comments
 (0)