Commit e76f087

squash
1 parent b8dd40e commit e76f087

4 files changed: +272 -38 lines changed

mypy/test/data.py

+40 -2
@@ -11,6 +11,7 @@
 import tempfile
 from abc import abstractmethod
 from dataclasses import dataclass
+from pathlib import Path
 from typing import Any, Iterator, NamedTuple, Pattern, Union
 from typing_extensions import Final, TypeAlias as _TypeAlias
 
@@ -426,8 +427,16 @@ class TestItem:
 
     id: str
     arg: str | None
+    # Processed, collapsed text data
     data: list[str]
+    # Start line: 1-based, inclusive, relative to testcase
     line: int
+    # End line: 1-based, exclusive, relative to testcase; not same as `line + len(test_item.data)` due to collapsing
+    end_line: int
+
+    @property
+    def trimmed_newlines(self) -> int:  # compensates for strip_list
+        return self.end_line - self.line - len(self.data)
 
 
 def parse_test_data(raw_data: str, name: str) -> list[TestItem]:
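
A note on the new bookkeeping (an illustrative sketch, not part of the commit): `end_line` is exclusive and is recorded before `strip_list` drops trailing blank lines, so `trimmed_newlines` recovers how many blanks were trimmed and lets a fix re-pad a section back to its original length.

    from __future__ import annotations

    from dataclasses import dataclass

    @dataclass
    class TestItem:  # abridged copy of the dataclass above, for illustration only
        id: str
        arg: str | None
        data: list[str]
        line: int       # 1-based, inclusive
        end_line: int   # 1-based, exclusive

        @property
        def trimmed_newlines(self) -> int:
            return self.end_line - self.line - len(self.data)

    # An [out] section occupying relative lines 5-6, where line 6 was a blank
    # line that strip_list removed from `data`:
    item = TestItem("out", None, ["main:1: error: oops"], line=5, end_line=7)
    assert item.trimmed_newlines == 1  # one trailing blank line to restore

This is why both `yield DataFileFix(...)` sites later in the commit pad their replacement with `[""] * test_item.trimmed_newlines`.
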
@@ -449,7 +458,7 @@ def parse_test_data(raw_data: str, name: str) -> list[TestItem]:
             if id:
                 data = collapse_line_continuation(data)
                 data = strip_list(data)
-                ret.append(TestItem(id, arg, strip_list(data), i0 + 1))
+                ret.append(TestItem(id, arg, data, i0 + 1, i))
 
             i0 = i
             id = s[1:-1]
@@ -470,7 +479,7 @@ def parse_test_data(raw_data: str, name: str) -> list[TestItem]:
     if id:
         data = collapse_line_continuation(data)
         data = strip_list(data)
-        ret.append(TestItem(id, arg, data, i0 + 1))
+        ret.append(TestItem(id, arg, data, i0 + 1, i - 1))
 
     return ret
 
@@ -693,6 +702,12 @@ def collect(self) -> Iterator[DataFileCollector]:
             yield DataFileCollector.from_parent(parent=self, name=data_file)
 
 
+class DataFileFix(NamedTuple):
+    lineno: int  # 1-offset, inclusive
+    end_lineno: int  # 1-offset, exclusive
+    lines: list[str]
+
+
 class DataFileCollector(pytest.Collector):
     """Represents a single `.test` data driven test file.
 
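Because `lineno` is 1-offset inclusive and `end_lineno` is 1-offset exclusive, a fix maps directly onto a half-open Python list slice. A minimal sketch with made-up data (same shape as the NamedTuple above):

    from typing import NamedTuple

    class DataFileFix(NamedTuple):
        lineno: int      # 1-offset, inclusive
        end_lineno: int  # 1-offset, exclusive
        lines: list[str]

    text = ["[case testFoo]", "x = 1", "[out]", "stale output", ""]
    fix = DataFileFix(lineno=4, end_lineno=5, lines=["main:1: error: fresh output"])
    text[fix.lineno - 1 : fix.end_lineno - 1] = fix.lines
    assert text == ["[case testFoo]", "x = 1", "[out]", "main:1: error: fresh output", ""]
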
@@ -701,6 +716,8 @@ class DataFileCollector(pytest.Collector):
 
     parent: DataSuiteCollector
 
+    _fixes: list[DataFileFix]
+
     @classmethod  # We have to fight with pytest here:
     def from_parent(
         cls, parent: DataSuiteCollector, *, name: str  # type: ignore[override]
@@ -716,6 +733,27 @@ def collect(self) -> Iterator[DataDrivenTestCase]:
             file=os.path.join(self.parent.obj.data_prefix, self.name),
         )
 
+    def setup(self) -> None:
+        super().setup()
+        self._fixes = []
+
+    def teardown(self) -> None:
+        super().teardown()
+        self._apply_fixes()
+
+    def enqueue_fix(self, fix: DataFileFix) -> None:
+        self._fixes.append(fix)
+
+    def _apply_fixes(self) -> None:
+        if not self._fixes:
+            return
+        data_path = Path(self.parent.obj.data_prefix) / self.name
+        lines = data_path.read_text().split("\n")
+        # start from end to prevent line offsets from shifting as we update
+        for fix in sorted(self._fixes, reverse=True):
+            lines[fix.lineno - 1 : fix.end_lineno - 1] = fix.lines
+        data_path.write_text("\n".join(lines))
+
 
 def add_test_name_suffix(name: str, suffix: str) -> str:
     # Find magic suffix of form "-foobar" (used for things like "-skip").
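
Why `sorted(self._fixes, reverse=True)` in `_apply_fixes`? Each splice may grow or shrink the file, which would invalidate the line numbers of any fix located further down; applying fixes bottom-up means each splice only shifts lines that no remaining fix refers to. The same technique in isolation, on hypothetical data:

    lines = ["keep", "old-a", "keep", "old-b"]
    # (lineno, end_lineno, replacement) with 1-offset, end-exclusive bounds
    fixes = [(2, 3, ["new-a1", "new-a2"]), (4, 5, ["new-b"])]
    for lineno, end_lineno, replacement in sorted(fixes, reverse=True):
        lines[lineno - 1 : end_lineno - 1] = replacement
    assert lines == ["keep", "new-a1", "new-a2", "keep", "new-b"]
    # Applied top-down instead, the first splice would shift "old-b" down to
    # line 5, and the second fix would overwrite the wrong line.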

mypy/test/helpers.py

-33
@@ -141,39 +141,6 @@ def assert_target_equivalence(name: str, expected: list[str], actual: list[str])
     )
 
 
-def update_testcase_output(testcase: DataDrivenTestCase, output: list[str]) -> None:
-    assert testcase.old_cwd is not None, "test was not properly set up"
-    testcase_path = os.path.join(testcase.old_cwd, testcase.file)
-    with open(testcase_path, encoding="utf8") as f:
-        data_lines = f.read().splitlines()
-    test = "\n".join(data_lines[testcase.line : testcase.last_line])
-
-    mapping: dict[str, list[str]] = {}
-    for old, new in zip(testcase.output, output):
-        PREFIX = "error:"
-        ind = old.find(PREFIX)
-        if ind != -1 and old[:ind] == new[:ind]:
-            old, new = old[ind + len(PREFIX) :], new[ind + len(PREFIX) :]
-        mapping.setdefault(old, []).append(new)
-
-    for old in mapping:
-        if test.count(old) == len(mapping[old]):
-            betweens = test.split(old)
-
-            # Interleave betweens and mapping[old]
-            from itertools import chain
-
-            interleaved = [betweens[0]] + list(
-                chain.from_iterable(zip(mapping[old], betweens[1:]))
-            )
-            test = "".join(interleaved)
-
-    data_lines[testcase.line : testcase.last_line] = [test]
-    data = "\n".join(data_lines)
-    with open(testcase_path, "w", encoding="utf8") as f:
-        print(data, file=f)
-
-
 def show_align_message(s1: str, s2: str) -> None:
     """Align s1 and s2 so that their first difference is highlighted.
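
For contrast, the removed helper worked by string matching: it split the raw test text on each stale message and interleaved the new messages into the gaps, and it silently gave up whenever a message did not occur exactly as many times as expected (the `test.count(old) == len(mapping[old])` guard). Its core interleave step, reduced to a toy case:

    from itertools import chain

    test = "x  # error: old\ny  # error: old"
    news = ["error: new-1", "error: new-2"]
    betweens = test.split("error: old")
    interleaved = [betweens[0]] + list(chain.from_iterable(zip(news, betweens[1:])))
    assert "".join(interleaved) == "x  # error: new-1\ny  # error: new-2"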

mypy/test/testcheck.py

+86 -3
@@ -5,14 +5,24 @@
 import os
 import re
 import sys
+from collections import defaultdict
+from typing import Iterator
 
 from mypy import build
 from mypy.build import Graph
 from mypy.errors import CompileError
 from mypy.modulefinder import BuildSource, FindModuleCache, SearchPaths
 from mypy.options import TYPE_VAR_TUPLE, UNPACK
 from mypy.test.config import test_data_prefix, test_temp_dir
-from mypy.test.data import DataDrivenTestCase, DataSuite, FileOperation, module_from_path
+from mypy.test.data import (
+    DataDrivenTestCase,
+    DataFileCollector,
+    DataFileFix,
+    DataSuite,
+    FileOperation,
+    module_from_path,
+    parse_test_data,
+)
 from mypy.test.helpers import (
     assert_module_equivalence,
     assert_string_arrays_equal,
@@ -22,7 +32,6 @@
     normalize_error_messages,
     parse_options,
     perform_file_operations,
-    update_testcase_output,
 )
 
 try:
@@ -192,7 +201,13 @@ def run_case_once(
             output = testcase.output2.get(incremental_step, [])
 
         if output != a and testcase.config.getoption("--update-data", False):
-            update_testcase_output(testcase, a)
+            collector = testcase.parent
+            assert isinstance(collector, DataFileCollector)
+            for fix in self.iter_data_file_fixes(
+                testcase, actual=a, incremental_step=incremental_step
+            ):
+                collector.enqueue_fix(fix)
+
         assert_string_arrays_equal(output, a, msg.format(testcase.file, testcase.line))
 
         if res:
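
With this wiring, a failing expectation no longer rewrites the data file inline mid-test; each fix is queued on the file's collector and flushed once per file in `DataFileCollector.teardown()`. Presumably (the option itself is registered outside this diff) the workflow looks like:

    pytest mypy/test/testcheck.py -k testSomeCase --update-data

which re-runs the case and rewrites its `[out]`/`[outN]` sections or its `# E:`/`# N:`/`# W:` assertion comments to match the actual output.
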
@@ -226,6 +241,74 @@ def run_case_once(
         if testcase.output_files:
             check_test_output_files(testcase, incremental_step, strip_prefix="tmp/")
 
+    def iter_data_file_fixes(
+        self, testcase: DataDrivenTestCase, *, actual: list[str], incremental_step: int
+    ) -> Iterator[DataFileFix]:
+        reports_by_line: dict[tuple[str, int], list[tuple[str, str]]] = defaultdict(list)
+        for error_line in actual:
+            comment_match = re.match(
+                r"^(?P<filename>[^:]+):(?P<lineno>\d+): (?P<severity>error|note|warning): (?P<msg>.+)$",
+                error_line,
+            )
+            if comment_match:
+                filename = comment_match.group("filename")
+                lineno = int(comment_match.group("lineno"))
+                severity = comment_match.group("severity")
+                msg = comment_match.group("msg")
+                reports_by_line[filename, lineno].append((severity, msg))
+
+        test_items = parse_test_data(testcase.data, testcase.name)
+
+        # If we have [out] and/or [outN], we update just those sections.
+        if any(re.match(r"^out\d*$", test_item.id) for test_item in test_items):
+            for test_item in test_items:
+                if (incremental_step < 2 and test_item.id == "out") or (
+                    incremental_step >= 2 and test_item.id == f"out{incremental_step}"
+                ):
+                    yield DataFileFix(
+                        lineno=testcase.line + test_item.line - 1,
+                        end_lineno=testcase.line + test_item.end_line - 1,
+                        lines=actual + [""] * test_item.trimmed_newlines,
+                    )
+
+            return
+
+        # Update assertion comments within the sections
+        for test_item in test_items:
+            if test_item.id == "case":
+                source_lines = test_item.data
+                file_path = "main"
+            elif test_item.id == "file":
+                source_lines = test_item.data
+                file_path = f"tmp/{test_item.arg}"
+            else:
+                continue  # other sections we don't touch
+
+            fix_lines = []
+            for lineno, source_line in enumerate(source_lines, start=1):
+                reports = reports_by_line.get((file_path, lineno))
+                comment_match = re.search(r"(?P<indent>\s+)(?P<comment># [EWN]: .+)$", source_line)
+                if comment_match:
+                    source_line = source_line[: comment_match.start("indent")]  # strip old comment
+                if reports:
+                    indent = comment_match.group("indent") if comment_match else "  "
+                    # multiline comments go on the first line, and subsequent lines are empty
+                    # except for a continuation backslash
+                    for j, (severity, msg) in enumerate(reports):
+                        out_l = source_line if j == 0 else " " * len(source_line)
+                        is_last = j == len(reports) - 1
+                        severity_char = severity[0].upper()
+                        continuation = "" if is_last else " \\"
+                        fix_lines.append(f"{out_l}{indent}# {severity_char}: {msg}{continuation}")
+                else:
+                    fix_lines.append(source_line)
+
+            yield DataFileFix(
+                lineno=testcase.line + test_item.line - 1,
+                end_lineno=testcase.line + test_item.end_line - 1,
+                lines=fix_lines + [""] * test_item.trimmed_newlines,
+            )
+
     def verify_cache(
         self,
         module_data: list[tuple[str, str, str]],
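
To make the multi-report comment layout concrete, here is the emission loop above rerun on hypothetical input: the first report shares the source line, later reports sit on blank lines padded to the same width, and every report except the last ends with a continuation backslash.

    source_line = "x = f()"
    indent = "  "
    reports = [("error", "bad thing"), ("note", "helpful hint")]
    fix_lines = []
    for j, (severity, msg) in enumerate(reports):
        out_l = source_line if j == 0 else " " * len(source_line)
        continuation = "" if j == len(reports) - 1 else " \\"
        fix_lines.append(f"{out_l}{indent}# {severity[0].upper()}: {msg}{continuation}")
    print("\n".join(fix_lines))
    # x = f()  # E: bad thing \
    #          # N: helpful hint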
