Skip to content

Commit 7c77693

Browse files
fnhirwaPranavBhatP
authored andcommitted
[ENH] Allow multiple instances from multiple mock classes in _safe_import (sktime#1818)
The issue was spotted in sktime via sktime/sktime#8061 Updated the implementation and copied the unit tests to ensure that it works with pytorch-forecasting. closes sktime#1815
1 parent fa1dbde commit 7c77693

File tree

4 files changed

+216
-27
lines changed

4 files changed

+216
-27
lines changed
Lines changed: 69 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Import a module/class, return a Mock object if import fails.
22
33
Copied from sktime/skbase.
4+
5+
Should be refactored and moved to a common location in skbase.
46
"""
57

68
import importlib
@@ -12,20 +14,21 @@
1214
def _safe_import(import_path, pkg_name=None):
1315
"""Import a module/class, return a Mock object if import fails.
1416
17+
Idiomatic usage is ``obj = _safe_import("a.b.c.obj")``.
1518
The function supports importing both top-level modules and nested attributes:
1619
17-
- Top-level module: "torch" -> imports torch
18-
- Nested module: "torch.nn" -> imports torch.nn
19-
- Class/function: "torch.nn.Linear" -> imports Linear class from torch.nn
20+
- Top-level module: ``"torch"`` -> same as ``import torch``
21+
- Nested module: ``"torch.nn"`` -> same as``from torch import nn``
22+
- Class/function: ``"torch.nn.Linear"`` -> same as ``from torch.nn import Linear``
2023
2124
Parameters
2225
----------
2326
import_path : str
2427
The path to the module/class to import. Can be:
2528
26-
- Single module: "torch"
27-
- Nested module: "torch.nn"
28-
- Class/attribute: "torch.nn.ReLU"
29+
- Single module: ``"torch"``
30+
- Nested module: ``"torch.nn"``
31+
- Class/attribute: ``"torch.nn.ReLU"``
2932
3033
Note: The dots in the path determine the import behavior:
3134
@@ -37,39 +40,44 @@ def _safe_import(import_path, pkg_name=None):
3740
The name of the package to check for installation. This is useful when
3841
the import name differs from the package name, for example:
3942
40-
- import: "sklearn" -> pkg_name="scikit-learn"
41-
- import: "cv2" -> pkg_name="opencv-python"
43+
- import: ``"sklearn"`` -> ``pkg_name="scikit-learn"``
44+
- import: ``"cv2"`` -> ``pkg_name="opencv-python"``
4245
43-
If None, uses the first part of import_path before the dot.
46+
If ``None``, uses the first part of ``import_path`` before the dot.
4447
4548
Returns
4649
-------
4750
object
48-
One of the following:
51+
If the import path and ``pkg_name`` is present, one of the following:
52+
53+
- The imported module if ``import_path`` has no dots
54+
- The imported submodule if ``import_path`` has one dot
55+
- The imported class/function if ``import_path`` has multiple dots
4956
50-
- The imported module if import_path has no dots
51-
- The imported submodule if import_path has one dot
52-
- The imported class/function if import_path has multiple dots
53-
- A MagicMock object that returns an installation message if the
54-
package is not found
57+
If the package or import path are not found:
58+
a unique ``MagicMock`` object per unique import path.
5559
5660
Examples
5761
--------
62+
>>> from pytorch_forecasting.utils.dependencies._safe_import import _safe_import
63+
5864
>>> # Import a top-level module
59-
>>> torch = safe_import("torch")
65+
>>> torch = _safe_import("torch")
6066
6167
>>> # Import a submodule
62-
>>> nn = safe_import("torch.nn")
68+
>>> nn = _safe_import("torch.nn")
6369
6470
>>> # Import a specific class
65-
>>> Linear = safe_import("torch.nn.Linear")
71+
>>> Linear = _safe_import("torch.nn.Linear")
6672
6773
>>> # Import with different package name
68-
>>> cv2 = safe_import("cv2", pkg_name="opencv-python")
74+
>>> cv2 = _safe_import("cv2", pkg_name="opencv-python")
6975
"""
76+
path_list = import_path.split(".")
77+
7078
if pkg_name is None:
71-
path_list = import_path.split(".")
7279
pkg_name = path_list[0]
80+
obj_name = path_list[-1]
7381

7482
if pkg_name in _get_installed_packages():
7583
try:
@@ -79,10 +87,44 @@ def _safe_import(import_path, pkg_name=None):
7987
module = importlib.import_module(module_name)
8088
return getattr(module, attr_name)
8189
except (ImportError, AttributeError):
82-
return importlib.import_module(import_path)
83-
else:
84-
mock_obj = MagicMock()
85-
mock_obj.__call__ = MagicMock(
86-
return_value=f"Please install {pkg_name} to use this functionality."
87-
)
88-
return mock_obj
90+
pass
91+
92+
mock_obj = _create_mock_class(obj_name)
93+
return mock_obj
94+
95+
96+
class CommonMagicMeta(type):
97+
def __getattr__(cls, name):
98+
return MagicMock()
99+
100+
def __setattr__(cls, name, value):
101+
pass # Ignore attribute writes
102+
103+
104+
class MagicAttribute(metaclass=CommonMagicMeta):
105+
def __getattr__(self, name):
106+
return MagicMock()
107+
108+
def __setattr__(self, name, value):
109+
pass # Ignore attribute writes
110+
111+
def __call__(self, *args, **kwargs):
112+
return self # Ensures instantiation returns the same object
113+
114+
115+
def _create_mock_class(name: str, bases=()):
116+
"""Create new dynamic mock class similar to MagicMock.
117+
118+
Parameters
119+
----------
120+
name : str
121+
The name of the new class.
122+
bases : tuple, default=()
123+
The base classes of the new class.
124+
125+
Returns
126+
-------
127+
a new class that behaves like MagicMock, with name ``name``.
128+
Forwards all attribute access to a MagicMock object stored in the instance.
129+
"""
130+
return type(name, (MagicAttribute,), {"__metaclass__": CommonMagicMeta})
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Tests for dependency utilities."""
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
__author__ = ["jgyasu", "fkiraly"]
2+
3+
from pytorch_forecasting.utils._dependencies import (
4+
_get_installed_packages,
5+
_safe_import,
6+
)
7+
8+
9+
def test_import_present_module():
10+
"""Test importing a dependency that is installed."""
11+
result = _safe_import("pandas")
12+
assert result is not None
13+
assert "pandas" in _get_installed_packages()
14+
15+
16+
def test_import_missing_module():
17+
"""Test importing a dependency that is not installed."""
18+
result = _safe_import("nonexistent_module")
19+
assert hasattr(result, "__name__")
20+
assert result.__name__ == "nonexistent_module"
21+
22+
23+
def test_import_without_pkg_name():
24+
"""Test importing a dependency with the same name as package name."""
25+
result = _safe_import("torch", pkg_name="torch")
26+
assert result is not None
27+
28+
29+
def test_import_with_different_pkg_name_1():
30+
"""Test importing a dependency with a different package name."""
31+
result = _safe_import("skbase", pkg_name="scikit-base")
32+
assert result is not None
33+
34+
35+
def test_import_with_different_pkg_name_2():
36+
"""Test importing another dependency with a different package name."""
37+
result = _safe_import("cv2", pkg_name="opencv-python")
38+
assert result is not None
39+
40+
41+
def test_import_submodule():
42+
"""Test importing a submodule."""
43+
result = _safe_import("torch.nn")
44+
assert result is not None
45+
46+
47+
def test_import_class():
48+
"""Test importing a class."""
49+
result = _safe_import("torch.nn.Linear")
50+
assert result is not None
51+
52+
53+
def test_import_existing_object():
54+
"""Test importing an existing object."""
55+
result = _safe_import("pandas.DataFrame")
56+
assert result is not None
57+
assert result.__name__ == "DataFrame"
58+
from pandas import DataFrame
59+
60+
assert result is DataFrame
61+
62+
63+
def test_multiple_inheritance_from_mock():
64+
"""Test multiple inheritance from dynamic MagicMock."""
65+
Class1 = _safe_import("foobar.foo.FooBar")
66+
Class2 = _safe_import("barfoobar.BarFooBar")
67+
68+
class NewClass(Class1, Class2):
69+
"""This should not trigger an error.
70+
71+
The class definition would trigger an error if multiple inheritance
72+
from Class1 and Class2 does not work, e.g., if it is simply
73+
identical to MagicMock.
74+
"""
75+
76+
pass
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from pytorch_forecasting.utils._dependencies import _safe_import
2+
3+
4+
def test_present_module():
5+
"""Test importing a dependency that is installed."""
6+
module = _safe_import("torch")
7+
assert module is not None
8+
9+
10+
def test_import_missing_module():
11+
"""Test importing a dependency that is not installed."""
12+
result = _safe_import("nonexistent_module")
13+
assert hasattr(result, "__name__")
14+
assert result.__name__ == "nonexistent_module"
15+
16+
17+
def test_import_without_pkg_name():
18+
"""Test importing a dependency with the same name as package name."""
19+
result = _safe_import("torch", pkg_name="torch")
20+
assert result is not None
21+
22+
23+
def test_import_with_different_pkg_name_1():
24+
"""Test importing a dependency with a different package name."""
25+
result = _safe_import("skbase", pkg_name="scikit-base")
26+
assert result is not None
27+
28+
29+
def test_import_with_different_pkg_name_2():
30+
"""Test importing another dependency with a different package name."""
31+
result = _safe_import("cv2", pkg_name="opencv-python")
32+
assert result is not None
33+
34+
35+
def test_import_submodule():
36+
"""Test importing a submodule."""
37+
result = _safe_import("torch.nn")
38+
assert result is not None
39+
40+
41+
def test_import_class():
42+
"""Test importing a class."""
43+
result = _safe_import("torch.nn.Linear")
44+
assert result is not None
45+
46+
47+
def test_import_existing_object():
48+
"""Test importing an existing object."""
49+
result = _safe_import("pandas.DataFrame")
50+
assert result is not None
51+
assert result.__name__ == "DataFrame"
52+
from pandas import DataFrame
53+
54+
assert result is DataFrame
55+
56+
57+
def test_multiple_inheritance_from_mock():
58+
"""Test multiple inheritance from dynamic MagicMock."""
59+
Class1 = _safe_import("foobar.foo.FooBar")
60+
Class2 = _safe_import("barfoobar.BarFooBar")
61+
62+
class NewClass(Class1, Class2):
63+
"""This should not trigger an error.
64+
65+
The class definition would trigger an error if multiple inheritance
66+
from Class1 and Class2 does not work, e.g., if it is simply
67+
identical to MagicMock.
68+
"""
69+
70+
pass

0 commit comments

Comments
 (0)