|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import gc |
| 8 | +import json |
| 9 | +from collections import defaultdict |
| 10 | +from pathlib import Path |
| 11 | + |
| 12 | +import torch.nn as nn |
| 13 | +from huggingface_hub import repo_exists, snapshot_download |
| 14 | +from safetensors import safe_open |
| 15 | +from torch.distributed.tensor import distribute_tensor, DTensor |
| 16 | +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME |
| 17 | + |
| 18 | +from torchtitan.tools.logging import logger |
| 19 | + |
| 20 | +INDEX_NAME_MAPPING = { |
| 21 | + "safetensors": SAFE_WEIGHTS_INDEX_NAME, |
| 22 | +} |
| 23 | + |
| 24 | +PATTERNS_TO_REMOVE = [ |
| 25 | + "._orig_mod", # Some optimizers add suffixes |
| 26 | + "._fsdp_wrapped_module", # FSDP wrapper |
| 27 | + "._checkpoint_wrapped_module", # checkpoint wrapper |
| 28 | + ".module", # DataParallel/DistributedDataParallel |
| 29 | + "_module.", # Some wrappers add prefix |
| 30 | +] |
| 31 | + |
| 32 | + |
| 33 | +def normalize_state_dict_key( |
| 34 | + key: str, patterns_to_remove: list[str] = PATTERNS_TO_REMOVE |
| 35 | +) -> str: |
| 36 | + """ |
| 37 | + Normalize the state dict key, remove the prefix or suffix added by various wrappers. |
| 38 | + Args: |
| 39 | + key: The original state dict key |
| 40 | + Returns: |
| 41 | + The normalized key |
| 42 | + """ |
| 43 | + normalized_key = key |
| 44 | + for pattern in patterns_to_remove: |
| 45 | + normalized_key = normalized_key.replace(pattern, "") |
| 46 | + |
| 47 | + return normalized_key |
| 48 | + |
| 49 | + |
| 50 | +def get_weight_map(pretrained_model_path: Path) -> dict[str, str]: |
| 51 | + """ |
| 52 | + Get the weight map from the pretrained model. |
| 53 | + Args: |
| 54 | + pretrained_model_path: The path to the pretrained model. |
| 55 | + Returns: |
| 56 | + weight_map: A dictionary mapping from the path to the weight map to the list of state dict keys. |
| 57 | + """ |
| 58 | + index_file = pretrained_model_path / INDEX_NAME_MAPPING["safetensors"] |
| 59 | + if not index_file.exists(): |
| 60 | + return None |
| 61 | + with open(index_file, "r") as f: |
| 62 | + metadata = json.load(f) |
| 63 | + return metadata["weight_map"] |
| 64 | + |
| 65 | + |
| 66 | +def group_state_dict_keys_and_st_partition_paths( |
| 67 | + pretrained_model_path: Path, |
| 68 | + state_dict_keys, |
| 69 | + weight_map, |
| 70 | + state_dict_map: dict[str, str] = None, |
| 71 | +): |
| 72 | + """ |
| 73 | + Group state dict keys and save them to a file. |
| 74 | + Args: |
| 75 | + pretrained_model_path: The path to the pretrained model. |
| 76 | + state_dict_keys: The state dict keys to group. |
| 77 | + weight_map: The weight map. |
| 78 | + state_dict_map: A dictionary mapping from the state dict key to the weight path. |
| 79 | + Returns: |
| 80 | + st_partition_map: A dictionary mapping from the weight path to the list of state dict keys. |
| 81 | + """ |
| 82 | + st_partition_map = defaultdict(list) |
| 83 | + for state_dict_key in state_dict_keys: |
| 84 | + ckpt_state_dict_key = ( |
| 85 | + state_dict_map[state_dict_key] |
| 86 | + if state_dict_map is not None |
| 87 | + else state_dict_key |
| 88 | + ) |
| 89 | + if weight_map is None: |
| 90 | + partition_path = pretrained_model_path / "model.safetensors" |
| 91 | + else: |
| 92 | + partition_path = pretrained_model_path / weight_map[ckpt_state_dict_key] |
| 93 | + st_partition_map[partition_path].append(state_dict_key) |
| 94 | + return st_partition_map |
| 95 | + |
| 96 | + |
| 97 | +def load_sharded_state_dict_for_model_from_path( |
| 98 | + pretrained_model_path: Path, |
| 99 | + model: nn.Module, |
| 100 | + mapping_dict: dict[str, str] = None, |
| 101 | + **kwargs, |
| 102 | +): |
| 103 | + """ |
| 104 | + Load the state dict sharded (depends on DTensor) from the pretrained model path. It only load the weights for current rank. |
| 105 | + Args: |
| 106 | + pretrained_model_path: The path to the pretrained model, it could be a local path or an s3 path. |
| 107 | + model: The model to load the state dict into. |
| 108 | + **kwargs: other arguments for torch.nn.Module.load_state_dict |
| 109 | + """ |
| 110 | + # check exceptions |
| 111 | + if not pretrained_model_path.exists(): |
| 112 | + raise ValueError( |
| 113 | + f"The pretrained model path {pretrained_model_path} does not exist." |
| 114 | + ) |
| 115 | + if not pretrained_model_path.is_dir(): |
| 116 | + raise ValueError( |
| 117 | + f"The pretrained model path {pretrained_model_path} is not a directory." |
| 118 | + ) |
| 119 | + # get the weight map |
| 120 | + weight_map = get_weight_map(pretrained_model_path) |
| 121 | + model_state_dict = model.state_dict() |
| 122 | + model_state_dict_keys = list(model_state_dict.keys()) |
| 123 | + |
| 124 | + # create a mapping_dict between the original state_dict_key and the weight_map_key if not provided |
| 125 | + mapping_dict = ( |
| 126 | + mapping_dict |
| 127 | + if mapping_dict is not None |
| 128 | + else {key: normalize_state_dict_key(key) for key in model_state_dict_keys} |
| 129 | + ) |
| 130 | + st_partition_map = group_state_dict_keys_and_st_partition_paths( |
| 131 | + pretrained_model_path, model_state_dict_keys, weight_map, mapping_dict |
| 132 | + ) |
| 133 | + |
| 134 | + # get the sharded state dict |
| 135 | + state_dict = {} |
| 136 | + for safetensor_partition_path, state_dict_keys in st_partition_map.items(): |
| 137 | + with safe_open(safetensor_partition_path, framework="pt", device="cpu") as f: |
| 138 | + for state_dict_key in state_dict_keys: |
| 139 | + model_tensor = model_state_dict[state_dict_key] |
| 140 | + ckpt_state_dict_key = mapping_dict[state_dict_key] |
| 141 | + if isinstance(model_tensor, DTensor): |
| 142 | + local_tensor = f.get_tensor(ckpt_state_dict_key) |
| 143 | + state_dict[state_dict_key] = distribute_tensor( |
| 144 | + local_tensor, |
| 145 | + model_tensor.device_mesh, |
| 146 | + model_tensor.placements, |
| 147 | + ) |
| 148 | + else: |
| 149 | + state_dict[state_dict_key] = f.get_tensor(ckpt_state_dict_key) |
| 150 | + model.load_state_dict(state_dict, **kwargs) |
| 151 | + del state_dict |
| 152 | + gc.collect() |
| 153 | + |
| 154 | + |
| 155 | +def load_sharded_state_dict_for_model_from_hf( |
| 156 | + pretrained_model_id_or_path: str, |
| 157 | + model: nn.Module, |
| 158 | + **kwargs, |
| 159 | +): |
| 160 | + """ |
| 161 | + Load the state dict sharded (depends on DTensor) from the pretrained model path. It only load the weights for current rank. |
| 162 | + Args: |
| 163 | + pretrained_model_id_or_path: The id or path to the pretrained model, it could be a repo id in huggingface, |
| 164 | + or a local path |
| 165 | + model: The model to load the state dict into. |
| 166 | + **kwargs: other arguments for torch.nn.Module.load_state_dict |
| 167 | + """ |
| 168 | + logger.info(f"Loading the state dict from {pretrained_model_id_or_path}") |
| 169 | + pretrained_model_id_or_path = Path(pretrained_model_id_or_path) |
| 170 | + if not pretrained_model_id_or_path.exists(): |
| 171 | + if not repo_exists(str(pretrained_model_id_or_path)): |
| 172 | + raise ValueError( |
| 173 | + f"The pretrained model {pretrained_model_id_or_path} does not exist" |
| 174 | + ) |
| 175 | + logger.info( |
| 176 | + f"Try to download the model from huggingface: {pretrained_model_id_or_path}" |
| 177 | + ) |
| 178 | + pretrained_model_path = Path( |
| 179 | + snapshot_download(str(pretrained_model_id_or_path)) |
| 180 | + ) |
| 181 | + elif not pretrained_model_id_or_path.is_dir(): |
| 182 | + raise ValueError( |
| 183 | + f"The pretrained model path {pretrained_model_id_or_path} is not a directory." |
| 184 | + ) |
| 185 | + else: |
| 186 | + pretrained_model_path = pretrained_model_id_or_path |
| 187 | + |
| 188 | + load_sharded_state_dict_for_model_from_path(pretrained_model_path, model, **kwargs) |
0 commit comments