Skip to content

Commit 3e03484

Browse files
authored
Improve --update-data handler (#15283)
1 parent 85e6719 commit 3e03484

File tree

6 files changed

+274
-37
lines changed

6 files changed

+274
-37
lines changed

mypy/test/data.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import tempfile
1212
from abc import abstractmethod
1313
from dataclasses import dataclass
14+
from pathlib import Path
1415
from typing import Any, Iterator, NamedTuple, Pattern, Union
1516
from typing_extensions import Final, TypeAlias as _TypeAlias
1617

@@ -426,8 +427,16 @@ class TestItem:
426427

427428
id: str
428429
arg: str | None
430+
# Processed, collapsed text data
429431
data: list[str]
432+
# Start line: 1-based, inclusive, relative to testcase
430433
line: int
434+
# End line: 1-based, exclusive, relative to testcase; not same as `line + len(test_item.data)` due to collapsing
435+
end_line: int
436+
437+
@property
438+
def trimmed_newlines(self) -> int: # compensates for strip_list
439+
return self.end_line - self.line - len(self.data)
431440

432441

433442
def parse_test_data(raw_data: str, name: str) -> list[TestItem]:
@@ -449,7 +458,7 @@ def parse_test_data(raw_data: str, name: str) -> list[TestItem]:
449458
if id:
450459
data = collapse_line_continuation(data)
451460
data = strip_list(data)
452-
ret.append(TestItem(id, arg, strip_list(data), i0 + 1))
461+
ret.append(TestItem(id, arg, data, i0 + 1, i))
453462

454463
i0 = i
455464
id = s[1:-1]
@@ -470,7 +479,7 @@ def parse_test_data(raw_data: str, name: str) -> list[TestItem]:
470479
if id:
471480
data = collapse_line_continuation(data)
472481
data = strip_list(data)
473-
ret.append(TestItem(id, arg, data, i0 + 1))
482+
ret.append(TestItem(id, arg, data, i0 + 1, i - 1))
474483

475484
return ret
476485

@@ -693,6 +702,12 @@ def collect(self) -> Iterator[DataFileCollector]:
693702
yield DataFileCollector.from_parent(parent=self, name=data_file)
694703

695704

705+
class DataFileFix(NamedTuple):
706+
lineno: int # 1-offset, inclusive
707+
end_lineno: int # 1-offset, exclusive
708+
lines: list[str]
709+
710+
696711
class DataFileCollector(pytest.Collector):
697712
"""Represents a single `.test` data driven test file.
698713
@@ -701,6 +716,8 @@ class DataFileCollector(pytest.Collector):
701716

702717
parent: DataSuiteCollector
703718

719+
_fixes: list[DataFileFix]
720+
704721
@classmethod # We have to fight with pytest here:
705722
def from_parent(
706723
cls, parent: DataSuiteCollector, *, name: str # type: ignore[override]
@@ -716,6 +733,27 @@ def collect(self) -> Iterator[DataDrivenTestCase]:
716733
file=os.path.join(self.parent.obj.data_prefix, self.name),
717734
)
718735

736+
def setup(self) -> None:
737+
super().setup()
738+
self._fixes = []
739+
740+
def teardown(self) -> None:
741+
super().teardown()
742+
self._apply_fixes()
743+
744+
def enqueue_fix(self, fix: DataFileFix) -> None:
745+
self._fixes.append(fix)
746+
747+
def _apply_fixes(self) -> None:
748+
if not self._fixes:
749+
return
750+
data_path = Path(self.parent.obj.data_prefix) / self.name
751+
lines = data_path.read_text().split("\n")
752+
# start from end to prevent line offsets from shifting as we update
753+
for fix in sorted(self._fixes, reverse=True):
754+
lines[fix.lineno - 1 : fix.end_lineno - 1] = fix.lines
755+
data_path.write_text("\n".join(lines))
756+
719757

720758
def add_test_name_suffix(name: str, suffix: str) -> str:
721759
# Find magic suffix of form "-foobar" (used for things like "-skip").

mypy/test/helpers.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -141,39 +141,6 @@ def assert_target_equivalence(name: str, expected: list[str], actual: list[str])
141141
)
142142

143143

144-
def update_testcase_output(testcase: DataDrivenTestCase, output: list[str]) -> None:
145-
assert testcase.old_cwd is not None, "test was not properly set up"
146-
testcase_path = os.path.join(testcase.old_cwd, testcase.file)
147-
with open(testcase_path, encoding="utf8") as f:
148-
data_lines = f.read().splitlines()
149-
test = "\n".join(data_lines[testcase.line : testcase.last_line])
150-
151-
mapping: dict[str, list[str]] = {}
152-
for old, new in zip(testcase.output, output):
153-
PREFIX = "error:"
154-
ind = old.find(PREFIX)
155-
if ind != -1 and old[:ind] == new[:ind]:
156-
old, new = old[ind + len(PREFIX) :], new[ind + len(PREFIX) :]
157-
mapping.setdefault(old, []).append(new)
158-
159-
for old in mapping:
160-
if test.count(old) == len(mapping[old]):
161-
betweens = test.split(old)
162-
163-
# Interleave betweens and mapping[old]
164-
from itertools import chain
165-
166-
interleaved = [betweens[0]] + list(
167-
chain.from_iterable(zip(mapping[old], betweens[1:]))
168-
)
169-
test = "".join(interleaved)
170-
171-
data_lines[testcase.line : testcase.last_line] = [test]
172-
data = "\n".join(data_lines)
173-
with open(testcase_path, "w", encoding="utf8") as f:
174-
print(data, file=f)
175-
176-
177144
def show_align_message(s1: str, s2: str) -> None:
178145
"""Align s1 and s2 so that the their first difference is highlighted.
179146

mypy/test/testcheck.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
normalize_error_messages,
2323
parse_options,
2424
perform_file_operations,
25-
update_testcase_output,
2625
)
26+
from mypy.test.update_data import update_testcase_output
2727

2828
try:
2929
import lxml # type: ignore[import]
@@ -192,7 +192,8 @@ def run_case_once(
192192
output = testcase.output2.get(incremental_step, [])
193193

194194
if output != a and testcase.config.getoption("--update-data", False):
195-
update_testcase_output(testcase, a)
195+
update_testcase_output(testcase, a, incremental_step=incremental_step)
196+
196197
assert_string_arrays_equal(output, a, msg.format(testcase.file, testcase.line))
197198

198199
if res:

mypy/test/testupdatedata.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import shlex
2+
import subprocess
3+
import sys
4+
import textwrap
5+
from pathlib import Path
6+
7+
from mypy.test.config import test_data_prefix
8+
from mypy.test.helpers import Suite
9+
10+
11+
class UpdateDataSuite(Suite):
12+
def _run_pytest_update_data(self, data_suite: str, *, max_attempts: int) -> str:
13+
"""
14+
Runs a suite of data test cases through 'pytest --update-data' until either tests pass
15+
or until a maximum number of attempts (needed for incremental tests).
16+
"""
17+
p = Path(test_data_prefix) / "check-update-data.test"
18+
assert not p.exists()
19+
try:
20+
p.write_text(textwrap.dedent(data_suite).lstrip())
21+
22+
test_nodeid = f"mypy/test/testcheck.py::TypeCheckSuite::{p.name}"
23+
args = [sys.executable, "-m", "pytest", "-n", "0", "-s", "--update-data", test_nodeid]
24+
if sys.version_info >= (3, 8):
25+
cmd = shlex.join(args)
26+
else:
27+
cmd = " ".join(args)
28+
for i in range(max_attempts - 1, -1, -1):
29+
res = subprocess.run(args)
30+
if res.returncode == 0:
31+
break
32+
print(f"`{cmd}` returned {res.returncode}: {i} attempts remaining")
33+
34+
return p.read_text()
35+
finally:
36+
p.unlink()
37+
38+
def test_update_data(self) -> None:
39+
# Note: We test multiple testcases rather than 'test case per test case'
40+
# so we could also exercise rewriting multiple testcases at once.
41+
actual = self._run_pytest_update_data(
42+
"""
43+
[case testCorrect]
44+
s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
45+
46+
[case testWrong]
47+
s: str = 42 # E: wrong error
48+
49+
[case testWrongMultiline]
50+
s: str = 42 # E: foo \
51+
# N: bar
52+
53+
[case testMissingMultiline]
54+
s: str = 42; i: int = 'foo'
55+
56+
[case testExtraneous]
57+
s: str = 'foo' # E: wrong error
58+
59+
[case testExtraneousMultiline]
60+
s: str = 'foo' # E: foo \
61+
# E: bar
62+
63+
[case testExtraneousMultilineNonError]
64+
s: str = 'foo' # W: foo \
65+
# N: bar
66+
67+
[case testOutCorrect]
68+
s: str = 42
69+
[out]
70+
main:1: error: Incompatible types in assignment (expression has type "int", variable has type "str")
71+
72+
[case testOutWrong]
73+
s: str = 42
74+
[out]
75+
main:1: error: foobar
76+
77+
[case testOutWrongIncremental]
78+
s: str = 42
79+
[out]
80+
main:1: error: foobar
81+
[out2]
82+
main:1: error: foobar
83+
84+
[case testWrongMultipleFiles]
85+
import a, b
86+
s: str = 42 # E: foo
87+
[file a.py]
88+
s1: str = 42 # E: bar
89+
[file b.py]
90+
s2: str = 43 # E: baz
91+
[builtins fixtures/list.pyi]
92+
""",
93+
max_attempts=3,
94+
)
95+
96+
# Assert
97+
expected = """
98+
[case testCorrect]
99+
s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
100+
101+
[case testWrong]
102+
s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
103+
104+
[case testWrongMultiline]
105+
s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
106+
107+
[case testMissingMultiline]
108+
s: str = 42; i: int = 'foo' # E: Incompatible types in assignment (expression has type "int", variable has type "str") \\
109+
# E: Incompatible types in assignment (expression has type "str", variable has type "int")
110+
111+
[case testExtraneous]
112+
s: str = 'foo'
113+
114+
[case testExtraneousMultiline]
115+
s: str = 'foo'
116+
117+
[case testExtraneousMultilineNonError]
118+
s: str = 'foo'
119+
120+
[case testOutCorrect]
121+
s: str = 42
122+
[out]
123+
main:1: error: Incompatible types in assignment (expression has type "int", variable has type "str")
124+
125+
[case testOutWrong]
126+
s: str = 42
127+
[out]
128+
main:1: error: Incompatible types in assignment (expression has type "int", variable has type "str")
129+
130+
[case testOutWrongIncremental]
131+
s: str = 42
132+
[out]
133+
main:1: error: Incompatible types in assignment (expression has type "int", variable has type "str")
134+
[out2]
135+
main:1: error: Incompatible types in assignment (expression has type "int", variable has type "str")
136+
137+
[case testWrongMultipleFiles]
138+
import a, b
139+
s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
140+
[file a.py]
141+
s1: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
142+
[file b.py]
143+
s2: str = 43 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
144+
[builtins fixtures/list.pyi]
145+
"""
146+
assert actual == textwrap.dedent(expected).lstrip()

mypy/test/update.py

Whitespace-only changes.

mypy/test/update_data.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from __future__ import annotations
2+
3+
import re
4+
from collections import defaultdict
5+
from typing import Iterator
6+
7+
from mypy.test.data import DataDrivenTestCase, DataFileCollector, DataFileFix, parse_test_data
8+
9+
10+
def update_testcase_output(
11+
testcase: DataDrivenTestCase, actual: list[str], *, incremental_step: int
12+
) -> None:
13+
collector = testcase.parent
14+
assert isinstance(collector, DataFileCollector)
15+
for fix in _iter_fixes(testcase, actual, incremental_step=incremental_step):
16+
collector.enqueue_fix(fix)
17+
18+
19+
def _iter_fixes(
20+
testcase: DataDrivenTestCase, actual: list[str], *, incremental_step: int
21+
) -> Iterator[DataFileFix]:
22+
reports_by_line: dict[tuple[str, int], list[tuple[str, str]]] = defaultdict(list)
23+
for error_line in actual:
24+
comment_match = re.match(
25+
r"^(?P<filename>[^:]+):(?P<lineno>\d+): (?P<severity>error|note|warning): (?P<msg>.+)$",
26+
error_line,
27+
)
28+
if comment_match:
29+
filename = comment_match.group("filename")
30+
lineno = int(comment_match.group("lineno"))
31+
severity = comment_match.group("severity")
32+
msg = comment_match.group("msg")
33+
reports_by_line[filename, lineno].append((severity, msg))
34+
35+
test_items = parse_test_data(testcase.data, testcase.name)
36+
37+
# If we have [out] and/or [outN], we update just those sections.
38+
if any(re.match(r"^out\d*$", test_item.id) for test_item in test_items):
39+
for test_item in test_items:
40+
if (incremental_step < 2 and test_item.id == "out") or (
41+
incremental_step >= 2 and test_item.id == f"out{incremental_step}"
42+
):
43+
yield DataFileFix(
44+
lineno=testcase.line + test_item.line - 1,
45+
end_lineno=testcase.line + test_item.end_line - 1,
46+
lines=actual + [""] * test_item.trimmed_newlines,
47+
)
48+
49+
return
50+
51+
# Update assertion comments within the sections
52+
for test_item in test_items:
53+
if test_item.id == "case":
54+
source_lines = test_item.data
55+
file_path = "main"
56+
elif test_item.id == "file":
57+
source_lines = test_item.data
58+
file_path = f"tmp/{test_item.arg}"
59+
else:
60+
continue # other sections we don't touch
61+
62+
fix_lines = []
63+
for lineno, source_line in enumerate(source_lines, start=1):
64+
reports = reports_by_line.get((file_path, lineno))
65+
comment_match = re.search(r"(?P<indent>\s+)(?P<comment># [EWN]: .+)$", source_line)
66+
if comment_match:
67+
source_line = source_line[: comment_match.start("indent")] # strip old comment
68+
if reports:
69+
indent = comment_match.group("indent") if comment_match else " "
70+
# multiline comments are on the first line and then on subsequent lines emtpy lines
71+
# with a continuation backslash
72+
for j, (severity, msg) in enumerate(reports):
73+
out_l = source_line if j == 0 else " " * len(source_line)
74+
is_last = j == len(reports) - 1
75+
severity_char = severity[0].upper()
76+
continuation = "" if is_last else " \\"
77+
fix_lines.append(f"{out_l}{indent}# {severity_char}: {msg}{continuation}")
78+
else:
79+
fix_lines.append(source_line)
80+
81+
yield DataFileFix(
82+
lineno=testcase.line + test_item.line - 1,
83+
end_lineno=testcase.line + test_item.end_line - 1,
84+
lines=fix_lines + [""] * test_item.trimmed_newlines,
85+
)

0 commit comments

Comments
 (0)