From ee326afdf39352247bc7b57200a39b61b3175349 Mon Sep 17 00:00:00 2001 From: sixgod Date: Sun, 5 Jan 2025 14:44:20 +0800 Subject: [PATCH 1/2] [model] Add cogagent model support vLLM --- vllm/model_executor/models/chatglm.py | 273 +++++++++++++++++++++++--- 1 file changed, 243 insertions(+), 30 deletions(-) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index ffd6891b2596..caac433a4007 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -1,6 +1,6 @@ # Adapted from -# https://github.com/THUDM/GLM-4 -"""Inference-only ChatGLM model compatible with THUDM weights.""" +# https://github.com/THUDM/CogAgent +"""Inference-only CogAgent model compatible with THUDM weights.""" from argparse import Namespace from array import array from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple, @@ -24,7 +24,6 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -33,7 +32,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (ModalityData, MultiModalKwargs, +from vllm.multimodal.inputs import (MultiModalData, MultiModalKwargs, NestedTensors) from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, @@ -54,7 +53,7 @@ def calculate_image_placeholder(vision_config): def mm_input_mapper_for_glmv( ctx: InputContext, - data: ModalityData[object], + data: MultiModalData[object], ) -> Dict: model_config = ctx.model_config tokenizer = cached_get_tokenizer( @@ -201,7 +200,6 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs): new_input_ids = [] final_processed_position = 0 - final_processed_position = 0 for boi_position, eoi_position in zip(boi_positions, eoi_positions): assert boi_position < eoi_position @@ -218,11 +216,131 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs): prompt = tokenizer.decode(new_input_ids) return token_inputs( - prompt_token_ids=new_input_ids, + prompt_token_ids=raw_batch_data['input_ids'][0], + # prompt_token_ids=new_input_ids, prompt=prompt, multi_modal_data=multi_modal_data, ) +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. 
+ if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + +def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: + # x: [b, np, sq, hn] + sq = x.size(1) + rot_dim = rope_cache.shape[-2] * 2 + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + # truncate to support variable sizes + rope_cache = rope_cache[:, :sq] + xshaped = x.chunk(2, -1) + cos, sin = rope_cache[...,0].unsqueeze(2), rope_cache[...,1].unsqueeze(2) + # print(f'\033[92m--cos\033[0m',cos.shape) + # print(f'\033[92m--xshaped\033[0m',xshaped[0].shape) + x_out2 = torch.concat( + [ + xshaped[0] * cos - xshaped[1] * sin, + xshaped[1] * cos + xshaped[0] * sin, + ], + -1, + ) + return torch.cat((x_out2, x_pass), dim=-1) + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, rope_ratio=1, original_impl=False, device=None, dtype=None): + super().__init__() + inv_freq = 1.0 / ( + 10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim) + ) + self.register_buffer("inv_freq", inv_freq) + self.dim = dim + self.original_impl = original_impl + self.rope_ratio = rope_ratio + + def impl(self, seq_length: int, dim: int, device: torch.device, dtype: torch.dtype): + base = 10000 * self.rope_ratio + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) + ) + seq = torch.arange(seq_length, device=inv_freq.device, dtype=torch.float32) + freqs = torch.outer(seq, inv_freq) + + emb = torch.stack((freqs.cos(), freqs.sin()), dim=-1).to(dtype=dtype) + return emb + + def forward_impl( + self, + seq_len: int, + n_elem: int, + dtype: torch.dtype, + device: torch.device, + base: int = 10000, + ): + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + base = base * self.rope_ratio + theta = 1.0 / ( + base + ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem) + ) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, dtype=torch.float, device=device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).float() + + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + + # this is to mimic the behaviour of complex32, else we will get different results + if dtype in (torch.float16, torch.bfloat16, torch.int8): + cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() + return cache + + def forward(self, max_seq_len, offset=0): + if self.original_impl: + return self.forward_impl( + max_seq_len, + self.dim, + dtype=self.inv_freq.dtype, + device=self.inv_freq.device, + ) + else: + return self.impl( + max_seq_len, + self.dim, + dtype=self.inv_freq.dtype, + device=self.inv_freq.device, + ) + class GLMAttention(nn.Module): @@ -236,10 +354,27 @@ def __init__( super().__init__() self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() + self.projection_size = config.kv_channels * config.num_attention_heads + self.qkv_hidden_size = 3 * self.projection_size + self.original_rope = config.original_rope + + + # Per attention head and per partition values. 
+ self.hidden_size_per_attention_head = ( + self.projection_size // config.num_attention_heads + ) + # num_attention_heads_per_partition self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.multi_query_attention = config.multi_query_attention + if self.multi_query_attention: + self.num_multi_query_groups_per_partition = config.multi_query_group_num + self.qkv_hidden_size = ( + self.projection_size + + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num + ) self.total_num_kv_heads = (config.multi_query_group_num if config.multi_query_attention else config.num_attention_heads) @@ -272,16 +407,6 @@ def __init__( quant_config=quant_config, ) - # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 - rope_ratio = getattr(config, "rope_ratio", 1.0) - max_positions = getattr(config, "seq_length", 8192) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim // 2, - max_position=max_positions, - base=10000 * rope_ratio, - is_neox_style=False, - ) self.attn = Attention(self.num_heads, self.head_dim, self.scaling, @@ -293,20 +418,78 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - position_ids: torch.Tensor, + rotary_pos_emb: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.query_key_value(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(position_ids, q, k) + + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = qkv.split( + [ + self.total_num_heads + * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition + * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition + * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view( + query_layer.size()[:-1] + + ( + self.total_num_heads, + self.hidden_size_per_attention_head, + ) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) + else: + new_tensor_shape = qkv.size()[:-1] + ( + self.total_num_heads, + 3 * self.hidden_size_per_attention_head, + ) + qkv = qkv.view(*new_tensor_shape) + + # [b, sq, np, 3 * hn] --> 3 [b, sq, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim( + qkv, 3 + ) + # q, k = self.rotary_emb(position_ids, q, k) + if rotary_pos_emb is not None: + query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + # [b, sq, np, hn] -> [b, np, sq, hn] + query_layer, key_layer, value_layer = [ + k.transpose(1, 2) for k in [query_layer, key_layer, value_layer] + ] + + query_layer = query_layer.reshape(query_layer.size(2), -1) + key_layer = key_layer.reshape(key_layer.size(2), -1) + value_layer = value_layer.reshape(value_layer.size(2), -1) + context_layer = self.attn( - q, - k, - v, + query_layer, + key_layer, + value_layer, kv_cache, attn_metadata, ) + attn_output, _ = self.dense(context_layer) return attn_output @@ -397,7 +580,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - position_ids: torch.Tensor, + rotary_pos_emb: torch.Tensor, kv_cache: 
torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: @@ -407,7 +590,7 @@ def forward( # Self attention. attention_output = self.self_attention( hidden_states=layernorm_output, - position_ids=position_ids, + rotary_pos_emb=rotary_pos_emb, kv_cache=kv_cache, attn_metadata=attn_metadata, ) @@ -471,15 +654,16 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - position_ids: torch.Tensor, + rotary_pos_emb: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ) -> torch.Tensor: for i in range(self.start_layer, self.end_layer): layer = self.layers[i] + # print(f'第{i}层\n') hidden_states = layer( hidden_states=hidden_states, - position_ids=position_ids, + rotary_pos_emb=rotary_pos_emb, kv_cache=kv_caches[i - self.start_layer], attn_metadata=attn_metadata, ) @@ -508,6 +692,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num self.kv_channels = config.kv_channels + self.seq_length = config.seq_length + rotary_dim = ( + config.hidden_size // config.num_attention_heads + if config.kv_channels is None + else config.kv_channels + ) + self.rotary_pos_emb = RotaryEmbedding( + rotary_dim // 2, + rope_ratio=config.rope_ratio, + original_impl=config.original_rope, + dtype=config.torch_dtype, + ) self.encoder = GLMTransformer(config, cache_config, quant_config, @@ -583,18 +779,33 @@ def forward( # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. - if intermediate_tensors is None and inputs_embeds is None: + seq_length = len(input_ids) + + if intermediate_tensors is not None: + + # if intermediate_tensors is None and inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) input_ids = None else: - inputs_embeds = intermediate_tensors["hidden_states"] + inputs_embeds = self.embedding(input_ids) + + # inputs_embeds = intermediate_tensors["hidden_states"] + + positions = positions.unsqueeze(0) + inputs_embeds = inputs_embeds.unsqueeze(0) + + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if positions is not None: + rotary_pos_emb = rotary_pos_emb[positions] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] # Run encoder. 
hidden_states = self.encoder( hidden_states=inputs_embeds, - position_ids=positions, + rotary_pos_emb=rotary_pos_emb, kv_caches=kv_caches, attn_metadata=attn_metadata, ) @@ -639,6 +850,8 @@ def forward(self, hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, **kwargs) + hidden_states = hidden_states.squeeze(0) + return hidden_states def compute_logits( From 998524481a3fa2eaf92025f3da688a490806d565 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 10 Jan 2025 22:45:11 +0800 Subject: [PATCH 2/2] revert to use get_rope Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/chatglm.py | 273 +++----------------------- 1 file changed, 31 insertions(+), 242 deletions(-) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index caac433a4007..7e37ce3086e6 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -24,6 +24,7 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -32,7 +33,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalData, MultiModalKwargs, +from vllm.multimodal.inputs import (ModalityData, MultiModalKwargs, NestedTensors) from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, @@ -53,7 +54,7 @@ def calculate_image_placeholder(vision_config): def mm_input_mapper_for_glmv( ctx: InputContext, - data: MultiModalData[object], + data: ModalityData[object], ) -> Dict: model_config = ctx.model_config tokenizer = cached_get_tokenizer( @@ -216,131 +217,11 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs): prompt = tokenizer.decode(new_input_ids) return token_inputs( - prompt_token_ids=raw_batch_data['input_ids'][0], - # prompt_token_ids=new_input_ids, + prompt_token_ids=new_input_ids, prompt=prompt, multi_modal_data=multi_modal_data, ) -def split_tensor_along_last_dim( - tensor: torch.Tensor, - num_partitions: int, - contiguous_split_chunks: bool = False, -) -> List[torch.Tensor]: - """Split a tensor along its last dimension. - - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - - Returns: - A list of Tensors - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = tensor.size()[last_dim] // num_partitions - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. 
- if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - -def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: - # x: [b, np, sq, hn] - sq = x.size(1) - rot_dim = rope_cache.shape[-2] * 2 - x, x_pass = x[..., :rot_dim], x[..., rot_dim:] - # truncate to support variable sizes - rope_cache = rope_cache[:, :sq] - xshaped = x.chunk(2, -1) - cos, sin = rope_cache[...,0].unsqueeze(2), rope_cache[...,1].unsqueeze(2) - # print(f'\033[92m--cos\033[0m',cos.shape) - # print(f'\033[92m--xshaped\033[0m',xshaped[0].shape) - x_out2 = torch.concat( - [ - xshaped[0] * cos - xshaped[1] * sin, - xshaped[1] * cos + xshaped[0] * sin, - ], - -1, - ) - return torch.cat((x_out2, x_pass), dim=-1) - -class RotaryEmbedding(nn.Module): - def __init__(self, dim, rope_ratio=1, original_impl=False, device=None, dtype=None): - super().__init__() - inv_freq = 1.0 / ( - 10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim) - ) - self.register_buffer("inv_freq", inv_freq) - self.dim = dim - self.original_impl = original_impl - self.rope_ratio = rope_ratio - - def impl(self, seq_length: int, dim: int, device: torch.device, dtype: torch.dtype): - base = 10000 * self.rope_ratio - inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) - ) - seq = torch.arange(seq_length, device=inv_freq.device, dtype=torch.float32) - freqs = torch.outer(seq, inv_freq) - - emb = torch.stack((freqs.cos(), freqs.sin()), dim=-1).to(dtype=dtype) - return emb - - def forward_impl( - self, - seq_len: int, - n_elem: int, - dtype: torch.dtype, - device: torch.device, - base: int = 10000, - ): - """Enhanced Transformer with Rotary Position Embedding. - - Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ - transformers/rope/__init__.py. MIT License: - https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. - """ - # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ - base = base * self.rope_ratio - theta = 1.0 / ( - base - ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem) - ) - - # Create position indexes `[0, 1, ..., seq_len - 1]` - seq_idx = torch.arange(seq_len, dtype=torch.float, device=device) - - # Calculate the product of position index and $\theta_i$ - idx_theta = torch.outer(seq_idx, theta).float() - - cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) - - # this is to mimic the behaviour of complex32, else we will get different results - if dtype in (torch.float16, torch.bfloat16, torch.int8): - cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() - return cache - - def forward(self, max_seq_len, offset=0): - if self.original_impl: - return self.forward_impl( - max_seq_len, - self.dim, - dtype=self.inv_freq.dtype, - device=self.inv_freq.device, - ) - else: - return self.impl( - max_seq_len, - self.dim, - dtype=self.inv_freq.dtype, - device=self.inv_freq.device, - ) - class GLMAttention(nn.Module): @@ -354,27 +235,10 @@ def __init__( super().__init__() self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() - self.projection_size = config.kv_channels * config.num_attention_heads - self.qkv_hidden_size = 3 * self.projection_size - self.original_rope = config.original_rope - - - # Per attention head and per partition values. 
- self.hidden_size_per_attention_head = ( - self.projection_size // config.num_attention_heads - ) - # num_attention_heads_per_partition self.total_num_heads = config.num_attention_heads - assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.multi_query_attention = config.multi_query_attention - if self.multi_query_attention: - self.num_multi_query_groups_per_partition = config.multi_query_group_num - self.qkv_hidden_size = ( - self.projection_size - + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num - ) self.total_num_kv_heads = (config.multi_query_group_num if config.multi_query_attention else config.num_attention_heads) @@ -407,6 +271,19 @@ def __init__( quant_config=quant_config, ) + # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 + rope_ratio = getattr(config, "rope_ratio", 1.0) + max_positions = getattr(config, "seq_length", 8192) + # NOTE: THUDM/cogagent-9b-20241220 uses original_rope=False, + # which is equivalent to is_neox_style=True + is_neox_style = not config.original_rope + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim // 2, + max_position=max_positions, + base=10000 * rope_ratio, + is_neox_style=is_neox_style, + ) self.attn = Attention(self.num_heads, self.head_dim, self.scaling, @@ -418,78 +295,20 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - rotary_pos_emb: torch.Tensor, + position_ids: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.query_key_value(hidden_states) - - if self.multi_query_attention: - (query_layer, key_layer, value_layer) = qkv.split( - [ - self.total_num_heads - * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition - * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition - * self.hidden_size_per_attention_head, - ], - dim=-1, - ) - query_layer = query_layer.view( - query_layer.size()[:-1] - + ( - self.total_num_heads, - self.hidden_size_per_attention_head, - ) - ) - key_layer = key_layer.view( - key_layer.size()[:-1] - + ( - self.num_multi_query_groups_per_partition, - self.hidden_size_per_attention_head, - ) - ) - value_layer = value_layer.view( - value_layer.size()[:-1] - + ( - self.num_multi_query_groups_per_partition, - self.hidden_size_per_attention_head, - ) - ) - else: - new_tensor_shape = qkv.size()[:-1] + ( - self.total_num_heads, - 3 * self.hidden_size_per_attention_head, - ) - qkv = qkv.view(*new_tensor_shape) - - # [b, sq, np, 3 * hn] --> 3 [b, sq, np, hn] - (query_layer, key_layer, value_layer) = split_tensor_along_last_dim( - qkv, 3 - ) - # q, k = self.rotary_emb(position_ids, q, k) - if rotary_pos_emb is not None: - query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) - key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) - - # [b, sq, np, hn] -> [b, np, sq, hn] - query_layer, key_layer, value_layer = [ - k.transpose(1, 2) for k in [query_layer, key_layer, value_layer] - ] - - query_layer = query_layer.reshape(query_layer.size(2), -1) - key_layer = key_layer.reshape(key_layer.size(2), -1) - value_layer = value_layer.reshape(value_layer.size(2), -1) - + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(position_ids, q, k) context_layer = self.attn( - query_layer, - key_layer, - value_layer, + q, + k, + v, kv_cache, attn_metadata, ) - attn_output, _ = self.dense(context_layer) return 
attn_output @@ -580,7 +399,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - rotary_pos_emb: torch.Tensor, + position_ids: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: @@ -590,7 +409,7 @@ def forward( # Self attention. attention_output = self.self_attention( hidden_states=layernorm_output, - rotary_pos_emb=rotary_pos_emb, + position_ids=position_ids, kv_cache=kv_cache, attn_metadata=attn_metadata, ) @@ -654,16 +473,15 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - rotary_pos_emb: torch.Tensor, + position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ) -> torch.Tensor: for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - # print(f'第{i}层\n') hidden_states = layer( hidden_states=hidden_states, - rotary_pos_emb=rotary_pos_emb, + position_ids=position_ids, kv_cache=kv_caches[i - self.start_layer], attn_metadata=attn_metadata, ) @@ -692,18 +510,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num self.kv_channels = config.kv_channels - self.seq_length = config.seq_length - rotary_dim = ( - config.hidden_size // config.num_attention_heads - if config.kv_channels is None - else config.kv_channels - ) - self.rotary_pos_emb = RotaryEmbedding( - rotary_dim // 2, - rope_ratio=config.rope_ratio, - original_impl=config.original_rope, - dtype=config.torch_dtype, - ) self.encoder = GLMTransformer(config, cache_config, quant_config, @@ -779,33 +585,18 @@ def forward( # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. - seq_length = len(input_ids) - - if intermediate_tensors is not None: - - # if intermediate_tensors is None and inputs_embeds is None: + if intermediate_tensors is None and inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) input_ids = None else: - inputs_embeds = self.embedding(input_ids) - - # inputs_embeds = intermediate_tensors["hidden_states"] - - positions = positions.unsqueeze(0) - inputs_embeds = inputs_embeds.unsqueeze(0) - - rotary_pos_emb = self.rotary_pos_emb(self.seq_length) - if positions is not None: - rotary_pos_emb = rotary_pos_emb[positions] - else: - rotary_pos_emb = rotary_pos_emb[None, :seq_length] + inputs_embeds = intermediate_tensors["hidden_states"] # Run encoder. hidden_states = self.encoder( hidden_states=inputs_embeds, - rotary_pos_emb=rotary_pos_emb, + position_ids=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, ) @@ -850,8 +641,6 @@ def forward(self, hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, **kwargs) - hidden_states = hidden_states.squeeze(0) - return hidden_states def compute_logits( @@ -992,4 +781,4 @@ def __new__( return ChatGLMV(vllm_config=vllm_config, prefix=prefix) # Initialize LLM else: - return ChatGLM(vllm_config=vllm_config, prefix=prefix) + return ChatGLM(vllm_config=vllm_config, prefix=prefix) \ No newline at end of file
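
The NOTE introduced in the second patch states that THUDM/cogagent-9b-20241220 sets original_rope=False, which is treated as equivalent to is_neox_style=True. A minimal standalone sketch of why that holds (not part of either patch; the tensor sizes and the 10000 base below are illustrative assumptions) is:

    import torch

    def glm_half_split_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # Rotation as written in the custom apply_rotary_pos_emb removed above:
        # the rotary dims are split into two halves, pairing (i, i + d/2).
        x0, x1 = x.chunk(2, dim=-1)
        return torch.cat([x0 * cos - x1 * sin, x1 * cos + x0 * sin], dim=-1)

    def neox_style_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # Standard "rotate_half" form used by neox-style rotary embeddings
        # (i.e. the is_neox_style=True convention).
        half = x.shape[-1] // 2
        rotated = torch.cat([-x[..., half:], x[..., :half]], dim=-1)
        return x * torch.cat([cos, cos], dim=-1) + rotated * torch.cat([sin, sin], dim=-1)

    # Toy sizes for the check only; in the model the rotary dim is head_dim // 2.
    seq_len, rot_dim = 5, 8
    inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # [seq, rot_dim/2]
    cos, sin = angles.cos(), angles.sin()
    x = torch.randn(seq_len, rot_dim)
    assert torch.allclose(glm_half_split_rope(x, cos, sin),
                          neox_style_rope(x, cos, sin), atol=1e-6)

Because the removed apply_rotary_pos_emb rotates the two halves of the rotary dimensions as (i, i + d/2) pairs, delegating to get_rope with is_neox_style=not config.original_rope reproduces the same rotation without the hand-rolled RotaryEmbedding module, which is what the second patch relies on.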
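
Neither patch includes a usage example. A minimal offline-inference sketch for the newly supported model might look like the following; the prompt string, image file, and sampling settings are placeholders chosen here rather than taken from the patches, and a real CogAgent prompt should follow the model's own chat template:

    from PIL import Image
    from vllm import LLM, SamplingParams

    # Model id taken from the NOTE above; trust_remote_code is required for THUDM repos.
    llm = LLM(model="THUDM/cogagent-9b-20241220", trust_remote_code=True)

    image = Image.open("screenshot.png")    # any RGB GUI screenshot
    prompt = "Describe the screenshot."     # placeholder; use the model's prompt template

    outputs = llm.generate(
        {"prompt": prompt, "multi_modal_data": {"image": image}},
        SamplingParams(temperature=0.2, max_tokens=256),
    )
    print(outputs[0].outputs[0].text)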