From e76f0877ad66af03fc1b0839250d857d5d07f522 Mon Sep 17 00:00:00 2001 From: Ilya Priven Date: Wed, 24 May 2023 22:35:39 -0400 Subject: [PATCH 1/2] squash --- mypy/test/data.py | 42 ++++++++++- mypy/test/helpers.py | 33 -------- mypy/test/testcheck.py | 89 +++++++++++++++++++++- mypy/test/testupdatedata.py | 146 ++++++++++++++++++++++++++++++++++++ 4 files changed, 272 insertions(+), 38 deletions(-) create mode 100644 mypy/test/testupdatedata.py diff --git a/mypy/test/data.py b/mypy/test/data.py index 57d2377241d9..daf815dbdbdc 100644 --- a/mypy/test/data.py +++ b/mypy/test/data.py @@ -11,6 +11,7 @@ import tempfile from abc import abstractmethod from dataclasses import dataclass +from pathlib import Path from typing import Any, Iterator, NamedTuple, Pattern, Union from typing_extensions import Final, TypeAlias as _TypeAlias @@ -426,8 +427,16 @@ class TestItem: id: str arg: str | None + # Processed, collapsed text data data: list[str] + # Start line: 1-based, inclusive, relative to testcase line: int + # End line: 1-based, exclusive, relative to testcase; not same as `line + len(test_item.data)` due to collapsing + end_line: int + + @property + def trimmed_newlines(self) -> int: # compensates for strip_list + return self.end_line - self.line - len(self.data) def parse_test_data(raw_data: str, name: str) -> list[TestItem]: @@ -449,7 +458,7 @@ def parse_test_data(raw_data: str, name: str) -> list[TestItem]: if id: data = collapse_line_continuation(data) data = strip_list(data) - ret.append(TestItem(id, arg, strip_list(data), i0 + 1)) + ret.append(TestItem(id, arg, data, i0 + 1, i)) i0 = i id = s[1:-1] @@ -470,7 +479,7 @@ def parse_test_data(raw_data: str, name: str) -> list[TestItem]: if id: data = collapse_line_continuation(data) data = strip_list(data) - ret.append(TestItem(id, arg, data, i0 + 1)) + ret.append(TestItem(id, arg, data, i0 + 1, i - 1)) return ret @@ -693,6 +702,12 @@ def collect(self) -> Iterator[DataFileCollector]: yield DataFileCollector.from_parent(parent=self, name=data_file) +class DataFileFix(NamedTuple): + lineno: int # 1-offset, inclusive + end_lineno: int # 1-offset, exclusive + lines: list[str] + + class DataFileCollector(pytest.Collector): """Represents a single `.test` data driven test file. @@ -701,6 +716,8 @@ class DataFileCollector(pytest.Collector): parent: DataSuiteCollector + _fixes: list[DataFileFix] + @classmethod # We have to fight with pytest here: def from_parent( cls, parent: DataSuiteCollector, *, name: str # type: ignore[override] @@ -716,6 +733,27 @@ def collect(self) -> Iterator[DataDrivenTestCase]: file=os.path.join(self.parent.obj.data_prefix, self.name), ) + def setup(self) -> None: + super().setup() + self._fixes = [] + + def teardown(self) -> None: + super().teardown() + self._apply_fixes() + + def enqueue_fix(self, fix: DataFileFix) -> None: + self._fixes.append(fix) + + def _apply_fixes(self) -> None: + if not self._fixes: + return + data_path = Path(self.parent.obj.data_prefix) / self.name + lines = data_path.read_text().split("\n") + # start from end to prevent line offsets from shifting as we update + for fix in sorted(self._fixes, reverse=True): + lines[fix.lineno - 1 : fix.end_lineno - 1] = fix.lines + data_path.write_text("\n".join(lines)) + def add_test_name_suffix(name: str, suffix: str) -> str: # Find magic suffix of form "-foobar" (used for things like "-skip"). diff --git a/mypy/test/helpers.py b/mypy/test/helpers.py index ca9b02eac805..849ccdc376bd 100644 --- a/mypy/test/helpers.py +++ b/mypy/test/helpers.py @@ -141,39 +141,6 @@ def assert_target_equivalence(name: str, expected: list[str], actual: list[str]) ) -def update_testcase_output(testcase: DataDrivenTestCase, output: list[str]) -> None: - assert testcase.old_cwd is not None, "test was not properly set up" - testcase_path = os.path.join(testcase.old_cwd, testcase.file) - with open(testcase_path, encoding="utf8") as f: - data_lines = f.read().splitlines() - test = "\n".join(data_lines[testcase.line : testcase.last_line]) - - mapping: dict[str, list[str]] = {} - for old, new in zip(testcase.output, output): - PREFIX = "error:" - ind = old.find(PREFIX) - if ind != -1 and old[:ind] == new[:ind]: - old, new = old[ind + len(PREFIX) :], new[ind + len(PREFIX) :] - mapping.setdefault(old, []).append(new) - - for old in mapping: - if test.count(old) == len(mapping[old]): - betweens = test.split(old) - - # Interleave betweens and mapping[old] - from itertools import chain - - interleaved = [betweens[0]] + list( - chain.from_iterable(zip(mapping[old], betweens[1:])) - ) - test = "".join(interleaved) - - data_lines[testcase.line : testcase.last_line] = [test] - data = "\n".join(data_lines) - with open(testcase_path, "w", encoding="utf8") as f: - print(data, file=f) - - def show_align_message(s1: str, s2: str) -> None: """Align s1 and s2 so that the their first difference is highlighted. diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index fad70945f740..298168352ed3 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -5,6 +5,8 @@ import os import re import sys +from collections import defaultdict +from typing import Iterator from mypy import build from mypy.build import Graph @@ -12,7 +14,15 @@ from mypy.modulefinder import BuildSource, FindModuleCache, SearchPaths from mypy.options import TYPE_VAR_TUPLE, UNPACK from mypy.test.config import test_data_prefix, test_temp_dir -from mypy.test.data import DataDrivenTestCase, DataSuite, FileOperation, module_from_path +from mypy.test.data import ( + DataDrivenTestCase, + DataFileCollector, + DataFileFix, + DataSuite, + FileOperation, + module_from_path, + parse_test_data, +) from mypy.test.helpers import ( assert_module_equivalence, assert_string_arrays_equal, @@ -22,7 +32,6 @@ normalize_error_messages, parse_options, perform_file_operations, - update_testcase_output, ) try: @@ -192,7 +201,13 @@ def run_case_once( output = testcase.output2.get(incremental_step, []) if output != a and testcase.config.getoption("--update-data", False): - update_testcase_output(testcase, a) + collector = testcase.parent + assert isinstance(collector, DataFileCollector) + for fix in self.iter_data_file_fixes( + testcase, actual=a, incremental_step=incremental_step + ): + collector.enqueue_fix(fix) + assert_string_arrays_equal(output, a, msg.format(testcase.file, testcase.line)) if res: @@ -226,6 +241,74 @@ def run_case_once( if testcase.output_files: check_test_output_files(testcase, incremental_step, strip_prefix="tmp/") + def iter_data_file_fixes( + self, testcase: DataDrivenTestCase, *, actual: list[str], incremental_step: int + ) -> Iterator[DataFileFix]: + reports_by_line: dict[tuple[str, int], list[tuple[str, str]]] = defaultdict(list) + for error_line in actual: + comment_match = re.match( + r"^(?P[^:]+):(?P\d+): (?Perror|note|warning): (?P.+)$", + error_line, + ) + if comment_match: + filename = comment_match.group("filename") + lineno = int(comment_match.group("lineno")) + severity = comment_match.group("severity") + msg = comment_match.group("msg") + reports_by_line[filename, lineno].append((severity, msg)) + + test_items = parse_test_data(testcase.data, testcase.name) + + # If we have [out] and/or [outN], we update just those sections. + if any(re.match(r"^out\d*$", test_item.id) for test_item in test_items): + for test_item in test_items: + if (incremental_step < 2 and test_item.id == "out") or ( + incremental_step >= 2 and test_item.id == f"out{incremental_step}" + ): + yield DataFileFix( + lineno=testcase.line + test_item.line - 1, + end_lineno=testcase.line + test_item.end_line - 1, + lines=actual + [""] * test_item.trimmed_newlines, + ) + + return + + # Update assertion comments within the sections + for test_item in test_items: + if test_item.id == "case": + source_lines = test_item.data + file_path = "main" + elif test_item.id == "file": + source_lines = test_item.data + file_path = f"tmp/{test_item.arg}" + else: + continue # other sections we don't touch + + fix_lines = [] + for lineno, source_line in enumerate(source_lines, start=1): + reports = reports_by_line.get((file_path, lineno)) + comment_match = re.search(r"(?P\s+)(?P# [EWN]: .+)$", source_line) + if comment_match: + source_line = source_line[: comment_match.start("indent")] # strip old comment + if reports: + indent = comment_match.group("indent") if comment_match else " " + # multiline comments are on the first line and then on subsequent lines emtpy lines + # with a continuation backslash + for j, (severity, msg) in enumerate(reports): + out_l = source_line if j == 0 else " " * len(source_line) + is_last = j == len(reports) - 1 + severity_char = severity[0].upper() + continuation = "" if is_last else " \\" + fix_lines.append(f"{out_l}{indent}# {severity_char}: {msg}{continuation}") + else: + fix_lines.append(source_line) + + yield DataFileFix( + lineno=testcase.line + test_item.line - 1, + end_lineno=testcase.line + test_item.end_line - 1, + lines=fix_lines + [""] * test_item.trimmed_newlines, + ) + def verify_cache( self, module_data: list[tuple[str, str, str]], diff --git a/mypy/test/testupdatedata.py b/mypy/test/testupdatedata.py new file mode 100644 index 000000000000..2e86c641ce6d --- /dev/null +++ b/mypy/test/testupdatedata.py @@ -0,0 +1,146 @@ +import shlex +import subprocess +import sys +import textwrap +from pathlib import Path + +from mypy.test.config import test_data_prefix +from mypy.test.helpers import Suite + + +class UpdateDataSuite(Suite): + def _run_pytest_update_data(self, data_suite: str, *, max_attempts: int) -> str: + """ + Runs a suite of data test cases through 'pytest --update-data' until either tests pass + or until a maximum number of attempts (needed for incremental tests). + """ + p = Path(test_data_prefix) / "check-update-data.test" + assert not p.exists() + try: + p.write_text(textwrap.dedent(data_suite).lstrip()) + + test_nodeid = f"mypy/test/testcheck.py::TypeCheckSuite::{p.name}" + args = [sys.executable, "-m", "pytest", "-n", "0", "-s", "--update-data", test_nodeid] + if sys.version_info >= (3, 8): + cmd = shlex.join(args) + else: + cmd = " ".join(args) + for i in range(max_attempts - 1, -1, -1): + res = subprocess.run(args) + if res.returncode == 0: + break + print(f"`{cmd}` returned {res.returncode}: {i} attempts remaining") + + return p.read_text() + finally: + p.unlink() + + def test_update_data(self) -> None: + # Note: We test multiple testcases rather than 'test case per test case' + # so we could also exercise rewriting multiple testcases at once. + actual = self._run_pytest_update_data( + """ + [case testCorrect] + s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str") + + [case testWrong] + s: str = 42 # E: wrong error + + [case testWrongMultiline] + s: str = 42 # E: foo \ + # N: bar + + [case testMissingMultiline] + s: str = 42; i: int = 'foo' + + [case testExtraneous] + s: str = 'foo' # E: wrong error + + [case testExtraneousMultiline] + s: str = 'foo' # E: foo \ + # E: bar + + [case testExtraneousMultilineNonError] + s: str = 'foo' # W: foo \ + # N: bar + + [case testOutCorrect] + s: str = 42 + [out] + main:1: error: Incompatible types in assignment (expression has type "int", variable has type "str") + + [case testOutWrong] + s: str = 42 + [out] + main:1: error: foobar + + [case testOutWrongIncremental] + s: str = 42 + [out] + main:1: error: foobar + [out2] + main:1: error: foobar + + [case testWrongMultipleFiles] + import a, b + s: str = 42 # E: foo + [file a.py] + s1: str = 42 # E: bar + [file b.py] + s2: str = 43 # E: baz + [builtins fixtures/list.pyi] + """, + max_attempts=3, + ) + + # Assert + expected = """ + [case testCorrect] + s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str") + + [case testWrong] + s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str") + + [case testWrongMultiline] + s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str") + + [case testMissingMultiline] + s: str = 42; i: int = 'foo' # E: Incompatible types in assignment (expression has type "int", variable has type "str") \\ + # E: Incompatible types in assignment (expression has type "str", variable has type "int") + + [case testExtraneous] + s: str = 'foo' + + [case testExtraneousMultiline] + s: str = 'foo' + + [case testExtraneousMultilineNonError] + s: str = 'foo' + + [case testOutCorrect] + s: str = 42 + [out] + main:1: error: Incompatible types in assignment (expression has type "int", variable has type "str") + + [case testOutWrong] + s: str = 42 + [out] + main:1: error: Incompatible types in assignment (expression has type "int", variable has type "str") + + [case testOutWrongIncremental] + s: str = 42 + [out] + main:1: error: Incompatible types in assignment (expression has type "int", variable has type "str") + [out2] + main:1: error: Incompatible types in assignment (expression has type "int", variable has type "str") + + [case testWrongMultipleFiles] + import a, b + s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str") + [file a.py] + s1: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str") + [file b.py] + s2: str = 43 # E: Incompatible types in assignment (expression has type "int", variable has type "str") + [builtins fixtures/list.pyi] + """ + assert actual == textwrap.dedent(expected).lstrip() From cf9e0d642ec7dc0cc6e2bf61bf33c6a969d8d798 Mon Sep 17 00:00:00 2001 From: Ilya Priven Date: Thu, 25 May 2023 00:05:07 -0400 Subject: [PATCH 2/2] factor out into separate file --- mypy/test/testcheck.py | 88 ++-------------------------------------- mypy/test/update.py | 0 mypy/test/update_data.py | 85 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 85 deletions(-) delete mode 100644 mypy/test/update.py create mode 100644 mypy/test/update_data.py diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index 298168352ed3..58c0ee803359 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -5,8 +5,6 @@ import os import re import sys -from collections import defaultdict -from typing import Iterator from mypy import build from mypy.build import Graph @@ -14,15 +12,7 @@ from mypy.modulefinder import BuildSource, FindModuleCache, SearchPaths from mypy.options import TYPE_VAR_TUPLE, UNPACK from mypy.test.config import test_data_prefix, test_temp_dir -from mypy.test.data import ( - DataDrivenTestCase, - DataFileCollector, - DataFileFix, - DataSuite, - FileOperation, - module_from_path, - parse_test_data, -) +from mypy.test.data import DataDrivenTestCase, DataSuite, FileOperation, module_from_path from mypy.test.helpers import ( assert_module_equivalence, assert_string_arrays_equal, @@ -33,6 +23,7 @@ parse_options, perform_file_operations, ) +from mypy.test.update_data import update_testcase_output try: import lxml # type: ignore[import] @@ -201,12 +192,7 @@ def run_case_once( output = testcase.output2.get(incremental_step, []) if output != a and testcase.config.getoption("--update-data", False): - collector = testcase.parent - assert isinstance(collector, DataFileCollector) - for fix in self.iter_data_file_fixes( - testcase, actual=a, incremental_step=incremental_step - ): - collector.enqueue_fix(fix) + update_testcase_output(testcase, a, incremental_step=incremental_step) assert_string_arrays_equal(output, a, msg.format(testcase.file, testcase.line)) @@ -241,74 +227,6 @@ def run_case_once( if testcase.output_files: check_test_output_files(testcase, incremental_step, strip_prefix="tmp/") - def iter_data_file_fixes( - self, testcase: DataDrivenTestCase, *, actual: list[str], incremental_step: int - ) -> Iterator[DataFileFix]: - reports_by_line: dict[tuple[str, int], list[tuple[str, str]]] = defaultdict(list) - for error_line in actual: - comment_match = re.match( - r"^(?P[^:]+):(?P\d+): (?Perror|note|warning): (?P.+)$", - error_line, - ) - if comment_match: - filename = comment_match.group("filename") - lineno = int(comment_match.group("lineno")) - severity = comment_match.group("severity") - msg = comment_match.group("msg") - reports_by_line[filename, lineno].append((severity, msg)) - - test_items = parse_test_data(testcase.data, testcase.name) - - # If we have [out] and/or [outN], we update just those sections. - if any(re.match(r"^out\d*$", test_item.id) for test_item in test_items): - for test_item in test_items: - if (incremental_step < 2 and test_item.id == "out") or ( - incremental_step >= 2 and test_item.id == f"out{incremental_step}" - ): - yield DataFileFix( - lineno=testcase.line + test_item.line - 1, - end_lineno=testcase.line + test_item.end_line - 1, - lines=actual + [""] * test_item.trimmed_newlines, - ) - - return - - # Update assertion comments within the sections - for test_item in test_items: - if test_item.id == "case": - source_lines = test_item.data - file_path = "main" - elif test_item.id == "file": - source_lines = test_item.data - file_path = f"tmp/{test_item.arg}" - else: - continue # other sections we don't touch - - fix_lines = [] - for lineno, source_line in enumerate(source_lines, start=1): - reports = reports_by_line.get((file_path, lineno)) - comment_match = re.search(r"(?P\s+)(?P# [EWN]: .+)$", source_line) - if comment_match: - source_line = source_line[: comment_match.start("indent")] # strip old comment - if reports: - indent = comment_match.group("indent") if comment_match else " " - # multiline comments are on the first line and then on subsequent lines emtpy lines - # with a continuation backslash - for j, (severity, msg) in enumerate(reports): - out_l = source_line if j == 0 else " " * len(source_line) - is_last = j == len(reports) - 1 - severity_char = severity[0].upper() - continuation = "" if is_last else " \\" - fix_lines.append(f"{out_l}{indent}# {severity_char}: {msg}{continuation}") - else: - fix_lines.append(source_line) - - yield DataFileFix( - lineno=testcase.line + test_item.line - 1, - end_lineno=testcase.line + test_item.end_line - 1, - lines=fix_lines + [""] * test_item.trimmed_newlines, - ) - def verify_cache( self, module_data: list[tuple[str, str, str]], diff --git a/mypy/test/update.py b/mypy/test/update.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/mypy/test/update_data.py b/mypy/test/update_data.py new file mode 100644 index 000000000000..b37b440de38f --- /dev/null +++ b/mypy/test/update_data.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import re +from collections import defaultdict +from typing import Iterator + +from mypy.test.data import DataDrivenTestCase, DataFileCollector, DataFileFix, parse_test_data + + +def update_testcase_output( + testcase: DataDrivenTestCase, actual: list[str], *, incremental_step: int +) -> None: + collector = testcase.parent + assert isinstance(collector, DataFileCollector) + for fix in _iter_fixes(testcase, actual, incremental_step=incremental_step): + collector.enqueue_fix(fix) + + +def _iter_fixes( + testcase: DataDrivenTestCase, actual: list[str], *, incremental_step: int +) -> Iterator[DataFileFix]: + reports_by_line: dict[tuple[str, int], list[tuple[str, str]]] = defaultdict(list) + for error_line in actual: + comment_match = re.match( + r"^(?P[^:]+):(?P\d+): (?Perror|note|warning): (?P.+)$", + error_line, + ) + if comment_match: + filename = comment_match.group("filename") + lineno = int(comment_match.group("lineno")) + severity = comment_match.group("severity") + msg = comment_match.group("msg") + reports_by_line[filename, lineno].append((severity, msg)) + + test_items = parse_test_data(testcase.data, testcase.name) + + # If we have [out] and/or [outN], we update just those sections. + if any(re.match(r"^out\d*$", test_item.id) for test_item in test_items): + for test_item in test_items: + if (incremental_step < 2 and test_item.id == "out") or ( + incremental_step >= 2 and test_item.id == f"out{incremental_step}" + ): + yield DataFileFix( + lineno=testcase.line + test_item.line - 1, + end_lineno=testcase.line + test_item.end_line - 1, + lines=actual + [""] * test_item.trimmed_newlines, + ) + + return + + # Update assertion comments within the sections + for test_item in test_items: + if test_item.id == "case": + source_lines = test_item.data + file_path = "main" + elif test_item.id == "file": + source_lines = test_item.data + file_path = f"tmp/{test_item.arg}" + else: + continue # other sections we don't touch + + fix_lines = [] + for lineno, source_line in enumerate(source_lines, start=1): + reports = reports_by_line.get((file_path, lineno)) + comment_match = re.search(r"(?P\s+)(?P# [EWN]: .+)$", source_line) + if comment_match: + source_line = source_line[: comment_match.start("indent")] # strip old comment + if reports: + indent = comment_match.group("indent") if comment_match else " " + # multiline comments are on the first line and then on subsequent lines emtpy lines + # with a continuation backslash + for j, (severity, msg) in enumerate(reports): + out_l = source_line if j == 0 else " " * len(source_line) + is_last = j == len(reports) - 1 + severity_char = severity[0].upper() + continuation = "" if is_last else " \\" + fix_lines.append(f"{out_l}{indent}# {severity_char}: {msg}{continuation}") + else: + fix_lines.append(source_line) + + yield DataFileFix( + lineno=testcase.line + test_item.line - 1, + end_lineno=testcase.line + test_item.end_line - 1, + lines=fix_lines + [""] * test_item.trimmed_newlines, + )