Skip to content

Commit 12b7f2f

Browse files
D4ve-Rabetlen
and Andrei Betlen authored
[Feat] Multi model support (#931)
* Update Llama class to handle chat_format & caching * Add settings.py * Add util.py & update __main__.py * multimodel * update settings.py * cleanup * delete util.py * Fix /v1/models endpoint * MultiLlama now iterable, app check-alive on "/" * instant model init if file is given * backward compability * revert model param mandatory * fix error * handle individual model config json * refactor * revert chathandler/clip_model changes * handle chat_handler in MulitLlama() * split settings into server/llama * reduce global vars * Update LlamaProxy to handle config files * Add free method to LlamaProxy * update arg parsers & install server alias * refactor cache settings * change server executable name * better var name * whitespace * Revert "whitespace" This reverts commit bc5cf51. * remove exe_name * Fix merge bugs * Fix type annotations * Fix type annotations * Fix uvicorn app factory * Fix settings * Refactor server * Remove formatting fix * Format * Use default model if not found in model settings * Fix * Cleanup * Fix * Fix * Remove unnused CommandLineSettings * Cleanup * Support default name for copilot-codex models --------- Co-authored-by: Andrei Betlen <[email protected]>
1 parent 4a85442 commit 12b7f2f

File tree

7 files changed

+1037
-788
lines changed

7 files changed

+1037
-788
lines changed

llama_cpp/server/__main__.py

+58-71
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
1010
Then run:
1111
```
12-
uvicorn llama_cpp.server.app:app --reload
12+
uvicorn llama_cpp.server.app:create_app --reload
1313
```
1414
1515
or
@@ -21,81 +21,68 @@
2121
Then visit http://localhost:8000/docs to see the interactive API docs.
2222
2323
"""
24+
from __future__ import annotations
25+
2426
import os
27+
import sys
2528
import argparse
26-
from typing import List, Literal, Union
2729

2830
import uvicorn
2931

30-
from llama_cpp.server.app import create_app, Settings
31-
32-
def get_base_type(annotation):
    """Unwrap a typing annotation down to its underlying concrete type.

    Literal["a", "b"]    -> type of the first literal value
    Optional[X] / Union  -> base type of the first non-None member
    List[X] / list[X]    -> base type of the element type X
    anything else        -> the annotation itself
    """
    origin = getattr(annotation, "__origin__", None)
    if origin is Literal:
        # A Literal's base type is the runtime type of its first allowed value.
        return type(annotation.__args__[0])
    elif origin is Union:
        # Drop NoneType so Optional[X] resolves to X's base type.
        candidates = [arg for arg in annotation.__args__ if arg is not type(None)]
        if candidates:
            return get_base_type(candidates[0])
    elif origin is list or origin is List:
        return get_base_type(annotation.__args__[0])
    else:
        return annotation
43-
44-
def contains_list_type(annotation) -> bool:
    """Return True when *annotation* is a list type or wraps one via Literal/Union."""
    outer = getattr(annotation, "__origin__", None)
    if outer is list or outer is List:
        return True
    if outer in (Literal, Union):
        # Recurse into the wrapped members (e.g. Optional[List[str]]).
        return any(contains_list_type(member) for member in annotation.__args__)
    return False
53-
54-
def parse_bool_arg(arg):
    """Convert a CLI-style boolean token to a bool.

    Accepts str or bytes. Truthy tokens: 1/on/t/true/y/yes; falsy tokens:
    0/off/f/false/n/no (case-insensitive, surrounding whitespace ignored).

    Raises:
        ValueError: if the token is not a recognised boolean spelling.
    """
    if isinstance(arg, bytes):
        arg = arg.decode('utf-8')
    normalized = str(arg).strip().lower()
    if normalized in {'1', 'on', 't', 'true', 'y', 'yes'}:
        return True
    if normalized in {'0', 'off', 'f', 'false', 'n', 'no'}:
        return False
    raise ValueError(f'Invalid boolean argument: {arg}')
69-
70-
if __name__ == "__main__":
71-
parser = argparse.ArgumentParser()
72-
for name, field in Settings.model_fields.items():
73-
description = field.description
74-
if field.default is not None and description is not None:
75-
description += f" (default: {field.default})"
76-
base_type = get_base_type(field.annotation) if field.annotation is not None else str
77-
list_type = contains_list_type(field.annotation)
78-
if base_type is not bool:
79-
parser.add_argument(
80-
f"--{name}",
81-
dest=name,
82-
nargs="*" if list_type else None,
83-
type=base_type,
84-
help=description,
85-
)
86-
if base_type is bool:
87-
parser.add_argument(
88-
f"--{name}",
89-
dest=name,
90-
type=parse_bool_arg,
91-
help=f"{description}",
92-
)
93-
32+
from llama_cpp.server.app import create_app
33+
from llama_cpp.server.settings import (
34+
Settings,
35+
ServerSettings,
36+
ModelSettings,
37+
ConfigFileSettings,
38+
)
39+
from llama_cpp.server.cli import add_args_from_model, parse_model_from_args
40+
41+
42+
def main():
    """CLI entry point for the llama.cpp python server.

    Builds argparse options from the Settings model, resolves server and
    model settings (from a JSON config file when one is given, otherwise
    from the CLI arguments), then builds the app and serves it with uvicorn.
    Exits with status 1 after printing help if settings cannot be loaded.
    """
    parser = argparse.ArgumentParser(
        description="🦙 Llama.cpp python server. Host your own LLMs!🚀"
    )
    add_args_from_model(parser, Settings)
    parser.add_argument(
        "--config_file",
        type=str,
        help="Path to a config file to load.",
    )
    server_settings: ServerSettings | None = None
    model_settings: list[ModelSettings] = []
    args = parser.parse_args()
    try:
        # The CONFIG_FILE environment variable takes precedence over --config_file.
        config_file = os.environ.get("CONFIG_FILE", args.config_file)
        if not config_file:
            # No config file: derive both settings objects from the CLI args.
            server_settings = parse_model_from_args(ServerSettings, args)
            model_settings = [parse_model_from_args(ModelSettings, args)]
        else:
            if not os.path.exists(config_file):
                raise ValueError(f"Config file {config_file} not found!")
            with open(config_file, "rb") as fh:
                raw = fh.read()
            cfg = ConfigFileSettings.model_validate_json(raw)
            server_settings = ServerSettings.model_validate(cfg)
            model_settings = cfg.models
    except Exception as exc:
        print(exc, file=sys.stderr)
        parser.print_help()
        sys.exit(1)
    # Both branches above either populate the settings or exit; these
    # asserts only narrow the Optional types for static checkers.
    assert server_settings is not None
    assert model_settings is not None
    uvicorn.run(
        create_app(
            server_settings=server_settings,
            model_settings=model_settings,
        ),
        # HOST/PORT environment variables override the configured values.
        host=os.environ.get("HOST", server_settings.host),
        port=int(os.environ.get("PORT", server_settings.port)),
        ssl_keyfile=server_settings.ssl_keyfile,
        ssl_certfile=server_settings.ssl_certfile,
    )


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)