From 8973a4d8cd955931acbc32d5709adb813ff40ee4 Mon Sep 17 00:00:00 2001 From: qdaxb <4157870+qdaxb@users.noreply.github.com> Date: Sat, 29 Nov 2025 23:10:55 +0800 Subject: [PATCH] feat(backend): add Completion Condition mechanism for CI monitoring Add a generic Completion Condition system to track async external conditions (e.g., CI pipelines) and trigger auto-fix workflows: - Add CompletionCondition model with status tracking and retry logic - Add GitHub/GitLab webhook endpoints for receiving CI events - Add CIMonitorService for handling CI events and auto-fix workflow - Add CompletionConditionService for CRUD operations - Add new webhook notification events: condition.satisfied, condition.ci_failed, condition.max_retry_reached, task.fully_completed - Add configuration for CI monitoring (CI_MONITOR_ENABLED, CI_MAX_RETRIES) - Add database migration for completion_conditions table The system automatically detects CI failures and creates fix subtasks with relevant logs, up to a configurable max retry limit. --- ...d5e6f7g_add_completion_conditions_table.py | 59 +++ backend/app/api/api.py | 18 +- backend/app/api/dependencies.py | 14 + .../api/endpoints/completion_conditions.py | 141 ++++++ .../app/api/endpoints/webhooks/__init__.py | 11 + backend/app/api/endpoints/webhooks/github.py | 207 +++++++++ backend/app/api/endpoints/webhooks/gitlab.py | 130 ++++++ backend/app/core/config.py | 7 + backend/app/models/__init__.py | 3 +- backend/app/models/completion_condition.py | 96 ++++ backend/app/schemas/completion_condition.py | 156 +++++++ backend/app/services/ci_monitor.py | 429 ++++++++++++++++++ backend/app/services/completion_condition.py | 309 +++++++++++++ 13 files changed, 1578 insertions(+), 2 deletions(-) create mode 100644 backend/alembic/versions/2b3c4d5e6f7g_add_completion_conditions_table.py create mode 100644 backend/app/api/endpoints/completion_conditions.py create mode 100644 backend/app/api/endpoints/webhooks/__init__.py create mode 100644 backend/app/api/endpoints/webhooks/github.py create mode 100644 backend/app/api/endpoints/webhooks/gitlab.py create mode 100644 backend/app/models/completion_condition.py create mode 100644 backend/app/schemas/completion_condition.py create mode 100644 backend/app/services/ci_monitor.py create mode 100644 backend/app/services/completion_condition.py diff --git a/backend/alembic/versions/2b3c4d5e6f7g_add_completion_conditions_table.py b/backend/alembic/versions/2b3c4d5e6f7g_add_completion_conditions_table.py new file mode 100644 index 00000000..8b8894f6 --- /dev/null +++ b/backend/alembic/versions/2b3c4d5e6f7g_add_completion_conditions_table.py @@ -0,0 +1,59 @@ +"""add completion_conditions table for CI monitoring + +Revision ID: 2b3c4d5e6f7g +Revises: 1a2b3c4d5e6f +Create Date: 2025-07-01 12:00:00.000000+08:00 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '2b3c4d5e6f7g' +down_revision: Union[str, Sequence[str], None] = '1a2b3c4d5e6f' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add completion_conditions table for tracking async completion conditions.""" + + # Create completion_conditions table + op.execute(""" + CREATE TABLE IF NOT EXISTS completion_conditions ( + id INT NOT NULL AUTO_INCREMENT, + subtask_id INT NOT NULL, + task_id INT NOT NULL, + user_id INT NOT NULL, + condition_type ENUM('CI_PIPELINE', 'EXTERNAL_TASK', 'APPROVAL', 'MANUAL_CONFIRM') NOT NULL DEFAULT 'CI_PIPELINE', + status ENUM('PENDING', 'IN_PROGRESS', 'SATISFIED', 'FAILED', 'CANCELLED') NOT NULL DEFAULT 'PENDING', + external_id VARCHAR(256) DEFAULT NULL, + external_url VARCHAR(1024) DEFAULT NULL, + git_platform ENUM('GITHUB', 'GITLAB') DEFAULT NULL, + git_domain VARCHAR(256) DEFAULT NULL, + repo_full_name VARCHAR(512) DEFAULT NULL, + branch_name VARCHAR(256) DEFAULT NULL, + retry_count INT NOT NULL DEFAULT 0, + max_retries INT NOT NULL DEFAULT 5, + last_failure_log TEXT DEFAULT NULL, + metadata JSON DEFAULT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + satisfied_at DATETIME DEFAULT NULL, + PRIMARY KEY (id), + KEY ix_completion_conditions_id (id), + KEY ix_completion_conditions_subtask_id (subtask_id), + KEY ix_completion_conditions_task_id (task_id), + KEY ix_completion_conditions_user_id (user_id), + KEY ix_completion_conditions_branch_name (branch_name), + KEY ix_completion_conditions_repo_branch (repo_full_name, branch_name) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """) + + +def downgrade() -> None: + """Remove completion_conditions table.""" + op.execute("DROP TABLE IF EXISTS completion_conditions") diff --git a/backend/app/api/api.py b/backend/app/api/api.py index 9b3e99e9..264af1c4 100644 --- a/backend/app/api/api.py +++ b/backend/app/api/api.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from app.api.endpoints import admin, auth, oidc, quota, repository, users +from app.api.endpoints import admin, auth, completion_conditions, oidc, quota, repository, users from app.api.endpoints.adapter import ( agents, bots, @@ -13,6 +13,7 @@ teams, ) from app.api.endpoints.kind import k_router +from app.api.endpoints.webhooks import github_router, gitlab_router from app.api.router import api_router api_router.include_router(auth.router, prefix="/auth", tags=["auth"]) @@ -29,3 +30,18 @@ api_router.include_router(quota.router, prefix="/quota", tags=["quota"]) api_router.include_router(dify.router, prefix="/dify", tags=["dify"]) api_router.include_router(k_router) + +# Completion conditions and CI monitoring +api_router.include_router( + completion_conditions.router, + prefix="/completion-conditions", + tags=["completion-conditions"], +) + +# External webhooks (no auth required) +api_router.include_router( + github_router, prefix="/webhooks/github", tags=["webhooks"] +) +api_router.include_router( + gitlab_router, prefix="/webhooks/gitlab", tags=["webhooks"] +) diff --git a/backend/app/api/dependencies.py b/backend/app/api/dependencies.py index a49ce7dc..1debe1b3 100644 --- a/backend/app/api/dependencies.py +++ b/backend/app/api/dependencies.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from contextlib import contextmanager from typing import Generator from sqlalchemy.orm import Session @@ -19,3 +20,16 @@ def get_db() -> Generator[Session, None, None]: yield db finally: db.close() + + +@contextmanager +def get_db_context() -> Generator[Session, None, None]: + """ + Database session context manager for use outside of FastAPI dependency injection. + Use this when you need a database session in async functions or background tasks. + """ + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/backend/app/api/endpoints/completion_conditions.py b/backend/app/api/endpoints/completion_conditions.py new file mode 100644 index 00000000..d8b3e4eb --- /dev/null +++ b/backend/app/api/endpoints/completion_conditions.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: 2025 Weibo, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Completion Conditions API endpoints +""" +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy.orm import Session + +from app.api.dependencies import get_db +from app.core import security +from app.models.user import User +from app.schemas.completion_condition import ( + CompletionConditionCreate, + CompletionConditionInDB, + CompletionConditionListResponse, + TaskCompletionStatus, +) +from app.services.completion_condition import completion_condition_service + +router = APIRouter() + + +@router.get("", response_model=CompletionConditionListResponse) +def list_completion_conditions( + subtask_id: Optional[int] = Query(None, description="Filter by subtask ID"), + task_id: Optional[int] = Query(None, description="Filter by task ID"), + db: Session = Depends(get_db), + current_user: User = Depends(security.get_current_user), +): + """ + List completion conditions with optional filters. + At least one of subtask_id or task_id must be provided. + """ + if subtask_id is None and task_id is None: + raise HTTPException( + status_code=400, + detail="At least one of subtask_id or task_id must be provided", + ) + + if subtask_id: + conditions = completion_condition_service.get_by_subtask_id( + db, subtask_id=subtask_id, user_id=current_user.id + ) + else: + conditions = completion_condition_service.get_by_task_id( + db, task_id=task_id, user_id=current_user.id + ) + + return CompletionConditionListResponse(total=len(conditions), items=conditions) + + +@router.get("/{condition_id}", response_model=CompletionConditionInDB) +def get_completion_condition( + condition_id: int, + db: Session = Depends(get_db), + current_user: User = Depends(security.get_current_user), +): + """Get a specific completion condition by ID""" + condition = completion_condition_service.get_by_id( + db, condition_id=condition_id, user_id=current_user.id + ) + if not condition: + raise HTTPException(status_code=404, detail="Completion condition not found") + return condition + + +@router.post("", response_model=CompletionConditionInDB) +def create_completion_condition( + condition_in: CompletionConditionCreate, + db: Session = Depends(get_db), + current_user: User = Depends(security.get_current_user), +): + """Create a new completion condition""" + condition = completion_condition_service.create_condition( + db, obj_in=condition_in, user_id=current_user.id + ) + return condition + + +@router.delete("/{condition_id}/cancel") +def cancel_completion_condition( + condition_id: int, + db: Session = Depends(get_db), + current_user: User = Depends(security.get_current_user), +): + """Cancel a completion condition""" + from app.models.completion_condition import ConditionStatus + + condition = completion_condition_service.get_by_id( + db, condition_id=condition_id, user_id=current_user.id + ) + if not condition: + raise HTTPException(status_code=404, detail="Completion condition not found") + + if condition.status in [ConditionStatus.SATISFIED, ConditionStatus.FAILED]: + raise HTTPException( + status_code=400, + detail=f"Cannot cancel condition in {condition.status} status", + ) + + condition = completion_condition_service.update_status( + db, condition_id=condition_id, status=ConditionStatus.CANCELLED + ) + return {"status": "cancelled", "id": condition_id} + + +@router.get("/tasks/{task_id}/completion-status", response_model=TaskCompletionStatus) +def get_task_completion_status( + task_id: int, + db: Session = Depends(get_db), + current_user: User = Depends(security.get_current_user), +): + """ + Get the overall completion status for a task, + including all completion conditions and their status. + """ + status = completion_condition_service.get_task_completion_status( + db, task_id=task_id, user_id=current_user.id + ) + + # Convert conditions to schema objects + from app.schemas.completion_condition import CompletionConditionInDB + + conditions_in_db = [ + CompletionConditionInDB.model_validate(c) for c in status["conditions"] + ] + + return TaskCompletionStatus( + task_id=task_id, + subtask_completed=True, # This would need to be checked from subtask status + all_conditions_satisfied=status["all_conditions_satisfied"], + pending_conditions=status["pending_conditions"], + in_progress_conditions=status["in_progress_conditions"], + satisfied_conditions=status["satisfied_conditions"], + failed_conditions=status["failed_conditions"], + conditions=conditions_in_db, + ) diff --git a/backend/app/api/endpoints/webhooks/__init__.py b/backend/app/api/endpoints/webhooks/__init__.py new file mode 100644 index 00000000..91c2add2 --- /dev/null +++ b/backend/app/api/endpoints/webhooks/__init__.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: 2025 Weibo, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Webhooks package for handling external CI events +""" +from app.api.endpoints.webhooks.github import router as github_router +from app.api.endpoints.webhooks.gitlab import router as gitlab_router + +__all__ = ["github_router", "gitlab_router"] diff --git a/backend/app/api/endpoints/webhooks/github.py b/backend/app/api/endpoints/webhooks/github.py new file mode 100644 index 00000000..57dd1f79 --- /dev/null +++ b/backend/app/api/endpoints/webhooks/github.py @@ -0,0 +1,207 @@ +# SPDX-FileCopyrightText: 2025 Weibo, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +GitHub webhook endpoint for receiving CI events +""" +import hashlib +import hmac +import logging +from typing import Any, Dict, Optional + +from fastapi import APIRouter, Header, HTTPException, Request + +from app.api.dependencies import get_db +from app.core.config import settings +from app.models.completion_condition import ConditionStatus, GitPlatform +from app.services.completion_condition import completion_condition_service +from app.services.ci_monitor import ci_monitor_service + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +def verify_github_signature( + payload: bytes, + signature: Optional[str], + secret: str, +) -> bool: + """Verify GitHub webhook signature""" + if not signature or not secret: + return False + + if signature.startswith("sha256="): + signature = signature[7:] + expected = hmac.new( + secret.encode("utf-8"), + payload, + hashlib.sha256, + ).hexdigest() + elif signature.startswith("sha1="): + signature = signature[5:] + expected = hmac.new( + secret.encode("utf-8"), + payload, + hashlib.sha1, + ).hexdigest() + else: + return False + + return hmac.compare_digest(signature, expected) + + +@router.post("") +async def github_webhook( + request: Request, + x_hub_signature_256: Optional[str] = Header(None, alias="X-Hub-Signature-256"), + x_hub_signature: Optional[str] = Header(None, alias="X-Hub-Signature"), + x_github_event: Optional[str] = Header(None, alias="X-GitHub-Event"), + x_github_delivery: Optional[str] = Header(None, alias="X-GitHub-Delivery"), +): + """ + Handle GitHub webhook events for CI monitoring. + + Supported events: + - check_run: GitHub Actions check runs + - workflow_run: GitHub Actions workflow runs + """ + # Read raw body for signature verification + body = await request.body() + + # Verify signature if secret is configured + if settings.GITHUB_WEBHOOK_SECRET: + signature = x_hub_signature_256 or x_hub_signature + if not verify_github_signature(body, signature, settings.GITHUB_WEBHOOK_SECRET): + logger.warning( + f"GitHub webhook signature verification failed for delivery {x_github_delivery}" + ) + raise HTTPException(status_code=401, detail="Invalid signature") + + # Parse JSON payload + try: + payload = await request.json() + except Exception as e: + logger.error(f"Failed to parse GitHub webhook payload: {e}") + raise HTTPException(status_code=400, detail="Invalid JSON payload") + + logger.info( + f"Received GitHub webhook: event={x_github_event}, delivery={x_github_delivery}" + ) + + # Route to appropriate handler based on event type + if x_github_event == "check_run": + await handle_check_run(payload) + elif x_github_event == "workflow_run": + await handle_workflow_run(payload) + elif x_github_event == "ping": + logger.info("Received GitHub ping event") + return {"status": "pong"} + else: + logger.debug(f"Ignoring GitHub event: {x_github_event}") + + return {"status": "ok"} + + +async def handle_check_run(payload: Dict[str, Any]): + """Handle GitHub check_run event""" + action = payload.get("action") + check_run = payload.get("check_run", {}) + repository = payload.get("repository", {}) + + repo_full_name = repository.get("full_name", "") + check_name = check_run.get("name", "") + status = check_run.get("status", "") + conclusion = check_run.get("conclusion") + head_sha = check_run.get("head_sha", "") + head_branch = check_run.get("check_suite", {}).get("head_branch", "") + html_url = check_run.get("html_url", "") + external_id = str(check_run.get("id", "")) + + logger.info( + f"GitHub check_run: repo={repo_full_name}, branch={head_branch}, " + f"name={check_name}, action={action}, status={status}, conclusion={conclusion}" + ) + + # Only process configured check types + if settings.CI_CHECK_TYPES: + check_type_match = any( + ct.lower() in check_name.lower() for ct in settings.CI_CHECK_TYPES + ) + if not check_type_match: + logger.debug(f"Ignoring check_run '{check_name}' - not in configured types") + return + + # Process based on action + if action == "created" and status == "in_progress": + await ci_monitor_service.handle_ci_started( + repo_full_name=repo_full_name, + branch_name=head_branch, + external_id=external_id, + external_url=html_url, + git_platform=GitPlatform.GITHUB, + ) + elif action == "completed": + if conclusion == "success": + await ci_monitor_service.handle_ci_success( + repo_full_name=repo_full_name, + branch_name=head_branch, + external_id=external_id, + git_platform=GitPlatform.GITHUB, + ) + elif conclusion in ("failure", "cancelled", "timed_out"): + await ci_monitor_service.handle_ci_failure( + repo_full_name=repo_full_name, + branch_name=head_branch, + external_id=external_id, + conclusion=conclusion, + git_platform=GitPlatform.GITHUB, + ) + + +async def handle_workflow_run(payload: Dict[str, Any]): + """Handle GitHub workflow_run event""" + action = payload.get("action") + workflow_run = payload.get("workflow_run", {}) + repository = payload.get("repository", {}) + + repo_full_name = repository.get("full_name", "") + workflow_name = workflow_run.get("name", "") + status = workflow_run.get("status", "") + conclusion = workflow_run.get("conclusion") + head_branch = workflow_run.get("head_branch", "") + html_url = workflow_run.get("html_url", "") + external_id = str(workflow_run.get("id", "")) + run_number = workflow_run.get("run_number", "") + + logger.info( + f"GitHub workflow_run: repo={repo_full_name}, branch={head_branch}, " + f"workflow={workflow_name}, action={action}, status={status}, conclusion={conclusion}" + ) + + # Process based on action and status + if action == "requested" or (action == "in_progress" and status == "in_progress"): + await ci_monitor_service.handle_ci_started( + repo_full_name=repo_full_name, + branch_name=head_branch, + external_id=external_id, + external_url=html_url, + git_platform=GitPlatform.GITHUB, + ) + elif action == "completed": + if conclusion == "success": + await ci_monitor_service.handle_ci_success( + repo_full_name=repo_full_name, + branch_name=head_branch, + external_id=external_id, + git_platform=GitPlatform.GITHUB, + ) + elif conclusion in ("failure", "cancelled", "timed_out"): + await ci_monitor_service.handle_ci_failure( + repo_full_name=repo_full_name, + branch_name=head_branch, + external_id=external_id, + conclusion=conclusion, + git_platform=GitPlatform.GITHUB, + ) diff --git a/backend/app/api/endpoints/webhooks/gitlab.py b/backend/app/api/endpoints/webhooks/gitlab.py new file mode 100644 index 00000000..2dfa3d80 --- /dev/null +++ b/backend/app/api/endpoints/webhooks/gitlab.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: 2025 Weibo, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +GitLab webhook endpoint for receiving CI events +""" +import logging +from typing import Any, Dict, Optional + +from fastapi import APIRouter, Header, HTTPException, Request + +from app.core.config import settings +from app.models.completion_condition import GitPlatform +from app.services.ci_monitor import ci_monitor_service + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.post("") +async def gitlab_webhook( + request: Request, + x_gitlab_token: Optional[str] = Header(None, alias="X-Gitlab-Token"), + x_gitlab_event: Optional[str] = Header(None, alias="X-Gitlab-Event"), +): + """ + Handle GitLab webhook events for CI monitoring. + + Supported events: + - Pipeline Hook: GitLab CI/CD pipeline events + """ + # Verify token if configured + if settings.GITLAB_WEBHOOK_TOKEN: + if x_gitlab_token != settings.GITLAB_WEBHOOK_TOKEN: + logger.warning("GitLab webhook token verification failed") + raise HTTPException(status_code=401, detail="Invalid token") + + # Parse JSON payload + try: + payload = await request.json() + except Exception as e: + logger.error(f"Failed to parse GitLab webhook payload: {e}") + raise HTTPException(status_code=400, detail="Invalid JSON payload") + + object_kind = payload.get("object_kind", "") + logger.info(f"Received GitLab webhook: event={x_gitlab_event}, kind={object_kind}") + + # Route to appropriate handler based on event type + if object_kind == "pipeline": + await handle_pipeline_event(payload) + elif object_kind == "build": + await handle_build_event(payload) + else: + logger.debug(f"Ignoring GitLab event: {object_kind}") + + return {"status": "ok"} + + +async def handle_pipeline_event(payload: Dict[str, Any]): + """Handle GitLab pipeline event""" + object_attributes = payload.get("object_attributes", {}) + project = payload.get("project", {}) + + # Extract pipeline information + pipeline_id = str(object_attributes.get("id", "")) + status = object_attributes.get("status", "") + ref = object_attributes.get("ref", "") # Branch name + sha = object_attributes.get("sha", "") + pipeline_url = object_attributes.get("url", "") + + # Extract project information + repo_full_name = project.get("path_with_namespace", "") + git_domain = project.get("web_url", "").split("/")[2] if project.get("web_url") else "" + + logger.info( + f"GitLab pipeline: repo={repo_full_name}, branch={ref}, " + f"pipeline_id={pipeline_id}, status={status}" + ) + + # Process based on status + if status == "running": + await ci_monitor_service.handle_ci_started( + repo_full_name=repo_full_name, + branch_name=ref, + external_id=pipeline_id, + external_url=pipeline_url, + git_platform=GitPlatform.GITLAB, + git_domain=git_domain, + ) + elif status == "success": + await ci_monitor_service.handle_ci_success( + repo_full_name=repo_full_name, + branch_name=ref, + external_id=pipeline_id, + git_platform=GitPlatform.GITLAB, + ) + elif status in ("failed", "canceled", "skipped"): + await ci_monitor_service.handle_ci_failure( + repo_full_name=repo_full_name, + branch_name=ref, + external_id=pipeline_id, + conclusion=status, + git_platform=GitPlatform.GITLAB, + ) + + +async def handle_build_event(payload: Dict[str, Any]): + """Handle GitLab build/job event""" + object_attributes = payload.get("object_attributes", {}) + project = payload.get("project", {}) + + # Extract build information + build_id = str(object_attributes.get("id", "")) + build_name = object_attributes.get("name", "") + status = object_attributes.get("status", "") + ref = object_attributes.get("ref", "") + pipeline_id = str(object_attributes.get("pipeline_id", "")) + + # Extract project information + repo_full_name = project.get("path_with_namespace", "") + + logger.info( + f"GitLab build: repo={repo_full_name}, branch={ref}, " + f"build={build_name}, status={status}" + ) + + # We primarily track pipeline status, but log build events for debugging + # Individual build failures will be handled when the pipeline fails diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 913eb1bc..b634f89a 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -89,6 +89,13 @@ class Settings(BaseSettings): WEBHOOK_HEADERS: str = "" WEBHOOK_TIMEOUT: int = 30 + # CI Monitor configuration + CI_MONITOR_ENABLED: bool = True + CI_MAX_RETRIES: int = 5 + CI_CHECK_TYPES: str = "test,lint,build" # Comma-separated list of check types to monitor + GITHUB_WEBHOOK_SECRET: str = "" # GitHub webhook signature secret + GITLAB_WEBHOOK_TOKEN: str = "" # GitLab webhook verification token + # YAML initialization configuration INIT_DATA_DIR: str = "/app/init_data" INIT_DATA_ENABLED: bool = True diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 65bf81ad..509bbe4b 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -5,6 +5,7 @@ """ Models package """ +from app.models.completion_condition import CompletionCondition from app.models.kind import Kind from app.models.shared_team import SharedTeam from app.models.skill_binary import SkillBinary @@ -14,4 +15,4 @@ # All models should import Base directly from app.db.base from app.models.user import User -__all__ = ["User", "Kind", "Subtask", "SharedTeam", "SkillBinary"] +__all__ = ["User", "Kind", "Subtask", "SharedTeam", "SkillBinary", "CompletionCondition"] diff --git a/backend/app/models/completion_condition.py b/backend/app/models/completion_condition.py new file mode 100644 index 00000000..63babd94 --- /dev/null +++ b/backend/app/models/completion_condition.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: 2025 Weibo, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Completion Condition model for tracking external async conditions +""" +from datetime import datetime +from enum import Enum as PyEnum + +from sqlalchemy import JSON, Column, DateTime, Integer, String, Text +from sqlalchemy import Enum as SQLEnum +from sqlalchemy.sql import func + +from app.db.base import Base + + +class ConditionType(str, PyEnum): + """Type of completion condition""" + + CI_PIPELINE = "CI_PIPELINE" + EXTERNAL_TASK = "EXTERNAL_TASK" + APPROVAL = "APPROVAL" + MANUAL_CONFIRM = "MANUAL_CONFIRM" + + +class ConditionStatus(str, PyEnum): + """Status of completion condition""" + + PENDING = "PENDING" + IN_PROGRESS = "IN_PROGRESS" + SATISFIED = "SATISFIED" + FAILED = "FAILED" + CANCELLED = "CANCELLED" + + +class GitPlatform(str, PyEnum): + """Git platform type""" + + GITHUB = "GITHUB" + GITLAB = "GITLAB" + + +class CompletionCondition(Base): + """ + Completion condition model for tracking async external conditions. + Used to track CI pipelines, external approvals, and other conditions + that must be satisfied before a task is considered truly complete. + """ + + __tablename__ = "completion_conditions" + + id = Column(Integer, primary_key=True, index=True) + subtask_id = Column(Integer, nullable=False, index=True) + task_id = Column(Integer, nullable=False, index=True) + user_id = Column(Integer, nullable=False, index=True) + + # Condition type and status + condition_type = Column( + SQLEnum(ConditionType), nullable=False, default=ConditionType.CI_PIPELINE + ) + status = Column( + SQLEnum(ConditionStatus), nullable=False, default=ConditionStatus.PENDING + ) + + # External resource identification + external_id = Column(String(256), nullable=True) # PR number, Pipeline ID, etc. + external_url = Column(String(1024), nullable=True) # Link to external resource + + # Git platform information + git_platform = Column(SQLEnum(GitPlatform), nullable=True) + git_domain = Column(String(256), nullable=True) + repo_full_name = Column(String(512), nullable=True) # owner/repo format + branch_name = Column(String(256), nullable=True, index=True) + + # Auto-fix retry tracking + retry_count = Column(Integer, nullable=False, default=0) + max_retries = Column(Integer, nullable=False, default=5) + last_failure_log = Column(Text, nullable=True) + + # Additional metadata (extensible) + metadata = Column(JSON, nullable=True) + + # Timestamps + created_at = Column(DateTime, default=func.now()) + updated_at = Column(DateTime, default=func.now(), onupdate=func.now()) + satisfied_at = Column(DateTime, nullable=True) + + __table_args__ = ( + { + "sqlite_autoincrement": True, + "mysql_engine": "InnoDB", + "mysql_charset": "utf8mb4", + "mysql_collate": "utf8mb4_unicode_ci", + }, + ) diff --git a/backend/app/schemas/completion_condition.py b/backend/app/schemas/completion_condition.py new file mode 100644 index 00000000..e2dcc7e4 --- /dev/null +++ b/backend/app/schemas/completion_condition.py @@ -0,0 +1,156 @@ +# SPDX-FileCopyrightText: 2025 Weibo, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Completion Condition schemas for API requests and responses +""" +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel + + +class ConditionType(str, Enum): + """Type of completion condition""" + + CI_PIPELINE = "CI_PIPELINE" + EXTERNAL_TASK = "EXTERNAL_TASK" + APPROVAL = "APPROVAL" + MANUAL_CONFIRM = "MANUAL_CONFIRM" + + +class ConditionStatus(str, Enum): + """Status of completion condition""" + + PENDING = "PENDING" + IN_PROGRESS = "IN_PROGRESS" + SATISFIED = "SATISFIED" + FAILED = "FAILED" + CANCELLED = "CANCELLED" + + +class GitPlatform(str, Enum): + """Git platform type""" + + GITHUB = "GITHUB" + GITLAB = "GITLAB" + + +class CompletionConditionBase(BaseModel): + """Base schema for completion condition""" + + subtask_id: int + task_id: int + condition_type: ConditionType = ConditionType.CI_PIPELINE + status: ConditionStatus = ConditionStatus.PENDING + external_id: Optional[str] = None + external_url: Optional[str] = None + git_platform: Optional[GitPlatform] = None + git_domain: Optional[str] = None + repo_full_name: Optional[str] = None + branch_name: Optional[str] = None + max_retries: int = 5 + metadata: Optional[Dict[str, Any]] = None + + +class CompletionConditionCreate(CompletionConditionBase): + """Schema for creating a completion condition""" + + pass + + +class CompletionConditionUpdate(BaseModel): + """Schema for updating a completion condition""" + + status: Optional[ConditionStatus] = None + external_id: Optional[str] = None + external_url: Optional[str] = None + retry_count: Optional[int] = None + last_failure_log: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None + satisfied_at: Optional[datetime] = None + + +class CompletionConditionInDB(CompletionConditionBase): + """Schema for completion condition from database""" + + id: int + user_id: int + retry_count: int = 0 + last_failure_log: Optional[str] = None + created_at: datetime + updated_at: datetime + satisfied_at: Optional[datetime] = None + + class Config: + from_attributes = True + + +class CompletionConditionListResponse(BaseModel): + """Paginated response for completion conditions""" + + total: int + items: List[CompletionConditionInDB] + + +class TaskCompletionStatus(BaseModel): + """Overall completion status for a task""" + + task_id: int + subtask_completed: bool + all_conditions_satisfied: bool + pending_conditions: int + in_progress_conditions: int + satisfied_conditions: int + failed_conditions: int + conditions: List[CompletionConditionInDB] + + +# Webhook event schemas +class CIEventType(str, Enum): + """CI event types""" + + PIPELINE_STARTED = "pipeline_started" + PIPELINE_SUCCESS = "pipeline_success" + PIPELINE_FAILED = "pipeline_failed" + CHECK_RUN_STARTED = "check_run_started" + CHECK_RUN_COMPLETED = "check_run_completed" + + +class CIWebhookEvent(BaseModel): + """Schema for CI webhook events""" + + event_type: CIEventType + repo_full_name: str + branch_name: str + external_id: str + external_url: Optional[str] = None + conclusion: Optional[str] = None # success, failure, etc. + logs_url: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None + + +class GitHubCheckRunEvent(BaseModel): + """Schema for GitHub check_run webhook event""" + + action: str # created, completed, rerequested + check_run: Dict[str, Any] + repository: Dict[str, Any] + + +class GitHubWorkflowRunEvent(BaseModel): + """Schema for GitHub workflow_run webhook event""" + + action: str # requested, completed + workflow_run: Dict[str, Any] + repository: Dict[str, Any] + + +class GitLabPipelineEvent(BaseModel): + """Schema for GitLab pipeline webhook event""" + + object_kind: str # pipeline + object_attributes: Dict[str, Any] + project: Dict[str, Any] diff --git a/backend/app/services/ci_monitor.py b/backend/app/services/ci_monitor.py new file mode 100644 index 00000000..9ad604b1 --- /dev/null +++ b/backend/app/services/ci_monitor.py @@ -0,0 +1,429 @@ +# SPDX-FileCopyrightText: 2025 Weibo, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +CI Monitor Service for handling CI events and auto-fix workflow +""" +import logging +from datetime import datetime +from typing import Any, Dict, List, Optional + +from sqlalchemy.orm import Session + +from app.api.dependencies import get_db_context +from app.core.config import settings +from app.models.completion_condition import ( + CompletionCondition, + ConditionStatus, + ConditionType, + GitPlatform, +) +from app.services.completion_condition import completion_condition_service +from app.services.webhook_notification import Notification, webhook_notification_service + +logger = logging.getLogger(__name__) + + +class CIMonitorService: + """Service for monitoring CI events and triggering auto-fix""" + + def __init__(self): + self.enabled = getattr(settings, "CI_MONITOR_ENABLED", True) + self.max_retries = getattr(settings, "CI_MAX_RETRIES", 5) + + async def handle_ci_started( + self, + repo_full_name: str, + branch_name: str, + external_id: str, + external_url: Optional[str] = None, + git_platform: GitPlatform = GitPlatform.GITHUB, + git_domain: Optional[str] = None, + ): + """Handle CI pipeline started event""" + if not self.enabled: + return + + with get_db_context() as db: + # Find matching conditions + conditions = completion_condition_service.find_by_repo_and_branch( + db, + repo_full_name=repo_full_name, + branch_name=branch_name, + git_platform=git_platform, + status_list=[ConditionStatus.PENDING], + ) + + if not conditions: + logger.debug( + f"No pending conditions found for {repo_full_name}/{branch_name}" + ) + return + + # Update conditions to IN_PROGRESS + for condition in conditions: + condition.status = ConditionStatus.IN_PROGRESS + condition.external_id = external_id + if external_url: + condition.external_url = external_url + + db.commit() + logger.info( + f"Updated {len(conditions)} conditions to IN_PROGRESS for " + f"{repo_full_name}/{branch_name}" + ) + + async def handle_ci_success( + self, + repo_full_name: str, + branch_name: str, + external_id: str, + git_platform: GitPlatform = GitPlatform.GITHUB, + ): + """Handle CI pipeline success event""" + if not self.enabled: + return + + with get_db_context() as db: + # Find matching conditions + conditions = completion_condition_service.find_by_repo_and_branch( + db, + repo_full_name=repo_full_name, + branch_name=branch_name, + git_platform=git_platform, + status_list=[ConditionStatus.PENDING, ConditionStatus.IN_PROGRESS], + ) + + if not conditions: + logger.debug( + f"No active conditions found for {repo_full_name}/{branch_name}" + ) + return + + # Update conditions to SATISFIED + for condition in conditions: + condition.status = ConditionStatus.SATISFIED + condition.satisfied_at = datetime.utcnow() + + db.commit() + logger.info( + f"Updated {len(conditions)} conditions to SATISFIED for " + f"{repo_full_name}/{branch_name}" + ) + + # Send notifications for each satisfied condition + for condition in conditions: + await self._send_condition_satisfied_notification(db, condition) + + # Check if all conditions for the task are satisfied + await self._check_task_fully_completed(db, condition.task_id) + + async def handle_ci_failure( + self, + repo_full_name: str, + branch_name: str, + external_id: str, + conclusion: str, + git_platform: GitPlatform = GitPlatform.GITHUB, + ): + """Handle CI pipeline failure event""" + if not self.enabled: + return + + with get_db_context() as db: + # Find matching conditions + conditions = completion_condition_service.find_by_repo_and_branch( + db, + repo_full_name=repo_full_name, + branch_name=branch_name, + git_platform=git_platform, + status_list=[ConditionStatus.PENDING, ConditionStatus.IN_PROGRESS], + ) + + if not conditions: + logger.debug( + f"No active conditions found for {repo_full_name}/{branch_name}" + ) + return + + for condition in conditions: + await self._handle_condition_failure( + db, condition, conclusion, git_platform + ) + + async def _handle_condition_failure( + self, + db: Session, + condition: CompletionCondition, + conclusion: str, + git_platform: GitPlatform, + ): + """Handle a single condition failure with retry logic""" + logger.info( + f"Handling CI failure for condition {condition.id}: " + f"retry_count={condition.retry_count}, max_retries={condition.max_retries}" + ) + + # Check if we can retry + if completion_condition_service.can_retry(condition): + # Attempt auto-fix + await self._trigger_auto_fix(db, condition, conclusion, git_platform) + else: + # Max retries reached + condition.status = ConditionStatus.FAILED + condition.last_failure_log = f"Max retries ({condition.max_retries}) reached. Last conclusion: {conclusion}" + db.commit() + + logger.warning( + f"Condition {condition.id} failed after {condition.retry_count} retries" + ) + + # Send max retry notification + await self._send_max_retry_notification(db, condition) + + async def _trigger_auto_fix( + self, + db: Session, + condition: CompletionCondition, + conclusion: str, + git_platform: GitPlatform, + ): + """Trigger auto-fix workflow by creating a new subtask""" + logger.info(f"Triggering auto-fix for condition {condition.id}") + + # Fetch CI logs + ci_logs = await self._fetch_ci_logs(condition, git_platform) + + # Increment retry count + completion_condition_service.increment_retry( + db, + condition_id=condition.id, + failure_log=ci_logs[:10000] if ci_logs else f"CI failed with conclusion: {conclusion}", + ) + + # Create a fix subtask + await self._create_fix_subtask(db, condition, ci_logs, conclusion) + + # Send notification about auto-fix attempt + await self._send_ci_failed_notification(db, condition, conclusion) + + async def _fetch_ci_logs( + self, + condition: CompletionCondition, + git_platform: GitPlatform, + ) -> Optional[str]: + """Fetch CI logs from the git platform""" + try: + if git_platform == GitPlatform.GITHUB: + from app.repository.github_provider import GitHubProvider + + provider = GitHubProvider() + # Note: This would need user context for authentication + # For now, return a placeholder + return f"CI logs for GitHub run {condition.external_id} - implement log fetching" + + elif git_platform == GitPlatform.GITLAB: + from app.repository.gitlab_provider import GitLabProvider + + provider = GitLabProvider() + return f"CI logs for GitLab pipeline {condition.external_id} - implement log fetching" + + except Exception as e: + logger.error(f"Failed to fetch CI logs: {e}") + return None + + async def _create_fix_subtask( + self, + db: Session, + condition: CompletionCondition, + ci_logs: Optional[str], + conclusion: str, + ): + """Create a new subtask to fix CI issues""" + from app.services.subtask import subtask_service + + # Build the fix prompt + fix_prompt = self._build_fix_prompt(condition, ci_logs, conclusion) + + # Get the original subtask to find context + original_subtask = subtask_service.get_subtask_by_id( + db, subtask_id=condition.subtask_id, user_id=condition.user_id + ) + + if not original_subtask: + logger.error(f"Original subtask {condition.subtask_id} not found") + return + + # Create new fix subtask + from app.schemas.subtask import SubtaskCreate + + fix_subtask_data = SubtaskCreate( + task_id=condition.task_id, + team_id=original_subtask.team_id, + title=f"Auto-fix CI failure (attempt {condition.retry_count})", + bot_ids=original_subtask.bot_ids, + prompt=fix_prompt, + parent_id=original_subtask.id, + message_id=original_subtask.message_id + 1, + ) + + fix_subtask = subtask_service.create_subtask( + db, obj_in=fix_subtask_data, user_id=condition.user_id + ) + + logger.info( + f"Created fix subtask {fix_subtask.id} for condition {condition.id}" + ) + + # TODO: Trigger subtask execution through executor manager + + def _build_fix_prompt( + self, + condition: CompletionCondition, + ci_logs: Optional[str], + conclusion: str, + ) -> str: + """Build the prompt for the fix subtask""" + prompt_parts = [ + f"CI pipeline failed with conclusion: {conclusion}", + f"This is auto-fix attempt {condition.retry_count + 1} of {condition.max_retries}.", + "", + "Please analyze the CI failure and fix the issues.", + ] + + if ci_logs: + prompt_parts.extend([ + "", + "## CI Failure Logs", + "```", + ci_logs[:8000], # Limit log size + "```", + ]) + + prompt_parts.extend([ + "", + "## Instructions", + "1. Analyze the CI failure logs above", + "2. Identify the root cause of the failure", + "3. Make the necessary code changes to fix the issue", + "4. Commit and push the fix", + ]) + + return "\n".join(prompt_parts) + + async def _check_task_fully_completed(self, db: Session, task_id: int): + """Check if all conditions for a task are satisfied and send notification""" + status = completion_condition_service.get_task_completion_status( + db, task_id=task_id + ) + + if status["all_conditions_satisfied"]: + logger.info(f"Task {task_id} is fully completed with all CI checks passed") + await self._send_task_fully_completed_notification(db, task_id) + + async def _send_condition_satisfied_notification( + self, db: Session, condition: CompletionCondition + ): + """Send notification when a condition is satisfied""" + try: + from app.models.user import User + + user = db.query(User).filter(User.id == condition.user_id).first() + user_name = user.user_name if user else "unknown" + + notification = Notification( + user_name=user_name, + event="condition.satisfied", + id=str(condition.id), + start_time=condition.created_at.isoformat() if condition.created_at else "", + end_time=datetime.utcnow().isoformat(), + description=f"CI check passed for {condition.repo_full_name}/{condition.branch_name}", + status="satisfied", + detail_url=condition.external_url or "", + ) + await webhook_notification_service.send_notification(notification) + except Exception as e: + logger.error(f"Failed to send condition satisfied notification: {e}") + + async def _send_ci_failed_notification( + self, db: Session, condition: CompletionCondition, conclusion: str + ): + """Send notification when CI fails and auto-fix is triggered""" + try: + from app.models.user import User + + user = db.query(User).filter(User.id == condition.user_id).first() + user_name = user.user_name if user else "unknown" + + notification = Notification( + user_name=user_name, + event="condition.ci_failed", + id=str(condition.id), + start_time=condition.created_at.isoformat() if condition.created_at else "", + end_time=datetime.utcnow().isoformat(), + description=f"CI failed ({conclusion}) for {condition.repo_full_name}/{condition.branch_name}. Auto-fix attempt {condition.retry_count}/{condition.max_retries}", + status="failed", + detail_url=condition.external_url or "", + ) + await webhook_notification_service.send_notification(notification) + except Exception as e: + logger.error(f"Failed to send CI failed notification: {e}") + + async def _send_max_retry_notification( + self, db: Session, condition: CompletionCondition + ): + """Send notification when max retries are reached""" + try: + from app.models.user import User + + user = db.query(User).filter(User.id == condition.user_id).first() + user_name = user.user_name if user else "unknown" + + notification = Notification( + user_name=user_name, + event="condition.max_retry_reached", + id=str(condition.id), + start_time=condition.created_at.isoformat() if condition.created_at else "", + end_time=datetime.utcnow().isoformat(), + description=f"Max retries ({condition.max_retries}) reached for {condition.repo_full_name}/{condition.branch_name}. Manual intervention required.", + status="failed", + detail_url=condition.external_url or "", + ) + await webhook_notification_service.send_notification(notification) + except Exception as e: + logger.error(f"Failed to send max retry notification: {e}") + + async def _send_task_fully_completed_notification( + self, db: Session, task_id: int + ): + """Send notification when task is fully completed with all CI checks""" + try: + from app.models.user import User + from app.services.adapters.task_kinds import task_kinds_service + + # Get task info + task = task_kinds_service.get_task_by_id(db, task_id=task_id) + if not task: + return + + user = db.query(User).filter(User.id == task.get("user_id")).first() + user_name = user.user_name if user else "unknown" + + notification = Notification( + user_name=user_name, + event="task.fully_completed", + id=str(task_id), + start_time=task.get("created_at", ""), + end_time=datetime.utcnow().isoformat(), + description=f"Task '{task.get('title', '')}' completed with all CI checks passed", + status="completed", + detail_url=f"{settings.FRONTEND_URL}/tasks/{task_id}", + ) + await webhook_notification_service.send_notification(notification) + except Exception as e: + logger.error(f"Failed to send task fully completed notification: {e}") + + +# Global service instance +ci_monitor_service = CIMonitorService() diff --git a/backend/app/services/completion_condition.py b/backend/app/services/completion_condition.py new file mode 100644 index 00000000..bcd29a61 --- /dev/null +++ b/backend/app/services/completion_condition.py @@ -0,0 +1,309 @@ +# SPDX-FileCopyrightText: 2025 Weibo, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Completion Condition Service for managing async completion conditions +""" +import logging +from datetime import datetime +from typing import Any, Dict, List, Optional + +from sqlalchemy import and_ +from sqlalchemy.orm import Session + +from app.core.config import settings +from app.models.completion_condition import ( + CompletionCondition, + ConditionStatus, + ConditionType, + GitPlatform, +) +from app.schemas.completion_condition import ( + CompletionConditionCreate, + CompletionConditionUpdate, +) +from app.services.base import BaseService + +logger = logging.getLogger(__name__) + + +class CompletionConditionService( + BaseService[CompletionCondition, CompletionConditionCreate, CompletionConditionUpdate] +): + """Service for managing completion conditions""" + + def __init__(self): + super().__init__(CompletionCondition) + + def create_condition( + self, + db: Session, + *, + obj_in: CompletionConditionCreate, + user_id: int, + ) -> CompletionCondition: + """Create a new completion condition""" + db_obj = CompletionCondition( + subtask_id=obj_in.subtask_id, + task_id=obj_in.task_id, + user_id=user_id, + condition_type=obj_in.condition_type, + status=obj_in.status, + external_id=obj_in.external_id, + external_url=obj_in.external_url, + git_platform=obj_in.git_platform, + git_domain=obj_in.git_domain, + repo_full_name=obj_in.repo_full_name, + branch_name=obj_in.branch_name, + max_retries=obj_in.max_retries, + metadata=obj_in.metadata, + ) + db.add(db_obj) + db.commit() + db.refresh(db_obj) + logger.info( + f"Created completion condition {db_obj.id} for subtask {obj_in.subtask_id}" + ) + return db_obj + + def get_by_id( + self, + db: Session, + *, + condition_id: int, + user_id: Optional[int] = None, + ) -> Optional[CompletionCondition]: + """Get a completion condition by ID""" + query = db.query(CompletionCondition).filter( + CompletionCondition.id == condition_id + ) + if user_id is not None: + query = query.filter(CompletionCondition.user_id == user_id) + return query.first() + + def get_by_subtask_id( + self, + db: Session, + *, + subtask_id: int, + user_id: Optional[int] = None, + ) -> List[CompletionCondition]: + """Get all completion conditions for a subtask""" + query = db.query(CompletionCondition).filter( + CompletionCondition.subtask_id == subtask_id + ) + if user_id is not None: + query = query.filter(CompletionCondition.user_id == user_id) + return query.all() + + def get_by_task_id( + self, + db: Session, + *, + task_id: int, + user_id: Optional[int] = None, + ) -> List[CompletionCondition]: + """Get all completion conditions for a task""" + query = db.query(CompletionCondition).filter( + CompletionCondition.task_id == task_id + ) + if user_id is not None: + query = query.filter(CompletionCondition.user_id == user_id) + return query.all() + + def find_by_repo_and_branch( + self, + db: Session, + *, + repo_full_name: str, + branch_name: str, + git_platform: Optional[GitPlatform] = None, + status_list: Optional[List[ConditionStatus]] = None, + ) -> List[CompletionCondition]: + """Find completion conditions by repository and branch""" + query = db.query(CompletionCondition).filter( + and_( + CompletionCondition.repo_full_name == repo_full_name, + CompletionCondition.branch_name == branch_name, + ) + ) + if git_platform: + query = query.filter(CompletionCondition.git_platform == git_platform) + if status_list: + query = query.filter(CompletionCondition.status.in_(status_list)) + return query.all() + + def update_status( + self, + db: Session, + *, + condition_id: int, + status: ConditionStatus, + failure_log: Optional[str] = None, + ) -> Optional[CompletionCondition]: + """Update the status of a completion condition""" + condition = self.get_by_id(db, condition_id=condition_id) + if not condition: + return None + + condition.status = status + if failure_log: + condition.last_failure_log = failure_log + + if status == ConditionStatus.SATISFIED: + condition.satisfied_at = datetime.utcnow() + + db.commit() + db.refresh(condition) + logger.info(f"Updated condition {condition_id} status to {status}") + return condition + + def increment_retry( + self, + db: Session, + *, + condition_id: int, + failure_log: Optional[str] = None, + ) -> Optional[CompletionCondition]: + """Increment retry count and update failure log""" + condition = self.get_by_id(db, condition_id=condition_id) + if not condition: + return None + + condition.retry_count += 1 + if failure_log: + condition.last_failure_log = failure_log + condition.status = ConditionStatus.PENDING + + db.commit() + db.refresh(condition) + logger.info( + f"Incremented retry count for condition {condition_id} to {condition.retry_count}" + ) + return condition + + def can_retry(self, condition: CompletionCondition) -> bool: + """Check if a condition can be retried""" + return condition.retry_count < condition.max_retries + + def cancel_by_subtask( + self, + db: Session, + *, + subtask_id: int, + ) -> int: + """Cancel all pending/in_progress conditions for a subtask""" + conditions = ( + db.query(CompletionCondition) + .filter( + and_( + CompletionCondition.subtask_id == subtask_id, + CompletionCondition.status.in_( + [ConditionStatus.PENDING, ConditionStatus.IN_PROGRESS] + ), + ) + ) + .all() + ) + + count = 0 + for condition in conditions: + condition.status = ConditionStatus.CANCELLED + count += 1 + + db.commit() + logger.info(f"Cancelled {count} conditions for subtask {subtask_id}") + return count + + def cancel_by_task( + self, + db: Session, + *, + task_id: int, + ) -> int: + """Cancel all pending/in_progress conditions for a task""" + conditions = ( + db.query(CompletionCondition) + .filter( + and_( + CompletionCondition.task_id == task_id, + CompletionCondition.status.in_( + [ConditionStatus.PENDING, ConditionStatus.IN_PROGRESS] + ), + ) + ) + .all() + ) + + count = 0 + for condition in conditions: + condition.status = ConditionStatus.CANCELLED + count += 1 + + db.commit() + logger.info(f"Cancelled {count} conditions for task {task_id}") + return count + + def get_task_completion_status( + self, + db: Session, + *, + task_id: int, + user_id: Optional[int] = None, + ) -> Dict[str, Any]: + """Get overall completion status for a task""" + conditions = self.get_by_task_id(db, task_id=task_id, user_id=user_id) + + pending = sum(1 for c in conditions if c.status == ConditionStatus.PENDING) + in_progress = sum( + 1 for c in conditions if c.status == ConditionStatus.IN_PROGRESS + ) + satisfied = sum( + 1 for c in conditions if c.status == ConditionStatus.SATISFIED + ) + failed = sum(1 for c in conditions if c.status == ConditionStatus.FAILED) + + # All conditions are satisfied if there are conditions and none are pending/in_progress/failed + all_satisfied = ( + len(conditions) > 0 + and pending == 0 + and in_progress == 0 + and failed == 0 + ) + + return { + "task_id": task_id, + "total_conditions": len(conditions), + "pending_conditions": pending, + "in_progress_conditions": in_progress, + "satisfied_conditions": satisfied, + "failed_conditions": failed, + "all_conditions_satisfied": all_satisfied, + "conditions": conditions, + } + + def has_unsatisfied_conditions( + self, + db: Session, + *, + subtask_id: int, + ) -> bool: + """Check if a subtask has any unsatisfied (pending/in_progress) conditions""" + count = ( + db.query(CompletionCondition) + .filter( + and_( + CompletionCondition.subtask_id == subtask_id, + CompletionCondition.status.in_( + [ConditionStatus.PENDING, ConditionStatus.IN_PROGRESS] + ), + ) + ) + .count() + ) + return count > 0 + + +# Global service instance +completion_condition_service = CompletionConditionService()