```diff
@@ -27,7 +27,7 @@
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType
 
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
```
```diff
@@ -154,6 +154,26 @@ def forward(
 
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if "A_log" in name:
+                name = name.replace("A_log", "A")
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            if is_pp_missing_parameter(name, self):
+                continue
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
                        SupportsV0Only):
```
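The moved loader keeps vLLM's per-parameter dispatch: a `Parameter` may carry a custom `weight_loader` attribute (used, e.g., for tensor-parallel sharding), and the loop falls back to `default_weight_loader` otherwise. Below is a minimal sketch of that protocol; the `default_weight_loader` here is a simplified stand-in for the real one in `vllm.model_executor.model_loader.weight_utils`, and `shard_column` is a hypothetical custom loader for illustration only.

```python
import torch
from torch.nn import Parameter


def default_weight_loader(param: Parameter,
                          loaded_weight: torch.Tensor) -> None:
    # Plain copy: checkpoint tensor and parameter shapes must match.
    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight)


def shard_column(param: Parameter, loaded_weight: torch.Tensor,
                 rank: int = 0, world_size: int = 1) -> None:
    # Hypothetical custom loader: keep only this rank's shard of the
    # checkpoint tensor instead of the full weight.
    shard = loaded_weight.chunk(world_size, dim=0)[rank]
    param.data.copy_(shard)


# Usage: a layer opts in by attaching a loader to the parameter, and the
# generic loop picks it up via getattr(..., default_weight_loader).
w = Parameter(torch.empty(4, 8))
w.weight_loader = shard_column
loader = getattr(w, "weight_loader", default_weight_loader)
loader(w, torch.randn(4, 8))
```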
```diff
@@ -257,20 +277,5 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "A_log" in name:
-                name = name.replace("A_log", "A")
-            # Skip loading extra bias for GPTQ models.
-            if name.endswith(".bias") and name not in params_dict:
-                continue
-            if is_pp_missing_parameter(name, self):
-                continue
-
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
```
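`AutoWeightsLoader` is what makes the inner `load_weights` sufficient: it walks the wrapper's submodules and, when a child defines its own `load_weights`, strips the child's prefix from the checkpoint names and delegates to it. A minimal sketch of that delegation pattern, assuming recursive prefix matching; this is an illustration, not vLLM's actual implementation:

```python
from typing import Dict, Iterable, List, Set, Tuple

import torch
import torch.nn as nn


def auto_load_weights(module: nn.Module,
                      weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
    children = dict(module.named_children())
    own_params = dict(module.named_parameters(recurse=False))
    by_child: Dict[str, List[Tuple[str, torch.Tensor]]] = {}
    loaded: Set[str] = set()

    for name, tensor in weights:
        prefix, _, rest = name.partition(".")
        if rest and prefix in children:
            # Queue the weight for the child module, prefix stripped.
            by_child.setdefault(prefix, []).append((rest, tensor))
        elif name in own_params:
            own_params[name].data.copy_(tensor)
            loaded.add(name)

    for prefix, child_weights in by_child.items():
        child = children[prefix]
        if hasattr(child, "load_weights"):
            # Delegate to the child's own loader (for this PR, that is
            # MambaModel.load_weights, which handles the A_log -> A rename
            # and the GPTQ bias skip).
            sub_loaded = child.load_weights(child_weights)
        else:
            sub_loaded = auto_load_weights(child, child_weights)
        loaded.update(f"{prefix}.{n}" for n in sub_loaded)
    return loaded
```

The net effect of the change: the Mamba-specific checkpoint quirks stay next to the module that owns those parameters, while `MambaForCausalLM.load_weights` shrinks to the same two-line delegation used by other vLLM models.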