Skip to content

Commit b3989af

Browse files
committed
auth4
1 parent f7e89d4 commit b3989af

File tree

14 files changed

+86
-365
lines changed

14 files changed

+86
-365
lines changed

src/app/endpoints/config.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55

66
from fastapi import APIRouter, Request
77

8-
from models.config import Configuration
9-
from configuration import configuration
108
from authorization.middleware import authorize
11-
from models.config import Action
9+
from configuration import configuration
10+
from models.config import Action, Configuration
1211
from utils.endpoints import check_configuration_loaded
1312

1413
logger = logging.getLogger(__name__)

src/app/endpoints/conversations.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,18 @@
99

1010
from client import AsyncLlamaStackClientHolder
1111
from configuration import configuration
12+
from app.database import get_session
13+
from auth import get_auth_dependency
14+
from authorization.middleware import authorize
15+
from models.config import Action
16+
from models.database.conversations import UserConversation
1217
from models.responses import (
1318
ConversationResponse,
1419
ConversationDeleteResponse,
1520
ConversationsListResponse,
1621
ConversationDetails,
1722
)
18-
from models.database.conversations import UserConversation
19-
from auth import get_auth_dependency
20-
from app.database import get_session
2123
from utils.endpoints import check_configuration_loaded, validate_conversation_ownership
22-
from authorization.middleware import authorize
23-
from models.config import Action
2424
from utils.suid import check_suid
2525

2626
logger = logging.getLogger("app.endpoints.handlers")

src/app/endpoints/feedback.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,16 @@
1010
from auth import get_auth_dependency
1111
from auth.interface import AuthTuple
1212
from authorization.middleware import authorize
13-
from models.config import Action
1413
from configuration import configuration
14+
from models.config import Action
15+
from models.requests import FeedbackRequest
1516
from models.responses import (
1617
ErrorResponse,
1718
FeedbackResponse,
1819
StatusResponse,
1920
UnauthorizedResponse,
2021
ForbiddenResponse,
2122
)
22-
from models.requests import FeedbackRequest
2323
from utils.suid import get_suid
2424

2525
logger = logging.getLogger(__name__)

src/app/endpoints/models.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import logging
44
from typing import Any
55

6+
from fastapi import APIRouter, HTTPException, Request, status
67
from fastapi.params import Depends
78
from llama_stack_client import APIConnectionError
8-
from fastapi import APIRouter, HTTPException, Request, status
99

1010
from client import AsyncLlamaStackClientHolder
1111
from configuration import configuration
@@ -67,7 +67,6 @@ async def models_endpoint_handler(
6767
Returns:
6868
ModelsResponse: An object containing the list of available models.
6969
"""
70-
7170
# Used only by the middleware
7271
_ = auth
7372

src/app/endpoints/query.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,18 @@
2424
from configuration import configuration
2525
from app.database import get_session
2626
import metrics
27+
import constants
28+
from authorization.middleware import authorize
29+
from models.config import Action
2730
from models.database.conversations import UserConversation
28-
from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
2931
from models.requests import QueryRequest, Attachment
30-
import constants
32+
from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
3133
from utils.endpoints import (
3234
check_configuration_loaded,
3335
get_agent,
3436
get_system_prompt,
3537
validate_conversation_ownership,
3638
)
37-
from authorization.middleware import authorize
38-
from models.config import Action
3939
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
4040
from utils.suid import get_suid
4141

src/app/endpoints/streaming_query.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
from auth import get_auth_dependency
2121
from auth.interface import AuthTuple
2222
from authorization.middleware import authorize
23-
from models.config import Action
2423
from client import AsyncLlamaStackClientHolder
2524
from configuration import configuration
2625
import metrics
26+
from models.config import Action
2727
from models.requests import QueryRequest
2828
from models.database.conversations import UserConversation
2929
from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt

src/authorization/engine.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class NoopAccessResolver(AccessResolver): # pylint: disable=too-few-public-meth
2828
"""No-op access resolver that does not perform any access checks."""
2929

3030
async def check_access(self, action: Action, user_roles: UserRoles) -> bool:
31-
"""Always return True, indicating access is granted."""
31+
"""Return True always, indicating access is granted."""
3232
_ = action # Unused
3333
_ = user_roles # Unused
3434
return True
@@ -68,22 +68,19 @@ def __init__(self, role_rules: list[JwtRoleRule]):
6868

6969
async def resolve_roles(self, auth: AuthTuple) -> UserRoles:
7070
"""Extract roles from JWT claims using configured rules."""
71-
7271
jwt_claims = self._get_claims(auth)
7372
return frozenset(
7473
role
7574
for rule in self.role_rules
76-
for role in self._evaluate_role_rules(rule, jwt_claims)
75+
for role in self.evaluate_role_rules(rule, jwt_claims)
7776
)
7877

7978
@staticmethod
80-
def _evaluate_role_rules(
81-
rule: JwtRoleRule, jwt_claims: dict[str, Any]
82-
) -> UserRoles:
83-
"""Get roles from a JWT role rule if it matches the claims"""
79+
def evaluate_role_rules(rule: JwtRoleRule, jwt_claims: dict[str, Any]) -> UserRoles:
80+
"""Get roles from a JWT role rule if it matches the claims."""
8481
return (
8582
frozenset(rule.roles)
86-
if __class__._evaluate_operator(
83+
if JwtRolesResolver._evaluate_operator(
8784
rule.negate,
8885
[match.value for match in parse(rule.jsonpath).find(jwt_claims)],
8986
rule.operator,
@@ -147,6 +144,7 @@ def __init__(self, access_rules: list[AccessRule]):
147144
self._access_lookup[rule.role].update(rule.actions)
148145

149146
async def check_access(self, action: Action, user_roles: UserRoles) -> bool:
147+
"""Check if the user has access to the specified action based on their roles."""
150148
if action != Action.ADMIN and self.check_access(action.ADMIN, user_roles):
151149
# Recurse to check if the roles allow the user to perform the admin action,
152150
# if they do, then we allow any action

src/authorization/middleware.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,11 @@ async def _perform_authorization_check(action: Action, kwargs: dict[str, Any]) -
8383

8484

8585
def authorize(action: Action) -> Callable:
86-
"""Decorator to check authorization for an endpoint (async version)."""
86+
"""Check authorization for an endpoint (async version)."""
8787

8888
def decorator(func: Callable) -> Callable:
8989
@wraps(func)
90-
async def wrapper(*args, **kwargs):
90+
async def wrapper(*args: Any, **kwargs: Any) -> Any:
9191
await _perform_authorization_check(action, kwargs)
9292
return await func(*args, **kwargs)
9393

src/authorization/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Authorization models and data structures."""

src/models/config.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -215,9 +215,7 @@ class JwtRoleRule(BaseModel):
215215

216216
@model_validator(mode="after")
217217
def check_jsonpath(self) -> Self:
218-
"""
219-
Verify that the JSONPath expression is valid.
220-
"""
218+
"""Verify that the JSONPath expression is valid."""
221219
try:
222220
jsonpath_ng.parse(self.jsonpath)
223221
return self
@@ -228,9 +226,7 @@ def check_jsonpath(self) -> Self:
228226

229227
@model_validator(mode="after")
230228
def check_roles(self) -> Self:
231-
"""
232-
Ensure that at least one role is specified.
233-
"""
229+
"""Ensure that at least one role is specified."""
234230
if not self.roles:
235231
raise ValueError("At least one role must be specified in the rule")
236232

0 commit comments

Comments
 (0)