
Commit 6402240

tiandiao123 authored and CjhHa1 committed
Revert "[inference] Async dynamic batching (hpcaitech#4894)" (hpcaitech#4909)
This reverts commit fced140.
1 parent eacdc4f · commit 6402240

File tree

2 files changed (+38, -40 lines)


colossalai/inference/dynamic_batching/io_struct.py

Lines changed: 2 additions & 6 deletions

@@ -103,21 +103,17 @@ def mark_finished_req(self, eos_id):
                 has_new_finish = True
         return has_new_finish

-    def filter_finished(self)->List[Req]:
+    def filter_finished(self):
         """
         Filter finished requests from the batch, the finished ones will be removed from 'reqs'.
         """
         # TODO: the logic of return should be defined here.
         unfinished_req = []
-        finished_req = []
         for req in self.reqs:
             if not req.has_generate_finished:
-                unfinished_req.append(req)
-            else:
-                finished_req.append(req)
+                unfinished_req.append(req)
         self.reqs = unfinished_req
         self.id_to_reqs = {req.request_id: req for req in self.reqs}
-        return finished_req

     def is_clear(self):
         return len(self.reqs) == 0
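
The practical difference is easier to see outside the diff view. Before the revert, filter_finished split the batch and returned the finished requests so a caller could decode and stream them; after the revert it only drops them. Below is a minimal sketch of both behaviors, using simplified stand-ins for the real `Req` and `Batch` classes, not the library code itself:

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Req:
    # simplified stand-in for the real Req
    request_id: int
    has_generate_finished: bool = False


@dataclass
class Batch:
    # simplified stand-in for the real Batch
    reqs: List[Req] = field(default_factory=list)
    id_to_reqs: Dict[int, Req] = field(default_factory=dict)

    def filter_finished_async(self) -> List[Req]:
        """Pre-revert behavior: drop finished reqs and hand them back to the caller."""
        unfinished, finished = [], []
        for req in self.reqs:
            (finished if req.has_generate_finished else unfinished).append(req)
        self.reqs = unfinished
        self.id_to_reqs = {req.request_id: req for req in self.reqs}
        return finished  # the caller can decode and stream these

    def filter_finished(self) -> None:
        """Reverted behavior: finished reqs are simply discarded."""
        self.reqs = [req for req in self.reqs if not req.has_generate_finished]
        self.id_to_reqs = {req.request_id: req for req in self.reqs}


batch = Batch(reqs=[Req(0, has_generate_finished=True), Req(1)])
finished = batch.filter_finished_async()
print([r.request_id for r in finished])    # [0]
print([r.request_id for r in batch.reqs])  # [1]
```
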

colossalai/inference/manager.py

Lines changed: 36 additions & 34 deletions

@@ -8,8 +8,6 @@
 from .dynamic_batching.stats import Stats
 from .tensor_parallel import TPInferEngine

-from transformers import AutoTokenizer
-_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"

 class DynamicBatchManager:
     def __init__(
@@ -61,6 +59,7 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques
         print("len(self.req_queue): ", len(self.req_queue))
         return

+<<<<<<< HEAD
     def add_input(self, request_id, sampling_params, prompts):
         """
         Encode and Add new input to req queue. support one sequence input for now.
@@ -75,6 +74,8 @@ def add_input(self, request_id, sampling_params, prompts):
         self.add_req(prompt_ids, sampling_params, request_id, prompts)
         return

+=======
+>>>>>>> 78cd937f... Revert "[inference] Async dynamic batching (#4894)" (#4909)
     def abort(self, request_id):
         if self.running_batch is not None:
             for req in self.running_batch.reqs:
@@ -114,26 +115,6 @@ def loop_for_fwd(self):
             if self.running_batch is None:
                 time.sleep(0.1)  # 10ms

-    def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,):
-        if tokenizer is not None:
-            self.tokenizer = tokenizer
-        else:
-            if "llama" in tokenizer_name.lower() and use_fast == True:
-                print(
-                    "For some LLaMA-based models, initializing the fast tokenizer may "
-                    "take a long time. To eliminate the initialization time, consider "
-                    f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-                    "tokenizer. This is done automatically in Colossalai.")
-
-                tokenizer_name = _FAST_LLAMA_TOKENIZER
-
-            try:
-                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code)
-            except TypeError as e:
-                use_fast = False
-                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code)
-
-
     def _step(self):
         """
         Logic for handling requests
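
The _set_tokenizer helper removed above wrapped transformers.AutoTokenizer with two conveniences: substituting a known fast LLaMA tokenizer and retrying without the fast path when it fails. A rough standalone sketch of that pattern, written as a free function for illustration only (it is not part of the module after this commit):

```python
from transformers import AutoTokenizer

# Same fallback checkpoint the removed code referenced.
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


def load_tokenizer(tokenizer=None, tokenizer_name: str = "",
                   trust_remote_code: bool = False, use_fast: bool = True):
    """Sketch of the deleted _set_tokenizer logic."""
    if tokenizer is not None:
        return tokenizer

    # Some LLaMA checkpoints take a long time to build a fast tokenizer,
    # so the removed helper swapped in a known-fast one.
    if "llama" in tokenizer_name.lower() and use_fast:
        tokenizer_name = _FAST_LLAMA_TOKENIZER

    try:
        return AutoTokenizer.from_pretrained(
            tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
        )
    except TypeError:
        # Retry with the slow tokenizer if the fast one cannot be constructed.
        return AutoTokenizer.from_pretrained(
            tokenizer_name, use_fast=False, trust_remote_code=trust_remote_code
        )
```
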
@@ -144,33 +125,32 @@ def _step(self):
             if new_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_batch)
                 self.running_batch = new_batch
-                yield from self._prefill_batch(self.running_batch)
+                self._prefill_batch(self.running_batch)
                 self._filter_runing_batch()
                 self.has_wait_tokens = 0
             return

         if self.has_wait_tokens < self.max_wait_tokens:
             self.stats_tool.count_output_tokens(self.running_batch)
-            yield from self._decode_batch(self.running_batch)
+            self._decode_batch(self.running_batch)
             self._filter_runing_batch()
             self.has_wait_tokens += 1
             return
         else:
             new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
             if new_mini_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_mini_batch)
-                yield from self._prefill_batch(new_mini_batch)
+                self._prefill_batch(new_mini_batch)
                 if not new_mini_batch.is_clear():
                     self._merge_batch(self.running_batch, new_mini_batch)
                     self.running_batch.merge(new_mini_batch)
                 self.has_wait_tokens = 0
-
             else:
                 self.stats_tool.count_output_tokens(self.running_batch)
-                yield from self._decode_batch(self.running_batch)
+                self._decode_batch(self.running_batch)
                 self._filter_runing_batch()
                 self.has_wait_tokens += 1
-
+
         return

     def _init_batch(self, batch: Batch, dtype="fp16"):
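
Dropping the yield from calls changes what _step is: before the revert it was a generator, so it only advances while loop_for_fwd (or another caller) iterates it and decoded outputs can be surfaced to that caller; after the revert it is an ordinary method that runs for its side effects. A toy sketch of the two driving styles, with stand-in classes rather than the real DynamicBatchManager:

```python
from typing import Iterator


class AsyncStyleEngine:
    """Pre-revert shape: _step is a generator and streams results to whoever iterates it."""

    def _step(self) -> Iterator[str]:
        # prefill/decode would run here; finished outputs are yielded upward
        yield "decoded text for a finished request"

    def loop_for_fwd(self) -> Iterator[str]:
        for _ in range(3):           # stand-in for the real serving loop
            yield from self._step()  # nothing runs unless the caller iterates


class SyncStyleEngine:
    """Post-revert shape: _step is an ordinary call that runs for its side effects."""

    def _step(self) -> None:
        pass  # prefill/decode run here; nothing is surfaced to the caller

    def loop_for_fwd(self) -> None:
        for _ in range(3):
            self._step()  # executes immediately, returns nothing


for text in AsyncStyleEngine().loop_for_fwd():
    print(text)                      # the caller receives streamed outputs

SyncStyleEngine().loop_for_fwd()     # batches are processed; nothing is streamed
```
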
@@ -206,8 +186,7 @@ def _prefill_batch(self, batch):
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        yield from self._handle_finish_req(batch, has_new_finished_req)
-
+        self._handle_finish_req(batch, has_new_finished_req)
         # delete finished reqs

     def _decode_batch(self, batch: Batch):
@@ -218,7 +197,7 @@ def _decode_batch(self, batch: Batch):
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        yield from self._handle_finish_req(batch, has_new_finished_req)
+        self._handle_finish_req(batch, has_new_finished_req)

     def _filter_batch(self, batch: Batch):
         batch_id = batch.batch_id
@@ -250,13 +229,11 @@ def _remove_batch(self, batch):

     def _handle_finish_req(self, batch: Batch, has_new_finished_req):
         if has_new_finished_req:
-            finished_reqs=batch.filter_finished()
+            batch.filter_finished()
             if batch.is_clear():
                 self._remove_batch(batch)
             else:
                 self._filter_batch(batch)
-            yield from self._output_process(finished_reqs)
-

     def _filter_runing_batch(self):
         if self.running_batch is not None and self.running_batch.is_clear():
@@ -269,6 +246,7 @@ def _add_token_id_to_req(self, batch: Batch, req_ans):
             req.output_metadata_list.append(new_gen_metadata)
         return

+<<<<<<< HEAD
     def _output_process(self, finished_reqs: List[Req]):
         """
         Process the output of a batch.
@@ -277,10 +255,13 @@ def _output_process(self, finished_reqs: List[Req]):
             output = self.tokenizer.decode(req.output_ids)
             yield req.prompts + output

+=======
+>>>>>>> 78cd937f... Revert "[inference] Async dynamic batching (#4894)" (#4909)
     def clean_up(self):
         # this logic should be implemented in the future.
         pass

+<<<<<<< HEAD
     def generate(self,prompts,sampling_params,request_id):
         """
         Generate the output of a request.
@@ -306,3 +287,24 @@ def start_dynamic_batching(args, tp_engine, waiting_req_list):
         raise

     return batch_manager
+=======
+
+def start_dynamic_batching(args, tp_engine, waiting_req_list):
+    # try:
+    batch_manager = DynamicBatchManager(
+        tp_engine=tp_engine,
+        max_total_token_num=args.max_total_token_num,
+        batch_max_tokens=args.batch_max_tokens,
+        eos_id=args.eos_id,
+        log_stats=not args.disable_log_stats,
+        log_stats_interval=args.log_stats_interval,
+        waiting_req_list=waiting_req_list,
+    )
+
+    # except Exception:
+    #     batch_manager.clean_up()
+    #     raise
+
+    batch_manager.loop_for_fwd()
+    return
+>>>>>>> 78cd937f... Revert "[inference] Async dynamic batching (#4894)" (#4909)
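
The re-added start_dynamic_batching builds a DynamicBatchManager from an args object and then blocks inside loop_for_fwd. A hedged usage sketch follows; the field values, tp_engine, and waiting_req_list are placeholders, and it assumes the conflict markers shown above have been resolved (as committed, the <<<<<<< and >>>>>>> lines would make the module unimportable):

```python
from argparse import Namespace

from colossalai.inference.manager import start_dynamic_batching

# Hypothetical configuration: the field names mirror what the re-added
# start_dynamic_batching reads from `args`; the values are placeholders.
args = Namespace(
    max_total_token_num=4096,
    batch_max_tokens=2048,
    eos_id=2,
    disable_log_stats=False,
    log_stats_interval=10,
)

tp_engine = ...        # placeholder for a TPInferEngine built elsewhere
waiting_req_list = []  # requests queued before the loop starts

# Builds the DynamicBatchManager and then blocks in loop_for_fwd();
# in this revision the function returns None rather than the manager.
start_dynamic_batching(args, tp_engine, waiting_req_list)
```
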
