
Commit 61a1b2e

[Inference] Fix bugs and docs for feat/online-server (#5598)
* fix test bugs
* add do sample test
* del useless lines
* fix comments
* fix tests
* delete version tag
* delete version tag
* add
* del test server
* fix test
* fix
* Revert "add". This reverts commit b9305fb.
1 parent 7bbb28e commit 61a1b2e

File tree

12 files changed: +98 -172 lines changed


colossalai/inference/config.py

+2 -3

@@ -1,9 +1,8 @@
 """
 Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
 """
-import dataclasses
 import logging
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 from typing import Any, Dict, Optional, Union
 
 import torch
@@ -218,7 +217,7 @@ def to_generation_config(self, model_config) -> GenerationConfig:
     @classmethod
     def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
         # Get the list of attributes of this dataclass.
-        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        attrs = [attr.name for attr in fields(cls)]
         inference_config_args = {}
         for attr in attrs:
             if attr in config_dict:
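
For readers following the change, from_dict filters an arbitrary dict down to the keys that match declared dataclass fields; the edit only switches from the module-qualified dataclasses.fields to the directly imported fields. A minimal standalone sketch of the same pattern, using a hypothetical DemoConfig rather than the real InferenceConfig:

from dataclasses import dataclass, fields
from typing import Any, Dict


@dataclass
class DemoConfig:
    # Hypothetical fields, only for illustration.
    max_batch_size: int = 8
    dtype: str = "fp16"

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "DemoConfig":
        # Keep only keys that correspond to declared dataclass fields.
        attrs = [attr.name for attr in fields(cls)]
        kwargs = {k: v for k, v in config_dict.items() if k in attrs}
        return cls(**kwargs)


print(DemoConfig.from_dict({"max_batch_size": 4, "unknown_key": 1}))
# DemoConfig(max_batch_size=4, dtype='fp16')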

colossalai/inference/core/async_engine.py

+36 -16

@@ -1,7 +1,7 @@
 import asyncio
 import logging
 from functools import partial
-from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type
+from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type
 
 from colossalai.inference.core.engine import InferenceEngine
 
@@ -10,7 +10,7 @@
 logger = logging.getLogger("colossalai-inference")
 
 
-def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
+def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "Tracer") -> None:
     msg = "Task finished unexpectedly. This should never happen! "
     try:
         try:
@@ -26,8 +26,14 @@ def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
 
 
 class RequstStream:
-    """A stream of Output for a request that can be
-    iterated over asynchronously."""
+    """
+    A stream of Output for a request that can be iterated over asynchronously.
+    Attributes: 1.request_id: The id of the request.
+                2._future: A future that will be set when the request is finished.
+    Methods: set_result and get_result, results will be set when finished, for once, and
+    the `self.future` will be set to done.
+
+    """
 
     def __init__(self, request_id: int) -> None:
         self.request_id = request_id
@@ -51,6 +57,10 @@ def finished(self) -> bool:
 class Tracer:
     """
     Recording new requests and finished requests.
+    Attributes: 1._request_streams: We create one stream for each request to trace the output.
+                2._finished_requests: A queue to store the finished requests.
+                3._new_requests: New requests will be stored in this queue first, before sending them to the engine.
+                4.new_requests_event: An event to notify the engine that there are new requests.
     """
 
     def __init__(self) -> None:
@@ -93,8 +103,8 @@ def add_request(self, request_id: int, **engine_add_request_kwargs) -> RequstStream:
             raise KeyError(f"Request {request_id} already exists.")
 
         stream = RequstStream(request_id)
+        logger.info(f"Added request {request_id}.")
         self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
-
         self.new_requests_event.set()
 
         return stream
@@ -108,6 +118,7 @@ def abort_request(self, request_id: int, *, verbose: bool = False) -> None:
 
         if request_id not in self._request_streams or self._request_streams[request_id].finished:
             # The request has already finished or been aborted.
+            # The requests in new_requests will be aborted when try to get them(if marked aborted)
             return
 
         self._request_streams[request_id].set_result(None)
@@ -117,9 +128,18 @@ def get_new_requests(self):
         Get new requests from http server.
         """
         new_requests: List[Dict] = []
+        finished_requests: Set[int] = set()
+
+        while not self._finished_requests.empty():
+            request_id = self._finished_requests.get_nowait()
+            finished_requests.add(request_id)
 
         while not self._new_requests.empty():
             stream, new_request = self._new_requests.get_nowait()
+            if new_request["request_id"] in finished_requests:
+                # The request has been aborted.
+                stream.set_result(None)
+                continue
             self._request_streams[stream.request_id] = stream
             new_requests.append(new_request)
 
@@ -133,7 +153,8 @@ async def wait_for_new_requests(self):
 
 class _AsyncInferenceEngine(InferenceEngine):
     """
-    Async methods for Inference Engine.
+    Async methods for Inference Engine. This engine is an extension for InferenceEngine, and the additional methods will only be used for
+    Methods: 1. async_step: The async version of Engine.step()
     """
 
     async def async_step(self) -> List[str]:
@@ -161,22 +182,23 @@ async def async_step(self) -> List[str]:
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
         self.request_handler.search_tokens(self.generation_config, logits)
-        # Return: List[Sequence]
+
         finished_sequences = self.request_handler.update()
         for sequence in finished_sequences:
             sequence.output = self.tokenizer.decode(sequence.output_token_id)
 
-        return finished_sequences, self.request_handler.current_requests_in_batch() > 0
+        return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0
 
 
 class AsyncInferenceEngine:
-    """An asynchronous wrapper for LLMEngine.
+    """An asynchronous wrapper for the InferenceEngine class.
 
     This class is used to wrap the InferenceEngine class to make it asynchronous.
     It uses asyncio to create a background loop that keeps processing incoming
-    requests. The LLMEngine is kicked by the generate method when there are
-    requests in the waiting queue. The generate method yields the outputs
-    from the InferenceEngine to the caller.
+    requests. Note that this class does not hold model directly, when incoming a new
+    request, it first called `add_request` and the Tracer will record the request, putting
+    it to the background `InferenceEngine`(done in background loop) to process. You can
+    consider this engine as an interface.
     """
 
     _engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine
@@ -253,7 +275,7 @@ async def add_request(
         prompt_token_ids: Optional[List[int]] = None,
     ) -> RequstStream:
         """
-        Add a request to the background tracker(waitting queue), start the background loop if needed.
+        Add a request to the background tracker(waiting queue), start the background loop if needed.
         """
         if not self.background_loop_status:
             if self.start_engine_loop:
@@ -276,14 +298,12 @@ async def generate(
         """
         Generate output from a request. It receives the request from http server, adds it into the
         waitting queue of Async Engine and streams the output sequence.
-
         """
         try:
             stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
             return await stream.get_result()
 
         except (Exception, asyncio.CancelledError) as e:
-            # If there is an exception or coroutine is cancelled, abort the
-            # request.
+            # If there is an exception or coroutine is cancelled, abort the request.
            self._abort(request_id)
            raise e
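
The RequstStream/Tracer changes above follow a common asyncio pattern: each request owns a future that the background loop resolves, and aborted request ids are queued so get_new_requests can drop them before they ever reach the engine. Below is a self-contained sketch of that pattern; the names mirror the diff but it is a simplified illustration, not the actual ColossalAI implementation.

import asyncio
from typing import Any, Dict, List, Set


class DemoStream:
    """One future per request; the background loop sets the result exactly once."""

    def __init__(self, request_id: int) -> None:
        # Assumes construction happens inside a running event loop.
        self.request_id = request_id
        self._future = asyncio.get_running_loop().create_future()

    def set_result(self, result: Any) -> None:
        if not self._future.done():
            self._future.set_result(result)

    async def get_result(self) -> Any:
        return await self._future


class DemoTracer:
    """Records new and finished (aborted) requests for a background engine loop."""

    def __init__(self) -> None:
        self._new_requests: asyncio.Queue = asyncio.Queue()
        self._finished_requests: asyncio.Queue = asyncio.Queue()

    def add_request(self, request_id: int, **kwargs) -> DemoStream:
        stream = DemoStream(request_id)
        self._new_requests.put_nowait((stream, {"request_id": request_id, **kwargs}))
        return stream

    def abort_request(self, request_id: int) -> None:
        self._finished_requests.put_nowait(request_id)

    def get_new_requests(self) -> List[Dict]:
        # Drain aborted ids first, then skip queued requests that match them.
        aborted: Set[int] = set()
        while not self._finished_requests.empty():
            aborted.add(self._finished_requests.get_nowait())
        new_requests: List[Dict] = []
        while not self._new_requests.empty():
            stream, request = self._new_requests.get_nowait()
            if request["request_id"] in aborted:
                stream.set_result(None)  # resolve the waiter instead of running the request
                continue
            new_requests.append(request)
        return new_requests


async def main() -> None:
    tracer = DemoTracer()
    stream = tracer.add_request(0, prompt="hello")
    tracer.abort_request(0)
    print(tracer.get_new_requests())  # [] because request 0 was aborted
    print(await stream.get_result())  # None


asyncio.run(main())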

colossalai/inference/core/engine.py

+10 -3

@@ -527,10 +527,15 @@ def generate(
             List[str]: Inference result returned by one generation.
         """
         with torch.inference_mode():
+<<<<<<< HEAD
 
             if isinstance(prompts, str) and isinstance(request_ids, int):
                 prompts = [prompts]
                 request_ids = [request_ids]
+=======
+            if prompts is not None or prompts_token_ids is not None:
+                self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
+>>>>>>> [Inference] Fix bugs and docs for feat/online-server (#5598)
 
             if prompts is not None or prompts_token_ids is not None:
                 gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
@@ -612,6 +617,9 @@ def add_request(
 
         block_size = self.inference_config.block_size
 
+        if request_ids is not None and not isinstance(request_ids, list):
+            request_ids = [request_ids]
+
         if prompts is not None and not isinstance(prompts, list):
             prompts = [prompts]
 
@@ -621,9 +629,10 @@ def add_request(
                 "input_ids"
             ]
 
+        # list of torch Tensor
         if isinstance(prompts_token_ids, list):
             if isinstance(prompts_token_ids[0], torch.Tensor):
-                prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids]
+                prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
         elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
             prompts_token_ids = prompts_token_ids.tolist()
         else:
@@ -738,8 +747,6 @@ def step(self) -> List[str]:
             logits = logits[:, -1, :]
         next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
         self.request_handler.append_next_tokens(next_tokens)
-
-        self.request_handler.search_tokens(self.generation_config, logits)
         finished_sequences = self.request_handler.update()
 
         return finished_sequences
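
For reference, the add_request hunks above normalize scalar request ids and prompts into lists and convert tensor token ids into plain Python lists before sequences are built. A standalone sketch of that normalization step, assuming input shapes like those in the diff (an illustration, not the full add_request method):

from typing import List, Optional, Union

import numpy as np
import torch


def normalize_inputs(
    request_ids: Optional[Union[int, List[int]]] = None,
    prompts: Optional[Union[str, List[str]]] = None,
    prompts_token_ids: Optional[Union[list, torch.Tensor, np.ndarray]] = None,
):
    # Wrap single values into lists, as the new request_ids check does.
    if request_ids is not None and not isinstance(request_ids, list):
        request_ids = [request_ids]
    if prompts is not None and not isinstance(prompts, list):
        prompts = [prompts]

    # Token ids may arrive as a list of tensors, a single tensor, or an ndarray.
    if isinstance(prompts_token_ids, list):
        if prompts_token_ids and isinstance(prompts_token_ids[0], torch.Tensor):
            prompts_token_ids = [t.tolist() for t in prompts_token_ids]
    elif isinstance(prompts_token_ids, (torch.Tensor, np.ndarray)):
        prompts_token_ids = prompts_token_ids.tolist()

    return request_ids, prompts, prompts_token_ids


print(normalize_inputs(0, "hello", [torch.tensor([1, 2, 3])]))
# ([0], ['hello'], [[1, 2, 3]])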

colossalai/inference/core/request_handler.py

+1 -1

@@ -328,7 +328,7 @@ def update_batch_finished(self, batch: BatchBucket, generation_config: GenerationConfig):
     def check_unfinished_seqs(self) -> bool:
         return self._has_waiting() or not self.running_list.is_empty()
 
-    def current_requests_in_batch(self) -> int:
+    def total_requests_in_batch_bucket(self) -> int:
         return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
 
     def search_tokens(self, generation_config: GenerationConfig, logits):
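
The renamed helper simply reports how many sequences currently sit in the prefill and running batch buckets; the async engine's async_step compares this count to zero to decide whether more work remains. A tiny sketch of that relationship with hypothetical stand-in objects (not the real BatchBucket or RequestHandler):

class DemoBucket:
    def __init__(self, current_batch_size: int) -> None:
        self.current_batch_size = current_batch_size


class DemoHandler:
    def __init__(self) -> None:
        self.prefill_bb = DemoBucket(2)   # requests still being prefilled
        self.running_bb = DemoBucket(3)   # requests already decoding

    def total_requests_in_batch_bucket(self) -> int:
        return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size


handler = DemoHandler()
print(handler.total_requests_in_batch_bucket() > 0)  # True while any request is in flight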

colossalai/inference/server/api_server.py

+6 -34

@@ -6,9 +6,10 @@
 Usage: (for local user)
 - First, Lauch an API locally. `python3 -m colossalai.inference.server.api_server --model path of your llama2 model`
 - Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the api
-- For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/v1/completion \
+- For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/completion \
     -H 'Content-Type: application/json' \
     -d '{"prompt":"hello, who are you? ","stream":"False"}'`
+Version: V1.0
 """
 
 import argparse
@@ -36,7 +37,8 @@
 app = FastAPI()
 
 
-@app.get("/v0/models")
+# NOTE: (CjhHa1) models are still under development, need to be updated
+@app.get("/models")
 def get_available_models() -> Response:
     return JSONResponse(supported_models_dict)
 
@@ -81,7 +83,7 @@ def stream_results():
     return JSONResponse(ret)
 
 
-@app.post("/v1/completion")
+@app.post("/completion")
 async def create_completion(request: Request):
     request_dict = await request.json()
     stream = request_dict.pop("stream", "false").lower()
@@ -95,7 +97,7 @@ async def create_completion(request: Request):
     return JSONResponse(content=ret)
 
 
-@app.post("/v1/chat")
+@app.post("/chat")
 async def create_chat(request: Request):
     request_dict = await request.json()
 
@@ -127,14 +129,6 @@ def add_engine_config(parser):
         help="model context length. If unspecified, " "will be automatically derived from the model.",
     )
     # Parallel arguments
-    parser.add_argument(
-        "--worker-use-ray",
-        action="store_true",
-        help="use Ray for distributed serving, will be " "automatically set when using more than 1 GPU",
-    )
-
-    parser.add_argument("--pipeline-parallel-size", "-pp", type=int, default=1, help="number of pipeline stages")
-
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")
 
     # KV cache arguments
@@ -149,28 +143,6 @@ def add_engine_config(parser):
         default=None,
         help=f"Allowed choices are {','.join(prompt_template_choices)}. Default to None.",
     )
-
-    # Quantization settings.
-    parser.add_argument(
-        "--quantization",
-        "-q",
-        type=str,
-        choices=["awq", "gptq", "squeezellm", None],
-        default=None,
-        help="Method used to quantize the weights. If "
-        "None, we first check the `quantization_config` "
-        "attribute in the model config file. If that is "
-        "None, we assume the model weights are not "
-        "quantized and use `dtype` to determine the data "
-        "type of the weights.",
-    )
-    parser.add_argument(
-        "--enforce-eager",
-        action="store_true",
-        help="Always use eager-mode PyTorch. If False, "
-        "will use eager mode and CUDA graph in hybrid "
-        "for maximal performance and flexibility.",
-    )
     return parser
 
 
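Matching the renamed routes, a client call against a locally launched server could look like the sketch below. The base address and the completion payload come from the usage docstring above; anything beyond that (for example the exact shape of the returned JSON) is an assumption rather than something the diff shows.

import requests

BASE_URL = "http://127.0.0.1:8000"  # address used in the usage docstring

# Completion endpoint (previously /v1/completion, now /completion).
resp = requests.post(
    f"{BASE_URL}/completion",
    json={"prompt": "hello, who are you? ", "stream": "False"},
)
print(resp.status_code, resp.text)

# List the available models (previously /v0/models, now /models).
print(requests.get(f"{BASE_URL}/models").json())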
colossalai/shardformer/layer/embedding.py

+1 -1

@@ -248,7 +248,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
         he initializer of weight, defaults to normal initializer.
 
     The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain:
-
+    ::
         max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
             renormalized to have norm max_norm. Note: this will modify weight in-place.
         norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.

examples/inference/client/locustfile.py

+5 -5

@@ -7,18 +7,18 @@ class QuickstartUser(HttpUser):
     @tag("online-generation")
     @task(5)
     def completion(self):
-        self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "False"})
+        self.client.post("/completion", json={"prompt": "hello, who are you? ", "stream": "False"})
 
     @tag("online-generation")
     @task(5)
     def completion_streaming(self):
-        self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "True"})
+        self.client.post("/completion", json={"prompt": "hello, who are you? ", "stream": "True"})
 
     @tag("online-chat")
     @task(5)
     def chat(self):
         self.client.post(
-            "v1/chat",
+            "/chat",
             json={
                 "converation": [
                     {"role": "system", "content": "you are a helpful assistant"},
@@ -32,7 +32,7 @@ def chat(self):
     @task(5)
     def chat_streaming(self):
         self.client.post(
-            "v1/chat",
+            "/chat",
             json={
                 "converation": [
                     {"role": "system", "content": "you are a helpful assistant"},
@@ -55,4 +55,4 @@ def generate(self):
     @tag("online-generation", "offline-generation")
     @task
     def get_models(self):
-        self.client.get("/v0/models")
+        self.client.get("/models")
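
Assuming the API server above is already running locally, the load test is typically launched by pointing locust at this file and the server address, for example: locust -f examples/inference/client/locustfile.py --host http://127.0.0.1:8000. The online-generation and online-chat tags used in the file allow a run to be restricted to the corresponding tasks via locust's --tags option.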
