Skip to content

Commit f330c8e

Browse files
committed
chore: simplify route addition when calling inspect
https://github.com/llamastack/llama-stack/pull/4191/files#r2557411918 Signed-off-by: Sébastien Han <[email protected]>
1 parent ead9e63 commit f330c8e

File tree

2 files changed

+97
-56
lines changed

2 files changed

+97
-56
lines changed

src/llama_stack/core/inspect.py

Lines changed: 58 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010

1111
from llama_stack.core.datatypes import StackRunConfig
1212
from llama_stack.core.external import load_external_apis
13-
from llama_stack.core.resolver import api_protocol_map
14-
from llama_stack.core.server.fastapi_router_registry import build_fastapi_router
13+
from llama_stack.core.server.fastapi_router_registry import (
14+
_ROUTER_FACTORIES,
15+
build_fastapi_router,
16+
get_router_routes,
17+
)
1518
from llama_stack.core.server.routes import get_all_api_routes
1619
from llama_stack_api import (
1720
Api,
@@ -46,6 +49,7 @@ async def list_routes(self, api_filter: str | None = None) -> ListRoutesResponse
4649
run_config: StackRunConfig = self.config.run_config
4750

4851
# Helper function to determine if a route should be included based on api_filter
52+
# TODO: remove this once we've migrated all APIs to FastAPI routers
4953
def should_include_route(webmethod) -> bool:
5054
if api_filter is None:
5155
# Default: only non-deprecated APIs
@@ -57,40 +61,15 @@ def should_include_route(webmethod) -> bool:
5761
# Filter by API level (non-deprecated routes only)
5862
return not webmethod.deprecated and webmethod.level == api_filter
5963

60-
ret = []
61-
external_apis = load_external_apis(run_config)
62-
all_endpoints = get_all_api_routes(external_apis)
63-
6464
# Helper function to get provider types for an API
65-
def get_provider_types(api: Api) -> list[str]:
65+
def _get_provider_types(api: Api) -> list[str]:
6666
if api.value in ["providers", "inspect"]:
6767
return [] # These APIs don't have "real" providers they're internal to the stack
6868
providers = run_config.providers.get(api.value, [])
6969
return [p.provider_type for p in providers] if providers else []
7070

71-
# Process webmethod-based routes (legacy)
72-
for api, endpoints in all_endpoints.items():
73-
# Skip APIs that have routers - they'll be processed separately
74-
if build_fastapi_router(api, None) is not None:
75-
continue
76-
77-
provider_types = get_provider_types(api)
78-
# Always include provider and inspect APIs, filter others based on run config
79-
if api.value in ["providers", "inspect"] or provider_types:
80-
ret.extend(
81-
[
82-
RouteInfo(
83-
route=e.path,
84-
method=next(iter([m for m in e.methods if m != "HEAD"])),
85-
provider_types=provider_types,
86-
)
87-
for e, webmethod in endpoints
88-
if e.methods is not None and should_include_route(webmethod)
89-
]
90-
)
91-
9271
# Helper function to determine if a router route should be included based on api_filter
93-
def should_include_router_route(route, router_prefix: str | None) -> bool:
72+
def _should_include_router_route(route, router_prefix: str | None) -> bool:
9473
"""Check if a router-based route should be included based on api_filter."""
9574
# Check deprecated status
9675
route_deprecated = getattr(route, "deprecated", False) or False
@@ -109,36 +88,59 @@ def should_include_router_route(route, router_prefix: str | None) -> bool:
10988
return not route_deprecated and prefix_level == api_filter
11089
return not route_deprecated
11190

112-
protocols = api_protocol_map(external_apis)
113-
for api in protocols.keys():
114-
# For route inspection, we don't need a real implementation
115-
router = build_fastapi_router(api, None)
116-
if not router:
91+
ret = []
92+
external_apis = load_external_apis(run_config)
93+
all_endpoints = get_all_api_routes(external_apis)
94+
95+
# Process routes from APIs with FastAPI routers
96+
for api_name in _ROUTER_FACTORIES.keys():
97+
api = Api(api_name)
98+
router = build_fastapi_router(api, None) # we don't need the impl here, just the routes
99+
if router:
100+
router_routes = get_router_routes(router)
101+
for route in router_routes:
102+
if _should_include_router_route(route, router.prefix):
103+
ret.append(
104+
RouteInfo(
105+
route=route.path,
106+
method=next(iter([m for m in route.methods if m != "HEAD"])),
107+
provider_types=_get_provider_types(api),
108+
)
109+
)
110+
111+
# Process routes from legacy webmethod-based APIs
112+
for api, endpoints in all_endpoints.items():
113+
# Skip APIs that have routers (already processed above)
114+
if api.value in _ROUTER_FACTORIES:
117115
continue
118116

119-
provider_types = get_provider_types(api)
120-
# Only include if there are providers (or it's a special API)
121-
if api.value in ["providers", "inspect"] or provider_types:
122-
router_prefix = getattr(router, "prefix", None)
123-
for route in router.routes:
124-
# Extract HTTP methods from the route
125-
# FastAPI routes have methods as a set
126-
if hasattr(route, "methods") and route.methods:
127-
methods = {m for m in route.methods if m != "HEAD"}
128-
if methods and should_include_router_route(route, router_prefix):
129-
# FastAPI already combines router prefix with route path
130-
# Only APIRoute has a path attribute, use getattr to safely access it
131-
path = getattr(route, "path", None)
132-
if path is None:
133-
continue
134-
135-
ret.append(
136-
RouteInfo(
137-
route=path,
138-
method=next(iter(methods)),
139-
provider_types=provider_types,
140-
)
117+
# Always include provider and inspect APIs, filter others based on run config
118+
if api.value in ["providers", "inspect"]:
119+
ret.extend(
120+
[
121+
RouteInfo(
122+
route=e.path,
123+
method=next(iter([m for m in e.methods if m != "HEAD"])),
124+
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
125+
)
126+
for e, webmethod in endpoints
127+
if e.methods is not None and should_include_route(webmethod)
128+
]
129+
)
130+
else:
131+
providers = run_config.providers.get(api.value, [])
132+
if providers: # Only process if there are providers for this API
133+
ret.extend(
134+
[
135+
RouteInfo(
136+
route=e.path,
137+
method=next(iter([m for m in e.methods if m != "HEAD"])),
138+
provider_types=[p.provider_type for p in providers],
141139
)
140+
for e, webmethod in endpoints
141+
if e.methods is not None and should_include_route(webmethod)
142+
]
143+
)
142144

143145
return ListRoutesResponse(data=ret)
144146

src/llama_stack/core/server/fastapi_router_registry.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from typing import Any, cast
1515

1616
from fastapi import APIRouter
17+
from fastapi.routing import APIRoute
18+
from starlette.routing import Route
1719

1820
# Router factories for APIs that have FastAPI routers
1921
# Add new APIs here as they are migrated to the router system
@@ -43,3 +45,40 @@ def build_fastapi_router(api: "Api", impl: Any) -> APIRouter | None:
4345
# If a router factory returns the wrong type, it will fail at runtime when
4446
# app.include_router(router) is called
4547
return cast(APIRouter, router_factory(impl))
48+
49+
50+
def get_router_routes(router: APIRouter) -> list[Route]:
51+
"""Extract routes from a FastAPI router.
52+
53+
Args:
54+
router: The FastAPI router to extract routes from
55+
56+
Returns:
57+
List of Route objects from the router
58+
"""
59+
routes = []
60+
61+
for route in router.routes:
62+
# FastAPI routers use APIRoute objects, which have path and methods attributes
63+
if isinstance(route, APIRoute):
64+
# Combine router prefix with route path
65+
routes.append(
66+
Route(
67+
path=route.path,
68+
methods=route.methods,
69+
name=route.name,
70+
endpoint=route.endpoint,
71+
)
72+
)
73+
elif isinstance(route, Route):
74+
# Fallback for regular Starlette Route objects
75+
routes.append(
76+
Route(
77+
path=route.path,
78+
methods=route.methods,
79+
name=route.name,
80+
endpoint=route.endpoint,
81+
)
82+
)
83+
84+
return routes

0 commit comments

Comments
 (0)