-import asyncio
+import time
 from typing import List
-
-from transformers import AutoTokenizer
+import asyncio
 
 from .dynamic_batching.infer_batch import InferBatch
 from .dynamic_batching.io_struct import Batch, Req
 from .dynamic_batching.req_queue import ReqQueue
 from .dynamic_batching.sampling_params import SamplingParams
 from .dynamic_batching.stats import Stats
 from .tensor_parallel import TPInferEngine
 
+from transformers import AutoTokenizer
 _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
 
-
 class DynamicBatchManager:
     def __init__(
         self,
         tp_engine: TPInferEngine,
         max_total_token_num,
         batch_max_tokens,
         eos_id,
-        model,
         log_stats=True,
         log_stats_interval=10,
         running_batch: Batch = None,
@@ -32,7 +30,6 @@ def __init__(
         batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests
         running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine
         eos_id : The end token of a seq
-        model: the model weight dir path, the app will load config, weights and tokenizer from this dir
         log_stats : whether to log stats
         log_stats_interval : log stats interval
         running_batch : running batch
@@ -48,32 +45,32 @@ def __init__(
         self.eos_id = eos_id
         self.has_wait_tokens = 0
         self.max_wait_tokens = 10
-        self.model = model
-
+
         self.stats_tool = Stats(log_stats, log_stats_interval)
         self.mem_usage_interval = log_stats_interval * 2
-        self._set_tokenizer(tokenizer_name=self.model)
 
-    async def add_req(self, request_id, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""):
+    def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str):
         """
         Add new request to req queue, during initialization all requests are held in waiting list.
         """
-        req = Req(request_id, prompt_ids, sampling_params, prompts)
+        req = Req(request_id, prompt_ids, sampling_params)
         self.req_queue.append(req)
         return
 
-    async def add_input(self, request_id, sampling_params, prompts):
+    def add_input(self, request_id, sampling_params, input_ids):
         """
         Encode and Add new input to req queue. support one sequence input for now.
         """
-        prompt_ids = self.tokenizer.encode(prompts)
+        prompt_ids = self.tokenizer.encode(input_ids)
         prompt_len = len(prompt_ids)
         if prompt_len > self.engine.max_input_len:
-            raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}")
+            raise ValueError(
+                f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}"
+            )
         sampling_params.stop_sentences_to_token_ids(self.tokenizer)
-        self.add_req(request_id, prompt_ids, sampling_params, prompts)
+        self.add_req(prompt_ids, sampling_params, request_id)
         return
-
+
     def abort(self, request_id):
         if self.running_batch is not None:
             for req in self.running_batch.reqs:
@@ -91,15 +88,10 @@ async def loop_for_fwd(self):
         The main loop for a dynamic batching process.
         """
         counter_count = 0
-        # self.running_batch is not None or self.req_queue.waiting_req_list
+        #self.running_batch is not None or self.req_queue.waiting_req_list
         while True:
-            if self.running_batch is not None or self.req_queue.waiting_req_list:
-                async for result in self._step():
-                    yield result
-            else:
-                # need to wait for new requests
-                await asyncio.sleep(0.1)
-                continue
+            # _step is a plain generator now, so iterate it synchronously
+            for item in self._step():
+                yield item
             counter_count += 1
             if self.running_batch is not None:
                 if counter_count % self.mem_usage_interval == 0:
@@ -111,33 +103,30 @@ async def loop_for_fwd(self):
                     )
                 self.stats_tool.print_stats()
 
-    def _set_tokenizer(
-        self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast: bool = True
-    ):
+            if self.running_batch is None:
+                time.sleep(0.1)  # 100 ms: wait briefly when there is nothing to run
+
+    def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast: bool = True):
         if tokenizer is not None:
-            self.tokenizer = tokenizer
+            self.tokenizer = tokenizer
         else:
             if "llama" in tokenizer_name.lower() and use_fast == True:
                 print(
-                    "For some LLaMA-based models, initializing the fast tokenizer may "
-                    "take a long time. To eliminate the initialization time, consider "
-                    f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-                    "tokenizer. This is done automatically in Colossalai."
-                )
-
-                tokenizer_name = _FAST_LLAMA_TOKENIZER
-
-            try:
-                self.tokenizer = AutoTokenizer.from_pretrained(
-                    tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
-                )
-            except TypeError:
+                    "For some LLaMA-based models, initializing the fast tokenizer may "
+                    "take a long time. To eliminate the initialization time, consider "
+                    f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
+                    "tokenizer. This is done automatically in Colossalai.")
+
+                tokenizer_name = _FAST_LLAMA_TOKENIZER
+
+            try:
+                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code)
+            except TypeError as e:
                 use_fast = False
-                self.tokenizer = AutoTokenizer.from_pretrained(
-                    tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
-                )
+                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code)
+
 
-    async def _step(self):
+    def _step(self):
         """
         Logic for handling requests
         """
@@ -147,36 +136,33 @@ async def _step(self):
             if new_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_batch)
                 self.running_batch = new_batch
-                async for item in self._prefill_batch(self.running_batch):
-                    yield item
+                yield from self._prefill_batch(self.running_batch)
                 self._filter_runing_batch()
                 self.has_wait_tokens = 0
             return
 
         if self.has_wait_tokens < self.max_wait_tokens:
             self.stats_tool.count_output_tokens(self.running_batch)
-            self._decode_batch(self.running_batch)
+            yield from self._decode_batch(self.running_batch)
             self._filter_runing_batch()
             self.has_wait_tokens += 1
             return
         else:
             new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
             if new_mini_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_mini_batch)
-                async for item in self._prefill_batch(new_mini_batch):
-                    yield item
+                yield from self._prefill_batch(new_mini_batch)
                 if not new_mini_batch.is_clear():
                     self._merge_batch(self.running_batch, new_mini_batch)
                     self.running_batch.merge(new_mini_batch)
                 self.has_wait_tokens = 0
-
+
             else:
                 self.stats_tool.count_output_tokens(self.running_batch)
-                async for item in self._decode_batch(self.running_batch):
-                    yield item
+                yield from self._decode_batch(self.running_batch)
                 self._filter_runing_batch()
                 self.has_wait_tokens += 1
-
+
         return
 
     def _init_batch(self, batch: Batch, dtype="fp16"):
@@ -201,7 +187,7 @@ def _init_batch(self, batch: Batch, dtype="fp16"):
         )
         self.engine.cache[batch_id] = batch_data
 
-    async def _prefill_batch(self, batch):
+    def _prefill_batch(self, batch):
         """
         For all batches, no matter it is a new batch or a mini batch, we need to do prefill first.
         """
@@ -212,20 +198,19 @@ async def _prefill_batch(self, batch):
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        async for item in self._handle_finish_req(batch, has_new_finished_req):
-            yield item
+        yield from self._handle_finish_req(batch, has_new_finished_req)
+
         # delete finished reqs
 
-    async def _decode_batch(self, batch: Batch):
+    def _decode_batch(self, batch: Batch):
         """
         Decoding process
         """
         ans = self.engine._decode_batch(batch.batch_id)
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        async for item in self._handle_finish_req(batch, has_new_finished_req):
-            yield item
+        yield from self._handle_finish_req(batch, has_new_finished_req)
 
     def _filter_batch(self, batch: Batch):
         batch_id = batch.batch_id
@@ -255,15 +240,15 @@ def _remove_batch(self, batch):
         batch.free_self()
         del batch
 
-    async def _handle_finish_req(self, batch: Batch, has_new_finished_req):
+    def _handle_finish_req(self, batch: Batch, has_new_finished_req):
         if has_new_finished_req:
-            finished_reqs = batch.filter_finished()
+            finished_reqs = batch.filter_finished()
             if batch.is_clear():
                 self._remove_batch(batch)
             else:
                 self._filter_batch(batch)
-            async for item in self._output_process(finished_reqs):
-                yield item
+            yield from self._output_process(finished_reqs)
+
 
     def _filter_runing_batch(self):
         if self.running_batch is not None and self.running_batch.is_clear():
@@ -282,24 +267,18 @@ async def _output_process(self, finished_reqs: List[Req]):
         """
         for req in finished_reqs:
             output = self.tokenizer.decode(req.output_ids)
-            yield req.prompts + output
+            yield output, req.request_id, req.output_metadata_list
 
     def clean_up(self):
         # this logic should be implemented in the future.
         pass
 
-    async def generate(self, request_id, prompt_id, sampling_params):
+    async def generate(self, request_id, prompt_id, sampling_params):
         """
         Generate the output of a request.
         """
-
-        await self.add_input(request_id, prompt_id, sampling_params)
-
-
-async def process_data(dbm):
-    async for data in dbm.loop_for_fwd():
-        print(data)
-
+        self.add_input(request_id, prompt_id, sampling_params)
+
 
 def start_dynamic_batching(args, tp_engine, waiting_req_list):
     try:
@@ -308,13 +287,21 @@ def start_dynamic_batching(args, tp_engine, waiting_req_list):
             max_total_token_num=args.max_total_token_num,
             batch_max_tokens=args.batch_max_tokens,
             eos_id=args.eos_id,
-            model=args.model,
             log_stats=not args.disable_log_stats,
             log_stats_interval=args.log_stats_interval,
             waiting_req_list=waiting_req_list,
         )
 
     except Exception:
-        raise RuntimeError("Failed to start dynamic batching")
+        batch_manager.clean_up()
+        raise
+
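+    # For LLaMA-family model classes (e.g. "LlamaForCausalLM"), the "llama" substring check in
+    # _set_tokenizer replaces this name with _FAST_LLAMA_TOKENIZER before loading the tokenizer.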
+    batch_manager._set_tokenizer(tokenizer_name=tp_engine.model.__class__.__name__)
+    batch_manager.add_input(4, sampling_params=SamplingParams(), input_ids="hello world")
+
+    # add_input is synchronous now and loop_for_fwd is an async generator, so drive the
+    # generator through the event loop instead of a plain for loop
+    async def _consume():
+        async for item in batch_manager.loop_for_fwd():
+            print(item)
+
+    asyncio.run(_consume())
 
     return batch_manager