
Commit cf39bf6

[Format] Apply isort and black for python/ (mlc-ai#1097)
[Format] Apply isort and black on `python/`

The commands I am using are:

```
isort --profile black python/
black python/
```

It is always recommended to format the code before submission, given we don't have a linter CI yet.
1 parent 62d0c03 commit cf39bf6

File tree: 6 files changed, +56 −36 lines

python/mlc_chat/__init__.py

Lines changed: 1 addition & 4 deletions
```diff
@@ -2,8 +2,5 @@
 
 MLC Chat is the app runtime of MLC LLM.
 """
+from .chat_module import ChatConfig, ChatModule, ConvConfig, GenerationConfig
 from .libinfo import __version__
-from .chat_module import ChatModule
-from .chat_module import ConvConfig
-from .chat_module import ChatConfig
-from .chat_module import GenerationConfig
```
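With the imports consolidated, the same four classes remain re-exported from the package root, so they can be imported directly from `mlc_chat` (a minimal sketch):

```python
# The classes re-exported by python/mlc_chat/__init__.py after this change.
from mlc_chat import ChatConfig, ChatModule, ConvConfig, GenerationConfig
```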

python/mlc_chat/chat_module.py

Lines changed: 12 additions & 5 deletions
```diff
@@ -352,10 +352,12 @@ def _get_chat_config(config_file_path: str, user_chat_config: Optional[ChatConfi
         # We override using user's chat config
         for field in fields(user_chat_config):
             field_name = field.name
-            if field_name == 'model_lib':
-                warn_msg = ('WARNING: Do not override "model_lib" in ChatConfig. '
-                            'This override will be ignored. '
-                            'Please use ChatModule.model_lib_path to override the full model library path instead.')
+            if field_name == "model_lib":
+                warn_msg = (
+                    'WARNING: Do not override "model_lib" in ChatConfig. '
+                    "This override will be ignored. "
+                    "Please use ChatModule.model_lib_path to override the full model library path instead."
+                )
                 warnings.warn(warn_msg)
                 continue
             field_value = getattr(user_chat_config, field_name)
@@ -740,7 +742,12 @@ def __init__(
 
         # 5. Look up model library
         self.model_lib_path = _get_lib_module_path(
-            model, self.model_path, self.chat_config, model_lib_path, device_name, self.config_file_path
+            model,
+            self.model_path,
+            self.chat_config,
+            model_lib_path,
+            device_name,
+            self.config_file_path,
        )
 
         # 6. Call reload
```
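The warning text above points users at `ChatModule.model_lib_path` instead of overriding `model_lib` in `ChatConfig`. A minimal sketch of that usage; the model name and library path are placeholders, not values from this commit:

```python
from mlc_chat import ChatConfig, ChatModule

# Placeholder model name and library path, for illustration only.
cm = ChatModule(
    model="Llama-2-7b-chat-hf-q4f16_1",
    chat_config=ChatConfig(temperature=0.7),  # do NOT set model_lib here; it is ignored
    model_lib_path="dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so",
)
print(cm.generate("Hello!"))
```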

python/mlc_chat/embeddings/openai.py

Lines changed: 9 additions & 13 deletions
```diff
@@ -1,17 +1,11 @@
 from __future__ import annotations
 
-from langchain.embeddings import OpenAIEmbeddings
-from langchain.embeddings.openai import embed_with_retry, async_embed_with_retry
-
 import logging
-from typing import (
-    List,
-    Optional,
-    Sequence,
-    Tuple,
-)
+from typing import List, Optional, Sequence, Tuple
 
 import numpy as np
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.embeddings.openai import async_embed_with_retry, embed_with_retry
 
 logger = logging.getLogger(__name__)
 
@@ -121,9 +115,9 @@ def _get_len_safe_embeddings(
                 self,
                 input="",
                 **self._invocation_params,
-            )[
-                "data"
-            ][0]["embedding"]
+            )["data"][
+                0
+            ]["embedding"]
             for _result, num_tokens in zip(results, num_tokens_in_batch):
                 if len(_result) == 0:
                     average = empty_average
@@ -155,7 +149,9 @@ async def _aget_len_safe_embeddings(
                     input="",
                     **self._invocation_params,
                 )
-            )["data"][0]["embedding"]
+            )[
+                "data"
+            ][0]["embedding"]
             for _result, num_tokens in zip(results, num_tokens_in_batch):
                 if len(_result) == 0:
                     average = empty_average
```
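The rewrapped subscript chains above are purely black line-length changes; before and after, the expression pulls the first embedding vector out of an OpenAI-style response payload. A minimal sketch with a made-up payload:

```python
# Made-up response payload in the OpenAI embeddings format.
response = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}

# Same lookup as the wrapped expression in the diff: ["data"][0]["embedding"]
vector = response["data"][0]["embedding"]
print(vector)  # [0.1, 0.2, 0.3]
```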

python/mlc_chat/gradio.py

Lines changed: 10 additions & 4 deletions
```diff
@@ -2,10 +2,11 @@
 # pylint: disable=import-error, import-outside-toplevel, invalid-name, line-too-long, protected-access
 # too-many-instance-attributes, too-many-locals, unused-import
 
-from typing import Dict
 import argparse
-import os
 import glob
+import os
+from typing import Dict
+
 import gradio as gr
 
 from .chat_module import ChatModule
@@ -148,7 +149,12 @@ def gradio_stats(self):
 
 
 def launch_gradio(
-    artifact_path: str = "dist", device: str = "auto", port: int = 7860, share: bool = False, host: str = "127.0.0.1"):
+    artifact_path: str = "dist",
+    device: str = "auto",
+    port: int = 7860,
+    share: bool = False,
+    host: str = "127.0.0.1",
+):
     r"""Launch the gradio interface with a given port, creating a publically sharable link if specified."""
 
     # create a gradio module
@@ -230,7 +236,7 @@ def launch_gradio(
     stats_button.click(mod.gradio_stats, [], [stats_output])
 
     # launch to the web
-    demo.launch(share=share, enable_queue=True, server_port=port,server_name=host)
+    demo.launch(share=share, enable_queue=True, server_port=port, server_name=host)
 
 
 if __name__ == "__main__":
```
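For reference, the reformatted signature can be called directly as below; this is a minimal sketch that simply passes the default values visible in the diff:

```python
from mlc_chat.gradio import launch_gradio

# Launch the gradio front end with the defaults from the signature above.
launch_gradio(
    artifact_path="dist",
    device="auto",
    port=7860,
    share=False,
    host="127.0.0.1",
)
```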

python/mlc_chat/interface/openai_api.py

Lines changed: 18 additions & 3 deletions
```diff
@@ -2,17 +2,19 @@
 Adapted from FastChat's OpenAI protocol: https://github.com/lm-sys/FastChat/blob/main/fastchat/protocol/openai_api_protocol.py
 """
 
-from typing import Literal, Optional, List, Dict, Any, Union
-from pydantic import BaseModel, Field
-import shortuuid
 import time
+from typing import Any, Dict, List, Literal, Optional, Union
+
+import shortuuid
+from pydantic import BaseModel, Field
 
 
 class ChatMessage(BaseModel):
     role: str
     content: str
     name: str | None = None
 
+
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: list[ChatMessage]
@@ -35,16 +37,19 @@ class ChatCompletionRequest(BaseModel):
     # logit_bias
     # user: Optional[str] = None
 
+
 class UsageInfo(BaseModel):
     prompt_tokens: int = 0
     completion_tokens: int | None = 0
     total_tokens: int = 0
 
+
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
     finish_reason: Literal["stop", "length"] | None = None
 
+
 class ChatCompletionResponse(BaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
     object: str = "chat.completion"
@@ -53,21 +58,25 @@ class ChatCompletionResponse(BaseModel):
     # TODO: Implement support for the following fields
     usage: UsageInfo | None = None
 
+
 class DeltaMessage(BaseModel):
     role: str | None = None
     content: str | None = None
 
+
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
     finish_reason: Literal["stop", "length"] | None = None
 
+
 class ChatCompletionStreamResponse(BaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
     object: str = "chat.completion.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))
     choices: list[ChatCompletionResponseStreamChoice]
 
+
 class CompletionRequest(BaseModel):
     model: str
     prompt: str | list[str]
@@ -91,35 +100,41 @@ class CompletionRequest(BaseModel):
     # logit_bias
     # user: Optional[str] = None
 
+
 class CompletionResponseChoice(BaseModel):
     index: int
     text: str
     logprobs: int | None = None
     finish_reason: Literal["stop", "length"] | None = None
 
+
 class CompletionResponse(BaseModel):
     id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}")
     object: str = "text_completion"
     created: int = Field(default_factory=lambda: int(time.time()))
     choices: list[CompletionResponseChoice]
     usage: UsageInfo
 
+
 class CompletionResponseStreamChoice(BaseModel):
     index: int
     text: str
     finish_reason: Optional[Literal["stop", "length"]] = None
 
+
 class CompletionStreamResponse(BaseModel):
     id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}")
     object: str = "text_completion"
     created: int = Field(default_factory=lambda: int(time.time()))
     choices: List[CompletionResponseStreamChoice]
 
+
 class EmbeddingsRequest(BaseModel):
     model: Optional[str] = None
     input: Union[str, List[Any]]
     user: Optional[str] = None
 
+
 class EmbeddingsResponse(BaseModel):
     object: str = "list"
     data: List[Dict[str, Any]]
```
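These pydantic models mirror the OpenAI chat/completions schema. A minimal sketch of constructing a request from the fields visible above; the model name is a placeholder, not part of this commit:

```python
from mlc_chat.interface.openai_api import ChatCompletionRequest, ChatMessage

# Build a request using only fields shown in the diff; the model name is a placeholder.
request = ChatCompletionRequest(
    model="Llama-2-7b-chat-hf-q4f16_1",
    messages=[ChatMessage(role="user", content="What is MLC LLM?")],
)
print(request.json())  # serialize the request body (pydantic v1-style API)
```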

python/mlc_chat/rest.py

Lines changed: 6 additions & 7 deletions
```diff
@@ -1,22 +1,19 @@
 import argparse
 import asyncio
 from contextlib import asynccontextmanager
+from dataclasses import dataclass, field, fields
 
-from mlc_chat.chat_module import GenerationConfig
-
+import numpy as np
 import uvicorn
 from fastapi import FastAPI
-from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
-
-from dataclasses import dataclass, field, fields
+from fastapi.responses import StreamingResponse
+from mlc_chat.chat_module import GenerationConfig
 
 from .base import set_global_random_seed
 from .chat_module import ChatModule
 from .interface.openai_api import *
 
-import numpy as np
-
 
 @dataclass
 class RestAPIArgs:
@@ -327,13 +324,15 @@ async def read_stats():
     """
     return session["chat_mod"].stats()
 
+
 @app.get("/verbose_stats")
 async def read_stats_verbose():
     """
     Get the verbose runtime stats.
     """
     return session["chat_mod"].stats(verbose=True)
 
+
 ARGS = convert_args_to_argparser().parse_args()
 if __name__ == "__main__":
     uvicorn.run("mlc_chat.rest:app", host=ARGS.host, port=ARGS.port, reload=False, access_log=False)
```
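With the server started (e.g. via `python -m mlc_chat.rest`), the stats endpoints touched above can be queried over HTTP. A minimal sketch; only `/verbose_stats` is visible in the diff, a matching `/stats` route for `read_stats` and the default host/port are assumptions:

```python
import requests

# Assumes the rest.py server is running locally; host and port are assumptions.
BASE_URL = "http://127.0.0.1:8000"

print(requests.get(f"{BASE_URL}/stats").text)          # plain runtime stats
print(requests.get(f"{BASE_URL}/verbose_stats").text)  # verbose runtime stats
```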
