
Commit e8f0642

[Inference]Add Nopadding Llama Modeling (#5327)
* add nopadding llama modeling
* add nopadding_llama.py
* rm unused codes
* fix bugs in test_xine_copy.py
* fix code style
1 parent c7c104c commit e8f0642

File tree

9 files changed: +386 additions, -49 deletions


colossalai/inference/config.py

Lines changed: 2 additions & 0 deletions

@@ -32,6 +32,7 @@ class InferenceConfig:
            During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
        prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, we will do a step of prefill
            when the actual value exceeds this ratio.
+       pad_input: Whether to pad all inputs to the max length.
        quant_mode (Optional[str]): Quantization mode.
        revision (Optional[str]): The specific version(a branch, name, a commit id, or a tag name) of model to use.
    """
@@ -49,6 +50,7 @@ class InferenceConfig:
    beam_width: int = 1
    # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
    prefill_ratio: Optional[float] = 1.2
+   pad_input: bool = False
    quant_mode: Optional[str] = None
    revision: Optional[str] = None

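For orientation, the new pad_input flag is what routes a model to either the padded path or the new no-padding path (see the engine diff that follows). Below is a minimal usage sketch; the InferenceEngine constructor arguments and the checkpoint name are assumptions for illustration — only InferenceConfig, add_request, and step appear in this commit:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    from colossalai.inference.config import InferenceConfig
    from colossalai.inference.core.engine import InferenceEngine  # constructor signature assumed

    # pad_input=False (the default) packs prompts without padding and selects the
    # "nopadding_llama" policy; pad_input=True keeps the original padded behaviour.
    config = InferenceConfig(pad_input=False)  # other fields keep their defaults

    model_name = "meta-llama/Llama-2-7b-hf"  # example checkpoint, not from the commit
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    engine = InferenceEngine(model, tokenizer, config)  # argument order assumed
    engine.add_request(prompts=["An example prompt"])
    finished = engine.step()  # step() returns finished sequences, per the diff below
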
colossalai/inference/core/engine.py

Lines changed: 11 additions & 3 deletions

@@ -57,7 +57,11 @@ def __init__(
        model.to(self.dtype)

        if model_policy is None:
-           model_policy = model_policy_map[self.model_config.model_type]()
+           if self.inference_config.pad_input:
+               model_type = "padding_" + self.model_config.model_type
+           else:
+               model_type = "nopadding_" + self.model_config.model_type
+           model_policy = model_policy_map[model_type]()

        pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)

@@ -168,7 +172,9 @@ def add_request(

        if prompts_token_ids is None:
            assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
-           prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]
+           prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
+               "input_ids"
+           ]

        if isinstance(prompts_token_ids, list):
            pass
@@ -237,7 +243,9 @@ def step(self) -> List[str]:
            self.v_cache,
        )

-       logits = logits[:, -1, :]
+       if self.inference_config.pad_input:
+           logits = logits[:, -1, :]
+
        self.request_handler.search_tokens(self.generation_config, logits)
        finished_sequences = self.request_handler.update()

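Why the logits[:, -1, :] slice became conditional: with padded inputs the model returns one logits row per (padded) position, so the engine must pick each sequence's last position; in the no-padding path, the new llama_model_forward (next file) already gathers each sequence's last hidden state, so logits arrive with one row per sequence. A toy shape comparison, with sizes invented for illustration:

    import torch

    batch_size, padded_len, vocab_size = 4, 16, 32000

    # Padded path: [batch, padded_len, vocab] -> keep only the last position.
    padded_logits = torch.randn(batch_size, padded_len, vocab_size)
    last_logits = padded_logits[:, -1, :]            # [batch, vocab]

    # No-padding path: the model already returns one row per sequence,
    # so the lm_head output is [batch, vocab] and no slicing is needed.
    nopad_logits = torch.randn(batch_size, vocab_size)

    assert last_logits.shape == nopad_logits.shape
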
colossalai/inference/modeling/models/nopadding_llama.py (new file)

Lines changed: 221 additions & 0 deletions

@@ -0,0 +1,221 @@
+# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
+from typing import List, Optional, Tuple
+
+import torch
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaDecoderLayer,
+    LlamaForCausalLM,
+    LlamaMLP,
+    LlamaModel,
+)
+
+from colossalai.inference.flash_decoding_utils import FDIntermTensors
+from colossalai.inference.struct import BatchInfo
+from colossalai.kernel.triton import (
+    context_attention_unpadded,
+    copy_kv_to_blocked_cache,
+    flash_decoding_attention,
+    get_xine_cache,
+    rotary_embedding,
+)
+from colossalai.logging import get_dist_logger
+
+from flash_attn.bert_padding import index_first_axis, pad_input  # noqa
+
+logger = get_dist_logger(__name__)
+
+try:
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+    logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.")
+
+
+@torch.no_grad()
+def llama_causal_lm_forward(
+    self: LlamaForCausalLM,
+    batch: BatchInfo = None,
+    k_caches: List[torch.Tensor] = None,
+    v_caches: List[torch.Tensor] = None,
+):
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    hidden_states = llama_model_forward(
+        self.model,
+        batch=batch,
+        k_caches=k_caches,
+        v_caches=v_caches,
+    )
+    logits = torch.mm(hidden_states, self.lm_head.weight.transpose(0, 1))
+    return logits
+
+
+@torch.no_grad()
+def llama_model_forward(
+    self: LlamaModel,
+    batch: BatchInfo = None,
+    k_caches: List[torch.Tensor] = None,
+    v_caches: List[torch.Tensor] = None,
+):
+    input_ids = batch.get_1D_inputs()
+    block_tables = batch.get_block_table_tensor()
+
+    sequence_lengths = batch.get_sequence_lengths()
+    batch_size = len(sequence_lengths)
+    kv_seq_len = sequence_lengths.max().item()
+
+    hidden_states = self.embed_tokens(input_ids)
+
+    cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts)
+
+    if batch.is_prompts:
+        output_tensor = torch.zeros(
+            (sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
+        )
+    else:
+        output_tensor = torch.zeros(
+            (batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
+        )
+    sm_scale = 1.0 / (batch.head_dim**0.5)
+
+    for layer_id, decoder_layer in enumerate(self.layers):
+        hidden_states = decoder_layer(
+            hidden_states,
+            block_tables=block_tables,
+            k_cache=k_caches[layer_id],
+            v_cache=v_caches[layer_id],
+            is_prompts=batch.is_prompts,
+            sequence_lengths=sequence_lengths,
+            kv_seq_len=kv_seq_len,
+            cos_sin=cos_sin,
+            fd_inter_tensor=batch.fd_inter_tensor,
+            output_tensor=output_tensor,
+            sm_scale=sm_scale,
+        )
+
+    if batch.is_prompts:
+        last_token_indexs = sequence_lengths.cumsum(dim=-1)
+        hidden_states = hidden_states[last_token_indexs - 1].contiguous()
+    hidden_states = self.norm(hidden_states)
+
+    return hidden_states
+
+
+@torch.no_grad()
+def llama_decoder_layer_forward(
+    self: LlamaDecoderLayer,
+    hidden_states: torch.Tensor,
+    block_tables: torch.Tensor = None,
+    k_cache: torch.Tensor = None,
+    v_cache: torch.Tensor = None,
+    is_prompts: bool = True,
+    sequence_lengths: torch.Tensor = None,
+    kv_seq_len: int = 0,
+    cos_sin: Tuple[torch.Tensor] = None,
+    fd_inter_tensor: FDIntermTensors = None,
+    output_tensor: torch.Tensor = None,
+    sm_scale: int = None,
+) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    residual = hidden_states
+
+    hidden_states = self.input_layernorm(hidden_states)
+    # Self Attention
+    hidden_states = self.self_attn(
+        hidden_states=hidden_states,
+        block_tables=block_tables,
+        k_cache=k_cache,
+        v_cache=v_cache,
+        is_prompts=is_prompts,
+        sequence_lengths=sequence_lengths,
+        kv_seq_len=kv_seq_len,
+        cos_sin=cos_sin,
+        fd_inter_tensor=fd_inter_tensor,
+        output_tensor=output_tensor,
+        sm_scale=sm_scale,
+    )

+    hidden_states = residual + hidden_states
+
+    # Fully Connected
+    residual = hidden_states
+    hidden_states = self.post_attention_layernorm(hidden_states)
+    hidden_states = self.mlp(hidden_states)
+    hidden_states = residual + hidden_states
+
+    return hidden_states
+
+
+# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
+@torch.no_grad()
+def llama_attn_forward(
+    self: LlamaAttention,
+    hidden_states: torch.Tensor,
+    block_tables: torch.Tensor = None,
+    k_cache: torch.Tensor = None,
+    v_cache: torch.Tensor = None,
+    is_prompts: bool = True,
+    sequence_lengths: torch.Tensor = None,
+    kv_seq_len: int = 0,
+    cos_sin: Tuple[torch.Tensor] = None,
+    fd_inter_tensor: FDIntermTensors = None,
+    output_tensor: torch.Tensor = None,
+    sm_scale: int = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    query_states = torch.mm(hidden_states, self.q_proj.weight.transpose(0, 1)).view(-1, self.num_heads, self.head_dim)
+    key_states = torch.mm(hidden_states, self.k_proj.weight.transpose(0, 1)).view(
+        -1, self.num_key_value_heads, self.head_dim
+    )
+    value_states = torch.mm(hidden_states, self.v_proj.weight.transpose(0, 1)).view(
+        -1, self.num_key_value_heads, self.head_dim
+    )
+
+    rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+
+    _, _, _, block_size = k_cache.shape
+
+    if is_prompts:
+        attn_output = context_attention_unpadded(
+            q=query_states,
+            k=key_states,
+            v=value_states,
+            k_cache=k_cache,
+            v_cache=v_cache,
+            context_lengths=sequence_lengths,
+            block_tables=block_tables,
+            block_size=block_size,
+            output=output_tensor,
+            max_seq_len=kv_seq_len,
+            sm_scale=sm_scale,
+        )
+    else:
+        copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
+        copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
+        attn_output = flash_decoding_attention(
+            q=query_states,
+            k_cache=k_cache,
+            v_cache=v_cache,
+            kv_seq_len=sequence_lengths,
+            block_tables=block_tables,
+            block_size=block_size,
+            max_seq_len_in_batch=kv_seq_len,
+            output=output_tensor,
+            mid_output=fd_inter_tensor.mid_output,
+            mid_output_lse=fd_inter_tensor.mid_output_lse,
+            sm_scale=sm_scale,
+        )
+        attn_output = attn_output.squeeze(1)
+
+    attn_output = attn_output.view(-1, self.num_heads, self.head_dim)
+    attn_output = attn_output.reshape(-1, self.hidden_size)
+    attn_output = torch.mm(attn_output, self.o_proj.weight.transpose(0, 1))
+
+    return attn_output
+
+
+@torch.no_grad()
+def nopad_mlp(self: LlamaMLP, hidden_states: torch.Tensor):
+    gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight.transpose(0, 1))
+    act_out = torch.nn.functional.silu(gate_proj_out, inplace=True)
+    up_proj_out = torch.mm(hidden_states, self.up_proj.weight.transpose(0, 1))
+    tmp_out = act_out * up_proj_out
+    return torch.mm(tmp_out, self.down_proj.weight.transpose(0, 1))

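The file above is the heart of the commit: prompts are packed into a single 1D token stream (batch.get_1D_inputs()) and per-sequence lengths drive the kernels, so no compute or KV-cache traffic is spent on padding tokens. After prefill, only each sequence's last hidden state is needed by the LM head, which llama_model_forward gathers with a cumulative sum over the lengths. A self-contained toy sketch of that gather (stand-in tensors, not the repo's BatchInfo API):

    import torch

    hidden_size = 8
    sequence_lengths = torch.tensor([3, 5, 2])           # three packed sequences
    total_tokens = int(sequence_lengths.sum())           # 10 tokens in the 1D stream

    # Packed hidden states: one row per real token, no padding rows.
    hidden_states = torch.randn(total_tokens, hidden_size)

    # Last token of each sequence sits at (cumulative length - 1):
    # sequence 0 -> index 2, sequence 1 -> index 7, sequence 2 -> index 9.
    last_token_indexs = sequence_lengths.cumsum(dim=-1)
    last_hidden = hidden_states[last_token_indexs - 1]   # [num_seqs, hidden_size]

    assert last_hidden.shape == (len(sequence_lengths), hidden_size)

Working on 2D [num_tokens, hidden] tensors is also why the projections and the LM head above are applied with plain torch.mm rather than batched matmuls.
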
colossalai/inference/modeling/models/llama.py renamed to colossalai/inference/modeling/models/padding_llama.py

Lines changed: 4 additions & 29 deletions

@@ -11,6 +11,7 @@
    context_attention_unpadded,
    copy_kv_to_blocked_cache,
    flash_decoding_attention,
+   get_xine_cache,
    rotary_embedding,
)
from colossalai.logging import get_dist_logger
@@ -101,12 +102,7 @@ def llama_model_forward(

    hidden_states = self.embed_tokens(input_ids)

-   # When testing, the performance of get_xine_cache is lower than that of get_cos_sin.
-   # cos = get_xine_cache(sequence_lengths, self._cos_cached, batch.is_prompts)
-   # sin = get_xine_cache(sequence_lengths, self._sin_cached, batch.is_prompts)
-   # cos_sin = (cos, sin)
-
-   cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, batch.dtype)
+   cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts)

    if batch.is_prompts:
        output_tensor = torch.zeros(
@@ -135,7 +131,9 @@ def llama_model_forward(
            sm_scale=sm_scale,
        )

+   hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous()
    hidden_states = self.norm(hidden_states)
+
    return hidden_states


@@ -327,26 +325,3 @@ def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_
    k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
    v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
    return (q, k, v, indices)
-
-
-@torch.no_grad()
-def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
-   """
-   Get cos and sin for the cache, and return nopad format.
-   Args:
-       lengths: shape(num_seqs,), stores lenghth of each sequence.
-       cos_cache: shape(max_rotary_position(e.g.2048), head_dim), cos cache constrcuted in model.
-       sin_cache: shape(max_rotary_position(e.g.2048), head_dim), sin cache constrcuted in model.
-       is_prompts: bool, mark if in prefill mode.
-       dtype: The data type of this inference process.
-   """
-
-   if is_prompts:
-       index_arrays = [torch.arange(length) for length in lengths]
-   else:
-       index_arrays = [(length - 1).view(-1) for length in lengths]
-   indices = torch.cat(index_arrays, dim=-1)
-   cos_output = cos_cache[indices].to(dtype=dtype)
-   sin_output = sin_cache[indices].to(dtype=dtype)
-
-   return (cos_output, sin_output)
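
The deleted get_cos_sin helper and the Triton-backed get_xine_cache that replaces it compute the same thing: a packed table of rotary cos/sin rows, one per prompt token during prefill and one per sequence (its latest position) during decoding. A standalone illustration of the indexing, mirroring the removed torch code rather than the Triton kernel:

    import torch

    max_positions, head_dim = 2048, 4
    cos_cache = torch.randn(max_positions, head_dim)     # stand-in rotary cos table
    lengths = torch.tensor([3, 2])                       # two sequences in the batch

    # Prefill: every position of every sequence -> indices [0, 1, 2, 0, 1]
    prefill_idx = torch.cat([torch.arange(int(l)) for l in lengths])
    cos_prefill = cos_cache[prefill_idx]                 # [sum(lengths), head_dim]

    # Decoding: only the current (last) position of each sequence -> indices [2, 1]
    decode_idx = lengths - 1
    cos_decode = cos_cache[decode_idx]                   # [num_seqs, head_dim]
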
Lines changed: 5 additions & 3 deletions

@@ -1,7 +1,9 @@
-from .llama import LlamaModelInferPolicy
+from .nopadding_llama import NoPaddingLlamaModelInferPolicy
+from .padding_llama import PaddingLlamaModelInferPolicy

model_policy_map = {
-   "llama": LlamaModelInferPolicy,
+   "padding_llama": PaddingLlamaModelInferPolicy,
+   "nopadding_llama": NoPaddingLlamaModelInferPolicy,
}

-__all__ = ["LlamaModelInferPolicy", "model_polic_map"]
+__all__ = ["PaddingLlamaModelInferPolicy", "NoPaddingLlamaModelInferPolicy", "model_polic_map"]

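These keys line up with the prefixing added to InferenceEngine.__init__ above: the engine builds "padding_" or "nopadding_" plus model_config.model_type (e.g. "llama") and looks the result up in model_policy_map. A minimal sketch of that lookup in isolation; select_policy and the placeholder classes are illustrative, not part of the commit:

    # Placeholders for the policy classes registered in model_policy_map.
    class PaddingLlamaModelInferPolicy: ...
    class NoPaddingLlamaModelInferPolicy: ...

    model_policy_map = {
        "padding_llama": PaddingLlamaModelInferPolicy,
        "nopadding_llama": NoPaddingLlamaModelInferPolicy,
    }

    def select_policy(model_type: str, pad_input: bool):
        # Mirrors the branch added to InferenceEngine.__init__ in this commit.
        key = ("padding_" if pad_input else "nopadding_") + model_type
        return model_policy_map[key]()

    assert isinstance(select_policy("llama", pad_input=False), NoPaddingLlamaModelInferPolicy)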