Skip to content

Commit ea590fb

Browse files
committed
Improve --update-data handler
1 parent 6c7e480 commit ea590fb

File tree

4 files changed

+177
-36
lines changed

4 files changed

+177
-36
lines changed

mypy/test/data.py

+30
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import sys
1111
import tempfile
1212
from abc import abstractmethod
13+
from pathlib import Path
1314
from typing import Any, Iterator, NamedTuple, Pattern, Union
1415
from typing_extensions import Final, TypeAlias as _TypeAlias
1516

@@ -698,6 +699,12 @@ def collect(self) -> Iterator[DataFileCollector]:
698699
yield DataFileCollector.from_parent(parent=self, name=data_file)
699700

700701

702+
class DataFileFix(NamedTuple):
    """A replacement for a contiguous range of lines in a data file.

    The range is half-open: it starts at ``lineno`` and stops just before
    ``end_lineno``. Both are 1-based line numbers.
    """

    # First line to replace (1-offset, inclusive).
    lineno: int
    # Line after the last one to replace (1-offset, exclusive).
    end_lineno: int
    # Replacement text, one entry per line, without line terminators.
    lines: list[str]
706+
707+
701708
class DataFileCollector(pytest.Collector):
702709
"""Represents a single `.test` data driven test file.
703710
@@ -706,6 +713,8 @@ class DataFileCollector(pytest.Collector):
706713

707714
parent: DataSuiteCollector
708715

716+
_fixes: list[DataFileFix]
717+
709718
@classmethod # We have to fight with pytest here:
710719
def from_parent(
711720
cls, parent: DataSuiteCollector, *, name: str # type: ignore[override]
@@ -721,6 +730,27 @@ def collect(self) -> Iterator[DataDrivenTestCase]:
721730
file=os.path.join(self.parent.obj.data_prefix, self.name),
722731
)
723732

733+
def setup(self) -> None:
    """Run the inherited setup, then start with an empty queue of pending fixes."""
    super().setup()
    # Reset the queue that ``enqueue_fix`` appends to.
    self._fixes = []
736+
737+
def teardown(self) -> None:
    """Run the inherited teardown, then write any queued fixes to the data file."""
    super().teardown()
    # Fixes are applied once per collector rather than once per test case.
    self._apply_fixes()
740+
741+
def enqueue_fix(self, fix: DataFileFix) -> None:
    """Queue ``fix``; queued fixes are written out by ``_apply_fixes``."""
    pending = self._fixes
    pending.append(fix)
743+
744+
def _apply_fixes(self) -> None:
    """Write every queued fix into this collector's data file.

    Does nothing (the file is not even opened) when no fixes are queued.
    Otherwise the file is read, the lines in each fix's 1-based half-open
    range ``[lineno, end_lineno)`` are replaced by ``fix.lines``, and the
    result is written back to the same path.
    """
    if not self._fixes:
        return
    data_path = Path(self.parent.obj.data_prefix) / self.name
    # Name the encoding explicitly instead of relying on the locale default,
    # which is not UTF-8 on every platform; otherwise non-ASCII test data
    # could fail to decode or be rewritten in a different encoding.
    lines = data_path.read_text(encoding="utf8").split("\n")
    # start from end to prevent line offsets from shifting as we update
    for fix in sorted(self._fixes, reverse=True):
        lines[fix.lineno - 1 : fix.end_lineno - 1] = fix.lines
    data_path.write_text("\n".join(lines), encoding="utf8")
753+
724754

725755
def add_test_name_suffix(name: str, suffix: str) -> str:
726756
# Find magic suffix of form "-foobar" (used for things like "-skip").

mypy/test/helpers.py

-33
Original file line numberDiff line numberDiff line change
@@ -141,39 +141,6 @@ def assert_target_equivalence(name: str, expected: list[str], actual: list[str])
141141
)
142142

143143

144-
def update_testcase_output(testcase: DataDrivenTestCase, output: list[str]) -> None:
    """Rewrite the test case's slice of its data file using the actual output.

    Each expected output line is paired with the actual line at the same
    position. When both share the same text before ``error:``, only the text
    after that marker is compared. An expected text is replaced only when it
    occurs in the test case source exactly as many times as it has
    replacements; the replacements are then substituted in order.
    """
    from itertools import chain

    assert testcase.old_cwd is not None, "test was not properly set up"
    path = os.path.join(testcase.old_cwd, testcase.file)
    with open(path, encoding="utf8") as src:
        all_lines = src.read().splitlines()
    start, stop = testcase.line, testcase.last_line
    case_text = "\n".join(all_lines[start:stop])

    marker = "error:"
    replacements: dict[str, list[str]] = {}
    for expected, got in zip(testcase.output, output):
        pos = expected.find(marker)
        if pos != -1 and expected[:pos] == got[:pos]:
            # Same location prefix: keep only the message part of both lines.
            cut = pos + len(marker)
            expected, got = expected[cut:], got[cut:]
        replacements.setdefault(expected, []).append(got)

    for expected, new_texts in replacements.items():
        if case_text.count(expected) != len(new_texts):
            continue  # ambiguous match; leave these occurrences untouched
        pieces = case_text.split(expected)
        # Interleave the text between occurrences with the replacements.
        case_text = "".join([pieces[0], *chain.from_iterable(zip(new_texts, pieces[1:]))])

    all_lines[start:stop] = [case_text]
    with open(path, "w", encoding="utf8") as dst:
        print("\n".join(all_lines), file=dst)
175-
176-
177144
def show_align_message(s1: str, s2: str) -> None:
178145
"""Align s1 and s2 so that the their first difference is highlighted.
179146

mypy/test/testcheck.py

+65-3
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,24 @@
55
import os
66
import re
77
import sys
8+
from collections import defaultdict
9+
from typing import Iterator
810

911
from mypy import build
1012
from mypy.build import Graph
1113
from mypy.errors import CompileError
1214
from mypy.modulefinder import BuildSource, FindModuleCache, SearchPaths
1315
from mypy.options import TYPE_VAR_TUPLE, UNPACK
1416
from mypy.test.config import test_data_prefix, test_temp_dir
15-
from mypy.test.data import DataDrivenTestCase, DataSuite, FileOperation, module_from_path
17+
from mypy.test.data import (
18+
DataDrivenTestCase,
19+
DataFileCollector,
20+
DataFileFix,
21+
DataSuite,
22+
FileOperation,
23+
module_from_path,
24+
parse_test_data,
25+
)
1626
from mypy.test.helpers import (
1727
assert_module_equivalence,
1828
assert_string_arrays_equal,
@@ -22,7 +32,6 @@
2232
normalize_error_messages,
2333
parse_options,
2434
perform_file_operations,
25-
update_testcase_output,
2635
)
2736

2837
try:
@@ -180,7 +189,11 @@ def run_case_once(
180189
raise AssertionError()
181190

182191
if output != a and testcase.config.getoption("--update-data", False):
183-
update_testcase_output(testcase, a)
192+
collector = testcase.parent
193+
assert isinstance(collector, DataFileCollector)
194+
for fix in self.iter_data_file_fixes(a, testcase):
195+
collector.enqueue_fix(fix)
196+
184197
assert_string_arrays_equal(output, a, msg.format(testcase.file, testcase.line))
185198

186199
if res:
@@ -214,6 +227,55 @@ def run_case_once(
214227
if testcase.output_files:
215228
check_test_output_files(testcase, incremental_step, strip_prefix="tmp/")
216229

230+
def iter_data_file_fixes(
    self, actual: list[str], testcase: DataDrivenTestCase
) -> Iterator[DataFileFix]:
    """Yield fixes that make inline ``# E:``/``# N:``/``# W:`` comments match ``actual``.

    ``actual`` lines of the form ``<file>:<line>: <severity>: <message>`` are
    grouped by file and line; lines that do not have this form are ignored.
    Every ``case`` section (reports for file ``main``) and ``file`` section
    (reports for ``tmp/<arg>``) of the test case is then rebuilt line by line:
    an existing trailing comment is stripped, and the grouped reports for
    that line, if any, are appended in its place. One ``DataFileFix`` is
    yielded per rebuilt section; other sections are left alone.
    """
    # Compiled once: each pattern is applied to every output or source line.
    report_re = re.compile(
        r"^(?P<filename>[^:]+):(?P<lineno>\d+): (?P<severity>error|note|warning): (?P<msg>.+)$"
    )
    comment_re = re.compile(r"(?P<indent>\s+)(?P<comment># [EWN]: .+)$")

    # (file path, 0-based line index) -> [(severity, message), ...] in output order
    reports_by_line: dict[tuple[str, int], list[tuple[str, str]]] = defaultdict(list)
    for error_line in actual:
        report_match = report_re.match(error_line)
        if report_match:
            filename = report_match.group("filename")
            line_index = int(report_match.group("lineno")) - 1
            severity = report_match.group("severity")
            msg = report_match.group("msg")
            reports_by_line[filename, line_index].append((severity, msg))

    for item in parse_test_data(testcase.data, testcase.name):
        if item.id == "case":
            file_path = "main"
        elif item.id == "file":
            file_path = f"tmp/{item.arg}"
        else:
            continue  # other sections we don't touch
        source_lines = item.data

        fix_lines: list[str] = []
        for line_index, source_line in enumerate(source_lines):
            # .get() so that looking up a line without reports adds no key
            reports = reports_by_line.get((file_path, line_index))
            comment_match = comment_re.search(source_line)
            if comment_match:
                source_line = source_line[: comment_match.start("indent")]  # strip old comment
            if not reports:
                fix_lines.append(source_line)
                continue
            # Reuse the spacing of the old comment when there was one.
            indent = comment_match.group("indent") if comment_match else " "
            # multiline comments are on the first line and then on subsequent lines empty lines
            # with a continuation backslash
            last = len(reports) - 1
            for j, (severity, msg) in enumerate(reports):
                out_l = source_line if j == 0 else " " * len(source_line)
                severity_char = severity[0].upper()
                continuation = "" if j == last else " \\"
                fix_lines.append(f"{out_l}{indent}# {severity_char}: {msg}{continuation}")

        first_lineno = testcase.line + item.line - 1  # both testcase and item are 1-offset
        end_lineno = first_lineno + len(item.data)
        yield DataFileFix(first_lineno, end_lineno, fix_lines)
278+
217279
def verify_cache(
218280
self,
219281
module_data: list[tuple[str, str, str]],

mypy/test/testupdatedata.py

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import subprocess
2+
import textwrap
3+
from os.path import basename
4+
from tempfile import NamedTemporaryFile
5+
6+
from mypy.test.config import test_data_prefix
7+
from mypy.test.helpers import Suite
8+
9+
10+
class UpdateDataSuite(Suite):
11+
def _update_test(self, testcase: str) -> str:
    """Run the type-check suite with ``--update-data`` on ``testcase``.

    The dedented ``testcase`` text is written to a temporary ``.test`` file
    under ``test_data_prefix``, pytest is run on that file in a subprocess,
    and the file's contents after the run are returned. The temporary file
    is always removed afterwards.

    Raises:
        AssertionError: if pytest does not exit with status 1.
    """
    import os
    import sys

    # delete=False so the file can be closed before the subprocess reads and
    # rewrites it; a still-open NamedTemporaryFile cannot be reopened by name
    # on every platform.
    tmp_file = NamedTemporaryFile(
        mode="w",
        dir=test_data_prefix,
        prefix="check-update-data-",
        suffix=".test",
        encoding="utf8",
        delete=False,
    )
    try:
        with tmp_file:
            tmp_file.write(textwrap.dedent(testcase))
        test_nodeid = f"mypy/test/testcheck.py::TypeCheckSuite::{basename(tmp_file.name)}"
        # Use the running interpreter's pytest, not whichever is first on PATH.
        res = subprocess.run(
            [sys.executable, "-m", "pytest", "-n", "0", "-s", "--update-data", test_nodeid],
            capture_output=True,
        )
        assert res.returncode == 1, (
            f"expected pytest to exit with status 1, got {res.returncode}\n"
            f"stdout:\n{res.stdout.decode(errors='replace')}\n"
            f"stderr:\n{res.stderr.decode(errors='replace')}"
        )
        with open(tmp_file.name, encoding="utf8") as updated:
            return updated.read()
    finally:
        os.unlink(tmp_file.name)
24+
25+
# Passes a multi-case data file to `_update_test` and compares the returned
# file contents with the expected rewritten file. The cases cover a comment
# that is already correct, a wrong comment, missing comments for two errors on
# one line, extraneous single- and multi-line comments, and wrong comments
# spread over `[file ...]` sections; the expected text leaves `[out]` and
# `[builtins ...]` unchanged.
def test_update_data(self) -> None:
26+
actual = self._update_test(
27+
"""
28+
[case testCorrect]
29+
s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
30+
[case testWrong]
31+
s: str = 42 # E: wrong error
32+
[case testMissingMultiline]
33+
s: str = 42; i: int = 'foo'
34+
[case testExtraneous]
35+
s: str = 'foo' # E: wrong error
36+
[case testExtraneousMultiline]
37+
s: str = 'foo' # E: foo \
38+
# E: bar
39+
[case testExtraneousMultilineNonError]
40+
s: str = 'foo' # W: foo \
41+
# N: bar
42+
[case testWrongMultipleFiles]
43+
import a, b
44+
s: str = 42 # E: foo
45+
[file a.py]
46+
s1: str = 42 # E: bar
47+
[file b.py]
48+
s2: str = 43 # E: baz
49+
[out]
50+
make sure we're not touching this
51+
[builtins fixtures/list.pyi]
52+
"""
53+
)
54+
55+
# Assert
56+
assert actual == textwrap.dedent(
57+
"""
58+
[case testCorrect]
59+
s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
60+
[case testWrong]
61+
s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
62+
[case testMissingMultiline]
63+
s: str = 42; i: int = 'foo' # E: Incompatible types in assignment (expression has type "int", variable has type "str") \\
64+
# E: Incompatible types in assignment (expression has type "str", variable has type "int")
65+
[case testExtraneous]
66+
s: str = 'foo'
67+
[case testExtraneousMultiline]
68+
s: str = 'foo'
69+
[case testExtraneousMultilineNonError]
70+
s: str = 'foo'
71+
[case testWrongMultipleFiles]
72+
import a, b
73+
s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
74+
[file a.py]
75+
s1: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
76+
[file b.py]
77+
s2: str = 43 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
78+
[out]
79+
make sure we're not touching this
80+
[builtins fixtures/list.pyi]
81+
"""
82+
)

0 commit comments

Comments
 (0)