|
55 | 55 | from . import config, convert_frame, exc, mutation_guard |
56 | 56 | from .eval_frame import set_guard_error_hook |
57 | 57 | from .source import DefaultsSource, LocalSource, TypeSource |
58 | | -from .types import GuardedCode, GuardFail, GuardFn # noqa: F401 |
| 58 | +from .types import CacheEntry, ExtraState, GuardedCode, GuardFail, GuardFn # noqa: F401 |
59 | 59 | from .utils import ( |
60 | 60 | common_constant_types, |
61 | 61 | dict_keys_repr, |
@@ -931,24 +931,22 @@ def must_add_nn_module_guards(guard): |
931 | 931 | ) |
932 | 932 |
|
933 | 933 |
|
| 934 | +class DeletedGuardFn: |
| 935 | + pass |
| 936 | + |
| 937 | + |
934 | 938 | # NB: Naively, you'd expect this to only be a function that produces |
935 | 939 | # the callable that constitutes the guard. However, there is some |
936 | 940 | # delicate handling for invalidating this check function when the |
937 | 941 | # locals/globals get invalidated, so there's some extra state |
938 | 942 | # we have to hold in this manager class. |
939 | | -# |
940 | | -# TODO: this object has reference cycle with itself, via check_fn which |
941 | | -# references back to CheckFunction via ___guarded_code in closure_vars. |
942 | | -# Ideally, there shouldn't be any ref cycle so that guards are |
943 | | -# promptly disposed of. |
944 | 943 | class CheckFunctionManager: |
945 | 944 | def __init__( |
946 | 945 | self, |
947 | 946 | output_graph=None, |
948 | 947 | guard_fail_fn: Optional[Callable[[GuardFail], None]] = None, |
949 | 948 | ): |
950 | 949 | guards = output_graph.guards if output_graph else None |
951 | | - self.valid = True |
952 | 950 | self._weakrefs: Dict[int, ReferenceType[object]] = {} |
953 | 951 | self.output_graph = output_graph |
954 | 952 |
|
@@ -1025,7 +1023,7 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn): |
1025 | 1023 | guards_log.debug("GUARDS:") |
1026 | 1024 |
|
1027 | 1025 | # Don't report this guard, it's always the same, useless! |
1028 | | - code_parts = ["___guarded_code.valid", "___check_global_state()"] |
| 1026 | + code_parts = ["___check_global_state()"] |
1029 | 1027 | verbose_code_parts = code_parts[:] |
1030 | 1028 |
|
1031 | 1029 | def add_code_part(code, guard, log_only=False): |
@@ -1157,7 +1155,6 @@ def convert(size_or_stride): |
1157 | 1155 | # we should only hit this case in NopTests() |
1158 | 1156 | global_state = convert_frame.GlobalStateGuard() |
1159 | 1157 | closure_vars = { |
1160 | | - "___guarded_code": self, |
1161 | 1158 | "___check_tensors": check_tensors_fn, |
1162 | 1159 | "___check_tensors_verbose": check_tensors_verbose_fn, |
1163 | 1160 | "___check_global_state": global_state.check, |
@@ -1194,14 +1191,28 @@ def convert(size_or_stride): |
1194 | 1191 | # Grab only G, but preserve "G" because guards access it as "G" |
1195 | 1192 | guard_fn.global_scope = globals_for_guard_fn |
1196 | 1193 | guard_fn.guard_fail_fn = guard_fail_fn |
| 1194 | + # will be populated by a non-owning reference to CacheEntry/ExtraState |
| 1195 | + # when the CacheEntry is constructed |
| 1196 | + guard_fn.cache_entry = None |
| 1197 | + guard_fn.extra_state = None |
1197 | 1198 | return guard_fn |
1198 | 1199 |
|
1199 | 1200 | def invalidate(self): |
1200 | | - # A weakref is no longer valid, self.check_fn should return false |
1201 | | - # TODO(janimesh) - Free up cache entry after the cache entry formation |
1202 | | - # is in python, and the underlying data structure is a doubly linked |
1203 | | - # list. |
1204 | | - self.valid = False |
| 1201 | + # Some tests reveal that CheckFunctionManager has no attribute |
| 1202 | + # check_fn, but this case should not be of any concern. |
| 1203 | + # This case doesn't seem easy to repro. |
| 1204 | + if ( |
| 1205 | + hasattr(self, "check_fn") |
| 1206 | + and self.check_fn is not DeletedGuardFn |
| 1207 | + and (cache_entry := self.check_fn.cache_entry) is not None |
| 1208 | + and (extra_state := self.check_fn.extra_state) is not None |
| 1209 | + ): |
| 1210 | + assert isinstance(cache_entry, CacheEntry) |
| 1211 | + assert isinstance(extra_state, ExtraState) |
| 1212 | + extra_state.invalidate(cache_entry) |
| 1213 | + self.check_fn.cache_entry = None |
| 1214 | + self.check_fn.extra_state = None |
| 1215 | + self.check_fn = DeletedGuardFn |
1205 | 1216 |
|
1206 | 1217 | def id_ref(self, obj): |
1207 | 1218 | """add a weakref, return the id""" |
|
0 commit comments