|
5 | 5 | import os
|
6 | 6 | import re
|
7 | 7 | import sys
|
| 8 | +from collections import defaultdict |
| 9 | +from typing import Iterator |
8 | 10 |
|
9 | 11 | from mypy import build
|
10 | 12 | from mypy.build import Graph
|
11 | 13 | from mypy.errors import CompileError
|
12 | 14 | from mypy.modulefinder import BuildSource, FindModuleCache, SearchPaths
|
13 | 15 | from mypy.options import TYPE_VAR_TUPLE, UNPACK
|
14 | 16 | from mypy.test.config import test_data_prefix, test_temp_dir
|
15 |
| -from mypy.test.data import DataDrivenTestCase, DataSuite, FileOperation, module_from_path |
| 17 | +from mypy.test.data import ( |
| 18 | + DataDrivenTestCase, |
| 19 | + DataFileCollector, |
| 20 | + DataFileFix, |
| 21 | + DataSuite, |
| 22 | + FileOperation, |
| 23 | + module_from_path, |
| 24 | + parse_test_data, |
| 25 | +) |
16 | 26 | from mypy.test.helpers import (
|
17 | 27 | assert_module_equivalence,
|
18 | 28 | assert_string_arrays_equal,
|
|
22 | 32 | normalize_error_messages,
|
23 | 33 | parse_options,
|
24 | 34 | perform_file_operations,
|
25 |
| - update_testcase_output, |
26 | 35 | )
|
27 | 36 |
|
28 | 37 | try:
|
@@ -180,7 +189,11 @@ def run_case_once(
|
180 | 189 | raise AssertionError()
|
181 | 190 |
|
182 | 191 | if output != a and testcase.config.getoption("--update-data", False):
|
183 |
| - update_testcase_output(testcase, a) |
| 192 | + collector = testcase.parent |
| 193 | + assert isinstance(collector, DataFileCollector) |
| 194 | + for fix in self.iter_data_file_fixes(a, testcase): |
| 195 | + collector.enqueue_fix(fix) |
| 196 | + |
184 | 197 | assert_string_arrays_equal(output, a, msg.format(testcase.file, testcase.line))
|
185 | 198 |
|
186 | 199 | if res:
|
@@ -214,6 +227,55 @@ def run_case_once(
|
214 | 227 | if testcase.output_files:
|
215 | 228 | check_test_output_files(testcase, incremental_step, strip_prefix="tmp/")
|
216 | 229 |
|
| 230 | + def iter_data_file_fixes( |
| 231 | + self, actual: list[str], testcase: DataDrivenTestCase |
| 232 | + ) -> Iterator[DataFileFix]: |
| 233 | + reports_by_line: dict[tuple[str, int], list[tuple[str, str]]] = defaultdict(list) |
| 234 | + for error_line in actual: |
| 235 | + comment_match = re.match( |
| 236 | + r"^(?P<filename>[^:]+):(?P<lineno>\d+): (?P<severity>error|note|warning): (?P<msg>.+)$", |
| 237 | + error_line, |
| 238 | + ) |
| 239 | + if comment_match: |
| 240 | + filename = comment_match.group("filename") |
| 241 | + lineno = int(comment_match.group("lineno")) - 1 |
| 242 | + severity = comment_match.group("severity") |
| 243 | + msg = comment_match.group("msg") |
| 244 | + reports_by_line[filename, lineno].append((severity, msg)) |
| 245 | + |
| 246 | + for item in parse_test_data(testcase.data, testcase.name): |
| 247 | + if item.id == "case": |
| 248 | + source_lines = item.data |
| 249 | + file_path = "main" |
| 250 | + elif item.id == "file": |
| 251 | + source_lines = item.data |
| 252 | + file_path = f"tmp/{item.arg}" |
| 253 | + else: |
| 254 | + continue # other sections we don't touch |
| 255 | + |
| 256 | + fix_lines = [] |
| 257 | + for lineno, source_line in enumerate(source_lines): |
| 258 | + reports = reports_by_line.get((file_path, lineno)) |
| 259 | + comment_match = re.search(r"(?P<indent>\s+)(?P<comment># [EWN]: .+)$", source_line) |
| 260 | + if comment_match: |
| 261 | + source_line = source_line[: comment_match.start("indent")] # strip old comment |
| 262 | + if reports: |
| 263 | + indent = comment_match.group("indent") if comment_match else " " |
| 264 | + # multiline comments are on the first line and then on subsequent lines emtpy lines |
| 265 | + # with a continuation backslash |
| 266 | + for j, (severity, msg) in enumerate(reports): |
| 267 | + out_l = source_line if j == 0 else " " * len(source_line) |
| 268 | + is_last = j == len(reports) - 1 |
| 269 | + severity_char = severity[0].upper() |
| 270 | + continuation = "" if is_last else " \\" |
| 271 | + fix_lines.append(f"{out_l}{indent}# {severity_char}: {msg}{continuation}") |
| 272 | + else: |
| 273 | + fix_lines.append(source_line) |
| 274 | + |
| 275 | + lineno = testcase.line + item.line - 1 # both testcase and item are 1-offset |
| 276 | + end_lineno = lineno + len(item.data) |
| 277 | + yield DataFileFix(lineno, end_lineno, fix_lines) |
| 278 | + |
217 | 279 | def verify_cache(
|
218 | 280 | self,
|
219 | 281 | module_data: list[tuple[str, str, str]],
|
|
0 commit comments