From 6a93e00a4518e662c4648af083b53228197f4cef Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 21 Jan 2025 19:32:05 +0000 Subject: [PATCH 01/10] first draft Signed-off-by: NickLucche --- vllm/model_executor/models/bart.py | 48 ++++++++++-------- vllm/model_executor/models/utils.py | 75 +++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 5847c50862e5..a37101d45904 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -44,7 +44,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsV0Only -from .utils import maybe_prefix +from .utils import QKVCrossParallelLinear, maybe_prefix logger = logging.get_logger(__name__) @@ -299,14 +299,22 @@ def __init__( f" and `num_heads`: {num_heads}).") self.scaling = self.head_dim**-0.5 - self.qkv_proj = QKVParallelLinear( - self.d_model, - self.d_model // self.total_num_heads, - self.total_num_heads, - self.total_num_kv_heads, - bias=bias, - quant_config=quant_config, - ) + # self.q_proj = ColumnParallelLinear( + # input_size=self.embed_dim, + # output_size=self.embed_dim, + # bias=bias, + # quant_config=quant_config, + # ) + + # self.kv_proj = QKVParallelLinear( + # hidden_size=self.d_model, + # head_size=self.d_model // self.total_num_heads, + # total_num_heads=0, + # total_num_kv_heads=self.total_num_kv_heads, + # bias=bias, + # quant_config=quant_config, + # ) + self.qkv_proj = QKVCrossParallelLinear(self.d_model, self.d_model // self.total_num_heads, self.total_num_heads, self.total_num_kv_heads, bias, quant_config=quant_config) self.out_proj = RowParallelLinear( embed_dim, @@ -347,18 +355,16 @@ def forward( ) -> torch.Tensor: """Input shape: Batch x Time x Channel""" - # (afeldman-nm 2024/07/22) TODO: - # Need a more efficient solution for q/k/v - qkv_dec, _ = self.qkv_proj(decoder_hidden_states) - q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size], - dim=-1) - if encoder_hidden_states is None: - k = None - v = None - else: - qkv_enc, _ = self.qkv_proj(encoder_hidden_states) - _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size], - dim=-1) + # q, _ = self.q_proj(decoder_hidden_states) + + # if encoder_hidden_states is None: + # k = None + # v = None + # else: + # # Prefill, cache encoder KV. 
+ # kv_enc, _ = self.kv_proj(encoder_hidden_states) + # k, v = kv_enc.split(self.kv_size, dim=-1) + q, k, v = self.qkv_proj(decoder_hidden_states, encoder_hidden_states) attn_output = self.attn(q, k, v) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index f9aa5da39a5f..c5a540abf67c 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -16,6 +16,9 @@ from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors from vllm.sequence import IntermediateTensors from vllm.utils import is_pin_memory_available +from vllm.model_executor.layers.linear import ColumnParallelLinear, QKVParallelLinear +from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig logger = init_logger(__name__) @@ -651,3 +654,75 @@ def cast_overflow_tensors( clamp_value = torch.finfo(tensors.dtype).max - offset tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value) return tensors + +class QKVCrossParallelLinear(torch.nn.Module): + def __init__(self, hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.weight = torch.nn.Parameter() # placeholder for loading + self.bias = torch.nn.Parameter() # placeholder for loading + + self.q_proj_decoder = ColumnParallelLinear( + input_size=hidden_size, + output_size=total_num_heads*head_size, + bias=bias, + quant_config=quant_config, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype + ) + self.kv_size = total_num_kv_heads*head_size + self.kv_proj_encoder = QKVParallelLinear( + hidden_size=hidden_size, + head_size=head_size, + total_num_heads=0, + total_num_kv_heads=total_num_kv_heads, + bias=bias, + quant_config=quant_config, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype + ) + + set_weight_attrs(self.weight, { + "output_dim": 0, + "weight_loader": self.weight_loader_weight, + }) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader_bias, + }) + # Do not show placeholders after loading the model. + # delattr(self, "weight") + # delattr(self, "bias") + + def forward(self, decoder_hidden_states, encoder_hidden_states): + q, _ = self.q_proj_decoder(decoder_hidden_states) + if encoder_hidden_states is None: + # Encoder KV already cached. + k = None + v = None + else: + # Prefill phase, encoder KV cached here. + kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states) + # Split kv in half + k, v = kv_enc.split(self.kv_size, dim=-1) + return q, k, v + + def weight_loader_weight(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + # NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder params. 
+ param = self.q_proj_decoder.weight if loaded_shard_id == "q" else self.kv_proj_encoder.weight + param.weight_loader(param, loaded_weight) if loaded_shard_id == "q" else param.weight_loader(param, loaded_weight, loaded_shard_id) + + def weight_loader_bias(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + param = self.q_proj_decoder.bias if loaded_shard_id == "q" else self.kv_proj_encoder.bias + param.weight_loader(param, loaded_weight) if loaded_shard_id == "q" else param.weight_loader(param, loaded_weight, loaded_shard_id) From c313996744f8d63e331b87f58ea7a27ad29ab667 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 22 Jan 2025 11:10:08 +0000 Subject: [PATCH 02/10] cleanup Signed-off-by: NickLucche --- vllm/model_executor/models/utils.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index c5a540abf67c..1e048b92a40c 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -666,8 +666,9 @@ def __init__(self, hidden_size: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = ""): super().__init__() - self.weight = torch.nn.Parameter() # placeholder for loading - self.bias = torch.nn.Parameter() # placeholder for loading + # Empty placeholders for loading as a single module. + self.weight = torch.nn.Parameter() + self.bias = torch.nn.Parameter() self.q_proj_decoder = ColumnParallelLinear( input_size=hidden_size, @@ -690,16 +691,11 @@ def __init__(self, hidden_size: int, ) set_weight_attrs(self.weight, { - "output_dim": 0, - "weight_loader": self.weight_loader_weight, + "weight_loader": self.weight_loader_weight, }) set_weight_attrs(self.bias, { - "output_dim": 0, - "weight_loader": self.weight_loader_bias, + "weight_loader": self.weight_loader_bias, }) - # Do not show placeholders after loading the model. 
- # delattr(self, "weight") - # delattr(self, "bias") def forward(self, decoder_hidden_states, encoder_hidden_states): q, _ = self.q_proj_decoder(decoder_hidden_states) From f42545579f3f739d560814e5eddf21219d950dc4 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 22 Jan 2025 17:34:12 +0000 Subject: [PATCH 03/10] submodules in dict to avoid param registration Signed-off-by: NickLucche --- vllm/model_executor/models/utils.py | 91 ++++++++++++++++++----------- 1 file changed, 58 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 1e048b92a40c..16fa6facbe43 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -12,13 +12,15 @@ from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors from vllm.sequence import IntermediateTensors from vllm.utils import is_pin_memory_available -from vllm.model_executor.layers.linear import ColumnParallelLinear, QKVParallelLinear -from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig logger = init_logger(__name__) @@ -656,7 +658,9 @@ def cast_overflow_tensors( return tensors class QKVCrossParallelLinear(torch.nn.Module): - def __init__(self, hidden_size: int, + + def __init__(self, + hidden_size: int, head_size: int, total_num_heads: int, total_num_kv_heads: Optional[int] = None, @@ -667,19 +671,23 @@ def __init__(self, hidden_size: int, prefix: str = ""): super().__init__() # Empty placeholders for loading as a single module. - self.weight = torch.nn.Parameter() - self.bias = torch.nn.Parameter() - - self.q_proj_decoder = ColumnParallelLinear( + self.weight = torch.nn.Parameter() + set_weight_attrs(self.weight, { + "weight_loader": self.weight_loader_weight, + }) + # Use a dictionary to avoid submodules parameters auto-registration: + # drop-in replacement for a `QKVParallelLinear` module. 
+ self.proj = dict() + self.proj["q_proj_decoder"] = ColumnParallelLinear( input_size=hidden_size, - output_size=total_num_heads*head_size, + output_size=total_num_heads * head_size, bias=bias, quant_config=quant_config, skip_bias_add=skip_bias_add, - params_dtype=params_dtype - ) - self.kv_size = total_num_kv_heads*head_size - self.kv_proj_encoder = QKVParallelLinear( + params_dtype=params_dtype, + prefix=f"{prefix}.q_proj_decoder") + self.kv_size = total_num_kv_heads * head_size + self.proj["kv_proj_encoder"] = QKVParallelLinear( hidden_size=hidden_size, head_size=head_size, total_num_heads=0, @@ -687,15 +695,22 @@ def __init__(self, hidden_size: int, bias=bias, quant_config=quant_config, skip_bias_add=skip_bias_add, - params_dtype=params_dtype - ) + params_dtype=params_dtype, + prefix=f"{prefix}.kv_proj_encoder") - set_weight_attrs(self.weight, { - "weight_loader": self.weight_loader_weight, - }) - set_weight_attrs(self.bias, { - "weight_loader": self.weight_loader_bias, - }) + if bias: + self.bias = torch.nn.Parameter() + set_weight_attrs(self.bias, { + "weight_loader": self.weight_loader_bias, + }) + + @property + def q_proj_decoder(self): + return self.proj["q_proj_decoder"] + + @property + def kv_proj_encoder(self): + return self.proj["kv_proj_encoder"] def forward(self, decoder_hidden_states, encoder_hidden_states): q, _ = self.q_proj_decoder(decoder_hidden_states) @@ -710,15 +725,25 @@ def forward(self, decoder_hidden_states, encoder_hidden_states): k, v = kv_enc.split(self.kv_size, dim=-1) return q, k, v - def weight_loader_weight(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None): - # NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder params. - param = self.q_proj_decoder.weight if loaded_shard_id == "q" else self.kv_proj_encoder.weight - param.weight_loader(param, loaded_weight) if loaded_shard_id == "q" else param.weight_loader(param, loaded_weight, loaded_shard_id) - - def weight_loader_bias(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None): - param = self.q_proj_decoder.bias if loaded_shard_id == "q" else self.kv_proj_encoder.bias - param.weight_loader(param, loaded_weight) if loaded_shard_id == "q" else param.weight_loader(param, loaded_weight, loaded_shard_id) + def weight_loader_weight(self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + # NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param. 
+ param = self.q_proj_decoder.weight if loaded_shard_id == "q" \ + else self.kv_proj_encoder.weight + param.weight_loader( + param, + loaded_weight) if loaded_shard_id == "q" else param.weight_loader( + param, loaded_weight, loaded_shard_id) + + def weight_loader_bias(self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + param = self.q_proj_decoder.bias if loaded_shard_id == "q" \ + else self.kv_proj_encoder.bias + param.weight_loader( + param, + loaded_weight) if loaded_shard_id == "q" else param.weight_loader( + param, loaded_weight, loaded_shard_id) From 80ef42dc9c35f0b971a73436ae9c8d38d00da7ed Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 22 Jan 2025 17:35:23 +0000 Subject: [PATCH 04/10] mllama test Signed-off-by: NickLucche --- vllm/model_executor/models/mllama.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index f74fa7a46629..c383a2289141 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -66,7 +66,7 @@ from .clip import CLIPMLP from .interfaces import SupportsMultiModal, SupportsV0Only from .llama import LlamaDecoderLayer, LlamaMLP -from .utils import maybe_prefix +from .utils import QKVCrossParallelLinear, maybe_prefix logger = init_logger(__name__) @@ -811,8 +811,7 @@ def __init__( self.q_local_size = self.num_local_heads * self.head_dim self.kv_local_size = self.num_local_key_value_heads * self.head_dim - # TODO: change to Q/KV separate linear after #7448 is merged - self.qkv_proj = QKVParallelLinear( + self.qkv_proj = QKVCrossParallelLinear( self.hidden_size, self.head_dim, self.num_heads, @@ -851,21 +850,12 @@ def forward( kv_range_for_decode: Optional[List[Tuple[int, int]]], cross_attention_states: Optional[torch.Tensor], ) -> torch.Tensor: - qkv_dec, _ = self.qkv_proj(hidden_states) - q, _, _ = qkv_dec.split( - [self.q_local_size, self.kv_local_size, self.kv_local_size], - dim=-1) - if cross_attention_states is None: - k = None - v = None - else: - qkv_enc, _ = self.qkv_proj(cross_attention_states) - _, k, v = qkv_enc.split( - [self.q_local_size, self.kv_local_size, self.kv_local_size], - dim=-1) + q, k, v = self.qkv_proj(hidden_states, cross_attention_states) + if cross_attention_states is not None: k = k.view(-1, self.num_local_key_value_heads, self.head_dim) v = v.view(-1, self.num_local_key_value_heads, self.head_dim) k = self.k_norm(k) + q = q.view(-1, self.num_local_heads, self.head_dim) q = self.q_norm(q) @@ -889,6 +879,7 @@ def _attention_with_mask( kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank] attn_metadata: AttentionMetadata = get_forward_context().attn_metadata # Skip writing kv-cache for the initial profiling run. 
+ # TODO (NickLucche) replace with custom attn bias and use standard attn if len(kv_cache.shape) > 1: i = torch.ones(1, dtype=torch.float32) if self.attn.backend in (_Backend.FLASH_ATTN, From 7d293d1e64794f7680e2dd7b39ef3071bf5e2a76 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 23 Jan 2025 18:07:08 +0000 Subject: [PATCH 05/10] format Signed-off-by: NickLucche --- vllm/model_executor/models/bart.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index a37101d45904..510914c7048b 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -314,7 +314,13 @@ def __init__( # bias=bias, # quant_config=quant_config, # ) - self.qkv_proj = QKVCrossParallelLinear(self.d_model, self.d_model // self.total_num_heads, self.total_num_heads, self.total_num_kv_heads, bias, quant_config=quant_config) + self.qkv_proj = QKVCrossParallelLinear(self.d_model, + self.d_model // + self.total_num_heads, + self.total_num_heads, + self.total_num_kv_heads, + bias, + quant_config=quant_config) self.out_proj = RowParallelLinear( embed_dim, From 5061ab7fb0da14ef2454db2f672c835b3c6562c6 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Fri, 24 Jan 2025 14:44:37 +0000 Subject: [PATCH 06/10] clean up comments Signed-off-by: NickLucche --- vllm/model_executor/models/bart.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 510914c7048b..a1f88a5a535e 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -299,21 +299,6 @@ def __init__( f" and `num_heads`: {num_heads}).") self.scaling = self.head_dim**-0.5 - # self.q_proj = ColumnParallelLinear( - # input_size=self.embed_dim, - # output_size=self.embed_dim, - # bias=bias, - # quant_config=quant_config, - # ) - - # self.kv_proj = QKVParallelLinear( - # hidden_size=self.d_model, - # head_size=self.d_model // self.total_num_heads, - # total_num_heads=0, - # total_num_kv_heads=self.total_num_kv_heads, - # bias=bias, - # quant_config=quant_config, - # ) self.qkv_proj = QKVCrossParallelLinear(self.d_model, self.d_model // self.total_num_heads, @@ -361,15 +346,6 @@ def forward( ) -> torch.Tensor: """Input shape: Batch x Time x Channel""" - # q, _ = self.q_proj(decoder_hidden_states) - - # if encoder_hidden_states is None: - # k = None - # v = None - # else: - # # Prefill, cache encoder KV. 
- # kv_enc, _ = self.kv_proj(encoder_hidden_states) - # k, v = kv_enc.split(self.kv_size, dim=-1) q, k, v = self.qkv_proj(decoder_hidden_states, encoder_hidden_states) attn_output = self.attn(q, k, v) From 00ae74d4f774e9d45efe26cfb572662c88e44427 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 18 Feb 2025 17:51:10 +0000 Subject: [PATCH 07/10] fix distributed use Signed-off-by: NickLucche --- vllm/commit_id.py | 1 + vllm/model_executor/models/bart.py | 6 ++---- vllm/model_executor/models/mllama.py | 8 +++++--- vllm/model_executor/models/utils.py | 5 ++++- 4 files changed, 12 insertions(+), 8 deletions(-) create mode 100644 vllm/commit_id.py diff --git a/vllm/commit_id.py b/vllm/commit_id.py new file mode 100644 index 000000000000..f22dbb8199db --- /dev/null +++ b/vllm/commit_id.py @@ -0,0 +1 @@ +__commit__ = "933dc175653650d405b1e344822a57dad241c075" diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index a1f88a5a535e..a54f20fd294c 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -299,6 +299,7 @@ def __init__( f" and `num_heads`: {num_heads}).") self.scaling = self.head_dim**-0.5 + # TP sharding sizes is accounted for within "*Parallel" layers. self.qkv_proj = QKVCrossParallelLinear(self.d_model, self.d_model // self.total_num_heads, @@ -326,10 +327,7 @@ def __init__( # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. assert tp_world_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - + self.num_kv_heads = self.num_heads # No GQA in bart self.attn = Attention(self.num_heads, self.head_dim, self.scaling, diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index c383a2289141..be2414d3ede9 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -798,14 +798,15 @@ def __init__( self.config = config self.pipeline_parallel_rank = get_pp_group().rank_in_group self.tensor_parallel_size = get_tp_group().world_size - self.num_heads = self.config.num_attention_heads + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_local_heads = self.num_heads // self.tensor_parallel_size - self.num_key_value_heads = self.config.num_key_value_heads self.num_local_key_value_heads = \ self.num_key_value_heads // self.tensor_parallel_size - self.dropout = config.dropout self.hidden_size = config.hidden_size self.head_dim = config.hidden_size // self.num_heads + self.layer_idx = layer_idx self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.q_local_size = self.num_local_heads * self.head_dim @@ -820,6 +821,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) + self.o_proj = RowParallelLinear( self.num_heads * self.head_dim, self.hidden_size, diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 16fa6facbe43..f8a9530e321d 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -686,7 +686,7 @@ def __init__(self, skip_bias_add=skip_bias_add, params_dtype=params_dtype, prefix=f"{prefix}.q_proj_decoder") - self.kv_size = total_num_kv_heads * head_size + self.proj["kv_proj_encoder"] = QKVParallelLinear( hidden_size=hidden_size, head_size=head_size, @@ -698,6 +698,9 @@ def 
__init__(self, params_dtype=params_dtype, prefix=f"{prefix}.kv_proj_encoder") + # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1. + self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size + if bias: self.bias = torch.nn.Parameter() set_weight_attrs(self.bias, { From 58e7308253aaba636e8068c3b976fbc4c1326e4f Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 18 Feb 2025 17:56:14 +0000 Subject: [PATCH 08/10] format Signed-off-by: NickLucche --- vllm/commit_id.py | 1 - vllm/model_executor/models/bart.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) delete mode 100644 vllm/commit_id.py diff --git a/vllm/commit_id.py b/vllm/commit_id.py deleted file mode 100644 index f22dbb8199db..000000000000 --- a/vllm/commit_id.py +++ /dev/null @@ -1 +0,0 @@ -__commit__ = "933dc175653650d405b1e344822a57dad241c075" diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index a54f20fd294c..d9cd386da057 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -299,7 +299,7 @@ def __init__( f" and `num_heads`: {num_heads}).") self.scaling = self.head_dim**-0.5 - # TP sharding sizes is accounted for within "*Parallel" layers. + # TP sharding sizes is accounted for within "*Parallel" layers. self.qkv_proj = QKVCrossParallelLinear(self.d_model, self.d_model // self.total_num_heads, @@ -327,7 +327,7 @@ def __init__( # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. assert tp_world_size % self.total_num_kv_heads == 0 - self.num_kv_heads = self.num_heads # No GQA in bart + self.num_kv_heads = self.num_heads # No GQA in bart self.attn = Attention(self.num_heads, self.head_dim, self.scaling, From 75b5425658d6fde438f4c538e0f131c9b0cba07f Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 19 Feb 2025 10:03:26 +0000 Subject: [PATCH 09/10] address review Signed-off-by: NickLucche --- vllm/model_executor/layers/linear.py | 95 +++++++++++++++++++++++++ vllm/model_executor/models/bart.py | 7 +- vllm/model_executor/models/mllama.py | 4 +- vllm/model_executor/models/utils.py | 101 +-------------------------- 4 files changed, 103 insertions(+), 104 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index b9c85aaf50b5..0f5e5616814f 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1178,3 +1178,98 @@ def extra_repr(self) -> str: s += f", tp_size={self.tp_size}" s += f", reduce_results={self.reduce_results}" return s + + +class QKVCrossParallelLinear(torch.nn.Module): + + def __init__(self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + # Empty placeholders for loading as a single module. + self.weight = torch.nn.Parameter() + set_weight_attrs(self.weight, { + "weight_loader": self.weight_loader_weight, + }) + # Use a dictionary to avoid submodules parameters auto-registration: + # drop-in replacement for a `QKVParallelLinear` module. 
+ self.proj = dict() + self.proj["q_proj_decoder"] = ColumnParallelLinear( + input_size=hidden_size, + output_size=total_num_heads * head_size, + bias=bias, + quant_config=quant_config, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + prefix=f"{prefix}.q_proj_decoder") + + self.proj["kv_proj_encoder"] = QKVParallelLinear( + hidden_size=hidden_size, + head_size=head_size, + total_num_heads=0, + total_num_kv_heads=total_num_kv_heads, + bias=bias, + quant_config=quant_config, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + prefix=f"{prefix}.kv_proj_encoder") + + # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1. + self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size + + if bias: + self.bias = torch.nn.Parameter() + set_weight_attrs(self.bias, { + "weight_loader": self.weight_loader_bias, + }) + + @property + def q_proj_decoder(self): + return self.proj["q_proj_decoder"] + + @property + def kv_proj_encoder(self): + return self.proj["kv_proj_encoder"] + + def forward(self, decoder_hidden_states, encoder_hidden_states): + q, _ = self.q_proj_decoder(decoder_hidden_states) + if encoder_hidden_states is None: + # Encoder KV already cached. + k = None + v = None + else: + # Prefill phase, encoder KV cached here. + kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states) + # Split kv in half + k, v = kv_enc.split(self.kv_size, dim=-1) + return q, k, v + + def weight_loader_weight(self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + # NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param. + param = self.q_proj_decoder.weight if loaded_shard_id == "q" \ + else self.kv_proj_encoder.weight + param.weight_loader( + param, + loaded_weight) if loaded_shard_id == "q" else param.weight_loader( + param, loaded_weight, loaded_shard_id) + + def weight_loader_bias(self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + param = self.q_proj_decoder.bias if loaded_shard_id == "q" \ + else self.kv_proj_encoder.bias + param.weight_loader( + param, + loaded_weight) if loaded_shard_id == "q" else param.weight_loader( + param, loaded_weight, loaded_shard_id) \ No newline at end of file diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index d9cd386da057..109b65d92cf9 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -31,6 +31,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVCrossParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -44,7 +45,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsV0Only -from .utils import QKVCrossParallelLinear, maybe_prefix +from .utils import maybe_prefix logger = logging.get_logger(__name__) @@ -169,7 +170,7 @@ def __init__( # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. 
assert tp_world_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) + self.num_kv_heads = self.num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -248,7 +249,7 @@ def __init__( # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. assert tp_world_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) + self.num_kv_heads = self.num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index be2414d3ede9..9f5137fdd1fc 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -44,6 +44,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, + QKVCrossParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -66,7 +67,7 @@ from .clip import CLIPMLP from .interfaces import SupportsMultiModal, SupportsV0Only from .llama import LlamaDecoderLayer, LlamaMLP -from .utils import QKVCrossParallelLinear, maybe_prefix +from .utils import maybe_prefix logger = init_logger(__name__) @@ -806,6 +807,7 @@ def __init__( self.num_key_value_heads // self.tensor_parallel_size self.hidden_size = config.hidden_size self.head_dim = config.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads self.layer_idx = layer_idx self.num_key_value_groups = self.num_heads // self.num_key_value_heads diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index f8a9530e321d..a705aeffef35 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -12,12 +12,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors from vllm.sequence import IntermediateTensors from vllm.utils import is_pin_memory_available @@ -655,98 +650,4 @@ def cast_overflow_tensors( if tensors.isinf().any() or tensors.isnan().any(): clamp_value = torch.finfo(tensors.dtype).max - offset tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value) - return tensors - -class QKVCrossParallelLinear(torch.nn.Module): - - def __init__(self, - hidden_size: int, - head_size: int, - total_num_heads: int, - total_num_kv_heads: Optional[int] = None, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): - super().__init__() - # Empty placeholders for loading as a single module. - self.weight = torch.nn.Parameter() - set_weight_attrs(self.weight, { - "weight_loader": self.weight_loader_weight, - }) - # Use a dictionary to avoid submodules parameters auto-registration: - # drop-in replacement for a `QKVParallelLinear` module. 
- self.proj = dict() - self.proj["q_proj_decoder"] = ColumnParallelLinear( - input_size=hidden_size, - output_size=total_num_heads * head_size, - bias=bias, - quant_config=quant_config, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - prefix=f"{prefix}.q_proj_decoder") - - self.proj["kv_proj_encoder"] = QKVParallelLinear( - hidden_size=hidden_size, - head_size=head_size, - total_num_heads=0, - total_num_kv_heads=total_num_kv_heads, - bias=bias, - quant_config=quant_config, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - prefix=f"{prefix}.kv_proj_encoder") - - # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1. - self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size - - if bias: - self.bias = torch.nn.Parameter() - set_weight_attrs(self.bias, { - "weight_loader": self.weight_loader_bias, - }) - - @property - def q_proj_decoder(self): - return self.proj["q_proj_decoder"] - - @property - def kv_proj_encoder(self): - return self.proj["kv_proj_encoder"] - - def forward(self, decoder_hidden_states, encoder_hidden_states): - q, _ = self.q_proj_decoder(decoder_hidden_states) - if encoder_hidden_states is None: - # Encoder KV already cached. - k = None - v = None - else: - # Prefill phase, encoder KV cached here. - kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states) - # Split kv in half - k, v = kv_enc.split(self.kv_size, dim=-1) - return q, k, v - - def weight_loader_weight(self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None): - # NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param. - param = self.q_proj_decoder.weight if loaded_shard_id == "q" \ - else self.kv_proj_encoder.weight - param.weight_loader( - param, - loaded_weight) if loaded_shard_id == "q" else param.weight_loader( - param, loaded_weight, loaded_shard_id) - - def weight_loader_bias(self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None): - param = self.q_proj_decoder.bias if loaded_shard_id == "q" \ - else self.kv_proj_encoder.bias - param.weight_loader( - param, - loaded_weight) if loaded_shard_id == "q" else param.weight_loader( - param, loaded_weight, loaded_shard_id) + return tensors \ No newline at end of file From 7c476771f2a8efd4ced3e25a120cd83e5cfa4ffe Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 5 Mar 2025 14:05:54 +0000 Subject: [PATCH 10/10] isort Signed-off-by: NickLucche --- vllm/model_executor/models/mllama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 9f5137fdd1fc..a9de63245d97 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -43,8 +43,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, QKVCrossParallelLinear, + QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig
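
For reference, a minimal usage sketch of the QKVCrossParallelLinear layer this series introduces (final location: vllm/model_executor/layers/linear.py). The hidden size, head counts and tensor shapes below are illustrative assumptions, not values from the patches, and a real invocation additionally requires vLLM's distributed / tensor-parallel state to be initialized before the layer is constructed.

    import torch

    from vllm.model_executor.layers.linear import QKVCrossParallelLinear

    # Illustrative sizes only (a BART-large-like config is assumed here).
    hidden_size, num_heads, num_kv_heads = 1024, 16, 16
    head_size = hidden_size // num_heads

    qkv_proj = QKVCrossParallelLinear(hidden_size, head_size, num_heads,
                                      num_kv_heads, bias=True)

    decoder_hidden = torch.randn(8, hidden_size)   # decoder token states
    encoder_hidden = torch.randn(32, hidden_size)  # encoder token states

    # Prefill: Q is projected from the decoder states, K/V from the encoder
    # output (these are what the cross-attention KV cache stores).
    q, k, v = qkv_proj(decoder_hidden, encoder_hidden)

    # Decode: encoder K/V are already cached, so no encoder states are passed
    # and k/v come back as None.
    q, k, v = qkv_proj(decoder_hidden, None)

The layer keeps its two sub-projections in a plain dict rather than as registered submodules, so the existing qkv_proj weight-loading path keeps working: the placeholder weight/bias parameters carry custom loaders that route shard id "q" to the decoder-side ColumnParallelLinear and "k"/"v" to the encoder-side QKVParallelLinear (built with total_num_heads=0), which also keeps KV-head sharding correct under tensor parallelism.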