diff --git a/src/app/endpoints/health.py b/src/app/endpoints/health.py index cea167c7..afecc6af 100644 --- a/src/app/endpoints/health.py +++ b/src/app/endpoints/health.py @@ -8,15 +8,57 @@ import logging from typing import Any -from fastapi import APIRouter - -from models.responses import ReadinessResponse, LivenessResponse, NotAvailableResponse +from llama_stack.providers.datatypes import HealthStatus +from fastapi import APIRouter, status, Response +from client import get_llama_stack_client +from configuration import configuration +from models.responses import ( + LivenessResponse, + ReadinessResponse, + ProviderHealthStatus, +) logger = logging.getLogger(__name__) router = APIRouter(tags=["health"]) +def get_providers_health_statuses() -> list[ProviderHealthStatus]: + """Check health of all providers. + + Returns: + One status per provider; a single 'unknown' error entry if the check fails. + """ + try: + llama_stack_config = configuration.llama_stack_configuration + + client = get_llama_stack_client(llama_stack_config) + + providers = client.providers.list() + logger.debug("Found %d providers", len(providers)) + + health_results = [ + ProviderHealthStatus( + provider_id=provider.provider_id, + status=str(provider.health.get("status", "unknown")), + message=str(provider.health.get("message", "")), + ) + for provider in providers + ] + return health_results + + except Exception as e: # pylint: disable=broad-exception-caught + # e.g. 
no providers defined + logger.error("Failed to check providers health: %s", e) + return [ + ProviderHealthStatus( + provider_id="unknown", + status=HealthStatus.ERROR.value, + message=f"Failed to initialize health check: {str(e)}", + ) + ] + + get_readiness_responses: dict[int | str, dict[str, Any]] = { 200: { "description": "Service is ready", @@ -24,15 +66,31 @@ }, 503: { "description": "Service is not ready", - "model": NotAvailableResponse, + "model": ReadinessResponse, }, } @router.get("/readiness", responses=get_readiness_responses) -def readiness_probe_get_method() -> ReadinessResponse: - """Ready status of service.""" - return ReadinessResponse(ready=True, reason="service is ready") +def readiness_probe_get_method(response: Response) -> ReadinessResponse: + """Ready status of service with provider health details.""" + provider_statuses = get_providers_health_statuses() + + # Check if any provider is unhealthy (not counting not_implemented as unhealthy) + unhealthy_providers = [ + p for p in provider_statuses if p.status == HealthStatus.ERROR.value + ] + + if unhealthy_providers: + ready = False + unhealthy_provider_names = [p.provider_id for p in unhealthy_providers] + reason = f"Providers not healthy: {', '.join(unhealthy_provider_names)}" + response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE + else: + ready = True + reason = "All providers are healthy" + + return ReadinessResponse(ready=ready, reason=reason, providers=unhealthy_providers) get_liveness_responses: dict[int | str, dict[str, Any]] = { diff --git a/src/models/responses.py b/src/models/responses.py index cdafd4b0..92c366c9 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -78,21 +78,47 @@ class InfoResponse(BaseModel): } +class ProviderHealthStatus(BaseModel): + """Model representing the health status of a provider. + + Attributes: + provider_id: The ID of the provider. + status: Health status reported by the provider (e.g. 'Error'), or 'unknown'. 
+ message: Optional message about the health status. + """ + + provider_id: str + status: str + message: Optional[str] = None + + class ReadinessResponse(BaseModel): - """Model representing a response to a readiness request. + """Model representing response to a readiness request. Attributes: - ready: The readiness of the service. + ready: Whether the service is ready. reason: The reason for the readiness. + providers: List of unhealthy providers (empty when the service is ready). Example: ```python - readiness_response = ReadinessResponse(ready=True, reason="service is ready") + readiness_response = ReadinessResponse( + ready=False, + reason="Providers not healthy: ollama", + providers=[ + ProviderHealthStatus( + provider_id="ollama", + status="Error", + message="Server is unavailable" + ) + ] + ) ``` """ ready: bool reason: str + providers: list[ProviderHealthStatus] # provides examples for /docs endpoint model_config = { @@ -100,7 +126,8 @@ class ReadinessResponse(BaseModel): "examples": [ { "ready": True, - "reason": "service is ready", + "reason": "All providers are healthy", + "providers": [], } ] } diff --git a/tests/unit/app/endpoints/test_health.py b/tests/unit/app/endpoints/test_health.py index e73cb3e6..df90bee0 100644 --- a/tests/unit/app/endpoints/test_health.py +++ b/tests/unit/app/endpoints/test_health.py @@ -1,16 +1,177 @@ -from app.endpoints.health import readiness_probe_get_method, liveness_probe_get_method +from unittest.mock import Mock +from app.endpoints.health import ( + readiness_probe_get_method, + liveness_probe_get_method, + get_providers_health_statuses, +) +from models.responses import ProviderHealthStatus, ReadinessResponse +from llama_stack.providers.datatypes import HealthStatus -def test_readiness_probe(mocker): - """Test the readiness endpoint handler.""" - response = readiness_probe_get_method() + +def test_readiness_probe_fails_due_to_unhealthy_providers(mocker): + """Test the readiness endpoint handler fails when providers are unhealthy.""" + # Mock 
get_providers_health_statuses to return an unhealthy provider + mock_get_providers_health_statuses = mocker.patch( + "app.endpoints.health.get_providers_health_statuses" + ) + mock_get_providers_health_statuses.return_value = [ + ProviderHealthStatus( + provider_id="test_provider", + status=HealthStatus.ERROR.value, + message="Provider is down", + ) + ] + + # Mock the Response object + mock_response = Mock() + + response = readiness_probe_get_method(mock_response) + + assert response.ready is False + assert "test_provider" in response.reason + assert "Providers not healthy" in response.reason + assert mock_response.status_code == 503 + + +def test_readiness_probe_success_when_all_providers_healthy(mocker): + """Test the readiness endpoint handler succeeds when all providers are healthy.""" + # Mock get_providers_health_statuses to return healthy providers + mock_get_providers_health_statuses = mocker.patch( + "app.endpoints.health.get_providers_health_statuses" + ) + mock_get_providers_health_statuses.return_value = [ + ProviderHealthStatus( + provider_id="provider1", + status=HealthStatus.OK.value, + message="Provider is healthy", + ), + ProviderHealthStatus( + provider_id="provider2", + status=HealthStatus.NOT_IMPLEMENTED.value, + message="Provider does not implement health check", + ), + ] + + # Mock the Response object + mock_response = Mock() + + response = readiness_probe_get_method(mock_response) assert response is not None + assert isinstance(response, ReadinessResponse) assert response.ready is True - assert response.reason == "service is ready" + assert response.reason == "All providers are healthy" + # Should return empty list since no providers are unhealthy + assert len(response.providers) == 0 -def test_liveness_probe(mocker): +def test_liveness_probe(): """Test the liveness endpoint handler.""" response = liveness_probe_get_method() assert response is not None assert response.alive is True + + +class TestProviderHealthStatus: + """Test cases for the 
ProviderHealthStatus model.""" + + def test_provider_health_status_creation(self): + """Test creating a ProviderHealthStatus instance.""" + status = ProviderHealthStatus( + provider_id="test_provider", status="ok", message="All good" + ) + assert status.provider_id == "test_provider" + assert status.status == "ok" + assert status.message == "All good" + + def test_provider_health_status_optional_fields(self): + """Test creating a ProviderHealthStatus with minimal fields.""" + status = ProviderHealthStatus(provider_id="test_provider", status="ok") + assert status.provider_id == "test_provider" + assert status.status == "ok" + assert status.message is None + + +class TestGetProvidersHealthStatuses: + """Test cases for the get_providers_health_statuses function.""" + + def test_get_providers_health_statuses(self, mocker): + """Test get_providers_health_statuses with providers in mixed health states.""" + # Mock the imports + mock_get_llama_stack_client = mocker.patch( + "app.endpoints.health.get_llama_stack_client" + ) + mock_configuration = mocker.patch("app.endpoints.health.configuration") + + # Mock the client and its methods + mock_client = mocker.Mock() + mock_get_llama_stack_client.return_value = mock_client + + # Mock providers.list() to return providers with health + mock_provider_1 = mocker.Mock() + mock_provider_1.provider_id = "provider1" + mock_provider_1.health = { + "status": HealthStatus.OK.value, + "message": "All good", + } + + mock_provider_2 = mocker.Mock() + mock_provider_2.provider_id = "provider2" + mock_provider_2.health = { + "status": HealthStatus.NOT_IMPLEMENTED.value, + "message": "Provider does not implement health check", + } + + mock_provider_3 = mocker.Mock() + mock_provider_3.provider_id = "unhealthy_provider" + mock_provider_3.health = { + "status": HealthStatus.ERROR.value, + "message": "Connection failed", + } + + mock_client.providers.list.return_value = [ + mock_provider_1, + mock_provider_2, + mock_provider_3, + ] + + # Mock configuration + 
mock_llama_stack_config = mocker.Mock() + mock_configuration.llama_stack_configuration = mock_llama_stack_config + + result = get_providers_health_statuses() + + assert len(result) == 3 + assert result[0].provider_id == "provider1" + assert result[0].status == HealthStatus.OK.value + assert result[0].message == "All good" + assert result[1].provider_id == "provider2" + assert result[1].status == HealthStatus.NOT_IMPLEMENTED.value + assert result[1].message == "Provider does not implement health check" + assert result[2].provider_id == "unhealthy_provider" + assert result[2].status == HealthStatus.ERROR.value + assert result[2].message == "Connection failed" + + def test_get_providers_health_statuses_connection_error(self, mocker): + """Test get_providers_health_statuses when connection fails.""" + # Mock the imports + mock_get_llama_stack_client = mocker.patch( + "app.endpoints.health.get_llama_stack_client" + ) + mock_configuration = mocker.patch("app.endpoints.health.configuration") + + # Mock configuration + mock_llama_stack_config = mocker.Mock() + mock_configuration.llama_stack_configuration = mock_llama_stack_config + + # Mock get_llama_stack_client to raise an exception + mock_get_llama_stack_client.side_effect = Exception("Connection error") + + result = get_providers_health_statuses() + + assert len(result) == 1 + assert result[0].provider_id == "unknown" + assert result[0].status == HealthStatus.ERROR.value + assert ( + result[0].message == "Failed to initialize health check: Connection error" + )