11"""Import a module/class, return a Mock object if import fails.
22
33Copied from sktime/skbase.
4+
5+ Should be refactored and moved to a common location in skbase.
46"""
57
68import importlib
1214def _safe_import (import_path , pkg_name = None ):
1315 """Import a module/class, return a Mock object if import fails.
1416
17+ Idiomatic usage is ``obj = _safe_import("a.b.c.obj")``.
1518 The function supports importing both top-level modules and nested attributes:
1619
17- - Top-level module: "torch" -> imports torch
18- - Nested module: "torch.nn" -> imports torch.nn
19- - Class/function: "torch.nn.Linear" -> imports Linear class from torch.nn
20+ - Top-level module: `` "torch"`` -> same as ``import torch``
21+ - Nested module: `` "torch.nn"`` -> same as``from torch import nn``
22+ - Class/function: `` "torch.nn.Linear"`` -> same as `` from torch.nn import Linear``
2023
2124 Parameters
2225 ----------
2326 import_path : str
2427 The path to the module/class to import. Can be:
2528
26- - Single module: "torch"
27- - Nested module: "torch.nn"
28- - Class/attribute: "torch.nn.ReLU"
29+ - Single module: `` "torch"``
30+ - Nested module: `` "torch.nn"``
31+ - Class/attribute: `` "torch.nn.ReLU"``
2932
3033 Note: The dots in the path determine the import behavior:
3134
@@ -37,39 +40,44 @@ def _safe_import(import_path, pkg_name=None):
3740 The name of the package to check for installation. This is useful when
3841 the import name differs from the package name, for example:
3942
40- - import: "sklearn" -> pkg_name="scikit-learn"
41- - import: "cv2" -> pkg_name="opencv-python"
43+ - import: `` "sklearn"`` -> `` pkg_name="scikit-learn"``
44+ - import: `` "cv2"`` -> `` pkg_name="opencv-python"``
4245
43- If None, uses the first part of import_path before the dot.
46+ If `` None`` , uses the first part of `` import_path`` before the dot.
4447
4548 Returns
4649 -------
4750 object
48- One of the following:
51+ If the import path and ``pkg_name`` is present, one of the following:
52+
53+ - The imported module if ``import_path`` has no dots
54+ - The imported submodule if ``import_path`` has one dot
55+ - The imported class/function if ``import_path`` has multiple dots
4956
50- - The imported module if import_path has no dots
51- - The imported submodule if import_path has one dot
52- - The imported class/function if import_path has multiple dots
53- - A MagicMock object that returns an installation message if the
54- package is not found
57+ If the package or import path are not found:
58+ a unique ``MagicMock`` object per unique import path.
5559
5660 Examples
5761 --------
62+ >>> from pytorch_forecasting.utils.dependencies._safe_import import _safe_import
63+
5864 >>> # Import a top-level module
59- >>> torch = safe_import ("torch")
65+ >>> torch = _safe_import ("torch")
6066
6167 >>> # Import a submodule
62- >>> nn = safe_import ("torch.nn")
68+ >>> nn = _safe_import ("torch.nn")
6369
6470 >>> # Import a specific class
65- >>> Linear = safe_import ("torch.nn.Linear")
71+ >>> Linear = _safe_import ("torch.nn.Linear")
6672
6773 >>> # Import with different package name
68- >>> cv2 = safe_import ("cv2", pkg_name="opencv-python")
74+ >>> cv2 = _safe_import ("cv2", pkg_name="opencv-python")
6975 """
76+ path_list = import_path .split ("." )
77+
7078 if pkg_name is None :
71- path_list = import_path .split ("." )
7279 pkg_name = path_list [0 ]
80+ obj_name = path_list [- 1 ]
7381
7482 if pkg_name in _get_installed_packages ():
7583 try :
@@ -79,10 +87,44 @@ def _safe_import(import_path, pkg_name=None):
7987 module = importlib .import_module (module_name )
8088 return getattr (module , attr_name )
8189 except (ImportError , AttributeError ):
82- return importlib .import_module (import_path )
83- else :
84- mock_obj = MagicMock ()
85- mock_obj .__call__ = MagicMock (
86- return_value = f"Please install { pkg_name } to use this functionality."
87- )
88- return mock_obj
90+ pass
91+
92+ mock_obj = _create_mock_class (obj_name )
93+ return mock_obj
94+
95+
96+ class CommonMagicMeta (type ):
97+ def __getattr__ (cls , name ):
98+ return MagicMock ()
99+
100+ def __setattr__ (cls , name , value ):
101+ pass # Ignore attribute writes
102+
103+
104+ class MagicAttribute (metaclass = CommonMagicMeta ):
105+ def __getattr__ (self , name ):
106+ return MagicMock ()
107+
108+ def __setattr__ (self , name , value ):
109+ pass # Ignore attribute writes
110+
111+ def __call__ (self , * args , ** kwargs ):
112+ return self # Ensures instantiation returns the same object
113+
114+
115+ def _create_mock_class (name : str , bases = ()):
116+ """Create new dynamic mock class similar to MagicMock.
117+
118+ Parameters
119+ ----------
120+ name : str
121+ The name of the new class.
122+ bases : tuple, default=()
123+ The base classes of the new class.
124+
125+ Returns
126+ -------
127+ a new class that behaves like MagicMock, with name ``name``.
128+ Forwards all attribute access to a MagicMock object stored in the instance.
129+ """
130+ return type (name , (MagicAttribute ,), {"__metaclass__" : CommonMagicMeta })
0 commit comments