|
| 1 | +#!/usr/bin/env python |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | +# |
| 4 | +# Copyright (c) 2023 Intel Corporation |
| 5 | +# |
| 6 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 7 | +# you may not use this file except in compliance with the License. |
| 8 | +# You may obtain a copy of the License at |
| 9 | +# |
| 10 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | +# |
| 12 | +# Unless required by applicable law or agreed to in writing, software |
| 13 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 14 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 15 | +# See the License for the specific language governing permissions and |
| 16 | +# limitations under the License. |
| 17 | +"""Load one specified tensor from a bin file.""" |
| 18 | + |
| 19 | +import io |
| 20 | +import os |
| 21 | +import warnings |
| 22 | +from typing import IO, Any, BinaryIO, Callable, Dict, Optional, Union |
| 23 | + |
| 24 | +from packaging.version import Version |
| 25 | +from torch.serialization import ( |
| 26 | + StorageType, |
| 27 | + _get_restore_location, |
| 28 | + _is_torchscript_zip, |
| 29 | + _is_zipfile, |
| 30 | + _maybe_decode_ascii, |
| 31 | + _open_file_like, |
| 32 | + _open_zipfile_reader, |
| 33 | +) |
| 34 | + |
| 35 | +from neural_compressor.adaptor.torch_utils.layer_wise_quant import modified_pickle as pickle |
| 36 | + |
| 37 | +from .utils import torch |
| 38 | + |
# Keep only the part of the torch version string that precedes any "+" suffix
# so that it can be parsed and compared as a release tuple.
torch_version = torch.__version__.partition("+")[0]
version = Version(torch_version)

# Type aliases used in the signature of `load`.
FILE_LIKE = Union[str, os.PathLike, BinaryIO, IO[bytes]]
MAP_LOCATION = Optional[Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str]]]

# torch releases older than 1.13 expose the untyped storage class under a private name.
UntypedStorage = torch._UntypedStorage if version.release < Version("1.13.0").release else torch.UntypedStorage
| 49 | + |
| 50 | + |
def _load(zip_file, tensor_name, prefix, map_location, pickle_module, pickle_file="data.pkl", **pickle_load_args):
    """Unpickle ``pickle_file`` from an opened archive, reading storages only for the requested tensor.

    Args:
        zip_file: opened archive reader; must provide ``get_record`` and
            ``get_storage_from_record``.
        tensor_name: name handed to ``Unpickler.load``. When it is set, a storage is
            read from the archive only if the entry currently being rebuilt matches
            this name (or its prefix-stripped form); every other storage resolves to
            ``None``. When it is ``None`` or empty, every storage is read.
        prefix: optional dotted-name component. If it appears in ``tensor_name``, the
            name with that component removed is accepted as well.
        map_location: forwarded to ``torch.serialization._get_restore_location``.
        pickle_module: module whose ``Unpickler`` is subclassed. Its ``load`` must
            accept ``tensor_name`` and the instance must expose the ``tensor_name``
            and ``metastack`` attributes read by ``persistent_load``.
        pickle_file: name of the archive record holding the pickled object graph.
        **pickle_load_args: extra keyword arguments for the ``Unpickler`` constructor.

    Returns:
        Whatever ``Unpickler.load(tensor_name)`` returns.
    """
    restore_location = _get_restore_location(map_location)

    # Storages already materialized, keyed by their record key, so that a storage
    # referenced several times is read from the archive only once.
    loaded_storages = {}

    def load_tensor(dtype, numel, key, location):
        """Read record ``data/<key>`` and wrap it as a typed storage at the restored location.

        Note that the caller passes a byte count as ``numel``.
        """
        name = f"data/{key}"

        # The storage API changed across torch releases, hence one branch per era.
        if version.release < Version("1.13.0").release:
            storage = zip_file.get_storage_from_record(name, numel, torch._UntypedStorage).storage()._untyped()
            typed_storage = torch.storage._TypedStorage(wrap_storage=restore_location(storage, location), dtype=dtype)
            loaded_storages[key] = typed_storage
        elif version.release < Version("2.0.0").release:  # pragma: no cover
            storage = zip_file.get_storage_from_record(name, numel, UntypedStorage).storage().untyped()
            typed_storage = torch.storage.TypedStorage(wrap_storage=restore_location(storage, location), dtype=dtype)
            loaded_storages[key] = typed_storage
        else:
            storage = zip_file.get_storage_from_record(name, numel, UntypedStorage)._typed_storage()._untyped_storage
            typed_storage = torch.storage.TypedStorage(
                wrap_storage=restore_location(storage, location), dtype=dtype, _internal=True
            )

            # Only storages with a non-null data pointer are cached.
            if typed_storage._data_ptr() != 0:
                loaded_storages[key] = typed_storage

        return typed_storage

    load_module_mapping: Dict[str, str] = {"torch.tensor": "torch._tensor"}

    # Need to subclass Unpickler instead of directly monkey-patching the find_class method
    # because it's marked readonly in pickle.
    # The type: ignore is because mypy can't statically determine the type of this class.
    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
        def find_class(self, mod_name, name):
            """Resolve storage class names to ``StorageType`` and remap renamed modules."""
            if type(name) is str and "Storage" in name:
                try:
                    return StorageType(name)
                except KeyError:  # pragma: no cover
                    # Not a known storage name: fall through to the regular lookup.
                    pass
            mod_name = load_module_mapping.get(mod_name, mod_name)
            return super().find_class(mod_name, name)

        def persistent_load(self, saved_id):
            """Return the storage for ``saved_id``, or ``None`` when it is filtered out."""
            assert isinstance(saved_id, tuple)
            typename = _maybe_decode_ascii(saved_id[0])
            data = saved_id[1:]

            assert (
                typename == "storage"
            ), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
            storage_type, key, location, numel = data

            if storage_type is UntypedStorage:  # pragma: no cover
                dtype = torch.uint8
            else:
                dtype = storage_type.dtype

            if key in loaded_storages:
                return loaded_storages[key]

            wanted = self.tensor_name
            # Only filter when a tensor name was requested. Building the accepted
            # names under this guard also avoids calling ``.split`` on ``None`` when
            # a prefix is given without a tensor name.
            if wanted:
                # A list (not a set) is used so that the membership test below works
                # even if the compared object is unhashable.
                accepted_names = [wanted]
                if prefix:
                    parts = wanted.split(".")
                    if prefix in parts:
                        parts.remove(prefix)  # drops the first matching component only
                        accepted_names.append(".".join(parts))
                # NOTE(review): metastack[-1][-2] is presumably the key of the entry
                # currently being rebuilt - confirm against modified_pickle.
                if self.metastack[-1][-2] not in accepted_names:
                    return None

            nbytes = numel * torch._utils._element_size(dtype)
            return load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))

    # Load the data (which may in turn use `persistent_load` to load tensors)
    data_file = io.BytesIO(zip_file.get_record(pickle_file))

    unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
    result = unpickler.load(tensor_name)

    torch._utils._validate_loaded_sparse_tensors()
    return result
| 139 | + |
| 140 | + |
def load(
    f: FILE_LIKE,
    tensor_name: Optional[str] = None,
    prefix: Optional[str] = None,
    map_location: MAP_LOCATION = None,
    pickle_module: Any = None,
    *,
    weights_only: bool = False,
    **pickle_load_args: Any,
) -> Any:
    """Load an object saved with :func:`torch.save`, optionally reading one tensor only.

    This follows :func:`torch.load`, with two extra arguments (``tensor_name`` and
    ``prefix``) that are forwarded to the unpickling step so that only the storage
    of the requested tensor has to be read from the archive.

    Storages are first deserialized on the CPU and are then moved to the device they
    were saved from. If this fails (e.g. because the run time system doesn't have
    certain devices), an exception is raised. Storages can be remapped to other
    devices with :attr:`map_location`:

    * a callable is called once per serialized storage with ``(storage, location)``,
      where ``location`` is the tag of the device the storage was saved from
      (``'cpu'``, ``'cuda:device_id'``, ...). It should return either ``None``
      (fall back to the default behavior) or the storage to use;
    * a :class:`torch.device` or a device string indicates where all tensors
      should be loaded;
    * a dict remaps location tags appearing in the file (keys) to the ones that
      specify where to put the storages (values).

    User extensions can register their own location tags and tagging and
    deserialization methods using :func:`torch.serialization.register_package`.

    Args:
        f: a file-like object (has to implement :meth:`read`, :meth:`readline`,
            :meth:`tell`, and :meth:`seek`), or a string or os.PathLike object
            containing a file name.
        tensor_name: name of the tensor to load. ``None`` applies no filtering.
        prefix: optional dotted-name component; when it occurs in ``tensor_name``,
            the name without that component is accepted as well.
        map_location: a function, :class:`torch.device`, string or a dict specifying
            how to remap storage locations.
        pickle_module: module used for unpickling metadata and objects. Defaults to
            the module-level ``pickle`` (the project's modified pickle) unless
            ``weights_only`` is requested.
        weights_only: requests a restricted unpickler. This loader has none, so the
            request is rejected for regular archives (see Raises).
        pickle_load_args: optional keyword arguments passed over to
            ``pickle_module.Unpickler``, e.g. :attr:`errors=...`. ``encoding``
            defaults to ``'utf-8'``; :attr:`encoding='latin1'` or
            :attr:`encoding='bytes'` can be given for files saved by Python 2.

    Returns:
        The unpickled object; the result of :func:`torch.jit.load` when ``f`` is a
        TorchScript archive; ``None`` (with a ``UserWarning``) when ``f`` is not a
        zip archive.

    Raises:
        RuntimeError: if ``weights_only`` is requested together with an explicit
            ``pickle_module``, or if ``weights_only`` is requested for a regular
            (non-TorchScript) archive.

    .. warning::
        Unless restricted, loading uses pickle implicitly, which is known to be
        insecure. It is possible to construct malicious pickle data which will
        execute arbitrary code during unpickling. Never load data that could have
        come from an untrusted source, or that could have been tampered with.
        **Only load data you trust**.

    .. note::
        A file which contains GPU tensors is loaded to GPU by default. Pass
        ``map_location='cpu'`` to avoid a GPU RAM surge when loading a checkpoint.

    Example:
        >>> # xdoctest: +SKIP("undefined filepaths")
        >>> load('tensors.pt')
        >>> load('tensors.pt', tensor_name='fc.weight', map_location='cpu')
    """
    torch._C._log_api_usage_once("torch.load")
    # Add ability to force safe only weight loads via environment variable
    if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ("1", "y", "yes", "true"):  # pragma: no cover
        weights_only = True

    if weights_only:  # pragma: no cover
        if pickle_module is not None:
            raise RuntimeError("Can not safely load weights when explicit pickle_module is specified")
    elif pickle_module is None:
        pickle_module = pickle

    pickle_load_args.setdefault("encoding", "utf-8")

    with _open_file_like(f, "rb") as opened_file:
        if _is_zipfile(opened_file):
            # The zipfile reader is going to advance the current file position.
            # If we want to actually tail call to torch.jit.load, we need to
            # reset back to the original position.
            orig_position = opened_file.tell()
            with _open_zipfile_reader(opened_file) as opened_zipfile:
                if _is_torchscript_zip(opened_zipfile):  # pragma: no cover
                    warnings.warn(
                        "'torch.load' received a zip file that looks like a TorchScript archive"
                        " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
                        " silence this warning)",
                        UserWarning,
                    )
                    opened_file.seek(orig_position)
                    return torch.jit.load(opened_file, map_location=map_location)
                if pickle_module is None:  # pragma: no cover
                    # Only reachable when weights_only was requested: fail with a clear
                    # message instead of an AttributeError on ``None.Unpickler``.
                    raise RuntimeError(
                        "weights_only loading is not supported by this loader: "
                        "no restricted unpickler is available for single-tensor loading"
                    )
                return _load(opened_zipfile, tensor_name, prefix, map_location, pickle_module, **pickle_load_args)

    # Only the zip-based format is handled; make the fall-through visible instead
    # of silently handing ``None`` back to the caller.
    warnings.warn(
        "The given file is not a zip archive; only the zip-based 'torch.save' format is supported, returning None",
        UserWarning,
    )
    return None
0 commit comments