diff --git a/application_sdk/activities/query_extraction/sql.py b/application_sdk/activities/query_extraction/sql.py index f712e3439..6ec6b51b8 100644 --- a/application_sdk/activities/query_extraction/sql.py +++ b/application_sdk/activities/query_extraction/sql.py @@ -1,7 +1,7 @@ import json import os from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type, cast from pydantic import BaseModel, Field from temporalio import activity @@ -13,6 +13,7 @@ get_workflow_id, ) from application_sdk.clients.sql import BaseSQLClient +from application_sdk.common.utils import parse_credentials_extra from application_sdk.constants import UPSTREAM_OBJECT_STORE_NAME from application_sdk.handlers import HandlerInterface from application_sdk.handlers.sql import BaseSQLHandler @@ -422,7 +423,9 @@ async def write_marker( ) logger.info(f"Marker file written to {marker_file_path}") - async def read_marker(self, workflow_args: Dict[str, Any]) -> Optional[int]: + async def read_marker( + self, workflow_args: Dict[str, Any], output_path: Optional[str] = None + ) -> Optional[int]: """Read the marker from the output path. This method reads the current marker value from a marker file to determine the @@ -441,8 +444,13 @@ async def read_marker(self, workflow_args: Dict[str, Any]) -> Optional[int]: Exception: If marker file reading fails (logged as warning, not re-raised) """ try: - output_path = workflow_args["output_path"].rsplit("/", 1)[0] - marker_file_path = os.path.join(output_path, "markerfile") + base_output_path: str = cast(str, workflow_args["output_path"]) + resolved_output_path: str = ( + output_path + if isinstance(output_path, str) and output_path + else base_output_path.rsplit("/", 1)[0] + ) + marker_file_path = os.path.join(resolved_output_path, "markerfile") logger.info(f"Downloading marker file from {marker_file_path}") await ObjectStore.download_file( @@ -463,6 +471,25 @@ async def read_marker(self, workflow_args: Dict[str, Any]) -> Optional[int]: logger.warning(f"Failed to read marker: {e}") return None + @activity.defn(name="miner_preflight_check") + @auto_heartbeater + async def preflight_check(self, workflow_args: Dict[str, Any]): + return await super().preflight_check(workflow_args) + + @activity.defn(name="miner_get_workflow_args") + @auto_heartbeater + async def get_workflow_args(self, workflow_config: Dict[str, Any]): + workflow_args = await super().get_workflow_args(workflow_config) + if "credential_guid" in workflow_args: + credentials = await SecretStore.get_credentials( + credential_guid=workflow_args["credential_guid"] + ) + extra = parse_credentials_extra(credentials) + workflow_args["deployment_type"] = extra.get( + "deployment_type", "provisioned" + ) + return workflow_args + @activity.defn @auto_heartbeater async def get_query_batches(