From 689818f7832281fbe8129156f9841ac3f71a8d4c Mon Sep 17 00:00:00 2001 From: hirwa Date: Tue, 15 Apr 2025 17:51:55 +0530 Subject: [PATCH 1/3] align with safe imports from sktime --- .../utils/_dependencies/_safe_import.py | 48 +++++++++++-- tests/test_utils/test_safe_import.py | 70 +++++++++++++++++++ 2 files changed, 111 insertions(+), 7 deletions(-) create mode 100644 tests/test_utils/test_safe_import.py diff --git a/pytorch_forecasting/utils/_dependencies/_safe_import.py b/pytorch_forecasting/utils/_dependencies/_safe_import.py index f4805f9c1..535e6eb77 100644 --- a/pytorch_forecasting/utils/_dependencies/_safe_import.py +++ b/pytorch_forecasting/utils/_dependencies/_safe_import.py @@ -67,8 +67,8 @@ def _safe_import(import_path, pkg_name=None): >>> # Import with different package name >>> cv2 = safe_import("cv2", pkg_name="opencv-python") """ + path_list = import_path.split(".") if pkg_name is None: - path_list = import_path.split(".") pkg_name = path_list[0] if pkg_name in _get_installed_packages(): @@ -80,9 +80,43 @@ def _safe_import(import_path, pkg_name=None): return getattr(module, attr_name) except (ImportError, AttributeError): return importlib.import_module(import_path) - else: - mock_obj = MagicMock() - mock_obj.__call__ = MagicMock( - return_value=f"Please install {pkg_name} to use this functionality." - ) - return mock_obj + obj_name = path_list[-1] + mock_obj = _create_mock_class(obj_name) + return mock_obj + + +class CommonMagicMeta(type): + def __getattr__(cls, name): + return MagicMock() + + def __setattr__(cls, name, value): + pass # Ignore attribute writes + + +class MagicAttribute(metaclass=CommonMagicMeta): + def __getattr__(self, name): + return MagicMock() + + def __setattr__(self, name, value): + pass # Ignore attribute writes + + def __call__(self, *args, **kwargs): + return self # Ensures instantiation returns the same object + + +def _create_mock_class(name: str, bases=()): + """Create new dynamic mock class similar to MagicMock. + + Parameters + ---------- + name : str + The name of the new class. + bases : tuple, default=() + The base classes of the new class. + + Returns + ------- + a new class that behaves like MagicMock, with name ``name``. + Forwards all attribute access to a MagicMock object stored in the instance. + """ + return type(name, (MagicAttribute,), {"__metaclass__": CommonMagicMeta}) diff --git a/tests/test_utils/test_safe_import.py b/tests/test_utils/test_safe_import.py new file mode 100644 index 000000000..fd4c55e09 --- /dev/null +++ b/tests/test_utils/test_safe_import.py @@ -0,0 +1,70 @@ +from pytorch_forecasting.utils._dependencies import _safe_import + + +def test_present_module(): + """Test importing a dependency that is installed.""" + module = _safe_import("torch") + assert module is not None + + +def test_import_missing_module(): + """Test importing a dependency that is not installed.""" + result = _safe_import("nonexistent_module") + assert hasattr(result, "__name__") + assert result.__name__ == "nonexistent_module" + + +def test_import_without_pkg_name(): + """Test importing a dependency with the same name as package name.""" + result = _safe_import("torch", pkg_name="torch") + assert result is not None + + +def test_import_with_different_pkg_name_1(): + """Test importing a dependency with a different package name.""" + result = _safe_import("skbase", pkg_name="scikit-base") + assert result is not None + + +def test_import_with_different_pkg_name_2(): + """Test importing another dependency with a different package name.""" + result = _safe_import("cv2", pkg_name="opencv-python") + assert result is not None + + +def test_import_submodule(): + """Test importing a submodule.""" + result = _safe_import("torch.nn") + assert result is not None + + +def test_import_class(): + """Test importing a class.""" + result = _safe_import("torch.nn.Linear") + assert result is not None + + +def test_import_existing_object(): + """Test importing an existing object.""" + result = _safe_import("pandas.DataFrame") + assert result is not None + assert result.__name__ == "DataFrame" + from pandas import DataFrame + + assert result is DataFrame + + +def test_multiple_inheritance_from_mock(): + """Test multiple inheritance from dynamic MagicMock.""" + Class1 = _safe_import("foobar.foo.FooBar") + Class2 = _safe_import("barfoobar.BarFooBar") + + class NewClass(Class1, Class2): + """This should not trigger an error. + + The class definition would trigger an error if multiple inheritance + from Class1 and Class2 does not work, e.g., if it is simply + identical to MagicMock. + """ + + pass From 4b1600657ef79c79a2d175991eb4f14c64bc3c33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 27 Apr 2025 17:56:34 +0200 Subject: [PATCH 2/3] fixes --- .../utils/_dependencies/_safe_import.py | 48 +++++++++++-------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/pytorch_forecasting/utils/_dependencies/_safe_import.py b/pytorch_forecasting/utils/_dependencies/_safe_import.py index 535e6eb77..089fec978 100644 --- a/pytorch_forecasting/utils/_dependencies/_safe_import.py +++ b/pytorch_forecasting/utils/_dependencies/_safe_import.py @@ -12,20 +12,21 @@ def _safe_import(import_path, pkg_name=None): """Import a module/class, return a Mock object if import fails. + Idiomatic usage is ``obj = _safe_import("a.b.c.obj")``. The function supports importing both top-level modules and nested attributes: - - Top-level module: "torch" -> imports torch - - Nested module: "torch.nn" -> imports torch.nn - - Class/function: "torch.nn.Linear" -> imports Linear class from torch.nn + - Top-level module: ``"torch"`` -> same as ``import torch`` + - Nested module: ``"torch.nn"`` -> same as``from torch import nn`` + - Class/function: ``"torch.nn.Linear"`` -> same as ``from torch.nn import Linear`` Parameters ---------- import_path : str The path to the module/class to import. Can be: - - Single module: "torch" - - Nested module: "torch.nn" - - Class/attribute: "torch.nn.ReLU" + - Single module: ``"torch"`` + - Nested module: ``"torch.nn"`` + - Class/attribute: ``"torch.nn.ReLU"`` Note: The dots in the path determine the import behavior: @@ -37,39 +38,44 @@ def _safe_import(import_path, pkg_name=None): The name of the package to check for installation. This is useful when the import name differs from the package name, for example: - - import: "sklearn" -> pkg_name="scikit-learn" - - import: "cv2" -> pkg_name="opencv-python" + - import: ``"sklearn"`` -> ``pkg_name="scikit-learn"`` + - import: ``"cv2"`` -> ``pkg_name="opencv-python"`` - If None, uses the first part of import_path before the dot. + If ``None``, uses the first part of ``import_path`` before the dot. Returns ------- object - One of the following: + If the import path and ``pkg_name`` is present, one of the following: - - The imported module if import_path has no dots - - The imported submodule if import_path has one dot - - The imported class/function if import_path has multiple dots - - A MagicMock object that returns an installation message if the - package is not found + - The imported module if ``import_path`` has no dots + - The imported submodule if ``import_path`` has one dot + - The imported class/function if ``import_path`` has multiple dots + + If the package or import path are not found: + a unique ``MagicMock`` object per unique import path. Examples -------- + >>> from pytorch_forecasting.utils.dependencies._safe_import import _safe_import + >>> # Import a top-level module - >>> torch = safe_import("torch") + >>> torch = _safe_import("torch") >>> # Import a submodule - >>> nn = safe_import("torch.nn") + >>> nn = _safe_import("torch.nn") >>> # Import a specific class - >>> Linear = safe_import("torch.nn.Linear") + >>> Linear = _safe_import("torch.nn.Linear") >>> # Import with different package name - >>> cv2 = safe_import("cv2", pkg_name="opencv-python") + >>> cv2 = _safe_import("cv2", pkg_name="opencv-python") """ path_list = import_path.split(".") + if pkg_name is None: pkg_name = path_list[0] + obj_name = path_list[-1] if pkg_name in _get_installed_packages(): try: @@ -79,8 +85,8 @@ def _safe_import(import_path, pkg_name=None): module = importlib.import_module(module_name) return getattr(module, attr_name) except (ImportError, AttributeError): - return importlib.import_module(import_path) - obj_name = path_list[-1] + pass + mock_obj = _create_mock_class(obj_name) return mock_obj From 0b2c884182902f5c00b775c3eb0cd4cfbbba2497 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 27 Apr 2025 18:20:15 +0200 Subject: [PATCH 3/3] tests --- .../utils/_dependencies/_safe_import.py | 2 + .../utils/_dependencies/tests/__init__.py | 1 + .../_dependencies/tests/test_safe_import.py | 76 +++++++++++++++++++ 3 files changed, 79 insertions(+) create mode 100644 pytorch_forecasting/utils/_dependencies/tests/__init__.py create mode 100644 pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py diff --git a/pytorch_forecasting/utils/_dependencies/_safe_import.py b/pytorch_forecasting/utils/_dependencies/_safe_import.py index 089fec978..f11313dd4 100644 --- a/pytorch_forecasting/utils/_dependencies/_safe_import.py +++ b/pytorch_forecasting/utils/_dependencies/_safe_import.py @@ -1,6 +1,8 @@ """Import a module/class, return a Mock object if import fails. Copied from sktime/skbase. + +Should be refactored and moved to a common location in skbase. """ import importlib diff --git a/pytorch_forecasting/utils/_dependencies/tests/__init__.py b/pytorch_forecasting/utils/_dependencies/tests/__init__.py new file mode 100644 index 000000000..4bb7bee24 --- /dev/null +++ b/pytorch_forecasting/utils/_dependencies/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for dependency utilities.""" diff --git a/pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py b/pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py new file mode 100644 index 000000000..e0e7e7ecb --- /dev/null +++ b/pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py @@ -0,0 +1,76 @@ +__author__ = ["jgyasu", "fkiraly"] + +from pytorch_forecasting.utils._dependencies import ( + _get_installed_packages, + _safe_import, +) + + +def test_import_present_module(): + """Test importing a dependency that is installed.""" + result = _safe_import("pandas") + assert result is not None + assert "pandas" in _get_installed_packages() + + +def test_import_missing_module(): + """Test importing a dependency that is not installed.""" + result = _safe_import("nonexistent_module") + assert hasattr(result, "__name__") + assert result.__name__ == "nonexistent_module" + + +def test_import_without_pkg_name(): + """Test importing a dependency with the same name as package name.""" + result = _safe_import("torch", pkg_name="torch") + assert result is not None + + +def test_import_with_different_pkg_name_1(): + """Test importing a dependency with a different package name.""" + result = _safe_import("skbase", pkg_name="scikit-base") + assert result is not None + + +def test_import_with_different_pkg_name_2(): + """Test importing another dependency with a different package name.""" + result = _safe_import("cv2", pkg_name="opencv-python") + assert result is not None + + +def test_import_submodule(): + """Test importing a submodule.""" + result = _safe_import("torch.nn") + assert result is not None + + +def test_import_class(): + """Test importing a class.""" + result = _safe_import("torch.nn.Linear") + assert result is not None + + +def test_import_existing_object(): + """Test importing an existing object.""" + result = _safe_import("pandas.DataFrame") + assert result is not None + assert result.__name__ == "DataFrame" + from pandas import DataFrame + + assert result is DataFrame + + +def test_multiple_inheritance_from_mock(): + """Test multiple inheritance from dynamic MagicMock.""" + Class1 = _safe_import("foobar.foo.FooBar") + Class2 = _safe_import("barfoobar.BarFooBar") + + class NewClass(Class1, Class2): + """This should not trigger an error. + + The class definition would trigger an error if multiple inheritance + from Class1 and Class2 does not work, e.g., if it is simply + identical to MagicMock. + """ + + pass