Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 69 additions & 27 deletions pytorch_forecasting/utils/_dependencies/_safe_import.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Import a module/class, return a Mock object if import fails.

Copied from sktime/skbase.

Should be refactored and moved to a common location in skbase.
"""

import importlib
Expand All @@ -12,20 +14,21 @@
def _safe_import(import_path, pkg_name=None):
"""Import a module/class, return a Mock object if import fails.

Idiomatic usage is ``obj = _safe_import("a.b.c.obj")``.
The function supports importing both top-level modules and nested attributes:

- Top-level module: "torch" -> imports torch
- Nested module: "torch.nn" -> imports torch.nn
- Class/function: "torch.nn.Linear" -> imports Linear class from torch.nn
- Top-level module: ``"torch"`` -> same as ``import torch``
- Nested module: ``"torch.nn"`` -> same as``from torch import nn``
- Class/function: ``"torch.nn.Linear"`` -> same as ``from torch.nn import Linear``

Parameters
----------
import_path : str
The path to the module/class to import. Can be:

- Single module: "torch"
- Nested module: "torch.nn"
- Class/attribute: "torch.nn.ReLU"
- Single module: ``"torch"``
- Nested module: ``"torch.nn"``
- Class/attribute: ``"torch.nn.ReLU"``

Note: The dots in the path determine the import behavior:

Expand All @@ -37,39 +40,44 @@
The name of the package to check for installation. This is useful when
the import name differs from the package name, for example:

- import: "sklearn" -> pkg_name="scikit-learn"
- import: "cv2" -> pkg_name="opencv-python"
- import: ``"sklearn"`` -> ``pkg_name="scikit-learn"``
- import: ``"cv2"`` -> ``pkg_name="opencv-python"``

If None, uses the first part of import_path before the dot.
If ``None``, uses the first part of ``import_path`` before the dot.

Returns
-------
object
One of the following:
If the import path and ``pkg_name`` is present, one of the following:

- The imported module if ``import_path`` has no dots
- The imported submodule if ``import_path`` has one dot
- The imported class/function if ``import_path`` has multiple dots

- The imported module if import_path has no dots
- The imported submodule if import_path has one dot
- The imported class/function if import_path has multiple dots
- A MagicMock object that returns an installation message if the
package is not found
If the package or import path are not found:
a unique ``MagicMock`` object per unique import path.

Examples
--------
>>> from pytorch_forecasting.utils.dependencies._safe_import import _safe_import

>>> # Import a top-level module
>>> torch = safe_import("torch")
>>> torch = _safe_import("torch")

>>> # Import a submodule
>>> nn = safe_import("torch.nn")
>>> nn = _safe_import("torch.nn")

>>> # Import a specific class
>>> Linear = safe_import("torch.nn.Linear")
>>> Linear = _safe_import("torch.nn.Linear")

>>> # Import with different package name
>>> cv2 = safe_import("cv2", pkg_name="opencv-python")
>>> cv2 = _safe_import("cv2", pkg_name="opencv-python")
"""
path_list = import_path.split(".")

if pkg_name is None:
path_list = import_path.split(".")
pkg_name = path_list[0]
obj_name = path_list[-1]

if pkg_name in _get_installed_packages():
try:
Expand All @@ -79,10 +87,44 @@
module = importlib.import_module(module_name)
return getattr(module, attr_name)
except (ImportError, AttributeError):
return importlib.import_module(import_path)
else:
mock_obj = MagicMock()
mock_obj.__call__ = MagicMock(
return_value=f"Please install {pkg_name} to use this functionality."
)
return mock_obj
pass

Check warning on line 90 in pytorch_forecasting/utils/_dependencies/_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/_safe_import.py#L90

Added line #L90 was not covered by tests

mock_obj = _create_mock_class(obj_name)
return mock_obj


class CommonMagicMeta(type):
def __getattr__(cls, name):
return MagicMock()

Check warning on line 98 in pytorch_forecasting/utils/_dependencies/_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/_safe_import.py#L98

Added line #L98 was not covered by tests

def __setattr__(cls, name, value):
pass # Ignore attribute writes

Check warning on line 101 in pytorch_forecasting/utils/_dependencies/_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/_safe_import.py#L101

Added line #L101 was not covered by tests


class MagicAttribute(metaclass=CommonMagicMeta):
def __getattr__(self, name):
return MagicMock()

Check warning on line 106 in pytorch_forecasting/utils/_dependencies/_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/_safe_import.py#L106

Added line #L106 was not covered by tests

def __setattr__(self, name, value):
pass # Ignore attribute writes

Check warning on line 109 in pytorch_forecasting/utils/_dependencies/_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/_safe_import.py#L109

Added line #L109 was not covered by tests

def __call__(self, *args, **kwargs):
return self # Ensures instantiation returns the same object

Check warning on line 112 in pytorch_forecasting/utils/_dependencies/_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/_safe_import.py#L112

Added line #L112 was not covered by tests


def _create_mock_class(name: str, bases=()):
"""Create new dynamic mock class similar to MagicMock.

Parameters
----------
name : str
The name of the new class.
bases : tuple, default=()
The base classes of the new class.

Returns
-------
a new class that behaves like MagicMock, with name ``name``.
Forwards all attribute access to a MagicMock object stored in the instance.
"""
return type(name, (MagicAttribute,), {"__metaclass__": CommonMagicMeta})
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests for dependency utilities."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
__author__ = ["jgyasu", "fkiraly"]

Check warning on line 1 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L1

Added line #L1 was not covered by tests

from pytorch_forecasting.utils._dependencies import (

Check warning on line 3 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L3

Added line #L3 was not covered by tests
_get_installed_packages,
_safe_import,
)


def test_import_present_module():

Check warning on line 9 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L9

Added line #L9 was not covered by tests
"""Test importing a dependency that is installed."""
result = _safe_import("pandas")
assert result is not None
assert "pandas" in _get_installed_packages()

Check warning on line 13 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L11-L13

Added lines #L11 - L13 were not covered by tests


def test_import_missing_module():

Check warning on line 16 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L16

Added line #L16 was not covered by tests
"""Test importing a dependency that is not installed."""
result = _safe_import("nonexistent_module")
assert hasattr(result, "__name__")
assert result.__name__ == "nonexistent_module"

Check warning on line 20 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L18-L20

Added lines #L18 - L20 were not covered by tests


def test_import_without_pkg_name():

Check warning on line 23 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L23

Added line #L23 was not covered by tests
"""Test importing a dependency with the same name as package name."""
result = _safe_import("torch", pkg_name="torch")
assert result is not None

Check warning on line 26 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L25-L26

Added lines #L25 - L26 were not covered by tests


def test_import_with_different_pkg_name_1():

Check warning on line 29 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L29

Added line #L29 was not covered by tests
"""Test importing a dependency with a different package name."""
result = _safe_import("skbase", pkg_name="scikit-base")
assert result is not None

Check warning on line 32 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L31-L32

Added lines #L31 - L32 were not covered by tests


def test_import_with_different_pkg_name_2():

Check warning on line 35 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L35

Added line #L35 was not covered by tests
"""Test importing another dependency with a different package name."""
result = _safe_import("cv2", pkg_name="opencv-python")
assert result is not None

Check warning on line 38 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L37-L38

Added lines #L37 - L38 were not covered by tests


def test_import_submodule():

Check warning on line 41 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L41

Added line #L41 was not covered by tests
"""Test importing a submodule."""
result = _safe_import("torch.nn")
assert result is not None

Check warning on line 44 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L43-L44

Added lines #L43 - L44 were not covered by tests


def test_import_class():

Check warning on line 47 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L47

Added line #L47 was not covered by tests
"""Test importing a class."""
result = _safe_import("torch.nn.Linear")
assert result is not None

Check warning on line 50 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L49-L50

Added lines #L49 - L50 were not covered by tests


def test_import_existing_object():

Check warning on line 53 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L53

Added line #L53 was not covered by tests
"""Test importing an existing object."""
result = _safe_import("pandas.DataFrame")
assert result is not None
assert result.__name__ == "DataFrame"
from pandas import DataFrame

Check warning on line 58 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L55-L58

Added lines #L55 - L58 were not covered by tests

assert result is DataFrame

Check warning on line 60 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L60

Added line #L60 was not covered by tests


def test_multiple_inheritance_from_mock():

Check warning on line 63 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L63

Added line #L63 was not covered by tests
"""Test multiple inheritance from dynamic MagicMock."""
Class1 = _safe_import("foobar.foo.FooBar")
Class2 = _safe_import("barfoobar.BarFooBar")

Check warning on line 66 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L65-L66

Added lines #L65 - L66 were not covered by tests

class NewClass(Class1, Class2):

Check warning on line 68 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L68

Added line #L68 was not covered by tests
"""This should not trigger an error.

The class definition would trigger an error if multiple inheritance
from Class1 and Class2 does not work, e.g., if it is simply
identical to MagicMock.
"""

pass

Check warning on line 76 in pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py#L76

Added line #L76 was not covered by tests
70 changes: 70 additions & 0 deletions tests/test_utils/test_safe_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from pytorch_forecasting.utils._dependencies import _safe_import


def test_present_module():
"""Test importing a dependency that is installed."""
module = _safe_import("torch")
assert module is not None


def test_import_missing_module():
"""Test importing a dependency that is not installed."""
result = _safe_import("nonexistent_module")
assert hasattr(result, "__name__")
assert result.__name__ == "nonexistent_module"


def test_import_without_pkg_name():
"""Test importing a dependency with the same name as package name."""
result = _safe_import("torch", pkg_name="torch")
assert result is not None


def test_import_with_different_pkg_name_1():
"""Test importing a dependency with a different package name."""
result = _safe_import("skbase", pkg_name="scikit-base")
assert result is not None


def test_import_with_different_pkg_name_2():
"""Test importing another dependency with a different package name."""
result = _safe_import("cv2", pkg_name="opencv-python")
assert result is not None


def test_import_submodule():
"""Test importing a submodule."""
result = _safe_import("torch.nn")
assert result is not None


def test_import_class():
"""Test importing a class."""
result = _safe_import("torch.nn.Linear")
assert result is not None


def test_import_existing_object():
"""Test importing an existing object."""
result = _safe_import("pandas.DataFrame")
assert result is not None
assert result.__name__ == "DataFrame"
from pandas import DataFrame

assert result is DataFrame


def test_multiple_inheritance_from_mock():
"""Test multiple inheritance from dynamic MagicMock."""
Class1 = _safe_import("foobar.foo.FooBar")
Class2 = _safe_import("barfoobar.BarFooBar")

class NewClass(Class1, Class2):
"""This should not trigger an error.

The class definition would trigger an error if multiple inheritance
from Class1 and Class2 does not work, e.g., if it is simply
identical to MagicMock.
"""

pass
Loading