 import time
 from typing import List
+import asyncio
 
 from .dynamic_batching.infer_batch import InferBatch
 from .dynamic_batching.io_struct import Batch, Req
 from .dynamic_batching.stats import Stats
 from .tensor_parallel import TPInferEngine
 
+from transformers import AutoTokenizer
+_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
 
 class DynamicBatchManager:
     def __init__(
@@ -54,6 +57,20 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques
         self.req_queue.append(req)
         return
 
+    def add_input(self, request_id, sampling_params, input_ids):
+        """
+        Encode a prompt and add it to the request queue. Only a single sequence per call is
+        supported for now.
+        """
+        # NOTE: `input_ids` is the raw prompt text here; it is tokenized before being enqueued.
+        prompt_ids = self.tokenizer.encode(input_ids)
+        prompt_len = len(prompt_ids)
+        if prompt_len > self.engine.max_input_len:
+            raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}")
+        sampling_params.stop_sentences_to_token_ids(self.tokenizer)
+        self.add_req(prompt_ids, sampling_params, request_id)
+        return
+
     def abort(self, request_id):
         if self.running_batch is not None:
             for req in self.running_batch.reqs:
@@ -66,13 +83,15 @@ def abort(self, request_id):
                 req.aborted = True
         return
 
-    def loop_for_fwd(self):
+    async def loop_for_fwd(self):
         """
         The main loop for a dynamic batching process.
         """
         counter_count = 0
-        while self.running_batch is not None or self.req_queue.waiting_req_list:
-            self._step()
+        # previous loop condition: self.running_batch is not None or self.req_queue.waiting_req_list
+        while True:
+            # _step() is a plain generator; re-yield the finished outputs it produces
+            for item in self._step():
+                yield item
             counter_count += 1
             if self.running_batch is not None:
                 if counter_count % self.mem_usage_interval == 0:
@@ -87,6 +106,26 @@ def loop_for_fwd(self):
             if self.running_batch is None:
-                time.sleep(0.1)  # 10ms
+                await asyncio.sleep(0.1)  # sleep briefly when idle without blocking the event loop
 
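+    # Tokenizer resolution: use the tokenizer object passed in; otherwise load one by name with
+    # AutoTokenizer, preferring the fast tokenizer and falling back to the slow one on failure.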
+    def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast: bool = True):
+        if tokenizer is not None:
+            self.tokenizer = tokenizer
+        else:
+            if "llama" in tokenizer_name.lower() and use_fast:
+                print(
+                    "For some LLaMA-based models, initializing the fast tokenizer may "
+                    "take a long time. To eliminate the initialization time, consider "
+                    f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
+                    "tokenizer. This is done automatically in ColossalAI."
+                )
+                tokenizer_name = _FAST_LLAMA_TOKENIZER
+
+            try:
+                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code)
+            except TypeError:
+                # the fast tokenizer could not be constructed; fall back to the slow one
+                use_fast = False
+                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code)
+
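+    # In _step: new batches are prefilled, the running batch is decoded, and every
+    # max_wait_tokens decode steps waiting requests may be prefilled as a mini-batch and merged
+    # into the running batch; outputs of finished requests are yielded to the caller.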
     def _step(self):
         """
         Logic for handling requests
@@ -97,32 +136,33 @@ def _step(self):
             if new_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_batch)
                 self.running_batch = new_batch
-                self._prefill_batch(self.running_batch)
+                yield from self._prefill_batch(self.running_batch)
                 self._filter_runing_batch()
                 self.has_wait_tokens = 0
             return
 
         if self.has_wait_tokens < self.max_wait_tokens:
             self.stats_tool.count_output_tokens(self.running_batch)
-            self._decode_batch(self.running_batch)
+            yield from self._decode_batch(self.running_batch)
             self._filter_runing_batch()
             self.has_wait_tokens += 1
             return
         else:
             new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
             if new_mini_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_mini_batch)
-                self._prefill_batch(new_mini_batch)
+                yield from self._prefill_batch(new_mini_batch)
                 if not new_mini_batch.is_clear():
                     self._merge_batch(self.running_batch, new_mini_batch)
                     self.running_batch.merge(new_mini_batch)
                     self.has_wait_tokens = 0
             else:
                 self.stats_tool.count_output_tokens(self.running_batch)
-                self._decode_batch(self.running_batch)
+                yield from self._decode_batch(self.running_batch)
                 self._filter_runing_batch()
                 self.has_wait_tokens += 1
 
         return
 
     def _init_batch(self, batch: Batch, dtype="fp16"):
@@ -158,7 +198,8 @@ def _prefill_batch(self, batch):
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        self._handle_finish_req(batch, has_new_finished_req)
+        yield from self._handle_finish_req(batch, has_new_finished_req)
+
         # delete finished reqs
 
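+    # Like _prefill_batch, _decode_batch is now a generator: outputs of requests that finish in
+    # this step are yielded via _handle_finish_req.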
     def _decode_batch(self, batch: Batch):
@@ -169,7 +210,7 @@ def _decode_batch(self, batch: Batch):
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        self._handle_finish_req(batch, has_new_finished_req)
+        yield from self._handle_finish_req(batch, has_new_finished_req)
 
     def _filter_batch(self, batch: Batch):
         batch_id = batch.batch_id
@@ -201,11 +242,13 @@ def _remove_batch(self, batch):
 
     def _handle_finish_req(self, batch: Batch, has_new_finished_req):
         if has_new_finished_req:
-            batch.filter_finished()
+            finished_reqs = batch.filter_finished()
             if batch.is_clear():
                 self._remove_batch(batch)
             else:
                 self._filter_batch(batch)
+            yield from self._output_process(finished_reqs)
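+        # the outputs yielded above bubble up through _prefill_batch/_decode_batch and _step
+        # to loop_for_fwd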
 
     def _filter_runing_batch(self):
         if self.running_batch is not None and self.running_batch.is_clear():
@@ -218,26 +261,47 @@ def _add_token_id_to_req(self, batch: Batch, req_ans):
             req.output_metadata_list.append(new_gen_metadata)
         return
 
+    def _output_process(self, finished_reqs: List[Req]):
+        """
+        Decode the outputs of finished requests and yield them.
+        """
+        # plain (non-async) generator so that _handle_finish_req can `yield from` it
+        for req in finished_reqs:
+            output = self.tokenizer.decode(req.output_ids)
+            yield output, req.request_id, req.output_metadata_list
+
     def clean_up(self):
         # this logic should be implemented in the future.
         pass
 
+    async def generate(self, request_id, prompt_id, sampling_params):
+        """
+        Generate the output of a request.
+        """
+        self.add_input(request_id, sampling_params=sampling_params, input_ids=prompt_id)
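+        # NOTE: for now this only enqueues the request; decoded outputs are produced by
+        # iterating loop_for_fwd()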
+
 
 def start_dynamic_batching(args, tp_engine, waiting_req_list):
-    # try:
-    batch_manager = DynamicBatchManager(
-        tp_engine=tp_engine,
-        max_total_token_num=args.max_total_token_num,
-        batch_max_tokens=args.batch_max_tokens,
-        eos_id=args.eos_id,
-        log_stats=not args.disable_log_stats,
-        log_stats_interval=args.log_stats_interval,
-        waiting_req_list=waiting_req_list,
-    )
-
-    # except Exception:
-    #     batch_manager.clean_up()
-    #     raise
-
-    batch_manager.loop_for_fwd()
-    return
+    batch_manager = DynamicBatchManager(
+        tp_engine=tp_engine,
+        max_total_token_num=args.max_total_token_num,
+        batch_max_tokens=args.batch_max_tokens,
+        eos_id=args.eos_id,
+        log_stats=not args.disable_log_stats,
+        log_stats_interval=args.log_stats_interval,
+        waiting_req_list=waiting_req_list,
+    )
+
+    try:
+        batch_manager._set_tokenizer(tokenizer_name=tp_engine.model.__class__.__name__)
+
+        # enqueue one demo request, then drain the forward loop and print each finished output
+        batch_manager.add_input(4, sampling_params=SamplingParams(), input_ids="hello world")
+
+        async def _consume():
+            async for item in batch_manager.loop_for_fwd():
+                print(item)
+
+        asyncio.run(_consume())
+    except Exception:
+        batch_manager.clean_up()
+        raise
+
+    return batch_manager