
Commit 90dfd01

Isotr0py authored and Alvant committed
[Model] Refactor and decouple weight loading logic for InternVL2 model (vllm-project#7067)
Signed-off-by: Alvant <[email protected]>
1 parent c155a82 commit 90dfd01

File tree

vllm/model_executor/models/intern_vit.py
vllm/model_executor/models/internvl.py

2 files changed: +38 -55 lines

vllm/model_executor/models/intern_vit.py

Lines changed: 10 additions & 1 deletion
@@ -4,7 +4,7 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-from typing import Optional
+from typing import Iterable, Optional, Tuple

 import torch
 import torch.nn as nn
@@ -16,6 +16,7 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 NORM2FN = {
     'rms_norm': RMSNorm,
@@ -268,3 +269,11 @@ def forward(
         encoder_outputs = self.encoder(inputs_embeds=hidden_states)

         return encoder_outputs
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
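
The load_weights added to the vision tower follows the common vLLM loading pattern: look up each incoming tensor by name in named_parameters() and dispatch to a per-parameter weight_loader if one is attached, otherwise fall back to default_weight_loader, which copies the tensor into the parameter. A minimal self-contained sketch of that pattern follows; the toy module and tensor names are illustrative only, not the real InternViT layers.

import torch
import torch.nn as nn


def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
    # Mirrors the fallback behaviour: plain copy into the parameter.
    param.data.copy_(loaded_weight)


class ToyVisionBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def load_weights(self, weights):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            param = params_dict[name]
            # Use a parameter-specific loader if present, else the default copy.
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)


block = ToyVisionBlock()
checkpoint = [("proj.weight", torch.randn(4, 4)), ("proj.bias", torch.randn(4))]
block.load_weights(checkpoint)

Because the vision tower now owns this method, the wrapping InternVL2 model no longer needs to know its parameter layout.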

vllm/model_executor/models/internvl.py

Lines changed: 28 additions & 54 deletions
@@ -4,6 +4,7 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
+import itertools
 from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

 import torch
@@ -414,58 +415,31 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)

-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
-            (".gate_up_proj", ".gate_proj", 0),
-            (".gate_up_proj", ".up_proj", 1),
-            (".gate_up_proj", ".w1", 0),
-            (".gate_up_proj", ".w3", 1),
-        ]
-        params_dict = dict(self.named_parameters())
+    def _filter_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
+                        prefix: str):
         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if self.config.text_config.tie_word_embeddings \
-                and "lm_head.weight" in name:
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                # We only do sharding for language model
-                # and not vision model for now.
-                if "vision_embed_tokens" in name and self.vision_embed_tokens:
-                    continue
-                if weight_name not in name:
-                    continue
-                param = params_dict[name.replace(weight_name, param_name)]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                if "wqkv" in name:
-                    config = self.config.text_config
-                    kv_groups = (config.num_attention_heads //
-                                 config.num_key_value_heads)
-                    head_dim = config.hidden_size // config.num_attention_heads
-                    loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
-                                                       head_dim,
-                                                       loaded_weight.shape[-1])
-                    wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
-                                             dim=1)
-                    wq = wq.reshape(-1, wq.shape[-1])
-                    wk = wk.reshape(-1, wk.shape[-1])
-                    wv = wv.reshape(-1, wv.shape[-1])
-                    weight_loader = param.weight_loader
-                    weight_loader(param, wq, 'q')
-                    weight_loader(param, wk, 'k')
-                    weight_loader(param, wv, 'v')
-                    continue
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+            name = name.split(".")
+            if prefix == name.pop(0):
+                name = ".".join(name)
+                yield name, loaded_weight
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        # prepare weight iterators for components
+        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+
+        # load vision encoder
+        vit_weights = self._filter_weights(vit_weights, "vision_model")
+        self.vision_model.load_weights(vit_weights)
+
+        # load mlp projector
+        mlp_weights = self._filter_weights(mlp_weights, "mlp1")
+        mlp_params_dict = dict(self.mlp1.named_parameters())
+        for name, loaded_weight in mlp_weights:
+            param = mlp_params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load llm backbone
+        llm_weights = self._filter_weights(llm_weights, "language_model")
+        self.language_model.load_weights(llm_weights)
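
The rewritten load_weights no longer walks the whole parameter dict itself: it splits the single checkpoint iterator into three with itertools.tee, filters each stream by the top-level name prefix, and routes the filtered streams to the sub-modules (vision tower and LLM backbone load themselves; the mlp1 projector is loaded directly). Below is a standalone sketch of that tee-and-filter routing; the prefixes and dummy tensors are illustrative, not the real checkpoint layout.

import itertools
from typing import Iterable, Iterator, Tuple

import torch


def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
                   prefix: str) -> Iterator[Tuple[str, torch.Tensor]]:
    # Keep only tensors whose first dotted component equals `prefix`,
    # stripping that prefix from the name before yielding.
    for name, loaded_weight in weights:
        parts = name.split(".")
        if parts[0] == prefix:
            yield ".".join(parts[1:]), loaded_weight


checkpoint = [
    ("vision_model.embed.weight", torch.zeros(2, 2)),
    ("mlp1.0.weight", torch.zeros(2, 2)),
    ("language_model.lm_head.weight", torch.zeros(2, 2)),
]

# Three independent iterators over the same stream of (name, tensor) pairs.
vit_w, mlp_w, llm_w = itertools.tee(checkpoint, 3)

print(list(filter_weights(vit_w, "vision_model")))    # vision tower weights
print(list(filter_weights(mlp_w, "mlp1")))            # projector weights
print(list(filter_weights(llm_w, "language_model")))  # LLM backbone weights

This per-component delegation is what removes the stacked-parameter mapping and the InternLM2-specific wqkv splitting from this file: the language model's own load_weights already handles sharding of its fused projections.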
