diff --git a/src/_pytask/collect.py b/src/_pytask/collect.py index f71cf80e..7fd1dff9 100644 --- a/src/_pytask/collect.py +++ b/src/_pytask/collect.py @@ -7,7 +7,6 @@ import sys import time import warnings -from importlib import util as importlib_util from pathlib import Path from typing import Any from typing import Generator @@ -28,6 +27,7 @@ from _pytask.outcomes import CollectionOutcome from _pytask.outcomes import count_outcomes from _pytask.path import find_case_sensitive_path +from _pytask.path import import_path from _pytask.report import CollectionReport from _pytask.session import Session from _pytask.shared import find_duplicates @@ -121,13 +121,7 @@ def pytask_collect_file( ) -> list[CollectionReport] | None: """Collect a file.""" if any(path.match(pattern) for pattern in session.config["task_files"]): - spec = importlib_util.spec_from_file_location(path.stem, str(path)) - - if spec is None: - raise ImportError(f"Can't find module {path.stem!r} at location {path}.") - - mod = importlib_util.module_from_spec(spec) - spec.loader.exec_module(mod) + mod = import_path(path, session.config["root"]) collected_reports = [] for name, obj in inspect.getmembers(mod): diff --git a/src/_pytask/path.py b/src/_pytask/path.py index 71f8f99e..3e7864c9 100644 --- a/src/_pytask/path.py +++ b/src/_pytask/path.py @@ -2,8 +2,11 @@ from __future__ import annotations import functools +import importlib.util import os +import sys from pathlib import Path +from types import ModuleType from typing import Sequence @@ -120,3 +123,76 @@ def find_case_sensitive_path(path: Path, platform: str) -> Path: """ out = path.resolve() if platform == "win32" else path return out + + +def import_path(path: Path, root: Path) -> ModuleType: + """Import and return a module from the given path. + + The function is taken from pytest when the import mode is set to ``importlib``. It + is pytest's recommended import mode for new projects although the default is set to + ``prepend``.
More discussion and information can be found in :gh:`373`. + + """ + module_name = _module_name_from_path(path, root) + + spec = importlib.util.spec_from_file_location(module_name, str(path)) + + if spec is None: + raise ImportError(f"Can't find module {module_name!r} at location {path}.") + + mod = importlib.util.module_from_spec(spec) + sys.modules[module_name] = mod + spec.loader.exec_module(mod) + _insert_missing_modules(sys.modules, module_name) + return mod + + +def _module_name_from_path(path: Path, root: Path) -> str: + """Return a dotted module name based on the given path, anchored on root. + + For example: path="projects/src/project/task_foo.py" and root="/projects", the + resulting module name will be "src.project.task_foo". + + """ + path = path.with_suffix("") + try: + relative_path = path.relative_to(root) + except ValueError: + # If we can't get a relative path to root, use the full path, except for the + # first part ("d:\\" or "/" depending on the platform, for example). + path_parts = path.parts[1:] + else: + # Use the parts for the relative path to the root path. + path_parts = relative_path.parts + + return ".".join(path_parts) + + +def _insert_missing_modules(modules: dict[str, ModuleType], module_name: str) -> None: + """Insert missing modules when importing modules with :func:`import_path`. + + When we want to import a module as ``src.project.task_foo`` for example, we need to + create empty modules ``src`` and ``src.project`` after inserting + ``src.project.task_foo``, otherwise ``src.project.task_foo`` is not importable by + ``__import__``. + + """ + module_parts = module_name.split(".") + while module_name: + if module_name not in modules: + try: + # If sys.meta_path is empty, calling import_module will issue a warning + # and raise ModuleNotFoundError. To avoid the warning, we check + # sys.meta_path explicitly and raise the error ourselves to fall back to + # creating a dummy module. 
+ if not sys.meta_path: + raise ModuleNotFoundError + importlib.import_module(module_name) + except ModuleNotFoundError: + module = ModuleType( + module_name, + doc="Empty module created by pytask.", + ) + modules[module_name] = module + module_parts.pop(-1) + module_name = ".".join(module_parts) diff --git a/tests/test_collect.py b/tests/test_collect.py index 7f114921..9a10831a 100644 --- a/tests/test_collect.py +++ b/tests/test_collect.py @@ -37,6 +37,46 @@ def task_write_text(depends_on, produces): assert tmp_path.joinpath("out.txt").read_text() == "Relative paths work." +@pytest.mark.end_to_end() +def test_collect_tasks_from_modules_with_the_same_name(tmp_path): + """We need to check that task modules can have the same name. See #373 and #374.""" + tmp_path.joinpath("a").mkdir() + tmp_path.joinpath("b").mkdir() + tmp_path.joinpath("a", "task_module.py").write_text("def task_a(): pass") + tmp_path.joinpath("b", "task_module.py").write_text("def task_a(): pass") + session = main({"paths": tmp_path}) + assert len(session.collection_reports) == 2 + assert all( + report.outcome == CollectionOutcome.SUCCESS + for report in session.collection_reports + ) + assert { + report.node.function.__module__ for report in session.collection_reports + } == {"a.task_module", "b.task_module"} + + +@pytest.mark.end_to_end() +def test_collect_module_name(tmp_path): + """We need to add a task module to the sys.modules. 
See #373 and #374.""" + source = """ + # without this import, everything works fine + from __future__ import annotations + + import dataclasses + + @dataclasses.dataclass + class Data: + x: int + + def task_my_task(): + pass + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + session = main({"paths": tmp_path}) + outcome = session.collection_reports[0].outcome + assert outcome == CollectionOutcome.SUCCESS + + @pytest.mark.end_to_end() def test_collect_filepathnode_with_unknown_type(tmp_path): """If a node cannot be parsed because unknown type, raise an error.""" diff --git a/tests/test_debugging.py b/tests/test_debugging.py index 229b510c..41a15f84 100644 --- a/tests/test_debugging.py +++ b/tests/test_debugging.py @@ -365,6 +365,7 @@ def task_1(): x = 3 print("hello18") assert count_continue == 2, "unexpected_failure: %d != 2" % count_continue + raise Exception("expected_failure") """ tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) @@ -399,7 +400,8 @@ def task_1(): assert "1" in rest assert "failed" in rest assert "Failed" in rest - assert "AssertionError: unexpected_failure" in rest + assert "AssertionError: unexpected_failure" not in rest + assert "expected_failure" in rest _flush(child) diff --git a/tests/test_path.py b/tests/test_path.py index 403aee9c..58737a47 100644 --- a/tests/test_path.py +++ b/tests/test_path.py @@ -1,15 +1,21 @@ from __future__ import annotations import sys +import textwrap from contextlib import ExitStack as does_not_raise # noqa: N813 from pathlib import Path from pathlib import PurePosixPath from pathlib import PureWindowsPath +from types import ModuleType +from typing import Any import pytest +from _pytask.path import _insert_missing_modules +from _pytask.path import _module_name_from_path from _pytask.path import find_case_sensitive_path from _pytask.path import find_closest_ancestor from _pytask.path import find_common_ancestor +from _pytask.path import import_path from 
_pytask.path import relative_to @@ -117,3 +123,182 @@ def test_find_case_sensitive_path(tmp_path, path, existing_paths, expected): result = find_case_sensitive_path(tmp_path / path, sys.platform) assert result == tmp_path / expected + + +@pytest.fixture() +def simple_module(tmp_path: Path) -> Path: + fn = tmp_path / "_src/project/mymod.py" + fn.parent.mkdir(parents=True) + fn.write_text("def foo(x): return 40 + x") + return fn + + +def test_importmode_importlib(simple_module: Path, tmp_path: Path) -> None: + """`importlib` mode does not change sys.path.""" + module = import_path(simple_module, root=tmp_path) + assert module.foo(2) == 42 # type: ignore[attr-defined] + assert str(simple_module.parent) not in sys.path + assert module.__name__ in sys.modules + assert module.__name__ == "_src.project.mymod" + assert "_src" in sys.modules + assert "_src.project" in sys.modules + + +def test_importmode_twice_is_different_module( + simple_module: Path, tmp_path: Path +) -> None: + """`importlib` mode always returns a new module.""" + module1 = import_path(simple_module, root=tmp_path) + module2 = import_path(simple_module, root=tmp_path) + assert module1 is not module2 + + +def test_no_meta_path_found( + simple_module: Path, monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + """Even without any meta_path should still import module.""" + monkeypatch.setattr(sys, "meta_path", []) + module = import_path(simple_module, root=tmp_path) + assert module.foo(2) == 42 # type: ignore[attr-defined] + + # mode='importlib' fails if no spec is found to load the module + import importlib.util + + monkeypatch.setattr( + importlib.util, "spec_from_file_location", lambda *args: None # noqa: ARG005 + ) + with pytest.raises(ImportError): + import_path(simple_module, root=tmp_path) + + +def test_importmode_importlib_with_dataclass(tmp_path: Path) -> None: + """ + Ensure that importlib mode works with a module containing dataclasses (#373, + pytest#7856). 
+ """ + fn = tmp_path.joinpath("_src/project/task_dataclass.py") + fn.parent.mkdir(parents=True) + fn.write_text( + textwrap.dedent( + """ + from dataclasses import dataclass + + @dataclass + class Data: + value: str + """ + ) + ) + + module = import_path(fn, root=tmp_path) + Data: Any = module.Data # noqa: N806 + data = Data(value="foo") + assert data.value == "foo" + assert data.__module__ == "_src.project.task_dataclass" + + +def test_importmode_importlib_with_pickle(tmp_path: Path) -> None: + """Ensure that importlib mode works with pickle (#373, pytest#7859).""" + fn = tmp_path.joinpath("_src/project/task_pickle.py") + fn.parent.mkdir(parents=True) + fn.write_text( + textwrap.dedent( + """ + import pickle + + def _action(): + return 42 + + def round_trip(): + s = pickle.dumps(_action) + return pickle.loads(s) + """ + ) + ) + + module = import_path(fn, root=tmp_path) + round_trip = module.round_trip + action = round_trip() + assert action() == 42 + + +def test_importmode_importlib_with_pickle_separate_modules(tmp_path: Path) -> None: + """ + Ensure that importlib mode can load pickles that look similar but are + defined in separate modules.
+ """ + fn1 = tmp_path.joinpath("_src/m1/project/task.py") + fn1.parent.mkdir(parents=True) + fn1.write_text( + textwrap.dedent( + """ + import dataclasses + import pickle + + @dataclasses.dataclass + class Data: + x: int = 42 + """ + ) + ) + + fn2 = tmp_path.joinpath("_src/m2/project/task.py") + fn2.parent.mkdir(parents=True) + fn2.write_text( + textwrap.dedent( + """ + import dataclasses + import pickle + + @dataclasses.dataclass + class Data: + x: str = "" + """ + ) + ) + + import pickle + + def round_trip(obj): + s = pickle.dumps(obj) + return pickle.loads(s) # noqa: S301 + + module = import_path(fn1, root=tmp_path) + Data1 = module.Data # noqa: N806 + + module = import_path(fn2, root=tmp_path) + Data2 = module.Data # noqa: N806 + + assert round_trip(Data1(20)) == Data1(20) + assert round_trip(Data2("hello")) == Data2("hello") + assert Data1.__module__ == "_src.m1.project.task" + assert Data2.__module__ == "_src.m2.project.task" + + +def test_module_name_from_path(tmp_path: Path) -> None: + result = _module_name_from_path(tmp_path / "src/project/task_foo.py", tmp_path) + assert result == "src.project.task_foo" + + # Path is not relative to root dir: use the full path to obtain the module name. + result = _module_name_from_path(Path("/home/foo/task_foo.py"), Path("/bar")) + assert result == "home.foo.task_foo" + + +def test_insert_missing_modules( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + monkeypatch.chdir(tmp_path) + # Use 'xxx' and 'xxy' as parent names as they are unlikely to exist and + # don't end up being imported. + modules = {"xxx.project.foo": ModuleType("xxx.project.foo")} + _insert_missing_modules(modules, "xxx.project.foo") + assert sorted(modules) == ["xxx", "xxx.project", "xxx.project.foo"] + + mod = ModuleType("mod", doc="My Module") + modules = {"xxy": mod} + _insert_missing_modules(modules, "xxy") + assert modules == {"xxy": mod} + + modules = {} + _insert_missing_modules(modules, "") + assert not modules