Skip to content

Commit a9c62c5

Browse files
stubtest: if a default is present in the stub, check that it is correct (#14085)
Helps with python/typeshed#8988.
1 parent d58a851 commit a9c62c5

File tree

3 files changed

+277
-1
lines changed

3 files changed

+277
-1
lines changed

mypy/evalexpr.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
"""
2+
3+
Evaluate an expression.
4+
5+
Used by stubtest; in a separate file because things break if we don't
6+
put it in a mypyc-compiled file.
7+
8+
"""
9+
import ast
10+
from typing_extensions import Final
11+
12+
import mypy.nodes
13+
from mypy.visitor import ExpressionVisitor
14+
15+
UNKNOWN = object()
16+
17+
18+
class _NodeEvaluator(ExpressionVisitor[object]):
19+
def visit_int_expr(self, o: mypy.nodes.IntExpr) -> int:
20+
return o.value
21+
22+
def visit_str_expr(self, o: mypy.nodes.StrExpr) -> str:
23+
return o.value
24+
25+
def visit_bytes_expr(self, o: mypy.nodes.BytesExpr) -> object:
26+
# The value of a BytesExpr is a string created from the repr()
27+
# of the bytes object. Get the original bytes back.
28+
try:
29+
return ast.literal_eval(f"b'{o.value}'")
30+
except SyntaxError:
31+
return ast.literal_eval(f'b"{o.value}"')
32+
33+
def visit_float_expr(self, o: mypy.nodes.FloatExpr) -> float:
34+
return o.value
35+
36+
def visit_complex_expr(self, o: mypy.nodes.ComplexExpr) -> object:
37+
return o.value
38+
39+
def visit_ellipsis(self, o: mypy.nodes.EllipsisExpr) -> object:
40+
return Ellipsis
41+
42+
def visit_star_expr(self, o: mypy.nodes.StarExpr) -> object:
43+
return UNKNOWN
44+
45+
def visit_name_expr(self, o: mypy.nodes.NameExpr) -> object:
46+
if o.name == "True":
47+
return True
48+
elif o.name == "False":
49+
return False
50+
elif o.name == "None":
51+
return None
52+
# TODO: Handle more names by figuring out a way to hook into the
53+
# symbol table.
54+
return UNKNOWN
55+
56+
def visit_member_expr(self, o: mypy.nodes.MemberExpr) -> object:
57+
return UNKNOWN
58+
59+
def visit_yield_from_expr(self, o: mypy.nodes.YieldFromExpr) -> object:
60+
return UNKNOWN
61+
62+
def visit_yield_expr(self, o: mypy.nodes.YieldExpr) -> object:
63+
return UNKNOWN
64+
65+
def visit_call_expr(self, o: mypy.nodes.CallExpr) -> object:
66+
return UNKNOWN
67+
68+
def visit_op_expr(self, o: mypy.nodes.OpExpr) -> object:
69+
return UNKNOWN
70+
71+
def visit_comparison_expr(self, o: mypy.nodes.ComparisonExpr) -> object:
72+
return UNKNOWN
73+
74+
def visit_cast_expr(self, o: mypy.nodes.CastExpr) -> object:
75+
return o.expr.accept(self)
76+
77+
def visit_assert_type_expr(self, o: mypy.nodes.AssertTypeExpr) -> object:
78+
return o.expr.accept(self)
79+
80+
def visit_reveal_expr(self, o: mypy.nodes.RevealExpr) -> object:
81+
return UNKNOWN
82+
83+
def visit_super_expr(self, o: mypy.nodes.SuperExpr) -> object:
84+
return UNKNOWN
85+
86+
def visit_unary_expr(self, o: mypy.nodes.UnaryExpr) -> object:
87+
operand = o.expr.accept(self)
88+
if operand is UNKNOWN:
89+
return UNKNOWN
90+
if o.op == "-":
91+
if isinstance(operand, (int, float, complex)):
92+
return -operand
93+
elif o.op == "+":
94+
if isinstance(operand, (int, float, complex)):
95+
return +operand
96+
elif o.op == "~":
97+
if isinstance(operand, int):
98+
return ~operand
99+
elif o.op == "not":
100+
if isinstance(operand, (bool, int, float, str, bytes)):
101+
return not operand
102+
return UNKNOWN
103+
104+
def visit_assignment_expr(self, o: mypy.nodes.AssignmentExpr) -> object:
105+
return o.value.accept(self)
106+
107+
def visit_list_expr(self, o: mypy.nodes.ListExpr) -> object:
108+
items = [item.accept(self) for item in o.items]
109+
if all(item is not UNKNOWN for item in items):
110+
return items
111+
return UNKNOWN
112+
113+
def visit_dict_expr(self, o: mypy.nodes.DictExpr) -> object:
114+
items = [
115+
(UNKNOWN if key is None else key.accept(self), value.accept(self))
116+
for key, value in o.items
117+
]
118+
if all(key is not UNKNOWN and value is not None for key, value in items):
119+
return dict(items)
120+
return UNKNOWN
121+
122+
def visit_tuple_expr(self, o: mypy.nodes.TupleExpr) -> object:
123+
items = [item.accept(self) for item in o.items]
124+
if all(item is not UNKNOWN for item in items):
125+
return tuple(items)
126+
return UNKNOWN
127+
128+
def visit_set_expr(self, o: mypy.nodes.SetExpr) -> object:
129+
items = [item.accept(self) for item in o.items]
130+
if all(item is not UNKNOWN for item in items):
131+
return set(items)
132+
return UNKNOWN
133+
134+
def visit_index_expr(self, o: mypy.nodes.IndexExpr) -> object:
135+
return UNKNOWN
136+
137+
def visit_type_application(self, o: mypy.nodes.TypeApplication) -> object:
138+
return UNKNOWN
139+
140+
def visit_lambda_expr(self, o: mypy.nodes.LambdaExpr) -> object:
141+
return UNKNOWN
142+
143+
def visit_list_comprehension(self, o: mypy.nodes.ListComprehension) -> object:
144+
return UNKNOWN
145+
146+
def visit_set_comprehension(self, o: mypy.nodes.SetComprehension) -> object:
147+
return UNKNOWN
148+
149+
def visit_dictionary_comprehension(self, o: mypy.nodes.DictionaryComprehension) -> object:
150+
return UNKNOWN
151+
152+
def visit_generator_expr(self, o: mypy.nodes.GeneratorExpr) -> object:
153+
return UNKNOWN
154+
155+
def visit_slice_expr(self, o: mypy.nodes.SliceExpr) -> object:
156+
return UNKNOWN
157+
158+
def visit_conditional_expr(self, o: mypy.nodes.ConditionalExpr) -> object:
159+
return UNKNOWN
160+
161+
def visit_type_var_expr(self, o: mypy.nodes.TypeVarExpr) -> object:
162+
return UNKNOWN
163+
164+
def visit_paramspec_expr(self, o: mypy.nodes.ParamSpecExpr) -> object:
165+
return UNKNOWN
166+
167+
def visit_type_var_tuple_expr(self, o: mypy.nodes.TypeVarTupleExpr) -> object:
168+
return UNKNOWN
169+
170+
def visit_type_alias_expr(self, o: mypy.nodes.TypeAliasExpr) -> object:
171+
return UNKNOWN
172+
173+
def visit_namedtuple_expr(self, o: mypy.nodes.NamedTupleExpr) -> object:
174+
return UNKNOWN
175+
176+
def visit_enum_call_expr(self, o: mypy.nodes.EnumCallExpr) -> object:
177+
return UNKNOWN
178+
179+
def visit_typeddict_expr(self, o: mypy.nodes.TypedDictExpr) -> object:
180+
return UNKNOWN
181+
182+
def visit_newtype_expr(self, o: mypy.nodes.NewTypeExpr) -> object:
183+
return UNKNOWN
184+
185+
def visit__promote_expr(self, o: mypy.nodes.PromoteExpr) -> object:
186+
return UNKNOWN
187+
188+
def visit_await_expr(self, o: mypy.nodes.AwaitExpr) -> object:
189+
return UNKNOWN
190+
191+
def visit_temp_node(self, o: mypy.nodes.TempNode) -> object:
192+
return UNKNOWN
193+
194+
195+
_evaluator: Final = _NodeEvaluator()
196+
197+
198+
def evaluate_expression(expr: mypy.nodes.Expression) -> object:
199+
"""Evaluate an expression at runtime.
200+
201+
Return the result of the expression, or UNKNOWN if the expression cannot be
202+
evaluated.
203+
"""
204+
return expr.accept(_evaluator)

mypy/stubtest.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,13 @@
2929

3030
import mypy.build
3131
import mypy.modulefinder
32+
import mypy.nodes
3233
import mypy.state
3334
import mypy.types
3435
import mypy.version
3536
from mypy import nodes
3637
from mypy.config_parser import parse_config_file
38+
from mypy.evalexpr import UNKNOWN, evaluate_expression
3739
from mypy.options import Options
3840
from mypy.util import FancyFormatter, bytes_to_human_readable_repr, is_dunder, plural_s
3941

@@ -573,6 +575,23 @@ def _verify_arg_default_value(
573575
f"has a default value of type {runtime_type}, "
574576
f"which is incompatible with stub argument type {stub_type}"
575577
)
578+
if stub_arg.initializer is not None:
579+
stub_default = evaluate_expression(stub_arg.initializer)
580+
if (
581+
stub_default is not UNKNOWN
582+
and stub_default is not ...
583+
and (
584+
stub_default != runtime_arg.default
585+
# We want the types to match exactly, e.g. in case the stub has
586+
# True and the runtime has 1 (or vice versa).
587+
or type(stub_default) is not type(runtime_arg.default) # noqa: E721
588+
)
589+
):
590+
yield (
591+
f'runtime argument "{runtime_arg.name}" '
592+
f"has a default value of {runtime_arg.default!r}, "
593+
f"which is different from stub argument default {stub_default!r}"
594+
)
576595
else:
577596
if stub_arg.kind.is_optional():
578597
yield (

mypy/test/teststubtest.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def test_arg_kind(self) -> Iterator[Case]:
302302
)
303303

304304
@collect_cases
305-
def test_default_value(self) -> Iterator[Case]:
305+
def test_default_presence(self) -> Iterator[Case]:
306306
yield Case(
307307
stub="def f1(text: str = ...) -> None: ...",
308308
runtime="def f1(text = 'asdf'): pass",
@@ -336,6 +336,59 @@ def f6(text: _T = ...) -> None: ...
336336
error="f6",
337337
)
338338

339+
@collect_cases
340+
def test_default_value(self) -> Iterator[Case]:
341+
yield Case(
342+
stub="def f1(text: str = 'x') -> None: ...",
343+
runtime="def f1(text = 'y'): pass",
344+
error="f1",
345+
)
346+
yield Case(
347+
stub='def f2(text: bytes = b"x\'") -> None: ...',
348+
runtime='def f2(text = b"x\'"): pass',
349+
error=None,
350+
)
351+
yield Case(
352+
stub='def f3(text: bytes = b"y\'") -> None: ...',
353+
runtime='def f3(text = b"x\'"): pass',
354+
error="f3",
355+
)
356+
yield Case(
357+
stub="def f4(text: object = 1) -> None: ...",
358+
runtime="def f4(text = 1.0): pass",
359+
error="f4",
360+
)
361+
yield Case(
362+
stub="def f5(text: object = True) -> None: ...",
363+
runtime="def f5(text = 1): pass",
364+
error="f5",
365+
)
366+
yield Case(
367+
stub="def f6(text: object = True) -> None: ...",
368+
runtime="def f6(text = True): pass",
369+
error=None,
370+
)
371+
yield Case(
372+
stub="def f7(text: object = not True) -> None: ...",
373+
runtime="def f7(text = False): pass",
374+
error=None,
375+
)
376+
yield Case(
377+
stub="def f8(text: object = not True) -> None: ...",
378+
runtime="def f8(text = True): pass",
379+
error="f8",
380+
)
381+
yield Case(
382+
stub="def f9(text: object = {1: 2}) -> None: ...",
383+
runtime="def f9(text = {1: 3}): pass",
384+
error="f9",
385+
)
386+
yield Case(
387+
stub="def f10(text: object = [1, 2]) -> None: ...",
388+
runtime="def f10(text = [1, 2]): pass",
389+
error=None,
390+
)
391+
339392
@collect_cases
340393
def test_static_class_method(self) -> Iterator[Case]:
341394
yield Case(

0 commit comments

Comments
 (0)