Skip to content

Commit 29b9f46

Browse files
committed
Minor type cleanup in the binder and explanation of Key
This also fixes a bug where the index expression x['m'] was being treated by the binder as equivalent to the member expression x.m.
1 parent 02de8f9 commit 29b9f46

File tree

2 files changed

+64
-25
lines changed

2 files changed

+64
-25
lines changed

mypy/binder.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
from typing import (Any, Dict, List, Set, Iterator)
1+
from typing import (Dict, List, Set, Iterator)
22
from contextlib import contextmanager
33

44
from mypy.types import Type, AnyType, PartialType
5-
from mypy.nodes import (Node, Var)
5+
from mypy.nodes import (Key, Node, Var, RefExpr)
66

77
from mypy.subtypes import is_subtype
88
from mypy.join import join_simple
99
from mypy.sametypes import is_same_type
1010

1111

12-
class Frame(Dict[Any, Type]):
12+
class Frame(Dict[Key, Type]):
1313
"""A Frame represents a specific point in the execution of a program.
1414
It carries information about the current types of expressions at
1515
that point, arising either from assignments to those expressions
@@ -25,10 +25,6 @@ def __init__(self) -> None:
2525
self.unreachable = False
2626

2727

28-
class Key(AnyType):
29-
pass
30-
31-
3228
class ConditionalTypeBinder:
3329
"""Keep track of conditional types of variables.
3430
@@ -84,8 +80,8 @@ def _add_dependencies(self, key: Key, value: Key = None) -> None:
8480
value = key
8581
else:
8682
self.dependencies.setdefault(key, set()).add(value)
87-
if isinstance(key, tuple):
88-
for elt in key:
83+
for elt in key:
84+
if isinstance(elt, Key):
8985
self._add_dependencies(elt, value)
9086

9187
def push_frame(self) -> Frame:
@@ -190,8 +186,8 @@ def pop_frame(self, can_skip: bool, fall_through: int) -> Frame:
190186

191187
return result
192188

193-
def get_declaration(self, expr: Any) -> Type:
194-
if hasattr(expr, 'node') and isinstance(expr.node, Var):
189+
def get_declaration(self, expr: Node) -> Type:
190+
if isinstance(expr, RefExpr) and isinstance(expr.node, Var):
195191
type = expr.node.type
196192
if isinstance(type, PartialType):
197193
return None

mypy/nodes.py

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,18 @@ def get_line(self) -> int: pass
8888
for alias, name in type_aliases.items()) # type: Dict[str, str]
8989

9090

91+
# See [Note Literals and literal_hash] below
92+
Key = tuple
93+
94+
9195
class Node(Context):
9296
"""Common base class for all non-type parse tree nodes."""
9397

9498
line = -1
9599

100+
# See [Note Literals and literal_hash] below
96101
literal = LITERAL_NO
97-
literal_hash = None # type: Any
102+
literal_hash = None # type: Key
98103

99104
def __str__(self) -> str:
100105
ans = self.accept(mypy.strconv.StrConv())
@@ -132,6 +137,44 @@ def deserialize(cls, data: JsonDict) -> 'Node':
132137
raise NotImplementedError('unexpected .class {}'.format(classname))
133138

134139

140+
# [Note Literals and literal_hash]
141+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
142+
#
143+
# Mypy uses the term "literal" to refer to any expression built out of
144+
# the following:
145+
#
146+
# * Plain literal expressions, like `1` (integer, float, string, etc.)
147+
#
148+
# * Compound literal expressions, like `(lit1, lit2)` (list, dict,
149+
# set, or tuple)
150+
#
151+
# * Operator expressions, like `lit1 + lit2`
152+
#
153+
# * Variable references, like `x`
154+
#
155+
# * Member references, like `lit.m`
156+
#
157+
# * Index expressions, like `lit[0]`
158+
#
159+
# A typical "literal" looks like `x[(i,j+1)].m`.
160+
#
161+
# An expression that is a literal has a `literal_hash`, with the
162+
# following properties.
163+
#
164+
# * `literal_hash` is a Key: a tuple containing basic data types and
165+
# possibly other Keys. So it can be used as a key in a dictionary
166+
# that will be compared by value (as opposed to the Node itself,
167+
# which is compared by identity).
168+
#
169+
# * Two expressions have equal `literal_hash`es if and only if they
170+
# are syntactically equal expressions. (NB: Actually, we also
171+
# identify as equal expressions like `3` and `3.0`; is this a good
172+
# idea?)
173+
#
174+
# * The elements of `literal_hash` that are tuples are exactly the
175+
# subexpressions of the original expression (e.g. the base and index
176+
# of an index expression, or the operands of an operator expression).
177+
135178
class SymbolNode(Node):
136179
# Nodes that can be stored in a symbol table.
137180

@@ -953,7 +996,7 @@ class IntExpr(Node):
953996

954997
def __init__(self, value: int) -> None:
955998
self.value = value
956-
self.literal_hash = value
999+
self.literal_hash = ('Literal', value)
9571000

9581001
def accept(self, visitor: NodeVisitor[T]) -> T:
9591002
return visitor.visit_int_expr(self)
@@ -978,7 +1021,7 @@ class StrExpr(Node):
9781021

9791022
def __init__(self, value: str) -> None:
9801023
self.value = value
981-
self.literal_hash = value
1024+
self.literal_hash = ('Literal', value)
9821025

9831026
def accept(self, visitor: NodeVisitor[T]) -> T:
9841027
return visitor.visit_str_expr(self)
@@ -992,7 +1035,7 @@ class BytesExpr(Node):
9921035

9931036
def __init__(self, value: str) -> None:
9941037
self.value = value
995-
self.literal_hash = value
1038+
self.literal_hash = ('Literal', value)
9961039

9971040
def accept(self, visitor: NodeVisitor[T]) -> T:
9981041
return visitor.visit_bytes_expr(self)
@@ -1006,7 +1049,7 @@ class UnicodeExpr(Node):
10061049

10071050
def __init__(self, value: str) -> None:
10081051
self.value = value
1009-
self.literal_hash = value
1052+
self.literal_hash = ('Literal', value)
10101053

10111054
def accept(self, visitor: NodeVisitor[T]) -> T:
10121055
return visitor.visit_unicode_expr(self)
@@ -1020,7 +1063,7 @@ class FloatExpr(Node):
10201063

10211064
def __init__(self, value: float) -> None:
10221065
self.value = value
1023-
self.literal_hash = value
1066+
self.literal_hash = ('Literal', value)
10241067

10251068
def accept(self, visitor: NodeVisitor[T]) -> T:
10261069
return visitor.visit_float_expr(self)
@@ -1034,7 +1077,7 @@ class ComplexExpr(Node):
10341077

10351078
def __init__(self, value: complex) -> None:
10361079
self.value = value
1037-
self.literal_hash = value
1080+
self.literal_hash = ('Literal', value)
10381081

10391082
def accept(self, visitor: NodeVisitor[T]) -> T:
10401083
return visitor.visit_complex_expr(self)
@@ -1223,7 +1266,7 @@ def __init__(self, base: Node, index: Node) -> None:
12231266
self.analyzed = None
12241267
if self.index.literal == LITERAL_YES:
12251268
self.literal = self.base.literal
1226-
self.literal_hash = ('Member', base.literal_hash,
1269+
self.literal_hash = ('Index', base.literal_hash,
12271270
index.literal_hash)
12281271

12291272
def accept(self, visitor: NodeVisitor[T]) -> T:
@@ -1336,7 +1379,7 @@ def __init__(self, operators: List[str], operands: List[Node]) -> None:
13361379
self.operands = operands
13371380
self.method_types = []
13381381
self.literal = min(o.literal for o in self.operands)
1339-
self.literal_hash = (('Comparison',) + tuple(operators) +
1382+
self.literal_hash = ((cast(Any, 'Comparison'),) + tuple(operators) +
13401383
tuple(o.literal_hash for o in operands))
13411384

13421385
def accept(self, visitor: NodeVisitor[T]) -> T:
@@ -1426,7 +1469,7 @@ def __init__(self, items: List[Node]) -> None:
14261469
self.items = items
14271470
if all(x.literal == LITERAL_YES for x in items):
14281471
self.literal = LITERAL_YES
1429-
self.literal_hash = ('List',) + tuple(x.literal_hash for x in items)
1472+
self.literal_hash = (cast(Any, 'List'),) + tuple(x.literal_hash for x in items)
14301473

14311474
def accept(self, visitor: NodeVisitor[T]) -> T:
14321475
return visitor.visit_list_expr(self)
@@ -1442,8 +1485,8 @@ def __init__(self, items: List[Tuple[Node, Node]]) -> None:
14421485
if all(x[0].literal == LITERAL_YES and x[1].literal == LITERAL_YES
14431486
for x in items):
14441487
self.literal = LITERAL_YES
1445-
self.literal_hash = ('Dict',) + tuple(
1446-
(x[0].literal_hash, x[1].literal_hash) for x in items) # type: ignore
1488+
self.literal_hash = (cast(Any, 'Dict'),) + tuple(
1489+
(x[0].literal_hash, x[1].literal_hash) for x in items)
14471490

14481491
def accept(self, visitor: NodeVisitor[T]) -> T:
14491492
return visitor.visit_dict_expr(self)
@@ -1458,7 +1501,7 @@ def __init__(self, items: List[Node]) -> None:
14581501
self.items = items
14591502
if all(x.literal == LITERAL_YES for x in items):
14601503
self.literal = LITERAL_YES
1461-
self.literal_hash = ('Tuple',) + tuple(x.literal_hash for x in items)
1504+
self.literal_hash = (cast(Any, 'Tuple'),) + tuple(x.literal_hash for x in items)
14621505

14631506
def accept(self, visitor: NodeVisitor[T]) -> T:
14641507
return visitor.visit_tuple_expr(self)
@@ -1473,7 +1516,7 @@ def __init__(self, items: List[Node]) -> None:
14731516
self.items = items
14741517
if all(x.literal == LITERAL_YES for x in items):
14751518
self.literal = LITERAL_YES
1476-
self.literal_hash = ('Set',) + tuple(x.literal_hash for x in items)
1519+
self.literal_hash = (cast(Any, 'Set'),) + tuple(x.literal_hash for x in items)
14771520

14781521
def accept(self, visitor: NodeVisitor[T]) -> T:
14791522
return visitor.visit_set_expr(self)

0 commit comments

Comments
 (0)