Skip to content

Commit 84e0f84

Browse files
committed
Fix memoization of collect_subfields (#91)
1 parent 4192a5c commit 84e0f84

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

src/graphql/execution/execute.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -210,9 +210,7 @@ def __init__(
210210
self.type_resolver = type_resolver # type: ignore
211211
self.errors = errors
212212
self.middleware_manager = middleware_manager
213-
self._subfields_cache: Dict[
214-
Tuple[GraphQLObjectType, int], Dict[str, List[FieldNode]]
215-
] = {}
213+
self._subfields_cache: Dict[Tuple, Dict[str, List[FieldNode]]] = {}
216214

217215
@classmethod
218216
def build(
@@ -995,10 +993,20 @@ def collect_subfields(
995993
subfields are not repeatedly calculated, which saves overhead when resolving
996994
lists of values.
997995
"""
998-
# Use id(field_nodes) as key, since a list cannot be hashed and
999-
# (after conversion to a tuple) hashing nodes would be too slow:
1000-
cache_key = return_type, id(field_nodes)
1001-
sub_field_nodes = self._subfields_cache.get(cache_key)
996+
cache = self._subfields_cache
997+
# We cannot use the field_nodes themselves as key for the cache, since they
998+
# are not hashable as a list. We also do not want to use the field_nodes
999+
# themselves (converted to a tuple) as keys, since hashing them is slow.
1000+
# Therefore we use the ids of the field_nodes as keys. Note that we do not
1001+
# use the id of the list, since we want to hit the cache for all lists of
1002+
# the same nodes, not only for the same list of nodes. Also, the list id may
1003+
# even be reused, in which case we would get wrong results from the cache.
1004+
key = (
1005+
(return_type, id(field_nodes[0]))
1006+
if len(field_nodes) == 1 # optimize most frequent case
1007+
else tuple((return_type, *map(id, field_nodes)))
1008+
)
1009+
sub_field_nodes = cache.get(key)
10021010
if sub_field_nodes is None:
10031011
sub_field_nodes = {}
10041012
visited_fragment_names: Set[str] = set()
@@ -1011,7 +1019,7 @@ def collect_subfields(
10111019
sub_field_nodes,
10121020
visited_fragment_names,
10131021
)
1014-
self._subfields_cache[cache_key] = sub_field_nodes
1022+
cache[key] = sub_field_nodes
10151023
return sub_field_nodes
10161024

10171025

0 commit comments

Comments
 (0)