|
28 | 28 | get_tensor_model_parallel_world_size)
|
29 | 29 | from vllm.envs import VLLM_USE_MODELSCOPE
|
30 | 30 | from vllm.logger import init_logger
|
| 31 | +from vllm.model_executor.layers.linear import ReplicatedLinear |
31 | 32 | from vllm.model_executor.layers.quantization.base_config import (
|
32 | 33 | QuantizationConfig)
|
33 | 34 | from vllm.model_executor.model_loader.tensorizer import (
|
@@ -786,6 +787,7 @@ def __init__(self, load_config: LoadConfig):
|
786 | 787 | with open(config_file_path, "r") as f:
|
787 | 788 | config = json.load(f)
|
788 | 789 | self.target_modules = config["target_modules"]
|
| 790 | + self.wo_sharded_weights_modules: List[str] = [] |
789 | 791 |
|
790 | 792 | def _get_config_file(self, qlora_adapter: str) -> str:
|
791 | 793 | is_local = os.path.isdir(qlora_adapter)
|
@@ -1005,16 +1007,21 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
|
1005 | 1007 | if any(target_module in weight_name for target_module in
|
1006 | 1008 | self.target_modules) and weight_name.endswith(".weight"):
|
1007 | 1009 | weight_name = weight_name.replace(".weight", ".qweight")
|
1008 | | - |
1009 | | - if any(module in weight_name |
1010 | | - for module in self.column_parallel_weights_modules): |
| 1010 | + # Without sharding |
| 1011 | + if any( |
| 1012 | + weight_name.startswith(module) |
| 1013 | + for module in self.wo_sharded_weights_modules): |
| 1014 | + weight_sub_tensor = weight_tensor |
| 1015 | + # Shard by column |
| 1016 | + elif any(module in weight_name |
| 1017 | + for module in self.column_parallel_weights_modules): |
1011 | 1018 |
|
1012 | 1019 | total_size = weight_tensor.size(-1)
|
1013 | 1020 | start_index = total_size // tp_size * tp_rank
|
1014 | 1021 | end_index = total_size // tp_size * (tp_rank + 1)
|
1015 | 1022 | weight_sub_tensor = weight_tensor[...,
|
1016 | 1023 | start_index:end_index]
|
1017 | | - |
| 1024 | + # Shard by row |
1018 | 1025 | else:
|
1019 | 1026 | total_size = weight_tensor.size(0)
|
1020 | 1027 | start_index = total_size // tp_size * tp_rank
|
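For context, here is a minimal standalone sketch (not part of the diff) of the three branches above: replicated modules keep the full tensor, column-parallel modules slice the last dimension, and everything else is sharded along the first dimension. The prefixes, module names, and the helper `shard_for_rank` are illustrative assumptions, not vLLM APIs.

```python
import torch

tp_size, tp_rank = 2, 0
# Hypothetical examples of the two module lists used in the diff.
wo_sharded_prefixes = ["model.layers.0.mlp.gate"]       # e.g. ReplicatedLinear modules
column_parallel_modules = ["down_proj.", "o_proj."]     # sharded along the last dim

def shard_for_rank(weight_name: str, weight_tensor: torch.Tensor) -> torch.Tensor:
    # Without sharding: replicated modules keep the full tensor on every rank.
    if any(weight_name.startswith(prefix) for prefix in wo_sharded_prefixes):
        return weight_tensor
    # Shard by column: each rank takes a contiguous slice of the last dimension.
    if any(module in weight_name for module in column_parallel_modules):
        size = weight_tensor.size(-1) // tp_size
        return weight_tensor[..., size * tp_rank:size * (tp_rank + 1)]
    # Shard by row: each rank takes a contiguous slice of the first dimension.
    size = weight_tensor.size(0) // tp_size
    return weight_tensor[size * tp_rank:size * (tp_rank + 1)]

full = torch.randn(8, 8)
print(shard_for_rank("model.layers.0.mlp.gate.qweight", full).shape)          # (8, 8)
print(shard_for_rank("model.layers.0.mlp.down_proj.qweight", full).shape)     # (8, 4)
print(shard_for_rank("model.layers.0.self_attn.q_proj.qweight", full).shape)  # (4, 8)
```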
@@ -1068,7 +1075,15 @@ def _load_weights(self, model_config: ModelConfig,
|
1068 | 1075 | model.column_parallel_weights_modules
|
1069 | 1076 | else:
|
1070 | 1077 | self.column_parallel_weights_modules = []
|
1071 | | - |
| 1078 | + # Some modules like `ReplicatedLinear` should not have their weights |
| 1079 | + # sharded. The reason for implementing it this way is to avoid new |
| 1080 | + # static variable in the model implementation. |
| 1081 | + # TODO: Can we reduce the static variables needed for BNB based on |
| 1082 | + # model information? |
| 1083 | + self.wo_sharded_weights_modules = [ |
| 1084 | + name for name, module in model.named_modules() |
| 1085 | + if isinstance(module, (ReplicatedLinear, )) |
| 1086 | + ] |
1072 | 1087 | self.model_type = type(model).__name__
|
1073 | 1088 |
|
1074 | 1089 | logger.info("Loading weights with BitsAndBytes quantization. "
|
|
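As a usage illustration of the `_load_weights` hunk above, the sketch below shows the `named_modules()` + `isinstance` pattern with a stand-in `MyReplicatedLinear` class, so it runs without vLLM installed; the class and toy model are assumptions for the example only.

```python
import torch.nn as nn

class MyReplicatedLinear(nn.Linear):
    """Stand-in for vllm.model_executor.layers.linear.ReplicatedLinear."""

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(16, 16)        # sharded as usual
        self.gate = MyReplicatedLinear(16, 4)  # weights kept whole on every rank

model = ToyModel()

# Same pattern as the diff: walk the module tree and record the dotted names of
# every module whose weights must not be sharded.
wo_sharded_weights_modules = [
    name for name, module in model.named_modules()
    if isinstance(module, (MyReplicatedLinear, ))
]
print(wo_sharded_weights_modules)  # ['gate']

# Later, a weight name such as 'gate.qweight' is matched via startswith()
# and its tensor is loaded without sharding.
print(any("gate.qweight".startswith(m) for m in wo_sharded_weights_modules))  # True
```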