1
- from typing import ( Dict , List , Set , Iterator , Union )
1
+ from typing import Dict , List , Set , Iterator , Union , Optional , cast
2
2
from contextlib import contextmanager
3
3
4
4
from mypy .types import Type , AnyType , PartialType , UnionType , NoneTyp
@@ -31,6 +31,13 @@ def __init__(self) -> None:
31
31
self .unreachable = False
32
32
33
33
34
+ class DeclarationsFrame (Dict [Key , Optional [Type ]]):
35
+ """Same as above, but allowed to have None values."""
36
+
37
+ def __init__ (self ) -> None :
38
+ self .unreachable = False
39
+
40
+
34
41
class ConditionalTypeBinder :
35
42
"""Keep track of conditional types of variables.
36
43
@@ -68,9 +75,9 @@ def __init__(self) -> None:
68
75
# has no corresponding element in this list.
69
76
self .options_on_return = [] # type: List[List[Frame]]
70
77
71
- # Maps expr.literal_hash] to get_declaration(expr)
78
+ # Maps expr.literal_hash to get_declaration(expr)
72
79
# for every expr stored in the binder
73
- self .declarations = Frame ()
80
+ self .declarations = DeclarationsFrame ()
74
81
# Set of other keys to invalidate if a key is changed, e.g. x -> {x.a, x[0]}
75
82
# Whenever a new key (e.g. x.a.b) is added, we update this
76
83
self .dependencies = {} # type: Dict[Key, Set[Key]]
@@ -101,7 +108,7 @@ def push_frame(self) -> Frame:
101
108
def _put (self , key : Key , type : Type , index : int = - 1 ) -> None :
102
109
self .frames [index ][key ] = type
103
110
104
- def _get (self , key : Key , index : int = - 1 ) -> Type :
111
+ def _get (self , key : Key , index : int = - 1 ) -> Optional [ Type ] :
105
112
if index < 0 :
106
113
index += len (self .frames )
107
114
for i in range (index , - 1 , - 1 ):
@@ -124,7 +131,7 @@ def put(self, expr: Expression, typ: Type) -> None:
124
131
def unreachable (self ) -> None :
125
132
self .frames [- 1 ].unreachable = True
126
133
127
- def get (self , expr : Expression ) -> Type :
134
+ def get (self , expr : Expression ) -> Optional [ Type ] :
128
135
return self ._get (expr .literal_hash )
129
136
130
137
def is_unreachable (self ) -> bool :
@@ -163,15 +170,17 @@ def update_from_options(self, frames: List[Frame]) -> bool:
163
170
# know anything about key in at least one possible frame.
164
171
continue
165
172
173
+ type = resulting_values [0 ]
174
+ assert type is not None
166
175
if isinstance (self .declarations .get (key ), AnyType ):
167
- type = resulting_values [ 0 ]
168
- if not all (is_same_type (type , t ) for t in resulting_values [1 :]):
176
+ # At this point resulting values can't contain None, see continue above
177
+ if not all (is_same_type (type , cast ( Type , t ) ) for t in resulting_values [1 :]):
169
178
type = AnyType ()
170
179
else :
171
- type = resulting_values [0 ]
172
180
for other in resulting_values [1 :]:
181
+ assert other is not None
173
182
type = join_simple (self .declarations [key ], type , other )
174
- if not is_same_type (type , current_value ):
183
+ if current_value is None or not is_same_type (type , current_value ):
175
184
self ._put (key , type )
176
185
changed = True
177
186
@@ -252,7 +261,7 @@ def invalidate_dependencies(self, expr: BindableExpression) -> None:
252
261
for dep in self .dependencies .get (expr .literal_hash , set ()):
253
262
self ._cleanse_key (dep )
254
263
255
- def most_recent_enclosing_type (self , expr : BindableExpression , type : Type ) -> Type :
264
+ def most_recent_enclosing_type (self , expr : BindableExpression , type : Type ) -> Optional [ Type ] :
256
265
if isinstance (type , AnyType ):
257
266
return get_declaration (expr )
258
267
key = expr .literal_hash
@@ -342,7 +351,7 @@ def top_frame_context(self) -> Iterator[Frame]:
342
351
self .pop_frame (True , 0 )
343
352
344
353
345
- def get_declaration (expr : BindableExpression ) -> Type :
354
+ def get_declaration (expr : BindableExpression ) -> Optional [ Type ] :
346
355
if isinstance (expr , RefExpr ) and isinstance (expr .node , Var ):
347
356
type = expr .node .type
348
357
if not isinstance (type , PartialType ):
0 commit comments