Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 65 additions & 7 deletions src/app/endpoints/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,89 @@
import logging
from typing import Any

from fastapi import APIRouter

from models.responses import ReadinessResponse, LivenessResponse, NotAvailableResponse
from llama_stack.providers.datatypes import HealthStatus

from fastapi import APIRouter, status, Response
from client import get_llama_stack_client
from configuration import configuration
from models.responses import (
LivenessResponse,
ReadinessResponse,
ProviderHealthStatus,
)

logger = logging.getLogger(__name__)
router = APIRouter(tags=["health"])


def get_providers_health_statuses() -> list[ProviderHealthStatus]:
    """Collect the health status reported for every provider.

    A client is built from ``configuration.llama_stack_configuration`` and
    asked for its provider list; each entry's ``health`` mapping is turned
    into a ``ProviderHealthStatus``.

    Returns:
        One status per listed provider. A missing ``"status"`` key is
        reported as ``"unknown"`` and a missing ``"message"`` key as an
        empty string. If anything raises while the list is being built, a
        single placeholder entry with ``provider_id="unknown"`` and the
        error status is returned instead.
    """
    try:
        client = get_llama_stack_client(configuration.llama_stack_configuration)

        registered = client.providers.list()
        logger.debug("Found %d providers", len(registered))

        statuses: list[ProviderHealthStatus] = []
        for entry in registered:
            statuses.append(
                ProviderHealthStatus(
                    provider_id=entry.provider_id,
                    status=str(entry.health.get("status", "unknown")),
                    message=str(entry.health.get("message", "")),
                )
            )
        return statuses

    except Exception as e:  # pylint: disable=broad-exception-caught
        # e.g. no providers defined: the failure is reported as one
        # synthetic unhealthy provider rather than propagated to the caller.
        logger.error("Failed to check providers health: %s", e)
        placeholder = ProviderHealthStatus(
            provider_id="unknown",
            status=HealthStatus.ERROR.value,
            message=f"Failed to initialize health check: {str(e)}",
        )
        return [placeholder]


# Response documentation passed to the /readiness route decorator, keyed by
# HTTP status code. Both outcomes are described by ReadinessResponse.
# The 503 entry previously listed "model" twice; only the last value of a
# duplicated key in a dict literal is kept, so the overridden
# NotAvailableResponse entry was dead and has been removed.
get_readiness_responses: dict[int | str, dict[str, Any]] = {
    200: {
        "description": "Service is ready",
        "model": ReadinessResponse,
    },
    503: {
        "description": "Service is not ready",
        "model": ReadinessResponse,
    },
}


@router.get("/readiness", responses=get_readiness_responses)
def readiness_probe_get_method(response: Response) -> ReadinessResponse:
    """Report service readiness together with provider health details.

    The function used to be defined twice in a row: the route decorator was
    attached to a stale zero-argument version that always answered "ready",
    while this provider-aware version was left unregistered. Only the
    provider-aware handler is kept, and it is the one that is decorated.

    Args:
        response: Outgoing HTTP response; its status code is set to 503
            when at least one provider is unhealthy.

    Returns:
        ReadinessResponse whose ``providers`` field lists only the unhealthy
        providers (empty when the service is ready).
    """
    provider_statuses = get_providers_health_statuses()

    # Only an explicit error status blocks readiness; any other status
    # (for example "not implemented") is not counted as unhealthy.
    unhealthy_providers = [
        p for p in provider_statuses if p.status == HealthStatus.ERROR.value
    ]

    if unhealthy_providers:
        ready = False
        unhealthy_provider_names = [p.provider_id for p in unhealthy_providers]
        reason = f"Providers not healthy: {', '.join(unhealthy_provider_names)}"
        response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
    else:
        ready = True
        reason = "All providers are healthy"

    return ReadinessResponse(ready=ready, reason=reason, providers=unhealthy_providers)


get_liveness_responses: dict[int | str, dict[str, Any]] = {
Expand Down
35 changes: 31 additions & 4 deletions src/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,29 +78,56 @@ class InfoResponse(BaseModel):
}


class ProviderHealthStatus(BaseModel):
    """Health report for a single provider.

    Attributes:
        provider_id: Identifier of the provider the report belongs to.
        status: Health state as a plain string. Previously documented as
            one of 'ok', 'unhealthy' or 'not_implemented'.
            NOTE(review): confirm these literals against the values the
            health check actually produces.
        message: Optional free-text detail accompanying the status.
    """

    # Required: which provider this entry describes.
    provider_id: str
    # Required: stored as a free-form string, not restricted to an enum here.
    status: str
    # Optional: defaults to None when no detail is supplied.
    message: Optional[str] = None


class ReadinessResponse(BaseModel):
"""Model representing a response to a readiness request.
"""Model representing response to a readiness request.

Attributes:
ready: The readiness of the service.
ready: If service is ready.
reason: The reason for the readiness.
providers: List of unhealthy providers in case of readiness failure.

Example:
```python
readiness_response = ReadinessResponse(ready=True, reason="service is ready")
readiness_response = ReadinessResponse(
ready=False,
reason="Service is not ready",
providers=[
ProviderHealthStatus(
provider_id="ollama",
status="Error",
message="Server is unavailable"
)
]
)
```
"""

ready: bool
reason: str
providers: list[ProviderHealthStatus]

# provides examples for /docs endpoint
model_config = {
"json_schema_extra": {
"examples": [
{
"ready": True,
"reason": "service is ready",
"reason": "Service is ready",
"providers": [],
}
]
}
Expand Down
173 changes: 167 additions & 6 deletions tests/unit/app/endpoints/test_health.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,177 @@
from app.endpoints.health import readiness_probe_get_method, liveness_probe_get_method
from unittest.mock import Mock

from app.endpoints.health import (
readiness_probe_get_method,
liveness_probe_get_method,
get_providers_health_statuses,
)
from models.responses import ProviderHealthStatus, ReadinessResponse
from llama_stack.providers.datatypes import HealthStatus

def test_readiness_probe(mocker):
    """Test the readiness endpoint handler when no providers are reported.

    The handler takes the outgoing response object as its argument, so
    calling it with no arguments (as this test used to) raises TypeError.
    The provider lookup is patched to return an empty list, which leaves no
    unhealthy providers and therefore a ready service.
    """
    mocker.patch(
        "app.endpoints.health.get_providers_health_statuses",
        return_value=[],
    )

    response = readiness_probe_get_method(Mock())

    assert response is not None
    assert response.ready is True

def test_readiness_probe_fails_due_to_unhealthy_providers(mocker):
    """Readiness must be refused when a provider reports the error status."""
    # One provider in the error state is enough to make the service unready.
    failing_provider = ProviderHealthStatus(
        provider_id="test_provider",
        status=HealthStatus.ERROR.value,
        message="Provider is down",
    )
    mocker.patch(
        "app.endpoints.health.get_providers_health_statuses",
        return_value=[failing_provider],
    )

    # Stand-in for the HTTP response object whose status code the handler sets.
    http_response = Mock()

    result = readiness_probe_get_method(http_response)

    assert result.ready is False
    for fragment in ("test_provider", "Providers not healthy"):
        assert fragment in result.reason
    assert http_response.status_code == 503


def test_readiness_probe_success_when_all_providers_healthy(mocker):
    """Test the readiness endpoint handler succeeds when all providers are healthy.

    A provider that does not implement a health check is not treated as
    unhealthy, so the mix of OK and NOT_IMPLEMENTED below must still yield a
    ready service.

    The test used to assert two different values for ``response.reason``;
    the leftover "service is ready" expectation could never hold alongside
    the current "All providers are healthy" text and has been dropped.
    """
    # Mock get_providers_health_statuses to return healthy providers
    mock_get_providers_health_statuses = mocker.patch(
        "app.endpoints.health.get_providers_health_statuses"
    )
    mock_get_providers_health_statuses.return_value = [
        ProviderHealthStatus(
            provider_id="provider1",
            status=HealthStatus.OK.value,
            message="Provider is healthy",
        ),
        ProviderHealthStatus(
            provider_id="provider2",
            status=HealthStatus.NOT_IMPLEMENTED.value,
            message="Provider does not implement health check",
        ),
    ]

    # Mock the Response object
    mock_response = Mock()

    response = readiness_probe_get_method(mock_response)
    assert response is not None
    assert isinstance(response, ReadinessResponse)
    assert response.ready is True
    assert response.reason == "All providers are healthy"
    # Should return empty list since no providers are unhealthy
    assert len(response.providers) == 0


def test_liveness_probe():
    """Test the liveness endpoint handler.

    The function header appeared twice; the body-less first copy, which
    requested an unused ``mocker`` fixture, has been removed.
    """
    response = liveness_probe_get_method()
    assert response is not None
    assert response.alive is True


class TestProviderHealthStatus:
    """Unit tests for constructing ProviderHealthStatus instances."""

    def test_provider_health_status_creation(self):
        """Every explicitly supplied field is stored unchanged."""
        supplied = {
            "provider_id": "test_provider",
            "status": "ok",
            "message": "All good",
        }

        health = ProviderHealthStatus(**supplied)

        assert health.provider_id == "test_provider"
        assert health.status == "ok"
        assert health.message == "All good"

    def test_provider_health_status_optional_fields(self):
        """Omitting the message leaves it as None."""
        health = ProviderHealthStatus(provider_id="test_provider", status="ok")

        assert health.provider_id == "test_provider"
        assert health.status == "ok"
        assert health.message is None


class TestGetProvidersHealthStatuses:
    """Unit tests for get_providers_health_statuses."""

    def test_get_providers_health_statuses(self, mocker):
        """Each listed provider is mapped to a status entry, in order."""
        # (provider_id, status, message) for every fake provider; the same
        # table drives both the fixtures and the assertions below.
        expected = [
            ("provider1", HealthStatus.OK.value, "All good"),
            (
                "provider2",
                HealthStatus.NOT_IMPLEMENTED.value,
                "Provider does not implement health check",
            ),
            ("unhealthy_provider", HealthStatus.ERROR.value, "Connection failed"),
        ]

        # Replace the names the function looks up in the endpoint module.
        client_factory = mocker.patch("app.endpoints.health.get_llama_stack_client")
        patched_configuration = mocker.patch("app.endpoints.health.configuration")

        fake_client = mocker.Mock()
        client_factory.return_value = fake_client

        # Build provider doubles exposing a provider_id and a health mapping.
        fake_providers = []
        for provider_id, health_status, message in expected:
            fake_provider = mocker.Mock()
            fake_provider.provider_id = provider_id
            fake_provider.health = {"status": health_status, "message": message}
            fake_providers.append(fake_provider)
        fake_client.providers.list.return_value = fake_providers

        patched_configuration.llama_stack_configuration = mocker.Mock()

        result = get_providers_health_statuses()

        assert len(result) == 3
        for actual, (provider_id, health_status, message) in zip(result, expected):
            assert actual.provider_id == provider_id
            assert actual.status == health_status
            assert actual.message == message

    def test_get_providers_health_statuses_connection_error(self, mocker):
        """A failure while creating the client yields one placeholder entry."""
        client_factory = mocker.patch("app.endpoints.health.get_llama_stack_client")
        patched_configuration = mocker.patch("app.endpoints.health.configuration")

        patched_configuration.llama_stack_configuration = mocker.Mock()

        # Creating the client blows up before any provider can be listed.
        client_factory.side_effect = Exception("Connection error")

        result = get_providers_health_statuses()

        assert len(result) == 1
        placeholder = result[0]
        assert placeholder.provider_id == "unknown"
        assert placeholder.status == HealthStatus.ERROR.value
        assert (
            placeholder.message
            == "Failed to initialize health check: Connection error"
        )