 from .dynamic_batching.stats import Stats
 from .tensor_parallel import TPInferEngine

-from transformers import AutoTokenizer
-_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"

 class DynamicBatchManager:
     def __init__(
@@ -61,6 +59,7 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques
         print("len(self.req_queue): ", len(self.req_queue))
         return

+<<<<<<< HEAD
     def add_input(self, request_id, sampling_params, prompts):
         """
         Encode and add a new input to the request queue. Only single-sequence input is supported for now.
@@ -75,6 +74,8 @@ def add_input(self, request_id, sampling_params, prompts):
         self.add_req(prompt_ids, sampling_params, request_id, prompts)
         return

+=======
+>>>>>>> 78cd937f... Revert "[inference] Async dynamic batching (#4894)" (#4909)
     def abort(self, request_id):
         if self.running_batch is not None:
             for req in self.running_batch.reqs:
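
A minimal usage sketch of the request API above, assuming an already-constructed DynamicBatchManager bound to `manager`; the SamplingParams constructor argument shown is an assumption, not the verified signature:

# Hypothetical caller-side usage; `manager` and the SamplingParams
# arguments are assumptions for illustration only.
params = SamplingParams(max_new_tokens=64)
manager.add_input(request_id="req-0", sampling_params=params, prompts="Hello, world")
# add_input encodes the prompt and delegates to add_req, which appends a Req
# to manager.req_queue; the scheduler admits it on a later _step.
manager.abort("req-0")  # marks the matching request as finished so it is filtered out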
@@ -114,26 +115,6 @@ def loop_for_fwd(self):
             if self.running_batch is None:
                 time.sleep(0.1)  # sleep 100 ms when idle

-    def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast: bool = True):
-        if tokenizer is not None:
-            self.tokenizer = tokenizer
-        else:
-            if "llama" in tokenizer_name.lower() and use_fast == True:
-                print(
-                    "For some LLaMA-based models, initializing the fast tokenizer may "
-                    "take a long time. To eliminate the initialization time, consider "
-                    f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-                    "tokenizer. This is done automatically in Colossalai.")
-
-                tokenizer_name = _FAST_LLAMA_TOKENIZER
-
-            try:
-                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code)
-            except TypeError as e:
-                use_fast = False
-                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code)
-
-
     def _step(self):
         """
         Logic for handling requests
@@ -144,33 +125,32 @@ def _step(self):
             if new_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_batch)
                 self.running_batch = new_batch
-                yield from self._prefill_batch(self.running_batch)
+                self._prefill_batch(self.running_batch)
                 self._filter_runing_batch()
                 self.has_wait_tokens = 0
             return

         if self.has_wait_tokens < self.max_wait_tokens:
             self.stats_tool.count_output_tokens(self.running_batch)
-            yield from self._decode_batch(self.running_batch)
+            self._decode_batch(self.running_batch)
             self._filter_runing_batch()
             self.has_wait_tokens += 1
             return
         else:
             new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
             if new_mini_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_mini_batch)
-                yield from self._prefill_batch(new_mini_batch)
+                self._prefill_batch(new_mini_batch)
                 if not new_mini_batch.is_clear():
                     self._merge_batch(self.running_batch, new_mini_batch)
                     self.running_batch.merge(new_mini_batch)
                 self.has_wait_tokens = 0
-
             else:
                 self.stats_tool.count_output_tokens(self.running_batch)
-                yield from self._decode_batch(self.running_batch)
+                self._decode_batch(self.running_batch)
                 self._filter_runing_batch()
                 self.has_wait_tokens += 1
-
+
         return

     def _init_batch(self, batch: Batch, dtype="fp16"):
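
For reference, the scheduling policy that _step implements after this revert (fully synchronous, no yields) can be restated as a self-contained sketch; all names below are hypothetical stand-ins, not the engine's real API:

# A distilled restatement of the _step policy above. SchedulerState and the
# print calls are illustrative stand-ins for the real batch/engine objects.
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class SchedulerState:
    waiting: List[str] = field(default_factory=list)  # queued request ids
    running: Optional[List[str]] = None               # currently running batch
    waited: int = 0                                   # decode steps since last admit
    max_wait: int = 10                                # budget before trying to admit

def step(s: SchedulerState) -> None:
    if s.running is None:                 # nothing running: try to start a batch
        if s.waiting:
            s.running, s.waiting = s.waiting, []
            print("prefill", s.running)   # stands in for _prefill_batch
            s.waited = 0
        return
    if s.waited < s.max_wait:             # decode until the wait budget is spent
        print("decode", s.running)        # stands in for _decode_batch
        s.waited += 1
        return
    if s.waiting:                         # budget spent: admit and merge new requests
        mini, s.waiting = s.waiting, []
        print("prefill", mini)
        s.running += mini                 # continuous batching: merge mid-flight
        s.waited = 0
    else:
        print("decode", s.running)
        s.waited += 1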
@@ -206,8 +186,7 @@ def _prefill_batch(self, batch):
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        yield from self._handle_finish_req(batch, has_new_finished_req)
-
+        self._handle_finish_req(batch, has_new_finished_req)
         # delete finished reqs

     def _decode_batch(self, batch: Batch):
@@ -218,7 +197,7 @@ def _decode_batch(self, batch: Batch):
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        yield from self._handle_finish_req(batch, has_new_finished_req)
+        self._handle_finish_req(batch, has_new_finished_req)

     def _filter_batch(self, batch: Batch):
         batch_id = batch.batch_id
@@ -250,13 +229,11 @@ def _remove_batch(self, batch):

     def _handle_finish_req(self, batch: Batch, has_new_finished_req):
         if has_new_finished_req:
-            finished_reqs = batch.filter_finished()
+            batch.filter_finished()
             if batch.is_clear():
                 self._remove_batch(batch)
             else:
                 self._filter_batch(batch)
-            yield from self._output_process(finished_reqs)
-

     def _filter_runing_batch(self):
         if self.running_batch is not None and self.running_batch.is_clear():
@@ -269,6 +246,7 @@ def _add_token_id_to_req(self, batch: Batch, req_ans):
             req.output_metadata_list.append(new_gen_metadata)
         return

+<<<<<<< HEAD
     def _output_process(self, finished_reqs: List[Req]):
         """
         Process the output of a batch.
@@ -277,10 +255,13 @@ def _output_process(self, finished_reqs: List[Req]):
             output = self.tokenizer.decode(req.output_ids)
             yield req.prompts + output

+=======
+>>>>>>> 78cd937f... Revert "[inference] Async dynamic batching (#4894)" (#4909)
     def clean_up(self):
         # this logic should be implemented in the future.
         pass

+<<<<<<< HEAD
     def generate(self, prompts, sampling_params, request_id):
         """
         Generate the output of a request.
@@ -306,3 +287,24 @@ def start_dynamic_batching(args, tp_engine, waiting_req_list):
         raise

     return batch_manager
+=======
+
+def start_dynamic_batching(args, tp_engine, waiting_req_list):
+    # try:
+    batch_manager = DynamicBatchManager(
+        tp_engine=tp_engine,
+        max_total_token_num=args.max_total_token_num,
+        batch_max_tokens=args.batch_max_tokens,
+        eos_id=args.eos_id,
+        log_stats=not args.disable_log_stats,
+        log_stats_interval=args.log_stats_interval,
+        waiting_req_list=waiting_req_list,
+    )
+
+    # except Exception:
+    #     batch_manager.clean_up()
+    #     raise
+
+    batch_manager.loop_for_fwd()
+    return
+>>>>>>> 78cd937f... Revert "[inference] Async dynamic batching (#4894)" (#4909)