Skip to content

Commit 2fe04f2

Browse files
sfeng33 and Mu Huai
authored and committed
[Model] Use autoweightloader for mamba (vllm-project#16950)
Signed-off-by: sfeng33 <[email protected]> Signed-off-by: Mu Huai <[email protected]>
1 parent 12ab34c commit 2fe04f2

File tree

1 file changed

+23
-18
lines changed

1 file changed

+23
-18
lines changed

vllm/model_executor/models/mamba.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from vllm.sequence import IntermediateTensors
2828
from vllm.utils import LayerBlockType
2929

30-
from .utils import (is_pp_missing_parameter,
30+
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
3131
make_empty_intermediate_tensors_factory, make_layers,
3232
maybe_prefix)
3333

@@ -154,6 +154,26 @@ def forward(
154154

155155
return hidden_states
156156

157+
def load_weights(self, weights: Iterable[Tuple[str,
158+
torch.Tensor]]) -> Set[str]:
159+
params_dict = dict(self.named_parameters())
160+
loaded_params: Set[str] = set()
161+
for name, loaded_weight in weights:
162+
if "A_log" in name:
163+
name = name.replace("A_log", "A")
164+
# Skip loading extra bias for GPTQ models.
165+
if name.endswith(".bias") and name not in params_dict:
166+
continue
167+
if is_pp_missing_parameter(name, self):
168+
continue
169+
170+
param = params_dict[name]
171+
weight_loader = getattr(param, "weight_loader",
172+
default_weight_loader)
173+
weight_loader(param, loaded_weight)
174+
loaded_params.add(name)
175+
return loaded_params
176+
157177

158178
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
159179
SupportsV0Only):
@@ -257,20 +277,5 @@ def sample(
257277

258278
def load_weights(self, weights: Iterable[Tuple[str,
259279
torch.Tensor]]) -> Set[str]:
260-
params_dict = dict(self.named_parameters())
261-
loaded_params: Set[str] = set()
262-
for name, loaded_weight in weights:
263-
if "A_log" in name:
264-
name = name.replace("A_log", "A")
265-
# Skip loading extra bias for GPTQ models.
266-
if name.endswith(".bias") and name not in params_dict:
267-
continue
268-
if is_pp_missing_parameter(name, self):
269-
continue
270-
271-
param = params_dict[name]
272-
weight_loader = getattr(param, "weight_loader",
273-
default_weight_loader)
274-
weight_loader(param, loaded_weight)
275-
loaded_params.add(name)
276-
return loaded_params
280+
loader = AutoWeightsLoader(self)
281+
return loader.load_weights(weights)

0 commit comments

Comments
 (0)