diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index b9c85aaf50b5..0f5e5616814f 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -1178,3 +1178,98 @@ def extra_repr(self) -> str:
         s += f", tp_size={self.tp_size}"
         s += f", reduce_results={self.reduce_results}"
         return s
+
+
+class QKVCrossParallelLinear(torch.nn.Module):
+
+    def __init__(self,
+                 hidden_size: int,
+                 head_size: int,
+                 total_num_heads: int,
+                 total_num_kv_heads: Optional[int] = None,
+                 bias: bool = True,
+                 skip_bias_add: bool = False,
+                 params_dtype: Optional[torch.dtype] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+        # Empty placeholders for loading as a single module.
+        self.weight = torch.nn.Parameter()
+        set_weight_attrs(self.weight, {
+            "weight_loader": self.weight_loader_weight,
+        })
+        # Use a dictionary to avoid submodule parameters auto-registration:
+        # drop-in replacement for a `QKVParallelLinear` module.
+        self.proj = dict()
+        self.proj["q_proj_decoder"] = ColumnParallelLinear(
+            input_size=hidden_size,
+            output_size=total_num_heads * head_size,
+            bias=bias,
+            quant_config=quant_config,
+            skip_bias_add=skip_bias_add,
+            params_dtype=params_dtype,
+            prefix=f"{prefix}.q_proj_decoder")
+
+        self.proj["kv_proj_encoder"] = QKVParallelLinear(
+            hidden_size=hidden_size,
+            head_size=head_size,
+            total_num_heads=0,
+            total_num_kv_heads=total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+            skip_bias_add=skip_bias_add,
+            params_dtype=params_dtype,
+            prefix=f"{prefix}.kv_proj_encoder")
+
+        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
+        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
+
+        if bias:
+            self.bias = torch.nn.Parameter()
+            set_weight_attrs(self.bias, {
+                "weight_loader": self.weight_loader_bias,
+            })
+
+    @property
+    def q_proj_decoder(self):
+        return self.proj["q_proj_decoder"]
+
+    @property
+    def kv_proj_encoder(self):
+        return self.proj["kv_proj_encoder"]
+
+    def forward(self, decoder_hidden_states, encoder_hidden_states):
+        q, _ = self.q_proj_decoder(decoder_hidden_states)
+        if encoder_hidden_states is None:
+            # Encoder KV already cached.
+            k = None
+            v = None
+        else:
+            # Prefill phase, encoder KV is cached here.
+            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
+            # Split KV in half.
+            k, v = kv_enc.split(self.kv_size, dim=-1)
+        return q, k, v
+
+    def weight_loader_weight(self,
+                             param: torch.nn.Parameter,
+                             loaded_weight: torch.Tensor,
+                             loaded_shard_id: Optional[str] = None):
+        # NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
+        param = self.q_proj_decoder.weight if loaded_shard_id == "q" \
+            else self.kv_proj_encoder.weight
+        param.weight_loader(
+            param,
+            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
+                param, loaded_weight, loaded_shard_id)
+
+    def weight_loader_bias(self,
+                           param: torch.nn.Parameter,
+                           loaded_weight: torch.Tensor,
+                           loaded_shard_id: Optional[str] = None):
+        param = self.q_proj_decoder.bias if loaded_shard_id == "q" \
+            else self.kv_proj_encoder.bias
+        param.weight_loader(
+            param,
+            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
+                param, loaded_weight, loaded_shard_id)
\ No newline at end of file
diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py
index 5847c50862e5..109b65d92cf9 100644
--- a/vllm/model_executor/models/bart.py
+++ b/vllm/model_executor/models/bart.py
@@ -31,6 +31,7 @@
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVCrossParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -169,7 +170,7 @@ def __init__(
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_world_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.num_kv_heads = self.num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim

@@ -248,7 +249,7 @@ def __init__(
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_world_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.num_kv_heads = self.num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim

@@ -299,14 +300,14 @@ def __init__(
                 f" and `num_heads`: {num_heads}).")
         self.scaling = self.head_dim**-0.5

-        self.qkv_proj = QKVParallelLinear(
-            self.d_model,
-            self.d_model // self.total_num_heads,
-            self.total_num_heads,
-            self.total_num_kv_heads,
-            bias=bias,
-            quant_config=quant_config,
-        )
+        # TP sharding sizes are accounted for within the "*Parallel" layers.
+        self.qkv_proj = QKVCrossParallelLinear(self.d_model,
+                                               self.d_model //
+                                               self.total_num_heads,
+                                               self.total_num_heads,
+                                               self.total_num_kv_heads,
+                                               bias,
+                                               quant_config=quant_config)

         self.out_proj = RowParallelLinear(
             embed_dim,
@@ -327,10 +328,7 @@ def __init__(
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_world_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
-        self.q_size = self.num_heads * self.head_dim
-        self.kv_size = self.num_kv_heads * self.head_dim
-
+        self.num_kv_heads = self.num_heads  # No GQA in bart
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
@@ -347,18 +345,7 @@ def forward(
     ) -> torch.Tensor:
         """Input shape: Batch x Time x Channel"""

-        # (afeldman-nm 2024/07/22) TODO:
-        # Need a more efficient solution for q/k/v
-        qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
-        q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-        if encoder_hidden_states is None:
-            k = None
-            v = None
-        else:
-            qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
-            _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
-                                    dim=-1)
+        q, k, v = self.qkv_proj(decoder_hidden_states, encoder_hidden_states)

         attn_output = self.attn(q, k, v)

diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index f74fa7a46629..a9de63245d97 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -43,6 +43,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVCrossParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -798,21 +799,22 @@ def __init__(
         self.config = config
         self.pipeline_parallel_rank = get_pp_group().rank_in_group
         self.tensor_parallel_size = get_tp_group().world_size
-        self.num_heads = self.config.num_attention_heads
+        self.num_heads = config.num_attention_heads
+        self.num_key_value_heads = config.num_key_value_heads
+
         self.num_local_heads = self.num_heads // self.tensor_parallel_size
-        self.num_key_value_heads = self.config.num_key_value_heads
         self.num_local_key_value_heads = \
             self.num_key_value_heads // self.tensor_parallel_size
-        self.dropout = config.dropout
         self.hidden_size = config.hidden_size
         self.head_dim = config.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+
         self.layer_idx = layer_idx
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.q_local_size = self.num_local_heads * self.head_dim
         self.kv_local_size = self.num_local_key_value_heads * self.head_dim

-        # TODO: change to Q/KV separate linear after #7448 is merged
-        self.qkv_proj = QKVParallelLinear(
+        self.qkv_proj = QKVCrossParallelLinear(
             self.hidden_size,
             self.head_dim,
             self.num_heads,
@@ -821,6 +823,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
         )
+
         self.o_proj = RowParallelLinear(
             self.num_heads * self.head_dim,
             self.hidden_size,
@@ -851,21 +854,12 @@ def forward(
         kv_range_for_decode: Optional[List[Tuple[int, int]]],
         cross_attention_states: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        qkv_dec, _ = self.qkv_proj(hidden_states)
-        q, _, _ = qkv_dec.split(
-            [self.q_local_size, self.kv_local_size, self.kv_local_size],
-            dim=-1)
-        if cross_attention_states is None:
-            k = None
-            v = None
-        else:
-            qkv_enc, _ = self.qkv_proj(cross_attention_states)
-            _, k, v = qkv_enc.split(
-                [self.q_local_size, self.kv_local_size, self.kv_local_size],
-                dim=-1)
+        q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
+        if cross_attention_states is not None:
             k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
             v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
             k = self.k_norm(k)
+
         q = q.view(-1, self.num_local_heads, self.head_dim)
         q = self.q_norm(q)

@@ -889,6 +883,7 @@ def _attention_with_mask(
         kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank]
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
         # Skip writing kv-cache for the initial profiling run.
+        # TODO (NickLucche) replace with custom attn bias and use standard attn
         if len(kv_cache.shape) > 1:
             i = torch.ones(1, dtype=torch.float32)
             if self.attn.backend in (_Backend.FLASH_ATTN,
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index f9aa5da39a5f..a705aeffef35 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -650,4 +650,4 @@ def cast_overflow_tensors(
     if tensors.isinf().any() or tensors.isnan().any():
         clamp_value = torch.finfo(tensors.dtype).max - offset
         tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
-    return tensors
+    return tensors
\ No newline at end of file
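
For readers skimming the patch, here is a minimal, framework-free sketch of the forward contract that `QKVCrossParallelLinear` introduces: Q is projected from the decoder stream, K/V from the encoder stream, and K/V are skipped entirely once the encoder KV has been cached. The class `ToyCrossQKV` and its attribute names are illustrative assumptions in plain PyTorch (no tensor-parallel sharding, no quantization), not the vLLM implementation.

```python
# Sketch only: mirrors the q_proj_decoder / kv_proj_encoder split above,
# but with ordinary nn.Linear layers and hypothetical names.
import torch
import torch.nn as nn


class ToyCrossQKV(nn.Module):

    def __init__(self, hidden_size: int, head_size: int, num_heads: int,
                 num_kv_heads: int):
        super().__init__()
        # Q projection consumes decoder states; KV projection consumes
        # encoder states.
        self.q_proj = nn.Linear(hidden_size, num_heads * head_size)
        self.kv_proj = nn.Linear(hidden_size, 2 * num_kv_heads * head_size)
        self.kv_size = num_kv_heads * head_size

    def forward(self, decoder_hidden_states, encoder_hidden_states=None):
        q = self.q_proj(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Decode phase: encoder KV is already cached, nothing to project.
            return q, None, None
        # Prefill phase: project encoder states once and split into K and V.
        kv = self.kv_proj(encoder_hidden_states)
        k, v = kv.split(self.kv_size, dim=-1)
        return q, k, v


layer = ToyCrossQKV(hidden_size=16, head_size=4, num_heads=4, num_kv_heads=2)
decoder_states = torch.randn(2, 3, 16)
encoder_states = torch.randn(2, 5, 16)
q, k, v = layer(decoder_states, encoder_states)  # prefill: all three returned
q, k, v = layer(decoder_states, None)            # decode: k and v are None
```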