Skip to content

Commit 35c67b7

Browse files
committed
Fix memoization of collect_subfields (#91)
1 parent d0a58ee commit 35c67b7

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

src/graphql/execution/execute.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,7 @@ def __init__(
157157
self.middleware_manager = middleware_manager
158158
if is_awaitable:
159159
self.is_awaitable = is_awaitable
160-
self._subfields_cache: Dict[
161-
Tuple[GraphQLObjectType, int], Dict[str, List[FieldNode]]
162-
] = {}
160+
self._subfields_cache: Dict[Tuple, Dict[str, List[FieldNode]]] = {}
163161

164162
@classmethod
165163
def build(
@@ -953,10 +951,20 @@ def collect_subfields(
953951
subfields are not repeatedly calculated, which saves overhead when resolving
954952
lists of values.
955953
"""
956-
# Use id(field_nodes) as key, since a list cannot be hashed and
957-
# (after conversion to a tuple) hashing nodes would be too slow:
958-
cache_key = return_type, id(field_nodes)
959-
sub_field_nodes = self._subfields_cache.get(cache_key)
954+
cache = self._subfields_cache
955+
# We cannot use the field_nodes themselves as key for the cache, since they
956+
# are not hashable as a list. We also do not want to use the field_nodes
957+
# themselves (converted to a tuple) as keys, since hashing them is slow.
958+
# Therefore we use the ids of the field_nodes as keys. Note that we do not
959+
# use the id of the list, since we want to hit the cache for all lists of
960+
# the same nodes, not only for the same list of nodes. Also, the list id may
961+
# even be reused, in which case we would get wrong results from the cache.
962+
key = (
963+
(return_type, id(field_nodes[0]))
964+
if len(field_nodes) == 1 # optimize most frequent case
965+
else tuple((return_type, *map(id, field_nodes)))
966+
)
967+
sub_field_nodes = cache.get(key)
960968
if sub_field_nodes is None:
961969
sub_field_nodes = {}
962970
visited_fragment_names: Set[str] = set()
@@ -969,7 +977,7 @@ def collect_subfields(
969977
sub_field_nodes,
970978
visited_fragment_names,
971979
)
972-
self._subfields_cache[cache_key] = sub_field_nodes
980+
cache[key] = sub_field_nodes
973981
return sub_field_nodes
974982

975983

0 commit comments

Comments
 (0)