|
9 | 9 |
|
10 | 10 | Then run:
|
11 | 11 | ```
|
12 |
| -uvicorn llama_cpp.server.app:app --reload |
| 12 | +uvicorn llama_cpp.server.app:create_app --reload |
13 | 13 | ```
|
14 | 14 |
|
15 | 15 | or
|
|
21 | 21 | Then visit http://localhost:8000/docs to see the interactive API docs.
|
22 | 22 |
|
23 | 23 | """
|
| 24 | +from __future__ import annotations |
| 25 | + |
24 | 26 | import os
|
| 27 | +import sys |
25 | 28 | import argparse
|
26 |
| -from typing import List, Literal, Union |
27 | 29 |
|
28 | 30 | import uvicorn
|
29 | 31 |
|
30 |
| -from llama_cpp.server.app import create_app, Settings |
31 |
| - |
32 |
| -def get_base_type(annotation): |
33 |
| - if getattr(annotation, '__origin__', None) is Literal: |
34 |
| - return type(annotation.__args__[0]) |
35 |
| - elif getattr(annotation, '__origin__', None) is Union: |
36 |
| - non_optional_args = [arg for arg in annotation.__args__ if arg is not type(None)] |
37 |
| - if non_optional_args: |
38 |
| - return get_base_type(non_optional_args[0]) |
39 |
| - elif getattr(annotation, '__origin__', None) is list or getattr(annotation, '__origin__', None) is List: |
40 |
| - return get_base_type(annotation.__args__[0]) |
41 |
| - else: |
42 |
| - return annotation |
43 |
| - |
44 |
| -def contains_list_type(annotation) -> bool: |
45 |
| - origin = getattr(annotation, '__origin__', None) |
46 |
| - |
47 |
| - if origin is list or origin is List: |
48 |
| - return True |
49 |
| - elif origin in (Literal, Union): |
50 |
| - return any(contains_list_type(arg) for arg in annotation.__args__) |
51 |
| - else: |
52 |
| - return False |
53 |
| - |
54 |
| -def parse_bool_arg(arg): |
55 |
| - if isinstance(arg, bytes): |
56 |
| - arg = arg.decode('utf-8') |
57 |
| - |
58 |
| - true_values = {'1', 'on', 't', 'true', 'y', 'yes'} |
59 |
| - false_values = {'0', 'off', 'f', 'false', 'n', 'no'} |
60 |
| - |
61 |
| - arg_str = str(arg).lower().strip() |
62 |
| - |
63 |
| - if arg_str in true_values: |
64 |
| - return True |
65 |
| - elif arg_str in false_values: |
66 |
| - return False |
67 |
| - else: |
68 |
| - raise ValueError(f'Invalid boolean argument: {arg}') |
69 |
| - |
70 |
| -if __name__ == "__main__": |
71 |
| - parser = argparse.ArgumentParser() |
72 |
| - for name, field in Settings.model_fields.items(): |
73 |
| - description = field.description |
74 |
| - if field.default is not None and description is not None: |
75 |
| - description += f" (default: {field.default})" |
76 |
| - base_type = get_base_type(field.annotation) if field.annotation is not None else str |
77 |
| - list_type = contains_list_type(field.annotation) |
78 |
| - if base_type is not bool: |
79 |
| - parser.add_argument( |
80 |
| - f"--{name}", |
81 |
| - dest=name, |
82 |
| - nargs="*" if list_type else None, |
83 |
| - type=base_type, |
84 |
| - help=description, |
85 |
| - ) |
86 |
| - if base_type is bool: |
87 |
| - parser.add_argument( |
88 |
| - f"--{name}", |
89 |
| - dest=name, |
90 |
| - type=parse_bool_arg, |
91 |
| - help=f"{description}", |
92 |
| - ) |
93 |
| - |
| 32 | +from llama_cpp.server.app import create_app |
| 33 | +from llama_cpp.server.settings import ( |
| 34 | + Settings, |
| 35 | + ServerSettings, |
| 36 | + ModelSettings, |
| 37 | + ConfigFileSettings, |
| 38 | +) |
| 39 | +from llama_cpp.server.cli import add_args_from_model, parse_model_from_args |
| 40 | + |
| 41 | + |
| 42 | +def main(): |
| 43 | + description = "🦙 Llama.cpp python server. Host your own LLMs!🚀" |
| 44 | + parser = argparse.ArgumentParser(description=description) |
| 45 | + |
| 46 | + add_args_from_model(parser, Settings) |
| 47 | + parser.add_argument( |
| 48 | + "--config_file", |
| 49 | + type=str, |
| 50 | + help="Path to a config file to load.", |
| 51 | + ) |
| 52 | + server_settings: ServerSettings | None = None |
| 53 | + model_settings: list[ModelSettings] = [] |
94 | 54 | args = parser.parse_args()
|
95 |
| - settings = Settings(**{k: v for k, v in vars(args).items() if v is not None}) |
96 |
| - app = create_app(settings=settings) |
97 |
| - |
| 55 | + try: |
| 56 | + # Load server and model settings from the config file (CONFIG_FILE env var, else --config_file) if one is given |
| 57 | + config_file = os.environ.get("CONFIG_FILE", args.config_file) |
| 58 | + if config_file: |
| 59 | + if not os.path.exists(config_file): |
| 60 | + raise ValueError(f"Config file {config_file} not found!") |
| 61 | + with open(config_file, "rb") as f: |
| 62 | + config_file_settings = ConfigFileSettings.model_validate_json(f.read()) |
| 63 | + server_settings = ServerSettings.model_validate(config_file_settings) |
| 64 | + model_settings = config_file_settings.models |
| 65 | + else: |
| 66 | + server_settings = parse_model_from_args(ServerSettings, args) |
| 67 | + model_settings = [parse_model_from_args(ModelSettings, args)] |
| 68 | + except Exception as e: |
| 69 | + print(e, file=sys.stderr) |
| 70 | + parser.print_help() |
| 71 | + sys.exit(1) |
| 72 | + assert server_settings is not None |
| 73 | + assert model_settings is not None |
| 74 | + app = create_app( |
| 75 | + server_settings=server_settings, |
| 76 | + model_settings=model_settings, |
| 77 | + ) |
98 | 78 | uvicorn.run(
|
99 |
| - app, host=os.getenv("HOST", settings.host), port=int(os.getenv("PORT", settings.port)), |
100 |
| - ssl_keyfile=settings.ssl_keyfile, ssl_certfile=settings.ssl_certfile |
| 79 | + app, |
| 80 | + host=os.getenv("HOST", server_settings.host), |
| 81 | + port=int(os.getenv("PORT", server_settings.port)), |
| 82 | + ssl_keyfile=server_settings.ssl_keyfile, |
| 83 | + ssl_certfile=server_settings.ssl_certfile, |
101 | 84 | )
|
| 85 | + |
| 86 | + |
| 87 | +if __name__ == "__main__": |
| 88 | + main() |
0 commit comments