|
3 | 3 | # pyre-strict |
4 | 4 |
|
5 | 5 | import os |
6 | | -from typing import Any, Dict, List, Optional |
| 6 | +from contextlib import AbstractContextManager, nullcontext |
| 7 | +from typing import Any, Dict, List, Optional, TYPE_CHECKING |
7 | 8 |
|
| 9 | +import numpy as np |
8 | 10 | import torch |
9 | 11 | from captum.concept._core.concept import Concept |
10 | 12 | from captum.concept._utils.common import concepts_to_str |
@@ -166,7 +168,29 @@ def load( |
166 | 168 | cavs_path = CAV.assemble_save_path(cavs_path, model_id, concepts, layer) |
167 | 169 |
|
168 | 170 | if os.path.exists(cavs_path): |
169 | | - save_dict = torch.load(cavs_path) |
| 171 | + # Necessary for Python >=3.7 and <3.9! |
| 172 | + if TYPE_CHECKING: |
| 173 | + ctx: AbstractContextManager[None, None] |
| 174 | + else: |
| 175 | + ctx: AbstractContextManager |
| 176 | + if hasattr(torch.serialization, "safe_globals"): |
| 177 | + safe_globals = [ |
| 178 | + # pyre-ignore[16]: Module `numpy.core.multiarray` has no attribute |
| 179 | + # `_reconstruct` |
| 180 | + np.core.multiarray._reconstruct, # type: ignore[attr-defined] |
| 181 | + np.ndarray, |
| 182 | + np.dtype, |
| 183 | + ] |
| 184 | + if hasattr(np, "dtypes"): |
| 185 | + # pyre-ignore[16]: Module `numpy` has no attribute `dtypes`. |
| 186 | + safe_globals.extend([np.dtypes.UInt32DType, np.dtypes.Int32DType]) |
| 187 | + ctx = torch.serialization.safe_globals(safe_globals) |
| 188 | + else: |
| 189 | + # safe globals not in existence in this version of torch yet. Use a |
| 190 | + # dummy context manager instead |
| 191 | + ctx = nullcontext() |
| 192 | + with ctx: |
| 193 | + save_dict = torch.load(cavs_path) |
170 | 194 |
|
171 | 195 | concept_names = save_dict["concept_names"] |
172 | 196 | concept_ids = save_dict["concept_ids"] |
|
0 commit comments