@@ -4,6 +4,7 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
+import itertools
 from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
@@ -414,58 +415,31 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
-            (".gate_up_proj", ".gate_proj", 0),
-            (".gate_up_proj", ".up_proj", 1),
-            (".gate_up_proj", ".w1", 0),
-            (".gate_up_proj", ".w3", 1),
-        ]
-        params_dict = dict(self.named_parameters())
+    def _filter_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
+                        prefix: str):
         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if self.config.text_config.tie_word_embeddings \
-                    and "lm_head.weight" in name:
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                # We only do sharding for language model
-                # and not vision model for now.
-                if "vision_embed_tokens" in name and self.vision_embed_tokens:
-                    continue
-                if weight_name not in name:
-                    continue
-                param = params_dict[name.replace(weight_name, param_name)]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                if "wqkv" in name:
-                    config = self.config.text_config
-                    kv_groups = (config.num_attention_heads //
-                                 config.num_key_value_heads)
-                    head_dim = config.hidden_size // config.num_attention_heads
-                    loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
-                                                       head_dim,
-                                                       loaded_weight.shape[-1])
-                    wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
-                                             dim=1)
-                    wq = wq.reshape(-1, wq.shape[-1])
-                    wk = wk.reshape(-1, wk.shape[-1])
-                    wv = wv.reshape(-1, wv.shape[-1])
-                    weight_loader = param.weight_loader
-                    weight_loader(param, wq, 'q')
-                    weight_loader(param, wk, 'k')
-                    weight_loader(param, wv, 'v')
-                    continue
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+            name = name.split(".")
+            if prefix == name.pop(0):
+                name = ".".join(name)
+                yield name, loaded_weight
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        # prepare weight iterators for components
+        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+
+        # load vision encoder
+        vit_weights = self._filter_weights(vit_weights, "vision_model")
+        self.vision_model.load_weights(vit_weights)
+
+        # load mlp projector
+        mlp_weights = self._filter_weights(mlp_weights, "mlp1")
+        mlp_params_dict = dict(self.mlp1.named_parameters())
+        for name, loaded_weight in mlp_weights:
+            param = mlp_params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load llm backbone
+        llm_weights = self._filter_weights(llm_weights, "language_model")
+        self.language_model.load_weights(llm_weights)
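
Note on the new loading path: instead of pattern-matching every checkpoint name in one loop, `load_weights` now splits the single weight stream into one stream per component and delegates to each component's own loader. Below is a minimal standalone sketch of how `_filter_weights` and `itertools.tee` interact; the weight names are illustrative placeholders, not taken from a real checkpoint.

import itertools
from typing import Iterable, Tuple

import torch


def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
    # Yield only the entries under `prefix`, with that prefix stripped,
    # so each sub-module sees names relative to itself.
    for name, loaded_weight in weights:
        parts = name.split(".")
        if prefix == parts.pop(0):
            yield ".".join(parts), loaded_weight


# Toy single-pass checkpoint iterator.
weights = iter([
    ("vision_model.encoder.proj.weight", torch.zeros(2, 2)),
    ("mlp1.0.weight", torch.zeros(2, 2)),
    ("language_model.model.embed_tokens.weight", torch.zeros(2, 2)),
])

# tee() turns the one-shot iterator into an independent iterator per
# component, so each can be filtered and consumed in turn.
vit, mlp, llm = itertools.tee(weights, 3)

print([n for n, _ in filter_weights(vit, "vision_model")])    # ['encoder.proj.weight']
print([n for n, _ in filter_weights(mlp, "mlp1")])            # ['0.weight']
print([n for n, _ in filter_weights(llm, "language_model")])  # ['model.embed_tokens.weight']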
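
One trade-off worth noting: `itertools.tee` buffers items the leading iterator has consumed until the lagging iterators catch up, so draining the three filtered streams one after another keeps references to essentially the whole checkpoint buffered until the final pass finishes. The upside is that each component's existing `load_weights` interface is reused unchanged; a single dispatch-by-prefix pass would trade that simplicity for lower peak buffering.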