Skip to content

Commit 9e6be35

Browse files
Fix import mechanism for task modules. (#373)
Co-authored-by: Nick Crews <[email protected]> Co-authored-by: Tobias Raabe <[email protected]>
1 parent 836518a commit 9e6be35

File tree

5 files changed

+306
-9
lines changed

5 files changed

+306
-9
lines changed

src/_pytask/collect.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import sys
88
import time
99
import warnings
10-
from importlib import util as importlib_util
1110
from pathlib import Path
1211
from typing import Any
1312
from typing import Generator
@@ -28,6 +27,7 @@
2827
from _pytask.outcomes import CollectionOutcome
2928
from _pytask.outcomes import count_outcomes
3029
from _pytask.path import find_case_sensitive_path
30+
from _pytask.path import import_path
3131
from _pytask.report import CollectionReport
3232
from _pytask.session import Session
3333
from _pytask.shared import find_duplicates
@@ -121,13 +121,7 @@ def pytask_collect_file(
121121
) -> list[CollectionReport] | None:
122122
"""Collect a file."""
123123
if any(path.match(pattern) for pattern in session.config["task_files"]):
124-
spec = importlib_util.spec_from_file_location(path.stem, str(path))
125-
126-
if spec is None:
127-
raise ImportError(f"Can't find module {path.stem!r} at location {path}.")
128-
129-
mod = importlib_util.module_from_spec(spec)
130-
spec.loader.exec_module(mod)
124+
mod = import_path(path, session.config["root"])
131125

132126
collected_reports = []
133127
for name, obj in inspect.getmembers(mod):

src/_pytask/path.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
from __future__ import annotations
33

44
import functools
5+
import importlib.util
56
import os
7+
import sys
68
from pathlib import Path
9+
from types import ModuleType
710
from typing import Sequence
811

912

@@ -120,3 +123,76 @@ def find_case_sensitive_path(path: Path, platform: str) -> Path:
120123
"""
121124
out = path.resolve() if platform == "win32" else path
122125
return out
126+
127+
128+
def import_path(path: Path, root: Path) -> ModuleType:
129+
"""Import and return a module from the given path.
130+
131+
The function is taken from pytest when the import mode is set to ``importlib``. It
132+
pytest's recommended import mode for new projects although the default is set to
133+
``prepend``. More discussion and information can be found in :gh:`373`.
134+
135+
"""
136+
module_name = _module_name_from_path(path, root)
137+
138+
spec = importlib.util.spec_from_file_location(module_name, str(path))
139+
140+
if spec is None:
141+
raise ImportError(f"Can't find module {module_name!r} at location {path}.")
142+
143+
mod = importlib.util.module_from_spec(spec)
144+
sys.modules[module_name] = mod
145+
spec.loader.exec_module(mod)
146+
_insert_missing_modules(sys.modules, module_name)
147+
return mod
148+
149+
150+
def _module_name_from_path(path: Path, root: Path) -> str:
151+
"""Return a dotted module name based on the given path, anchored on root.
152+
153+
For example: path="projects/src/project/task_foo.py" and root="/projects", the
154+
resulting module name will be "src.project.task_foo".
155+
156+
"""
157+
path = path.with_suffix("")
158+
try:
159+
relative_path = path.relative_to(root)
160+
except ValueError:
161+
# If we can't get a relative path to root, use the full path, except for the
162+
# first part ("d:\\" or "/" depending on the platform, for example).
163+
path_parts = path.parts[1:]
164+
else:
165+
# Use the parts for the relative path to the root path.
166+
path_parts = relative_path.parts
167+
168+
return ".".join(path_parts)
169+
170+
171+
def _insert_missing_modules(modules: dict[str, ModuleType], module_name: str) -> None:
172+
"""Insert missing modules when importing modules with :func:`import_path`.
173+
174+
When we want to import a module as ``src.project.task_foo`` for example, we need to
175+
create empty modules ``src`` and ``src.project`` after inserting
176+
``src.project.task_foo``, otherwise ``src.project.task_foo`` is not importable by
177+
``__import__``.
178+
179+
"""
180+
module_parts = module_name.split(".")
181+
while module_name:
182+
if module_name not in modules:
183+
try:
184+
# If sys.meta_path is empty, calling import_module will issue a warning
185+
# and raise ModuleNotFoundError. To avoid the warning, we check
186+
# sys.meta_path explicitly and raise the error ourselves to fall back to
187+
# creating a dummy module.
188+
if not sys.meta_path:
189+
raise ModuleNotFoundError
190+
importlib.import_module(module_name)
191+
except ModuleNotFoundError:
192+
module = ModuleType(
193+
module_name,
194+
doc="Empty module created by pytask.",
195+
)
196+
modules[module_name] = module
197+
module_parts.pop(-1)
198+
module_name = ".".join(module_parts)

tests/test_collect.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,46 @@ def task_write_text(depends_on, produces):
3737
assert tmp_path.joinpath("out.txt").read_text() == "Relative paths work."
3838

3939

40+
@pytest.mark.end_to_end()
41+
def test_collect_tasks_from_modules_with_the_same_name(tmp_path):
42+
"""We need to check that task modules can have the same name. See #373 and #374."""
43+
tmp_path.joinpath("a").mkdir()
44+
tmp_path.joinpath("b").mkdir()
45+
tmp_path.joinpath("a", "task_module.py").write_text("def task_a(): pass")
46+
tmp_path.joinpath("b", "task_module.py").write_text("def task_a(): pass")
47+
session = main({"paths": tmp_path})
48+
assert len(session.collection_reports) == 2
49+
assert all(
50+
report.outcome == CollectionOutcome.SUCCESS
51+
for report in session.collection_reports
52+
)
53+
assert {
54+
report.node.function.__module__ for report in session.collection_reports
55+
} == {"a.task_module", "b.task_module"}
56+
57+
58+
@pytest.mark.end_to_end()
59+
def test_collect_module_name(tmp_path):
60+
"""We need to add a task module to the sys.modules. See #373 and #374."""
61+
source = """
62+
# without this import, everything works fine
63+
from __future__ import annotations
64+
65+
import dataclasses
66+
67+
@dataclasses.dataclass
68+
class Data:
69+
x: int
70+
71+
def task_my_task():
72+
pass
73+
"""
74+
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
75+
session = main({"paths": tmp_path})
76+
outcome = session.collection_reports[0].outcome
77+
assert outcome == CollectionOutcome.SUCCESS
78+
79+
4080
@pytest.mark.end_to_end()
4181
def test_collect_filepathnode_with_unknown_type(tmp_path):
4282
"""If a node cannot be parsed because unknown type, raise an error."""

tests/test_debugging.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ def task_1():
365365
x = 3
366366
print("hello18")
367367
assert count_continue == 2, "unexpected_failure: %d != 2" % count_continue
368+
raise Exception("expected_failure")
368369
"""
369370
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
370371

@@ -399,7 +400,8 @@ def task_1():
399400
assert "1" in rest
400401
assert "failed" in rest
401402
assert "Failed" in rest
402-
assert "AssertionError: unexpected_failure" in rest
403+
assert "AssertionError: unexpected_failure" not in rest
404+
assert "expected_failure" in rest
403405
_flush(child)
404406

405407

tests/test_path.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
from __future__ import annotations
22

33
import sys
4+
import textwrap
45
from contextlib import ExitStack as does_not_raise # noqa: N813
56
from pathlib import Path
67
from pathlib import PurePosixPath
78
from pathlib import PureWindowsPath
9+
from types import ModuleType
10+
from typing import Any
811

912
import pytest
13+
from _pytask.path import _insert_missing_modules
14+
from _pytask.path import _module_name_from_path
1015
from _pytask.path import find_case_sensitive_path
1116
from _pytask.path import find_closest_ancestor
1217
from _pytask.path import find_common_ancestor
18+
from _pytask.path import import_path
1319
from _pytask.path import relative_to
1420

1521

@@ -117,3 +123,182 @@ def test_find_case_sensitive_path(tmp_path, path, existing_paths, expected):
117123

118124
result = find_case_sensitive_path(tmp_path / path, sys.platform)
119125
assert result == tmp_path / expected
126+
127+
128+
@pytest.fixture()
129+
def simple_module(tmp_path: Path) -> Path:
130+
fn = tmp_path / "_src/project/mymod.py"
131+
fn.parent.mkdir(parents=True)
132+
fn.write_text("def foo(x): return 40 + x")
133+
return fn
134+
135+
136+
def test_importmode_importlib(simple_module: Path, tmp_path: Path) -> None:
137+
"""`importlib` mode does not change sys.path."""
138+
module = import_path(simple_module, root=tmp_path)
139+
assert module.foo(2) == 42 # type: ignore[attr-defined]
140+
assert str(simple_module.parent) not in sys.path
141+
assert module.__name__ in sys.modules
142+
assert module.__name__ == "_src.project.mymod"
143+
assert "_src" in sys.modules
144+
assert "_src.project" in sys.modules
145+
146+
147+
def test_importmode_twice_is_different_module(
148+
simple_module: Path, tmp_path: Path
149+
) -> None:
150+
"""`importlib` mode always returns a new module."""
151+
module1 = import_path(simple_module, root=tmp_path)
152+
module2 = import_path(simple_module, root=tmp_path)
153+
assert module1 is not module2
154+
155+
156+
def test_no_meta_path_found(
157+
simple_module: Path, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
158+
) -> None:
159+
"""Even without any meta_path should still import module."""
160+
monkeypatch.setattr(sys, "meta_path", [])
161+
module = import_path(simple_module, root=tmp_path)
162+
assert module.foo(2) == 42 # type: ignore[attr-defined]
163+
164+
# mode='importlib' fails if no spec is found to load the module
165+
import importlib.util
166+
167+
monkeypatch.setattr(
168+
importlib.util, "spec_from_file_location", lambda *args: None # noqa: ARG005
169+
)
170+
with pytest.raises(ImportError):
171+
import_path(simple_module, root=tmp_path)
172+
173+
174+
def test_importmode_importlib_with_dataclass(tmp_path: Path) -> None:
175+
"""
176+
Ensure that importlib mode works with a module containing dataclasses (#373,
177+
pytest#7856).
178+
"""
179+
fn = tmp_path.joinpath("_src/project/task_dataclass.py")
180+
fn.parent.mkdir(parents=True)
181+
fn.write_text(
182+
textwrap.dedent(
183+
"""
184+
from dataclasses import dataclass
185+
186+
@dataclass
187+
class Data:
188+
value: str
189+
"""
190+
)
191+
)
192+
193+
module = import_path(fn, root=tmp_path)
194+
Data: Any = module.Data # noqa: N806
195+
data = Data(value="foo")
196+
assert data.value == "foo"
197+
assert data.__module__ == "_src.project.task_dataclass"
198+
199+
200+
def test_importmode_importlib_with_pickle(tmp_path: Path) -> None:
201+
"""Ensure that importlib mode works with pickle (#373, pytest#7859)."""
202+
fn = tmp_path.joinpath("_src/project/task_pickle.py")
203+
fn.parent.mkdir(parents=True)
204+
fn.write_text(
205+
textwrap.dedent(
206+
"""
207+
import pickle
208+
209+
def _action():
210+
return 42
211+
212+
def round_trip():
213+
s = pickle.dumps(_action)
214+
return pickle.loads(s)
215+
"""
216+
)
217+
)
218+
219+
module = import_path(fn, root=tmp_path)
220+
round_trip = module.round_trip
221+
action = round_trip()
222+
assert action() == 42
223+
224+
225+
def test_importmode_importlib_with_pickle_separate_modules(tmp_path: Path) -> None:
226+
"""
227+
Ensure that importlib mode works can load pickles that look similar but are
228+
defined in separate modules.
229+
"""
230+
fn1 = tmp_path.joinpath("_src/m1/project/task.py")
231+
fn1.parent.mkdir(parents=True)
232+
fn1.write_text(
233+
textwrap.dedent(
234+
"""
235+
import dataclasses
236+
import pickle
237+
238+
@dataclasses.dataclass
239+
class Data:
240+
x: int = 42
241+
"""
242+
)
243+
)
244+
245+
fn2 = tmp_path.joinpath("_src/m2/project/task.py")
246+
fn2.parent.mkdir(parents=True)
247+
fn2.write_text(
248+
textwrap.dedent(
249+
"""
250+
import dataclasses
251+
import pickle
252+
253+
@dataclasses.dataclass
254+
class Data:
255+
x: str = ""
256+
"""
257+
)
258+
)
259+
260+
import pickle
261+
262+
def round_trip(obj):
263+
s = pickle.dumps(obj)
264+
return pickle.loads(s) # noqa: S301
265+
266+
module = import_path(fn1, root=tmp_path)
267+
Data1 = module.Data # noqa: N806
268+
269+
module = import_path(fn2, root=tmp_path)
270+
Data2 = module.Data # noqa: N806
271+
272+
assert round_trip(Data1(20)) == Data1(20)
273+
assert round_trip(Data2("hello")) == Data2("hello")
274+
assert Data1.__module__ == "_src.m1.project.task"
275+
assert Data2.__module__ == "_src.m2.project.task"
276+
277+
278+
def test_module_name_from_path(tmp_path: Path) -> None:
279+
result = _module_name_from_path(tmp_path / "src/project/task_foo.py", tmp_path)
280+
assert result == "src.project.task_foo"
281+
282+
# Path is not relative to root dir: use the full path to obtain the module name.
283+
result = _module_name_from_path(Path("/home/foo/task_foo.py"), Path("/bar"))
284+
assert result == "home.foo.task_foo"
285+
286+
287+
def test_insert_missing_modules(
288+
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
289+
) -> None:
290+
monkeypatch.chdir(tmp_path)
291+
# Use 'xxx' and 'xxy' as parent names as they are unlikely to exist and
292+
# don't end up being imported.
293+
modules = {"xxx.project.foo": ModuleType("xxx.project.foo")}
294+
_insert_missing_modules(modules, "xxx.project.foo")
295+
assert sorted(modules) == ["xxx", "xxx.project", "xxx.project.foo"]
296+
297+
mod = ModuleType("mod", doc="My Module")
298+
modules = {"xxy": mod}
299+
_insert_missing_modules(modules, "xxy")
300+
assert modules == {"xxy": mod}
301+
302+
modules = {}
303+
_insert_missing_modules(modules, "")
304+
assert not modules

0 commit comments

Comments
 (0)