|
| 1 | +# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py |
| 2 | +from typing import List, Optional, Tuple |
| 3 | + |
| 4 | +import torch |
| 5 | +from transformers.models.llama.modeling_llama import ( |
| 6 | + LlamaAttention, |
| 7 | + LlamaDecoderLayer, |
| 8 | + LlamaForCausalLM, |
| 9 | + LlamaMLP, |
| 10 | + LlamaModel, |
| 11 | +) |
| 12 | + |
| 13 | +from colossalai.inference.flash_decoding_utils import FDIntermTensors |
| 14 | +from colossalai.inference.struct import BatchInfo |
| 15 | +from colossalai.kernel.triton import ( |
| 16 | + context_attention_unpadded, |
| 17 | + copy_kv_to_blocked_cache, |
| 18 | + flash_decoding_attention, |
| 19 | + get_xine_cache, |
| 20 | + rotary_embedding, |
| 21 | +) |
| 22 | +from colossalai.logging import get_dist_logger |
| 23 | + |
| 24 | +from flash_attn.bert_padding import index_first_axis, pad_input # noqa |
| 25 | + |
| 26 | +logger = get_dist_logger(__name__) |
| 27 | + |
| 28 | +try: |
| 29 | + HAS_TRITON = True |
| 30 | +except ImportError: |
| 31 | + HAS_TRITON = False |
| 32 | + logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.") |
| 33 | + |
| 34 | + |
| 35 | +@torch.no_grad() |
| 36 | +def llama_causal_lm_forward( |
| 37 | + self: LlamaForCausalLM, |
| 38 | + batch: BatchInfo = None, |
| 39 | + k_caches: List[torch.Tensor] = None, |
| 40 | + v_caches: List[torch.Tensor] = None, |
| 41 | +): |
| 42 | + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) |
| 43 | + hidden_states = llama_model_forward( |
| 44 | + self.model, |
| 45 | + batch=batch, |
| 46 | + k_caches=k_caches, |
| 47 | + v_caches=v_caches, |
| 48 | + ) |
| 49 | + logits = torch.mm(hidden_states, self.lm_head.weight.transpose(0, 1)) |
| 50 | + return logits |
| 51 | + |
| 52 | + |
| 53 | +@torch.no_grad() |
| 54 | +def llama_model_forward( |
| 55 | + self: LlamaModel, |
| 56 | + batch: BatchInfo = None, |
| 57 | + k_caches: List[torch.Tensor] = None, |
| 58 | + v_caches: List[torch.Tensor] = None, |
| 59 | +): |
| 60 | + input_ids = batch.get_1D_inputs() |
| 61 | + block_tables = batch.get_block_table_tensor() |
| 62 | + |
| 63 | + sequence_lengths = batch.get_sequence_lengths() |
| 64 | + batch_size = len(sequence_lengths) |
| 65 | + kv_seq_len = sequence_lengths.max().item() |
| 66 | + |
| 67 | + hidden_states = self.embed_tokens(input_ids) |
| 68 | + |
| 69 | + cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts) |
| 70 | + |
| 71 | + if batch.is_prompts: |
| 72 | + output_tensor = torch.zeros( |
| 73 | + (sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device |
| 74 | + ) |
| 75 | + else: |
| 76 | + output_tensor = torch.zeros( |
| 77 | + (batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device |
| 78 | + ) |
| 79 | + sm_scale = 1.0 / (batch.head_dim**0.5) |
| 80 | + |
| 81 | + for layer_id, decoder_layer in enumerate(self.layers): |
| 82 | + hidden_states = decoder_layer( |
| 83 | + hidden_states, |
| 84 | + block_tables=block_tables, |
| 85 | + k_cache=k_caches[layer_id], |
| 86 | + v_cache=v_caches[layer_id], |
| 87 | + is_prompts=batch.is_prompts, |
| 88 | + sequence_lengths=sequence_lengths, |
| 89 | + kv_seq_len=kv_seq_len, |
| 90 | + cos_sin=cos_sin, |
| 91 | + fd_inter_tensor=batch.fd_inter_tensor, |
| 92 | + output_tensor=output_tensor, |
| 93 | + sm_scale=sm_scale, |
| 94 | + ) |
| 95 | + |
| 96 | + if batch.is_prompts: |
| 97 | + last_token_indexs = sequence_lengths.cumsum(dim=-1) |
| 98 | + hidden_states = hidden_states[last_token_indexs - 1].contiguous() |
| 99 | + hidden_states = self.norm(hidden_states) |
| 100 | + |
| 101 | + return hidden_states |
| 102 | + |
| 103 | + |
| 104 | +@torch.no_grad() |
| 105 | +def llama_decoder_layer_forward( |
| 106 | + self: LlamaDecoderLayer, |
| 107 | + hidden_states: torch.Tensor, |
| 108 | + block_tables: torch.Tensor = None, |
| 109 | + k_cache: torch.Tensor = None, |
| 110 | + v_cache: torch.Tensor = None, |
| 111 | + is_prompts: bool = True, |
| 112 | + sequence_lengths: torch.Tensor = None, |
| 113 | + kv_seq_len: int = 0, |
| 114 | + cos_sin: Tuple[torch.Tensor] = None, |
| 115 | + fd_inter_tensor: FDIntermTensors = None, |
| 116 | + output_tensor: torch.Tensor = None, |
| 117 | + sm_scale: int = None, |
| 118 | +) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: |
| 119 | + residual = hidden_states |
| 120 | + |
| 121 | + hidden_states = self.input_layernorm(hidden_states) |
| 122 | + # Self Attention |
| 123 | + hidden_states = self.self_attn( |
| 124 | + hidden_states=hidden_states, |
| 125 | + block_tables=block_tables, |
| 126 | + k_cache=k_cache, |
| 127 | + v_cache=v_cache, |
| 128 | + is_prompts=is_prompts, |
| 129 | + sequence_lengths=sequence_lengths, |
| 130 | + kv_seq_len=kv_seq_len, |
| 131 | + cos_sin=cos_sin, |
| 132 | + fd_inter_tensor=fd_inter_tensor, |
| 133 | + output_tensor=output_tensor, |
| 134 | + sm_scale=sm_scale, |
| 135 | + ) |
| 136 | + |
| 137 | + hidden_states = residual + hidden_states |
| 138 | + |
| 139 | + # Fully Connected |
| 140 | + residual = hidden_states |
| 141 | + hidden_states = self.post_attention_layernorm(hidden_states) |
| 142 | + hidden_states = self.mlp(hidden_states) |
| 143 | + hidden_states = residual + hidden_states |
| 144 | + |
| 145 | + return hidden_states |
| 146 | + |
| 147 | + |
| 148 | +# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward |
| 149 | +@torch.no_grad() |
| 150 | +def llama_attn_forward( |
| 151 | + self: LlamaAttention, |
| 152 | + hidden_states: torch.Tensor, |
| 153 | + block_tables: torch.Tensor = None, |
| 154 | + k_cache: torch.Tensor = None, |
| 155 | + v_cache: torch.Tensor = None, |
| 156 | + is_prompts: bool = True, |
| 157 | + sequence_lengths: torch.Tensor = None, |
| 158 | + kv_seq_len: int = 0, |
| 159 | + cos_sin: Tuple[torch.Tensor] = None, |
| 160 | + fd_inter_tensor: FDIntermTensors = None, |
| 161 | + output_tensor: torch.Tensor = None, |
| 162 | + sm_scale: int = None, |
| 163 | +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
| 164 | + query_states = torch.mm(hidden_states, self.q_proj.weight.transpose(0, 1)).view(-1, self.num_heads, self.head_dim) |
| 165 | + key_states = torch.mm(hidden_states, self.k_proj.weight.transpose(0, 1)).view( |
| 166 | + -1, self.num_key_value_heads, self.head_dim |
| 167 | + ) |
| 168 | + value_states = torch.mm(hidden_states, self.v_proj.weight.transpose(0, 1)).view( |
| 169 | + -1, self.num_key_value_heads, self.head_dim |
| 170 | + ) |
| 171 | + |
| 172 | + rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) |
| 173 | + |
| 174 | + _, _, _, block_size = k_cache.shape |
| 175 | + |
| 176 | + if is_prompts: |
| 177 | + attn_output = context_attention_unpadded( |
| 178 | + q=query_states, |
| 179 | + k=key_states, |
| 180 | + v=value_states, |
| 181 | + k_cache=k_cache, |
| 182 | + v_cache=v_cache, |
| 183 | + context_lengths=sequence_lengths, |
| 184 | + block_tables=block_tables, |
| 185 | + block_size=block_size, |
| 186 | + output=output_tensor, |
| 187 | + max_seq_len=kv_seq_len, |
| 188 | + sm_scale=sm_scale, |
| 189 | + ) |
| 190 | + else: |
| 191 | + copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables) |
| 192 | + copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables) |
| 193 | + attn_output = flash_decoding_attention( |
| 194 | + q=query_states, |
| 195 | + k_cache=k_cache, |
| 196 | + v_cache=v_cache, |
| 197 | + kv_seq_len=sequence_lengths, |
| 198 | + block_tables=block_tables, |
| 199 | + block_size=block_size, |
| 200 | + max_seq_len_in_batch=kv_seq_len, |
| 201 | + output=output_tensor, |
| 202 | + mid_output=fd_inter_tensor.mid_output, |
| 203 | + mid_output_lse=fd_inter_tensor.mid_output_lse, |
| 204 | + sm_scale=sm_scale, |
| 205 | + ) |
| 206 | + attn_output = attn_output.squeeze(1) |
| 207 | + |
| 208 | + attn_output = attn_output.view(-1, self.num_heads, self.head_dim) |
| 209 | + attn_output = attn_output.reshape(-1, self.hidden_size) |
| 210 | + attn_output = torch.mm(attn_output, self.o_proj.weight.transpose(0, 1)) |
| 211 | + |
| 212 | + return attn_output |
| 213 | + |
| 214 | + |
| 215 | +@torch.no_grad() |
| 216 | +def nopad_mlp(self: LlamaMLP, hidden_states: torch.Tensor): |
| 217 | + gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight.transpose(0, 1)) |
| 218 | + act_out = torch.nn.functional.silu(gate_proj_out, inplace=True) |
| 219 | + up_proj_out = torch.mm(hidden_states, self.up_proj.weight.transpose(0, 1)) |
| 220 | + tmp_out = act_out * up_proj_out |
| 221 | + return torch.mm(tmp_out, self.down_proj.weight.transpose(0, 1)) |
0 commit comments