
Commit d509e79

Revert "[inference]Re push async dynamic batching (#4901)" (#4905)
This reverts commit fbf3c09.
1 parent fbf3c09 commit d509e79

File tree

4 files changed: +109 −107 lines

colossalai/inference/dynamic_batching/io_struct.py

Lines changed: 5 additions & 10 deletions
@@ -4,7 +4,7 @@
 
 
 class Req:
-    def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str = ""):
+    def __init__(self, request_id, prompt_ids, sample_params: SamplingParams):
        self.request_id = request_id
        self.prompt_ids = prompt_ids
        self.input_len = len(prompt_ids)
@@ -14,7 +14,6 @@ def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompt
        self.output_metadata_list = []
        self.has_generate_finished = False
        self.aborted = False
-       self.prompts = prompts
 
    def to_rpc_obj(self):
        return {
@@ -37,11 +36,7 @@ def stop_sequences_matched(self):
        if self.sample_params.stop_sequences is not None:
            for stop_token_ids in self.sample_params.stop_sequences:
                stop_len = len(stop_token_ids)
-               if (
-                   stop_len > 0
-                   and len(self.output_ids) >= stop_len
-                   and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len))
-               ):
+               if stop_len > 0 and len(self.output_ids) >= stop_len and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)):
                    return True
        return False
 
@@ -107,7 +102,7 @@ def mark_finished_req(self, eos_id):
                has_new_finish = True
        return has_new_finish
 
-   def filter_finished(self) -> List[Req]:
+   def filter_finished(self)->List[Req]:
        """
        Filter finished requests from the batch, the finished ones will be removed from 'reqs'.
        """
@@ -116,9 +111,9 @@ def filter_finished(self) -> List[Req]:
        finished_req = []
        for req in self.reqs:
            if not req.has_generate_finished:
-               unfinished_req.append(req)
+               unfinished_req.append(req)
            else:
-               finished_req.append(req)
+               finished_req.append(req)
        self.reqs = unfinished_req
        self.id_to_reqs = {req.request_id: req for req in self.reqs}
        return finished_req
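
The one-line condition restored here is logically the same as the multi-line form it replaces: a request has finished when its generated ids end with one of the configured stop sequences. A small standalone sketch of that suffix check, runnable in isolation (the helper name ends_with_stop_sequence is illustrative and not part of the diff):

from typing import List, Sequence

def ends_with_stop_sequence(output_ids: List[int], stop_sequences: Sequence[List[int]]) -> bool:
    # True if output_ids ends with any of the stop sequences, mirroring the diff's condition
    for stop_token_ids in stop_sequences:
        stop_len = len(stop_token_ids)
        if (
            stop_len > 0
            and len(output_ids) >= stop_len
            and all(output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len))
        ):
            return True
    return False

# [1, 5, 7, 8] ends with [7, 8], so the first call matches; [5, 7] is not a suffix, so the second does not
assert ends_with_stop_sequence([1, 5, 7, 8], [[2, 3], [7, 8]])
assert not ends_with_stop_sequence([1, 5, 7, 8], [[5, 7]])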

colossalai/inference/manager.py

Lines changed: 63 additions & 76 deletions
@@ -1,7 +1,6 @@
-import asyncio
+import time
 from typing import List
-
-from transformers import AutoTokenizer
+import asyncio
 
 from .dynamic_batching.infer_batch import InferBatch
 from .dynamic_batching.io_struct import Batch, Req
@@ -10,17 +9,16 @@
 from .dynamic_batching.stats import Stats
 from .tensor_parallel import TPInferEngine
 
+from transformers import AutoTokenizer
 _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
 
-
 class DynamicBatchManager:
     def __init__(
         self,
         tp_engine: TPInferEngine,
         max_total_token_num,
         batch_max_tokens,
         eos_id,
-        model,
         log_stats=True,
         log_stats_interval=10,
         running_batch: Batch = None,
@@ -32,7 +30,6 @@ def __init__(
        batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests
        running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine
        eos_id : The end token of a seq
-       model: the model weight dir path, the app will load config, weights and tokenizer from this dir
        log_stats : whether to log stats
        log_stats_interval : log stats interval
        running_batch : running batch
@@ -48,32 +45,32 @@ def __init__(
        self.eos_id = eos_id
        self.has_wait_tokens = 0
        self.max_wait_tokens = 10
-       self.model = model
-
+
        self.stats_tool = Stats(log_stats, log_stats_interval)
        self.mem_usage_interval = log_stats_interval * 2
-       self._set_tokenizer(tokenizer_name=self.model)
 
-   async def add_req(self, request_id, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""):
+   def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str):
        """
        Add new request to req queue, during initialization all requests are held in waiting list.
        """
-       req = Req(request_id, prompt_ids, sampling_params, prompts)
+       req = Req(request_id, prompt_ids, sampling_params)
        self.req_queue.append(req)
        return
 
-   async def add_input(self, request_id, sampling_params, prompts):
+   def add_input(self, request_id, sampling_params, input_ids):
        """
        Encode and Add new input to req queue. support one sequence input for now.
        """
-       prompt_ids = self.tokenizer.encode(prompts)
+       prompt_ids = self.tokenizer.encode(input_ids)
        prompt_len = len(prompt_ids)
        if prompt_len > self.engine.max_input_len:
-           raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}")
+           raise ValueError(
+               f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}"
+           )
        sampling_params.stop_sentences_to_token_ids(self.tokenizer)
-       self.add_req(request_id, prompt_ids, sampling_params, prompts)
+       self.add_req(prompt_ids, sampling_params, request_id)
        return
-
+
    def abort(self, request_id):
        if self.running_batch is not None:
            for req in self.running_batch.reqs:
@@ -91,15 +88,10 @@ async def loop_for_fwd(self):
        The main loop for a dynamic batching process.
        """
        counter_count = 0
-       # self.running_batch is not None or self.req_queue.waiting_req_list
+       #self.running_batch is not None or self.req_queue.waiting_req_list
        while True:
-           if self.running_batch is not None or self.req_queue.waiting_req_list:
-               async for result in self._step():
-                   yield result
-           else:
-               # need to wait for new requests
-               await asyncio.sleep(0.1)
-               continue
+           async for item in self._step():
+               yield item
            counter_count += 1
            if self.running_batch is not None:
                if counter_count % self.mem_usage_interval == 0:
@@ -111,33 +103,30 @@ async def loop_for_fwd(self):
                    )
                    self.stats_tool.print_stats()
 
-   def _set_tokenizer(
-       self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast: bool = True
-   ):
+           if self.running_batch is None:
+               time.sleep(0.1)  # 10ms
+
+   def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,):
        if tokenizer is not None:
-           self.tokenizer = tokenizer
+           self.tokenizer = tokenizer
        else:
            if "llama" in tokenizer_name.lower() and use_fast == True:
                print(
-                   "For some LLaMA-based models, initializing the fast tokenizer may "
-                   "take a long time. To eliminate the initialization time, consider "
-                   f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-                   "tokenizer. This is done automatically in Colossalai."
-               )
-
-               tokenizer_name = _FAST_LLAMA_TOKENIZER
-
-           try:
-               self.tokenizer = AutoTokenizer.from_pretrained(
-                   tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
-               )
-           except TypeError:
+                   "For some LLaMA-based models, initializing the fast tokenizer may "
+                   "take a long time. To eliminate the initialization time, consider "
+                   f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
+                   "tokenizer. This is done automatically in Colossalai.")
+
+               tokenizer_name = _FAST_LLAMA_TOKENIZER
+
+           try:
+               self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code)
+           except TypeError as e:
                use_fast = False
-               self.tokenizer = AutoTokenizer.from_pretrained(
-                   tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
-               )
+               self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code)
+
 
-   async def _step(self):
+   def _step(self):
        """
        Logic for handling requests
        """
@@ -147,36 +136,33 @@ async def _step(self):
        if new_batch is not None:
            self.stats_tool.count_prompt_tokens(new_batch)
            self.running_batch = new_batch
-           async for item in self._prefill_batch(self.running_batch):
-               yield item
+           yield from self._prefill_batch(self.running_batch)
            self._filter_runing_batch()
            self.has_wait_tokens = 0
            return
 
        if self.has_wait_tokens < self.max_wait_tokens:
            self.stats_tool.count_output_tokens(self.running_batch)
-           self._decode_batch(self.running_batch)
+           yield from self._decode_batch(self.running_batch)
            self._filter_runing_batch()
            self.has_wait_tokens += 1
            return
        else:
            new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
            if new_mini_batch is not None:
                self.stats_tool.count_prompt_tokens(new_mini_batch)
-               async for item in self._prefill_batch(new_mini_batch):
-                   yield item
+               yield from self._prefill_batch(new_mini_batch)
                if not new_mini_batch.is_clear():
                    self._merge_batch(self.running_batch, new_mini_batch)
                    self.running_batch.merge(new_mini_batch)
                self.has_wait_tokens = 0
-
+
            else:
                self.stats_tool.count_output_tokens(self.running_batch)
-               async for item in self._decode_batch(self.running_batch):
-                   yield item
+               yield from self._decode_batch(self.running_batch)
                self._filter_runing_batch()
                self.has_wait_tokens += 1
-
+
        return
 
    def _init_batch(self, batch: Batch, dtype="fp16"):
@@ -201,7 +187,7 @@ def _init_batch(self, batch: Batch, dtype="fp16"):
        )
        self.engine.cache[batch_id] = batch_data
 
-   async def _prefill_batch(self, batch):
+   def _prefill_batch(self, batch):
        """
        For all batches, no matter it is a new batch or a mini batch, we need to do prefill first.
        """
@@ -212,20 +198,19 @@ async def _prefill_batch(self, batch):
        req_to_out_token_id = ans
        self._add_token_id_to_req(batch, req_to_out_token_id)
        has_new_finished_req = batch.mark_finished_req(self.eos_id)
-       async for item in self._handle_finish_req(batch, has_new_finished_req):
-           yield item
+       yield from self._handle_finish_req(batch, has_new_finished_req)
+
        # delete finished reqs
 
-   async def _decode_batch(self, batch: Batch):
+   def _decode_batch(self, batch: Batch):
        """
        Decoding process
        """
        ans = self.engine._decode_batch(batch.batch_id)
        req_to_out_token_id = ans
        self._add_token_id_to_req(batch, req_to_out_token_id)
        has_new_finished_req = batch.mark_finished_req(self.eos_id)
-       async for item in self._handle_finish_req(batch, has_new_finished_req):
-           yield item
+       yield from self._handle_finish_req(batch, has_new_finished_req)
 
    def _filter_batch(self, batch: Batch):
        batch_id = batch.batch_id
@@ -255,15 +240,15 @@ def _remove_batch(self, batch):
        batch.free_self()
        del batch
 
-   async def _handle_finish_req(self, batch: Batch, has_new_finished_req):
+   def _handle_finish_req(self, batch: Batch, has_new_finished_req):
        if has_new_finished_req:
-           finished_reqs = batch.filter_finished()
+           finished_reqs=batch.filter_finished()
            if batch.is_clear():
                self._remove_batch(batch)
            else:
                self._filter_batch(batch)
-           async for item in self._output_process(finished_reqs):
-               yield item
+           yield from self._output_process(finished_reqs)
+
 
    def _filter_runing_batch(self):
        if self.running_batch is not None and self.running_batch.is_clear():
@@ -282,24 +267,18 @@ async def _output_process(self, finished_reqs: List[Req]):
        """
        for req in finished_reqs:
            output = self.tokenizer.decode(req.output_ids)
-           yield req.prompts + output
+           yield output, req.request_id, req.output_metadata_list
 
    def clean_up(self):
        # this logic should be implemented in the future.
        pass
 
-   async def generate(self, request_id, prompt_id, sampling_params):
+   async def generate(self,request_id,prompt_id,sampling_params):
        """
        Generate the output of a request.
        """
-
-       await self.add_input(request_id, prompt_id, sampling_params)
-
-
-async def process_data(dbm):
-   async for data in dbm.loop_for_fwd():
-       print(data)
-
+       self.add_input(request_id,prompt_id,sampling_params)
+
 
 def start_dynamic_batching(args, tp_engine, waiting_req_list):
    try:
@@ -308,13 +287,21 @@ def start_dynamic_batching(args, tp_engine, waiting_req_list):
            max_total_token_num=args.max_total_token_num,
            batch_max_tokens=args.batch_max_tokens,
            eos_id=args.eos_id,
-           model=args.model,
            log_stats=not args.disable_log_stats,
            log_stats_interval=args.log_stats_interval,
            waiting_req_list=waiting_req_list,
        )
 
    except Exception:
-       raise RuntimeError("Failed to start dynamic batching")
+       batch_manager.clean_up()
+       raise
+
+   batch_manager._set_tokenizer(tokenizer_name = tp_engine.model.__class__.__name__)
+   prod_task = asyncio.create_task(batch_manager.add_input(4,sampling_params=SamplingParams(),input_ids="hello world"))
+
+   asyncio.run(prod_task)
+
+   for item in batch_manager.loop_for_fwd():
+       print(item)
 
    return batch_manager
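
Most of the manager.py changes swap async generators back to plain generators: _prefill_batch, _decode_batch, _handle_finish_req, and _step lose their async qualifier and are driven with "yield from" instead of "async for ... yield". A toy sketch of the two delegation styles the diff moves between; the function names are illustrative only and not ColossalAI APIs:

import asyncio
from typing import AsyncGenerator, Generator

def prefill_batch_sync() -> Generator[str, None, None]:
    # plain generator, like the reverted _prefill_batch / _decode_batch
    yield "token-a"
    yield "token-b"

def step_sync() -> Generator[str, None, None]:
    # delegation style restored by the revert: "yield from"
    yield from prefill_batch_sync()

async def prefill_batch_async() -> AsyncGenerator[str, None]:
    # async generator, like the version this commit reverts away from
    yield "token-a"
    yield "token-b"

async def step_async() -> AsyncGenerator[str, None]:
    # delegation style being reverted: "async for ... yield"
    # ("yield from" is a SyntaxError inside an async generator)
    async for item in prefill_batch_async():
        yield item

async def main() -> None:
    print(list(step_sync()))                      # ['token-a', 'token-b']
    print([item async for item in step_async()])  # ['token-a', 'token-b']

asyncio.run(main())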

colossalai/inference/test_async.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+import asyncio
+
+shared_list = []
+
+async def producer():
+    for i in range(5):
+        await asyncio.sleep(1)  # simulate an asynchronous data-fetching operation
+        shared_list.append(i)
+        print(f"Produced {i}")
+
+async def consumer():
+    last_index = 0
+    while True:
+        await asyncio.sleep(0.5)  # small delay so the polling loop is not too tight
+        if last_index < len(shared_list):
+            item = shared_list[last_index]
+            print(f"Consumed {item}")
+            yield item
+            last_index += 1
+
+async def main():
+    # create the producer task
+    prod_task = asyncio.create_task(producer())
+
+    # wait for the producer task to finish
+    await prod_task
+
+    async for data in consumer():
+        print(data)
+    # for the purposes of this example, wait only a short while, then stop the consumer
+    await asyncio.sleep(5)
+
+asyncio.run(main())
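
The test drives the hand-off by having the consumer poll a shared list on a fixed sleep. A minimal alternative sketch, not part of the commit, that expresses the same producer/consumer flow with asyncio.Queue so the consumer wakes only when an item is actually available:

import asyncio
from typing import AsyncGenerator

async def producer(queue: asyncio.Queue) -> None:
    for i in range(5):
        await asyncio.sleep(1)  # simulate asynchronously produced data
        await queue.put(i)
        print(f"Produced {i}")
    await queue.put(None)       # sentinel: nothing more will be produced

async def consumer(queue: asyncio.Queue) -> AsyncGenerator[int, None]:
    while True:
        item = await queue.get()  # wakes as soon as an item arrives, no polling
        if item is None:
            break
        print(f"Consumed {item}")
        yield item

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    prod_task = asyncio.create_task(producer(queue))
    async for data in consumer(queue):
        print(data)
    await prod_task

asyncio.run(main())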
