From 4f3c19114ca00bb86cf377dd7cc687eb457b66a8 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 6 Feb 2025 16:04:25 -0800 Subject: [PATCH 1/4] Add HTTPS serving option --- llama_stack/cli/stack/run.py | 13 ++++++ llama_stack/distribution/datatypes.py | 20 ++++++++++ llama_stack/distribution/server/server.py | 44 +++++++++++++++++++-- llama_stack/distribution/start_conda_env.sh | 5 ++- llama_stack/distribution/start_container.sh | 12 +++++- 5 files changed, 88 insertions(+), 6 deletions(-) diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index f84def184b..502dfbed43 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -55,6 +55,16 @@ def _add_arguments(self): default=[], metavar="KEY=VALUE", ) + self.parser.add_argument( + "--ssl-keyfile", + type=str, + help="Path to SSL key file for HTTPS", + ) + self.parser.add_argument( + "--ssl-certfile", + type=str, + help="Path to SSL certificate file for HTTPS", + ) def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: import importlib.resources @@ -178,4 +188,7 @@ def get_conda_prefix(env_name): return run_args.extend(["--env", f"{key}={value}"]) + if args.ssl_keyfile and args.ssl_certfile: + run_args.extend(["--ssl-keyfile", args.ssl_keyfile, "--ssl-certfile", args.ssl_certfile]) + run_with_pty(run_args) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 8b579b6365..a9b64398e0 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -117,6 +117,21 @@ class Provider(BaseModel): config: Dict[str, Any] +class ServerConfig(BaseModel): + port: int = Field( + default=8321, + description="Port to listen on", + ) + ssl_certfile: Optional[str] = Field( + default=None, + description="Path to SSL certificate file for HTTPS", + ) + ssl_keyfile: Optional[str] = Field( + default=None, + description="Path to SSL key file for HTTPS", + ) + + class StackRunConfig(BaseModel): version: str = LLAMA_STACK_RUN_CONFIG_VERSION @@ -159,6 +174,11 @@ class StackRunConfig(BaseModel): eval_tasks: List[EvalTaskInput] = Field(default_factory=list) tool_groups: List[ToolGroupInput] = Field(default_factory=list) + server: ServerConfig = Field( + default_factory=ServerConfig, + description="Configuration for the HTTP(S) server", + ) + class BuildConfig(BaseModel): version: str = LLAMA_STACK_BUILD_CONFIG_VERSION diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index fcd0e3cad1..69d3e3a626 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -282,6 +282,19 @@ def main(): action="append", help="Environment variables in KEY=value format. Can be specified multiple times.", ) + parser.add_argument( + "--ssl-keyfile", + help="Path to SSL key file for HTTPS", + ) + parser.add_argument( + "--ssl-certfile", + help="Path to SSL certificate file for HTTPS", + ) + + if args.ssl_keyfile and not args.ssl_certfile: + parser.error("You must provide both --ssl-keyfile and --ssl-certfile when using HTTPS") + if args.ssl_certfile and not args.ssl_keyfile: + parser.error("You must provide both --ssl-keyfile and --ssl-certfile when using HTTPS") args = parser.parse_args() if args.env: @@ -381,11 +394,36 @@ def main(): import uvicorn - # FYI this does not do hot-reloads + # Configure SSL if certificates are provided + port = args.port or config.server.port + + ssl_config = None + if args.ssl_keyfile: + keyfile = args.ssl_keyfile + certfile = args.ssl_certfile + else: + keyfile = config.server.ssl_keyfile + certfile = config.server.ssl_certfile + + if keyfile and certfile: + ssl_config = { + "ssl_keyfile": keyfile, + "ssl_certfile": certfile, + } + print(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}") listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0" - print(f"Listening on {listen_host}:{args.port}") - uvicorn.run(app, host=listen_host, port=args.port) + print(f"Listening on {listen_host}:{port}") + + uvicorn_config = { + "app": app, + "host": listen_host, + "port": port, + } + if ssl_config: + uvicorn_config.update(ssl_config) + + uvicorn.run(**uvicorn_config) def extract_path_params(route: str) -> List[str]: diff --git a/llama_stack/distribution/start_conda_env.sh b/llama_stack/distribution/start_conda_env.sh index c37f30ef00..fe830059ff 100755 --- a/llama_stack/distribution/start_conda_env.sh +++ b/llama_stack/distribution/start_conda_env.sh @@ -34,6 +34,7 @@ shift # Process environment variables from --env arguments env_vars="" +other_args="" while [[ $# -gt 0 ]]; do case "$1" in --env) @@ -48,6 +49,7 @@ while [[ $# -gt 0 ]]; do fi ;; *) + other_args="$other_args $1" shift ;; esac @@ -61,4 +63,5 @@ $CONDA_PREFIX/bin/python \ -m llama_stack.distribution.server.server \ --yaml-config "$yaml_config" \ --port "$port" \ - $env_vars + $env_vars \ + $other_args diff --git a/llama_stack/distribution/start_container.sh b/llama_stack/distribution/start_container.sh index 2c5d65d09e..a5f543fb43 100755 --- a/llama_stack/distribution/start_container.sh +++ b/llama_stack/distribution/start_container.sh @@ -40,8 +40,12 @@ shift port="$1" shift +# Initialize other_args +other_args="" + # Process environment variables from --env arguments env_vars="" + while [[ $# -gt 0 ]]; do case "$1" in --env) @@ -55,6 +59,7 @@ while [[ $# -gt 0 ]]; do fi ;; *) + other_args="$other_args $1" shift ;; esac @@ -93,5 +98,8 @@ $CONTAINER_BINARY run $CONTAINER_OPTS -it \ -v "$yaml_config:/app/config.yaml" \ $mounts \ --env LLAMA_STACK_PORT=$port \ - --entrypoint='["python", "-m", "llama_stack.distribution.server.server", "--yaml-config", "/app/config.yaml"]' \ - $container_image:$version_tag + --entrypoint python \ + $container_image:$version_tag \ + -m llama_stack.distribution.server.server \ + --yaml-config /app/config.yaml \ + $other_args From e23213ee8af7214c506220cba16014e49f00b819 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 6 Feb 2025 16:54:31 -0800 Subject: [PATCH 2/4] fix --- llama_stack/distribution/server/server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 69d3e3a626..f72a8cf541 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -291,12 +291,13 @@ def main(): help="Path to SSL certificate file for HTTPS", ) + args = parser.parse_args() + if args.ssl_keyfile and not args.ssl_certfile: parser.error("You must provide both --ssl-keyfile and --ssl-certfile when using HTTPS") if args.ssl_certfile and not args.ssl_keyfile: parser.error("You must provide both --ssl-keyfile and --ssl-certfile when using HTTPS") - args = parser.parse_args() if args.env: for env_pair in args.env: try: From 997b097e26279a6d1372670e54a7e1324f5732d4 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 7 Feb 2025 09:11:15 -0800 Subject: [PATCH 3/4] Address feedback --- llama_stack/distribution/datatypes.py | 2 ++ llama_stack/distribution/server/server.py | 7 ++----- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index a9b64398e0..b366de0407 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -121,6 +121,8 @@ class ServerConfig(BaseModel): port: int = Field( default=8321, description="Port to listen on", + ge=1024, + le=65535, ) ssl_certfile: Optional[str] = Field( default=None, diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index f72a8cf541..2ecae8857f 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -285,19 +285,16 @@ def main(): parser.add_argument( "--ssl-keyfile", help="Path to SSL key file for HTTPS", + required="--ssl-certfile" in sys.argv, ) parser.add_argument( "--ssl-certfile", help="Path to SSL certificate file for HTTPS", + required="--ssl-keyfile" in sys.argv, ) args = parser.parse_args() - if args.ssl_keyfile and not args.ssl_certfile: - parser.error("You must provide both --ssl-keyfile and --ssl-certfile when using HTTPS") - if args.ssl_certfile and not args.ssl_keyfile: - parser.error("You must provide both --ssl-keyfile and --ssl-certfile when using HTTPS") - if args.env: for env_pair in args.env: try: From b3bfff78c58ebe83acc52275ca1739b92066737f Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 7 Feb 2025 09:28:26 -0800 Subject: [PATCH 4/4] SSL -> TLS thanks @leseb --- llama_stack/cli/stack/run.py | 12 ++++++------ llama_stack/distribution/datatypes.py | 8 ++++---- llama_stack/distribution/server/server.py | 22 +++++++++++----------- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index 502dfbed43..e7d6df2926 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -56,14 +56,14 @@ def _add_arguments(self): metavar="KEY=VALUE", ) self.parser.add_argument( - "--ssl-keyfile", + "--tls-keyfile", type=str, - help="Path to SSL key file for HTTPS", + help="Path to TLS key file for HTTPS", ) self.parser.add_argument( - "--ssl-certfile", + "--tls-certfile", type=str, - help="Path to SSL certificate file for HTTPS", + help="Path to TLS certificate file for HTTPS", ) def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: @@ -188,7 +188,7 @@ def get_conda_prefix(env_name): return run_args.extend(["--env", f"{key}={value}"]) - if args.ssl_keyfile and args.ssl_certfile: - run_args.extend(["--ssl-keyfile", args.ssl_keyfile, "--ssl-certfile", args.ssl_certfile]) + if args.tls_keyfile and args.tls_certfile: + run_args.extend(["--tls-keyfile", args.tls_keyfile, "--tls-certfile", args.tls_certfile]) run_with_pty(run_args) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index b366de0407..97706f22a5 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -124,13 +124,13 @@ class ServerConfig(BaseModel): ge=1024, le=65535, ) - ssl_certfile: Optional[str] = Field( + tls_certfile: Optional[str] = Field( default=None, - description="Path to SSL certificate file for HTTPS", + description="Path to TLS certificate file for HTTPS", ) - ssl_keyfile: Optional[str] = Field( + tls_keyfile: Optional[str] = Field( default=None, - description="Path to SSL key file for HTTPS", + description="Path to TLS key file for HTTPS", ) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 2ecae8857f..d2c32de119 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -283,14 +283,14 @@ def main(): help="Environment variables in KEY=value format. Can be specified multiple times.", ) parser.add_argument( - "--ssl-keyfile", - help="Path to SSL key file for HTTPS", - required="--ssl-certfile" in sys.argv, + "--tls-keyfile", + help="Path to TLS key file for HTTPS", + required="--tls-certfile" in sys.argv, ) parser.add_argument( - "--ssl-certfile", - help="Path to SSL certificate file for HTTPS", - required="--ssl-keyfile" in sys.argv, + "--tls-certfile", + help="Path to TLS certificate file for HTTPS", + required="--tls-keyfile" in sys.argv, ) args = parser.parse_args() @@ -396,12 +396,12 @@ def main(): port = args.port or config.server.port ssl_config = None - if args.ssl_keyfile: - keyfile = args.ssl_keyfile - certfile = args.ssl_certfile + if args.tls_keyfile: + keyfile = args.tls_keyfile + certfile = args.tls_certfile else: - keyfile = config.server.ssl_keyfile - certfile = config.server.ssl_certfile + keyfile = config.server.tls_keyfile + certfile = config.server.tls_certfile if keyfile and certfile: ssl_config = {