Skip to content

ref: Run mypy once and collate messages per-file #56

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 75 additions & 67 deletions src/pytest_mypy_testing/plugin.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# SPDX-FileCopyrightText: 2020 David Fritzsche
# SPDX-License-Identifier: Apache-2.0 OR MIT

import importlib.util
import os
import pathlib
import tempfile
from typing import Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union
from typing import Dict, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union

import mypy.api
import pytest
Expand All @@ -20,11 +20,10 @@
PYTEST_VERSION = pytest.__version__
PYTEST_VERSION_INFO = tuple(int(part) for part in PYTEST_VERSION.split(".")[:3])

have_xdist = importlib.util.find_spec("xdist") is not None


class MypyResult(NamedTuple):
mypy_args: List[str]
returncode: int
output_lines: List[str]
file_messages: List[Message]
non_item_messages: List[Message]

Expand Down Expand Up @@ -55,13 +54,15 @@ def __init__(
self.mypy_item = mypy_item
for mark in self.mypy_item.marks:
self.add_marker(mark)
if have_xdist:
self.add_marker(pytest.mark.xdist_group("mypy"))
Comment on lines +57 to +58
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be done separately, but targets the same overall goal of running mypy once. Without this, pytest-xdist could potentially build a mypy cache for every worker. This puts all mypy jobs on the same worker (provided the user passes --dist loadgroup to pytest), so only one cache is built.


@classmethod
def from_parent(cls, parent, name, mypy_item):
return super().from_parent(parent=parent, name=name, mypy_item=mypy_item)

def runtest(self) -> None:
returncode, actual_messages = self.parent.run_mypy(self.mypy_item)
actual_messages = self.parent.run_mypy(self.mypy_item)

errors = diff_message_sequences(
actual_messages, self.mypy_item.expected_messages
Expand Down Expand Up @@ -119,7 +120,7 @@ def __init__(
)
self.add_marker("mypy")
self.mypy_file = parse_file(self.path, config=config)
self._mypy_result: Optional[MypyResult] = None
COLLECTION.add(self)

@classmethod
def from_parent(cls, parent, **kwargs):
Expand All @@ -131,71 +132,78 @@ def collect(self) -> Iterator[PytestMypyTestItem]:
parent=self, name="[mypy]" + item.name, mypy_item=item
)

def run_mypy(self, item: MypyTestItem) -> Tuple[int, List[Message]]:
if self._mypy_result is None:
self._mypy_result = self._run_mypy(self.path)
return (
self._mypy_result.returncode,
sorted(
item.actual_messages + self._mypy_result.non_item_messages,
key=lambda msg: msg.lineno,
),
def run_mypy(self, item: MypyTestItem) -> List[Message]:
mypy_result = COLLECTION.run_mypy(self)
return sorted(
item.actual_messages + mypy_result.non_item_messages,
key=lambda msg: msg.lineno,
)

def _run_mypy(self, filename: Union[pathlib.Path, os.PathLike, str]) -> MypyResult:
filename = pathlib.Path(filename)
with tempfile.TemporaryDirectory(prefix="pytest-mypy-testing-") as tmp_dir_name:
mypy_cache_dir = os.path.join(tmp_dir_name, "mypy_cache")
os.makedirs(mypy_cache_dir)

mypy_args = [
"--cache-dir={}".format(mypy_cache_dir),
"--check-untyped-defs",
"--hide-error-context",
"--no-color-output",
"--no-error-summary",
"--no-pretty",
"--soft-error-limit=-1",
"--no-silence-site-packages",
"--no-warn-unused-configs",
"--show-column-numbers",
"--show-error-codes",
"--show-traceback",
str(filename),
]

out, err, returncode = mypy.api.run(mypy_args)

lines = (out + err).splitlines()

file_messages = [
msg
for msg in map(Message.from_output, lines)
if (msg.filename == self.mypy_file.filename)
and not (

class MypyFileCollection:
def __init__(self):
self.files: List[PytestMypyFile] = []
self._mypy_results: Optional[Dict[str, MypyResult]] = None

def add(self, file: PytestMypyFile):
self.files.append(file)

def run_mypy(self, file: PytestMypyFile) -> MypyResult:
if self._mypy_results is None:
self._mypy_results = self._run_mypy()
return self._mypy_results[str(file.path)]

def _run_mypy(self) -> Dict[str, MypyResult]:
mypy_args = [
"--cache-dir={}".format(self.files[0].config.cache.mkdir("mypy-cache")),
"--check-untyped-defs",
"--hide-error-context",
"--no-color-output",
"--no-error-summary",
"--no-pretty",
"--soft-error-limit=-1",
"--no-warn-unused-configs",
"--show-column-numbers",
"--show-error-codes",
"--show-traceback",
*(str(file.path) for file in self.files),
]

out, err, returncode = mypy.api.run(mypy_args)

messages_by_file = {}

for line in (out + err).splitlines():
msg = Message.from_output(line)
if msg.filename and not (
msg.severity is Severity.NOTE
and msg.message
== "See https://mypy.readthedocs.io/en/stable/running_mypy.html#missing-imports"
and msg.message.endswith("#missing-imports")
):
messages_by_file.setdefault(msg.filename, []).append(msg)

ret = {}
for file in self.files:
file_messages = messages_by_file.get(str(file.path), [])

non_item_messages = []

for msg in file_messages:
for item in file.mypy_file.items:
if item.lineno <= msg.lineno <= item.end_lineno:
item.actual_messages.append(msg)
break
else:
non_item_messages.append(msg)

ret[str(file.path)] = MypyResult(
file_messages=file_messages,
non_item_messages=non_item_messages,
)
]

non_item_messages = []

for msg in file_messages:
for item in self.mypy_file.items:
if item.lineno <= msg.lineno <= item.end_lineno:
item.actual_messages.append(msg)
break
else:
non_item_messages.append(msg)

return MypyResult(
mypy_args=mypy_args,
returncode=returncode,
output_lines=lines,
file_messages=file_messages,
non_item_messages=non_item_messages,
)
return ret


COLLECTION = MypyFileCollection()


def pytest_collect_file(file_path: pathlib.Path, parent):
Expand Down