 import asyncio
 import logging
 from functools import partial
-from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type
+from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type

 from colossalai.inference.core.engine import InferenceEngine

 logger = logging.getLogger("colossalai-inference")


-def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
+def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "Tracer") -> None:
     msg = "Task finished unexpectedly. This should never happen!"
     try:
         try:
@@ -26,8 +26,14 @@ def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTrac


 class RequstStream:
-    """A stream of Output for a request that can be
-    iterated over asynchronously."""
+    """
+    A stream of Output for a request that can be iterated over asynchronously.
+    Attributes: 1. request_id: The id of the request.
+                2. _future: A future that will be set when the request is finished.
+    Methods: set_result and get_result; the result is set once, when the request
+        finishes, and `self._future` is then marked done.
+
+    """

     def __init__(self, request_id: int) -> None:
         self.request_id = request_id
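A minimal sketch of the future-based pattern this docstring describes; the set_result/get_result bodies are outside this hunk, so the code below illustrates the documented contract rather than the PR's implementation:

import asyncio

class FutureBackedStream:
    # Illustrative stand-in for RequstStream's documented behavior.
    def __init__(self, request_id: int) -> None:
        self.request_id = request_id
        # In practice, create this inside a running event loop.
        self._future: asyncio.Future = asyncio.Future()

    def set_result(self, result) -> None:
        # The result is set at most once; the future is then marked done.
        if not self._future.done():
            self._future.set_result(result)

    async def get_result(self):
        # Suspends the awaiting coroutine until set_result() runs.
        return await self._future

    @property
    def finished(self) -> bool:
        return self._future.done()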
@@ -51,6 +57,10 @@ def finished(self) -> bool:

 class Tracer:
     """
     Recording new requests and finished requests.
+    Attributes: 1. _request_streams: One stream is created per request, to trace its output.
+                2. _finished_requests: A queue storing the finished requests.
+                3. _new_requests: New requests are put in this queue first, before being sent to the engine.
+                4. new_requests_event: An event that notifies the engine of new requests.
     """

     def __init__(self) -> None:
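The four attributes above map naturally onto asyncio primitives. A hedged sketch of the state __init__ presumably sets up (the real body is not shown in this hunk, so the exact types are assumptions):

import asyncio
from typing import Dict

class TracerState:
    # Illustrative only; mirrors the attribute list in the docstring above.
    def __init__(self) -> None:
        self._request_streams: Dict[int, object] = {}             # request_id -> stream tracing its output
        self._finished_requests: asyncio.Queue = asyncio.Queue()  # ids of finished/aborted requests
        self._new_requests: asyncio.Queue = asyncio.Queue()       # (stream, kwargs) awaiting dispatch
        self.new_requests_event = asyncio.Event()                 # wakes the engine when work arrives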
@@ -93,8 +103,8 @@ def add_request(self, request_id: int, **engine_add_request_kwargs) -> RequstStr
             raise KeyError(f"Request {request_id} already exists.")

         stream = RequstStream(request_id)
+        logger.info(f"Added request {request_id}.")
         self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
-
         self.new_requests_event.set()

         return stream
@@ -108,6 +118,7 @@ def abort_request(self, request_id: int, *, verbose: bool = False) -> None:

         if request_id not in self._request_streams or self._request_streams[request_id].finished:
             # The request has already finished or been aborted.
+            # Requests still in new_requests are aborted when they are fetched (if marked as finished).
             return

         self._request_streams[request_id].set_result(None)
@@ -117,9 +128,18 @@ def get_new_requests(self):
         Get new requests from http server.
         """
         new_requests: List[Dict] = []
+        finished_requests: Set[int] = set()
+
+        while not self._finished_requests.empty():
+            request_id = self._finished_requests.get_nowait()
+            finished_requests.add(request_id)

         while not self._new_requests.empty():
             stream, new_request = self._new_requests.get_nowait()
+            if new_request["request_id"] in finished_requests:
+                # The request has been aborted.
+                stream.set_result(None)
+                continue
             self._request_streams[stream.request_id] = stream
             new_requests.append(new_request)
@@ -133,7 +153,8 @@ async def wait_for_new_requests(self):

 class _AsyncInferenceEngine(InferenceEngine):
     """
-    Async methods for Inference Engine.
+    Async methods for the Inference Engine. This engine is an extension of InferenceEngine, and the additional methods are only used by AsyncInferenceEngine.
+    Methods: 1. async_step: the async version of Engine.step()
     """

     async def async_step(self) -> List[str]:
@@ -161,22 +182,23 @@ async def async_step(self) -> List[str]:
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
         self.request_handler.search_tokens(self.generation_config, logits)
-        # Return: List[Sequence]
+
         finished_sequences = self.request_handler.update()
         for sequence in finished_sequences:
             sequence.output = self.tokenizer.decode(sequence.output_token_id)

-        return finished_sequences, self.request_handler.current_requests_in_batch() > 0
+        return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0


 class AsyncInferenceEngine:
-    """An asynchronous wrapper for LLMEngine.
+    """An asynchronous wrapper for the InferenceEngine class.

     This class is used to wrap the InferenceEngine class to make it asynchronous.
     It uses asyncio to create a background loop that keeps processing incoming
-    requests. The LLMEngine is kicked by the generate method when there are
-    requests in the waiting queue. The generate method yields the outputs
-    from the InferenceEngine to the caller.
+    requests. Note that this class does not hold the model directly; when a new request
+    comes in, it first calls `add_request`, and the Tracer records it and hands it
+    to the background `InferenceEngine` (driven by the background loop) for processing.
+    You can consider this engine as an interface.
     """

     _engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine
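For context, this is roughly how async_step plugs into the background loop the docstring describes; the loop below is a hedged sketch (the real loop is outside this diff), and engine.add_request plus seq.request_id are assumptions:

async def engine_loop(engine: _AsyncInferenceEngine, tracer: Tracer) -> None:
    while True:
        await tracer.wait_for_new_requests()   # sleep until the server adds work
        for request in tracer.get_new_requests():
            engine.add_request(**request)      # hand off to the synchronous engine
        has_running = True
        while has_running:
            # async_step yields (finished sequences, whether work remains).
            finished, has_running = await engine.async_step()
            for seq in finished:
                # Resolve the per-request stream so generate() can return.
                tracer._request_streams[seq.request_id].set_result(seq)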
@@ -253,7 +275,7 @@ async def add_request(
         prompt_token_ids: Optional[List[int]] = None,
     ) -> RequstStream:
         """
-        Add a request to the background tracker(waitting queue), start the background loop if needed.
+        Add a request to the background tracker (waiting queue), starting the background loop if needed.
         """
         if not self.background_loop_status:
             if self.start_engine_loop:
@@ -276,14 +298,12 @@ async def generate(
         """
         Generate output from a request. It receives the request from http server, adds it into the
         waitting queue of Async Engine and streams the output sequence.
-
         """
         try:
             stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
             return await stream.get_result()

         except (Exception, asyncio.CancelledError) as e:
-            # If there is an exception or coroutine is cancelled, abort the
-            # request.
+            # If there is an exception or the coroutine is cancelled, abort the request.
             self._abort(request_id)
             raise e
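Putting the pieces together, a hypothetical caller-side usage; the AsyncInferenceEngine constructor arguments are not shown in this diff, so they are elided below:

import asyncio

async def main() -> None:
    # Model/config arguments omitted; see the engine's constructor.
    engine = AsyncInferenceEngine(start_engine_loop=True)
    output = await engine.generate(request_id=0, prompt="Hello, world")
    print(output)

asyncio.run(main())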