diff --git a/tests/asyncio/test_context.py b/tests/asyncio/test_context.py index 046d610..f6471e6 100644 --- a/tests/asyncio/test_context.py +++ b/tests/asyncio/test_context.py @@ -30,7 +30,8 @@ async def test_workflow_headers(qstash_client: AsyncQStash) -> None: url=WORKFLOW_ENDPOINT, initial_payload="my-payload", env=None, - retries=None, + retries=1, + failure_url=WORKFLOW_ENDPOINT, ) async def execute() -> None: @@ -60,28 +61,41 @@ async def execute() -> None: "queue": None, "headers": { "Content-Type": "application/json", - "Upstash-Callback": WORKFLOW_ENDPOINT, + "Upstash-Method": "PATCH", + "Upstash-Workflow-Init": "false", + "Upstash-Workflow-RunId": "wfr-id", + "Upstash-Workflow-Url": "https://www.my-website.com/api", + "Upstash-Feature-Set": "WF_NoDelete,InitialBody", + "Upstash-Failure-Callback-Forward-Upstash-Workflow-Is-Failure": "true", + "Upstash-Failure-Callback-Forward-Upstash-Workflow-Failure-Callback": "true", + "Upstash-Failure-Callback-Workflow-Runid": "wfr-id", + "Upstash-Failure-Callback-Workflow-Init": "false", + "Upstash-Failure-Callback-Workflow-Url": "https://www.my-website.com/api", + "Upstash-Failure-Callback-Workflow-Calltype": "failureCall", + "Upstash-Callback-Failure-Callback-Forward-Upstash-Workflow-Is-Failure": "true", + "Upstash-Callback-Failure-Callback-Forward-Upstash-Workflow-Failure-Callback": "true", + "Upstash-Callback-Failure-Callback-Workflow-Runid": "wfr-id", + "Upstash-Callback-Failure-Callback-Workflow-Init": "false", + "Upstash-Callback-Failure-Callback-Workflow-Url": "https://www.my-website.com/api", + "Upstash-Callback-Failure-Callback-Workflow-Calltype": "failureCall", + "Upstash-Failure-Callback-Retries": "1", + "Upstash-Callback-Failure-Callback-Retries": "1", + "Upstash-Retries": "10", + "Upstash-Callback-Retries": "1", + "Upstash-Forward-my-header": "my-value", + "Upstash-Callback": "https://www.my-website.com/api", + "Upstash-Callback-Workflow-RunId": "wfr-id", + "Upstash-Callback-Workflow-CallType": "fromCallback", + "Upstash-Callback-Workflow-Init": "false", + "Upstash-Callback-Workflow-Url": "https://www.my-website.com/api", "Upstash-Callback-Feature-Set": "LazyFetch,InitialBody", "Upstash-Callback-Forward-Upstash-Workflow-Callback": "true", - "Upstash-Callback-Forward-Upstash-Workflow-Concurrent": "1", - "Upstash-Callback-Forward-Upstash-Workflow-ContentType": "application/json", "Upstash-Callback-Forward-Upstash-Workflow-StepId": "1", "Upstash-Callback-Forward-Upstash-Workflow-StepName": "my-step", "Upstash-Callback-Forward-Upstash-Workflow-StepType": "Call", - "Upstash-Callback-Retries": "3", - "Upstash-Callback-Workflow-CallType": "fromCallback", - "Upstash-Callback-Workflow-Init": "false", - "Upstash-Callback-Workflow-RunId": "wfr-id", - "Upstash-Callback-Workflow-Url": WORKFLOW_ENDPOINT, - "Upstash-Failure-Callback-Retries": "3", - "Upstash-Feature-Set": "WF_NoDelete,InitialBody", - "Upstash-Forward-my-header": "my-value", - "Upstash-Method": "PATCH", - "Upstash-Retries": str(retries), + "Upstash-Callback-Forward-Upstash-Workflow-Concurrent": "1", + "Upstash-Callback-Forward-Upstash-Workflow-ContentType": "application/json", "Upstash-Workflow-CallType": "toCallback", - "Upstash-Workflow-Init": "false", - "Upstash-Workflow-RunId": "wfr-id", - "Upstash-Workflow-Url": WORKFLOW_ENDPOINT, }, } ], diff --git a/tests/test_context.py b/tests/test_context.py index 305374e..ec9845b 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -29,7 +29,8 @@ def test_workflow_headers(qstash_client: QStash) -> None: url=WORKFLOW_ENDPOINT, initial_payload="my-payload", env=None, - retries=None, + retries=1, + failure_url=WORKFLOW_ENDPOINT, ) def execute() -> None: @@ -59,28 +60,41 @@ def execute() -> None: "queue": None, "headers": { "Content-Type": "application/json", - "Upstash-Callback": WORKFLOW_ENDPOINT, + "Upstash-Method": "PATCH", + "Upstash-Workflow-Init": "false", + "Upstash-Workflow-RunId": "wfr-id", + "Upstash-Workflow-Url": "https://www.my-website.com/api", + "Upstash-Feature-Set": "WF_NoDelete,InitialBody", + "Upstash-Failure-Callback-Forward-Upstash-Workflow-Is-Failure": "true", + "Upstash-Failure-Callback-Forward-Upstash-Workflow-Failure-Callback": "true", + "Upstash-Failure-Callback-Workflow-Runid": "wfr-id", + "Upstash-Failure-Callback-Workflow-Init": "false", + "Upstash-Failure-Callback-Workflow-Url": "https://www.my-website.com/api", + "Upstash-Failure-Callback-Workflow-Calltype": "failureCall", + "Upstash-Callback-Failure-Callback-Forward-Upstash-Workflow-Is-Failure": "true", + "Upstash-Callback-Failure-Callback-Forward-Upstash-Workflow-Failure-Callback": "true", + "Upstash-Callback-Failure-Callback-Workflow-Runid": "wfr-id", + "Upstash-Callback-Failure-Callback-Workflow-Init": "false", + "Upstash-Callback-Failure-Callback-Workflow-Url": "https://www.my-website.com/api", + "Upstash-Callback-Failure-Callback-Workflow-Calltype": "failureCall", + "Upstash-Failure-Callback-Retries": "1", + "Upstash-Callback-Failure-Callback-Retries": "1", + "Upstash-Retries": "10", + "Upstash-Callback-Retries": "1", + "Upstash-Forward-my-header": "my-value", + "Upstash-Callback": "https://www.my-website.com/api", + "Upstash-Callback-Workflow-RunId": "wfr-id", + "Upstash-Callback-Workflow-CallType": "fromCallback", + "Upstash-Callback-Workflow-Init": "false", + "Upstash-Callback-Workflow-Url": "https://www.my-website.com/api", "Upstash-Callback-Feature-Set": "LazyFetch,InitialBody", "Upstash-Callback-Forward-Upstash-Workflow-Callback": "true", - "Upstash-Callback-Forward-Upstash-Workflow-Concurrent": "1", - "Upstash-Callback-Forward-Upstash-Workflow-ContentType": "application/json", "Upstash-Callback-Forward-Upstash-Workflow-StepId": "1", "Upstash-Callback-Forward-Upstash-Workflow-StepName": "my-step", "Upstash-Callback-Forward-Upstash-Workflow-StepType": "Call", - "Upstash-Callback-Retries": "3", - "Upstash-Callback-Workflow-CallType": "fromCallback", - "Upstash-Callback-Workflow-Init": "false", - "Upstash-Callback-Workflow-RunId": "wfr-id", - "Upstash-Callback-Workflow-Url": WORKFLOW_ENDPOINT, - "Upstash-Failure-Callback-Retries": "3", - "Upstash-Feature-Set": "WF_NoDelete,InitialBody", - "Upstash-Forward-my-header": "my-value", - "Upstash-Method": "PATCH", - "Upstash-Retries": str(retries), + "Upstash-Callback-Forward-Upstash-Workflow-Concurrent": "1", + "Upstash-Callback-Forward-Upstash-Workflow-ContentType": "application/json", "Upstash-Workflow-CallType": "toCallback", - "Upstash-Workflow-Init": "false", - "Upstash-Workflow-RunId": "wfr-id", - "Upstash-Workflow-Url": WORKFLOW_ENDPOINT, }, } ], diff --git a/upstash_workflow/asyncio/context/auto_executor.py b/upstash_workflow/asyncio/context/auto_executor.py index 6615389..b1b4c13 100644 --- a/upstash_workflow/asyncio/context/auto_executor.py +++ b/upstash_workflow/asyncio/context/auto_executor.py @@ -84,6 +84,7 @@ async def submit_steps_to_qstash( self.context.retries, lazy_step.retries if isinstance(lazy_step, _LazyCallStep) else None, lazy_step.timeout if isinstance(lazy_step, _LazyCallStep) else None, + self.context.failure_url, ).headers will_wait = ( diff --git a/upstash_workflow/asyncio/context/context.py b/upstash_workflow/asyncio/context/context.py index 4e24069..68d0d2b 100644 --- a/upstash_workflow/asyncio/context/context.py +++ b/upstash_workflow/asyncio/context/context.py @@ -47,6 +47,7 @@ def __init__( headers: Dict[str, str], steps: List[DefaultStep], url: str, + failure_url: Optional[str], initial_payload: TInitialPayload, env: Optional[Dict[str, Optional[str]]] = None, retries: Optional[int] = None, @@ -55,6 +56,7 @@ def __init__( self.workflow_run_id: str = workflow_run_id self._steps: List[DefaultStep] = steps self.url: str = url + self.failure_url = failure_url self.headers: Dict[str, str] = headers self.request_payload: TInitialPayload = initial_payload self.env: Dict[str, Optional[str]] = env or {} diff --git a/upstash_workflow/asyncio/serve/authorization.py b/upstash_workflow/asyncio/serve/authorization.py index 44563df..7646a34 100644 --- a/upstash_workflow/asyncio/serve/authorization.py +++ b/upstash_workflow/asyncio/serve/authorization.py @@ -38,6 +38,7 @@ async def try_authentication( initial_payload=context.request_payload, env=context.env, retries=context.retries, + failure_url=context.failure_url, ) try: diff --git a/upstash_workflow/asyncio/serve/options.py b/upstash_workflow/asyncio/serve/options.py index b114fd0..1a7ad6f 100644 --- a/upstash_workflow/asyncio/serve/options.py +++ b/upstash_workflow/asyncio/serve/options.py @@ -1,19 +1,42 @@ import os import json import logging -from typing import Callable, Dict, Optional, cast, TypeVar +from typing import Callable, Dict, Optional, cast, TypeVar, Any, Generic, Awaitable from qstash import AsyncQStash, Receiver from upstash_workflow.workflow_types import _Response from upstash_workflow.constants import DEFAULT_RETRIES from upstash_workflow.types import ( _FinishCondition, ) -from upstash_workflow.asyncio.types import ServeBaseOptions +from upstash_workflow import AsyncWorkflowContext -_logger = logging.getLogger(__name__) +from dataclasses import dataclass -TResponse = TypeVar("TResponse") +_logger = logging.getLogger(__name__) TInitialPayload = TypeVar("TInitialPayload") +TResponse = TypeVar("TResponse") + + +@dataclass +class ServeOptions(Generic[TInitialPayload, TResponse]): + qstash_client: AsyncQStash + initial_payload_parser: Callable[[str], TInitialPayload] + receiver: Optional[Receiver] + base_url: Optional[str] + env: Dict[str, Optional[str]] + retries: int + url: Optional[str] + failure_function: Optional[ + Callable[[AsyncWorkflowContext, int, str, Dict[str, str]], Awaitable[Any]] + ] + failure_url: Optional[str] + + +@dataclass +class ServeBaseOptions( + Generic[TInitialPayload, TResponse], ServeOptions[TInitialPayload, TResponse] +): + on_step_finish: Callable[[str, _FinishCondition], TResponse] def _process_options( @@ -26,6 +49,10 @@ def _process_options( env: Optional[Dict[str, Optional[str]]] = None, retries: Optional[int] = DEFAULT_RETRIES, url: Optional[str] = None, + failure_function: Optional[ + Callable[[AsyncWorkflowContext, int, str, Dict[str, str]], Awaitable[Any]] + ] = None, + failure_url: Optional[str] = None, ) -> ServeBaseOptions[TInitialPayload, TResponse]: environment = env if env is not None else dict(os.environ) @@ -94,6 +121,8 @@ def _initial_payload_parser(initial_request: str) -> TInitialPayload: env=environment, retries=DEFAULT_RETRIES if retries is None else retries, url=url, + failure_url=failure_url, + failure_function=failure_function, ) diff --git a/upstash_workflow/asyncio/serve/serve.py b/upstash_workflow/asyncio/serve/serve.py index a4220a0..f1d5bc7 100644 --- a/upstash_workflow/asyncio/serve/serve.py +++ b/upstash_workflow/asyncio/serve/serve.py @@ -5,6 +5,7 @@ from upstash_workflow.workflow_types import _Response, _AsyncRequest from upstash_workflow.asyncio.workflow_parser import ( _get_payload, + _handle_failure, ) from upstash_workflow.workflow_parser import _validate_request, _parse_request from upstash_workflow.asyncio.workflow_requests import ( @@ -39,6 +40,10 @@ def _serve_base( env: Optional[Dict[str, Optional[str]]] = None, retries: Optional[int] = None, url: Optional[str] = None, + failure_function: Optional[ + Callable[[AsyncWorkflowContext, int, str, Dict[str, str]], Awaitable[Any]] + ] = None, + failure_url: Optional[str] = None, ) -> Dict[str, Callable[[TRequest], Awaitable[TResponse]]]: processed_options = _process_options( qstash_client=qstash_client, @@ -49,6 +54,8 @@ def _serve_base( env=env, retries=retries, url=url, + failure_function=failure_function, + failure_url=failure_url, ) qstash_client = processed_options.qstash_client on_step_finish = processed_options.on_step_finish @@ -58,9 +65,17 @@ def _serve_base( env = processed_options.env retries = processed_options.retries url = processed_options.url + failure_url = processed_options.failure_url + failure_function = processed_options.failure_function async def _handler(request: TRequest) -> TResponse: - workflow_url = _determine_urls(cast(_AsyncRequest, request), url, base_url) + workflow_url, workflow_failure_url = _determine_urls( + cast(_AsyncRequest, request), + url, + base_url, + False if failure_function is None else True, + failure_url, + ) request_payload = await _get_payload(request) or "" _verify_request( @@ -78,6 +93,20 @@ async def _handler(request: TRequest) -> TResponse: raw_initial_payload = parse_request_response.raw_initial_payload steps = parse_request_response.steps + failure_check = await _handle_failure( + request, + request_payload, + qstash_client, + initial_payload_parser, + route_function, + failure_function, + env, + retries, + ) + + if failure_check == "is-failure-callback": + return on_step_finish(workflow_run_id, "failure-callback") + workflow_context = AsyncWorkflowContext( qstash_client=qstash_client, workflow_run_id=workflow_run_id, @@ -89,6 +118,7 @@ async def _handler(request: TRequest) -> TResponse: url=workflow_url, env=env, retries=retries, + failure_url=workflow_failure_url, ) auth_check = await _DisabledWorkflowContext[Any].try_authentication( @@ -110,6 +140,7 @@ async def _handler(request: TRequest) -> TResponse: raw_initial_payload, qstash_client, workflow_url, + workflow_failure_url, retries, ) @@ -153,6 +184,10 @@ def serve( env: Optional[Dict[str, Optional[str]]] = None, retries: Optional[int] = None, url: Optional[str] = None, + failure_function: Optional[ + Callable[[AsyncWorkflowContext, int, str, Dict[str, str]], Awaitable[Any]] + ] = None, + failure_url: Optional[str] = None, ) -> Dict[str, Callable[[TRequest], Awaitable[TResponse]]]: """ Creates a method that handles incoming requests and runs the provided @@ -178,4 +213,6 @@ def serve( env=env, retries=retries, url=url, + failure_function=failure_function, + failure_url=failure_url, ) diff --git a/upstash_workflow/asyncio/types.py b/upstash_workflow/asyncio/types.py deleted file mode 100644 index 52e6c62..0000000 --- a/upstash_workflow/asyncio/types.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import ( - Callable, - Optional, - Dict, - Generic, - TypeVar, -) -from qstash import AsyncQStash, Receiver -from dataclasses import dataclass -from upstash_workflow.types import _FinishCondition - -TInitialPayload = TypeVar("TInitialPayload") -TResponse = TypeVar("TResponse") - - -@dataclass -class ServeOptions(Generic[TInitialPayload, TResponse]): - qstash_client: AsyncQStash - initial_payload_parser: Callable[[str], TInitialPayload] - receiver: Optional[Receiver] - base_url: Optional[str] - env: Dict[str, Optional[str]] - retries: int - url: Optional[str] - - -@dataclass -class ServeBaseOptions( - Generic[TInitialPayload, TResponse], ServeOptions[TInitialPayload, TResponse] -): - on_step_finish: Callable[[str, _FinishCondition], TResponse] diff --git a/upstash_workflow/asyncio/workflow_parser.py b/upstash_workflow/asyncio/workflow_parser.py index d30557b..857979c 100644 --- a/upstash_workflow/asyncio/workflow_parser.py +++ b/upstash_workflow/asyncio/workflow_parser.py @@ -1,5 +1,16 @@ -from typing import Optional +from typing import Optional, cast from upstash_workflow.workflow_types import _AsyncRequest +import json +from typing import Callable, Dict, Any, Literal, Awaitable, TypeVar +from upstash_workflow.utils import _decode_base64 +from upstash_workflow.constants import ( + WORKFLOW_FAILURE_HEADER, +) +from upstash_workflow.error import WorkflowError +from upstash_workflow.workflow_requests import _recreate_user_headers +from upstash_workflow.asyncio.serve.authorization import _DisabledWorkflowContext +from qstash import AsyncQStash +from upstash_workflow import AsyncWorkflowContext async def _get_payload(request: _AsyncRequest) -> Optional[str]: @@ -13,3 +24,80 @@ async def _get_payload(request: _AsyncRequest) -> Optional[str]: return (await request.body()).decode() except Exception: return None + + +TInitialPayload = TypeVar("TInitialPayload") +TRequest = TypeVar("TRequest", bound=_AsyncRequest) + + +async def _handle_failure( + request: TRequest, + request_payload: str, + qstash_client: AsyncQStash, + initial_payload_parser: Callable[[str], Any], + route_function: Callable[[AsyncWorkflowContext[TInitialPayload]], Awaitable[None]], + failure_function: Optional[ + Callable[ + [AsyncWorkflowContext[TInitialPayload], int, str, Dict[str, str]], + Awaitable[Any], + ] + ], + env: Dict[str, Any], + retries: int, +) -> Literal["not-failure-callback", "is-failure-callback"]: + if request.headers and request.headers.get(WORKFLOW_FAILURE_HEADER) != "true": + return "not-failure-callback" + + if not failure_function: + raise WorkflowError( + "Workflow endpoint is called to handle a failure, " + "but a failure_function is not provided in serve options. " + "Either provide a failure_url or a failure_function." + ) + + try: + payload = json.loads(request_payload) + status = payload["status"] + header = payload["header"] + body = payload["body"] + url = payload["url"] + source_body = payload["sourceBody"] + workflow_run_id = payload["workflowRunId"] + + decoded_body = _decode_base64(body) if body else "{}" + error_payload = json.loads(decoded_body) + + # Create context + workflow_context = AsyncWorkflowContext( + qstash_client=qstash_client, + workflow_run_id=workflow_run_id, + initial_payload=initial_payload_parser(_decode_base64(source_body)) + if source_body + else None, + headers=_recreate_user_headers(request.headers or {}), + steps=[], + url=url, + failure_url=url, + env=env, + retries=retries, + ) + + # Attempt running route_function until the first step + auth_check = await _DisabledWorkflowContext[Any].try_authentication( + route_function, + cast(AsyncWorkflowContext[TInitialPayload], workflow_context), + ) + + if auth_check == "run-ended": + raise WorkflowError("Not authorized to run the failure function.") + + await failure_function( + cast(AsyncWorkflowContext[TInitialPayload], workflow_context), + status, + error_payload.get("message"), + header, + ) + except Exception as error: + raise error + + return "is-failure-callback" diff --git a/upstash_workflow/asyncio/workflow_requests.py b/upstash_workflow/asyncio/workflow_requests.py index 13c5dae..467fbff 100644 --- a/upstash_workflow/asyncio/workflow_requests.py +++ b/upstash_workflow/asyncio/workflow_requests.py @@ -79,6 +79,7 @@ async def _handle_third_party_call_result( request_payload: str, client: AsyncQStash, workflow_url: str, + workflow_failure_url: Optional[str], retries: int, ) -> Literal["call-will-retry", "is-call-return", "continue-workflow"]: """ @@ -173,6 +174,7 @@ async def _handle_third_party_call_result( user_headers, None, retries, + workflow_failure_url=workflow_failure_url, ).headers call_response = { diff --git a/upstash_workflow/context/auto_executor.py b/upstash_workflow/context/auto_executor.py index e3303a9..63bed04 100644 --- a/upstash_workflow/context/auto_executor.py +++ b/upstash_workflow/context/auto_executor.py @@ -77,6 +77,7 @@ def submit_steps_to_qstash( self.context.retries, lazy_step.retries if isinstance(lazy_step, _LazyCallStep) else None, lazy_step.timeout if isinstance(lazy_step, _LazyCallStep) else None, + self.context.failure_url, ).headers will_wait = ( diff --git a/upstash_workflow/context/context.py b/upstash_workflow/context/context.py index 9f51060..281b5c1 100644 --- a/upstash_workflow/context/context.py +++ b/upstash_workflow/context/context.py @@ -46,6 +46,7 @@ def __init__( headers: Dict[str, str], steps: List[DefaultStep], url: str, + failure_url: Optional[str], initial_payload: TInitialPayload, env: Optional[Dict[str, Optional[str]]] = None, retries: Optional[int] = None, @@ -54,6 +55,7 @@ def __init__( self.workflow_run_id: str = workflow_run_id self._steps: List[DefaultStep] = steps self.url: str = url + self.failure_url = failure_url self.headers: Dict[str, str] = headers self.request_payload: TInitialPayload = initial_payload self.env: Dict[str, Optional[str]] = env or {} diff --git a/upstash_workflow/fastapi.py b/upstash_workflow/fastapi.py index dda95ff..57431a0 100644 --- a/upstash_workflow/fastapi.py +++ b/upstash_workflow/fastapi.py @@ -3,7 +3,7 @@ import os from fastapi import FastAPI, Request from fastapi.responses import JSONResponse -from typing import Callable, Awaitable, cast, TypeVar, Optional, Dict +from typing import Callable, Awaitable, cast, TypeVar, Optional, Dict, Any from qstash import AsyncQStash, Receiver from upstash_workflow import async_serve, AsyncWorkflowContext from upstash_workflow.workflow_types import _Response as WorkflowResponse @@ -29,6 +29,10 @@ def post( env: Optional[Dict[str, Optional[str]]] = None, retries: Optional[int] = None, url: Optional[str] = None, + failure_function: Optional[ + Callable[[AsyncWorkflowContext, int, str, Dict[str, str]], Awaitable[Any]] + ] = None, + failure_url: Optional[str] = None, ) -> Callable[ [AsyncRouteFunction[TInitialPayload]], AsyncRouteFunction[TInitialPayload] ]: @@ -74,6 +78,8 @@ def decorator( env=env, retries=retries, url=url, + failure_function=failure_function, + failure_url=failure_url, ).get("handler"), ) diff --git a/upstash_workflow/flask.py b/upstash_workflow/flask.py index 7de9626..abfe119 100644 --- a/upstash_workflow/flask.py +++ b/upstash_workflow/flask.py @@ -2,7 +2,7 @@ import os from flask import Flask, request from werkzeug.wrappers import Response -from typing import Callable, cast, TypeVar, Optional, Dict +from typing import Callable, cast, TypeVar, Optional, Dict, Any from qstash import QStash, Receiver from upstash_workflow import serve, WorkflowContext from upstash_workflow.workflow_types import ( @@ -32,6 +32,10 @@ def route( env: Optional[Dict[str, Optional[str]]] = None, retries: Optional[int] = None, url: Optional[str] = None, + failure_function: Optional[ + Callable[[WorkflowContext, int, str, Dict[str, str]], Any] + ] = None, + failure_url: Optional[str] = None, ) -> Callable[ [RouteFunction[TInitialPayload]], RouteFunction[TInitialPayload], @@ -83,6 +87,8 @@ def decorator( env=env, retries=retries, url=url, + failure_function=failure_function, + failure_url=failure_url, ).get("handler"), ) diff --git a/upstash_workflow/serve/authorization.py b/upstash_workflow/serve/authorization.py index d8e1b2d..2e897f6 100644 --- a/upstash_workflow/serve/authorization.py +++ b/upstash_workflow/serve/authorization.py @@ -76,6 +76,7 @@ def try_authentication( initial_payload=context.request_payload, env=context.env, retries=context.retries, + failure_url=context.failure_url, ) try: diff --git a/upstash_workflow/serve/options.py b/upstash_workflow/serve/options.py index 1ae4428..8f5599e 100644 --- a/upstash_workflow/serve/options.py +++ b/upstash_workflow/serve/options.py @@ -2,14 +2,26 @@ import json import re import logging -from typing import Callable, Dict, Optional, cast, TypeVar, Match, Union +from typing import ( + Callable, + Dict, + Optional, + cast, + TypeVar, + Match, + Union, + Any, + Generic, + Tuple, +) from qstash import QStash, Receiver from upstash_workflow.workflow_types import _Response, _SyncRequest, _AsyncRequest from upstash_workflow.constants import DEFAULT_RETRIES from upstash_workflow.types import ( _FinishCondition, - ServeBaseOptions, ) +from upstash_workflow import WorkflowContext +from dataclasses import dataclass _logger = logging.getLogger(__name__) @@ -17,6 +29,28 @@ TInitialPayload = TypeVar("TInitialPayload") +@dataclass +class ServeOptions(Generic[TInitialPayload, TResponse]): + qstash_client: QStash + initial_payload_parser: Callable[[str], TInitialPayload] + receiver: Optional[Receiver] + base_url: Optional[str] + env: Dict[str, Optional[str]] + retries: int + url: Optional[str] + failure_function: Optional[ + Callable[[WorkflowContext[TInitialPayload], int, str, Dict[str, str]], Any] + ] + failure_url: Optional[str] + + +@dataclass +class ServeBaseOptions( + Generic[TInitialPayload, TResponse], ServeOptions[TInitialPayload, TResponse] +): + on_step_finish: Callable[[str, _FinishCondition], TResponse] + + def _process_options( *, qstash_client: Optional[QStash] = None, @@ -27,6 +61,10 @@ def _process_options( env: Optional[Dict[str, Optional[str]]] = None, retries: Optional[int] = DEFAULT_RETRIES, url: Optional[str] = None, + failure_function: Optional[ + Callable[[WorkflowContext, int, str, Dict[str, str]], Any] + ] = None, + failure_url: Optional[str] = None, ) -> ServeBaseOptions[TInitialPayload, TResponse]: """ Fills the options with default values if they are not provided. @@ -108,6 +146,8 @@ def _initial_payload_parser(initial_request: str) -> TInitialPayload: env=environment, retries=DEFAULT_RETRIES if retries is None else retries, url=url, + failure_url=failure_url, + failure_function=failure_function, ) @@ -115,7 +155,9 @@ def _determine_urls( request: Union[_SyncRequest, _AsyncRequest], url: Optional[str], base_url: Optional[str], -) -> str: + failure_function_exists: bool, + failure_url: Optional[str], +) -> Tuple[str, Optional[str]]: initial_workflow_url = str(url if url is not None else request.url) if base_url: @@ -130,7 +172,8 @@ def replace_base(match: Match[str]) -> str: else: workflow_url = initial_workflow_url - return workflow_url + workflow_failure_url = workflow_url if failure_function_exists else failure_url + return (workflow_url, workflow_failure_url) AUTH_FAIL_MESSAGE = "Failed to authenticate Workflow request. If this is unexpected, see the caveat https://upstash.com/docs/workflow/basics/caveats#avoid-non-deterministic-code-outside-context-run" diff --git a/upstash_workflow/serve/serve.py b/upstash_workflow/serve/serve.py index fe8cf82..8476ee2 100644 --- a/upstash_workflow/serve/serve.py +++ b/upstash_workflow/serve/serve.py @@ -7,6 +7,7 @@ _get_payload, _validate_request, _parse_request, + _handle_failure, ) from upstash_workflow.workflow_requests import ( _verify_request, @@ -40,6 +41,10 @@ def _serve_base( env: Optional[Dict[str, Optional[str]]] = None, retries: Optional[int] = None, url: Optional[str] = None, + failure_function: Optional[ + Callable[[WorkflowContext, int, str, Dict[str, str]], Any] + ] = None, + failure_url: Optional[str] = None, ) -> Dict[str, Callable[[TRequest], TResponse]]: processed_options = _process_options( qstash_client=qstash_client, @@ -50,6 +55,8 @@ def _serve_base( env=env, retries=retries, url=url, + failure_function=failure_function, + failure_url=failure_url, ) qstash_client = processed_options.qstash_client on_step_finish = processed_options.on_step_finish @@ -59,6 +66,8 @@ def _serve_base( env = processed_options.env retries = processed_options.retries url = processed_options.url + failure_url = processed_options.failure_url + failure_function = processed_options.failure_function def _handler(request: TRequest) -> TResponse: """ @@ -70,7 +79,13 @@ def _handler(request: TRequest) -> TResponse: :param request: The incoming request to handle. :return: A response. """ - workflow_url = _determine_urls(cast(_SyncRequest, request), url, base_url) + workflow_url, workflow_failure_url = _determine_urls( + cast(_SyncRequest, request), + url, + base_url, + False if failure_function is None else True, + failure_url, + ) request_payload = _get_payload(request) or "" _verify_request( @@ -88,6 +103,20 @@ def _handler(request: TRequest) -> TResponse: raw_initial_payload = parse_request_response.raw_initial_payload steps = parse_request_response.steps + failure_check = _handle_failure( + request, + request_payload, + qstash_client, + initial_payload_parser, + route_function, + failure_function, + env, + retries, + ) + + if failure_check == "is-failure-callback": + return on_step_finish(workflow_run_id, "failure-callback") + workflow_context = WorkflowContext( qstash_client=qstash_client, workflow_run_id=workflow_run_id, @@ -99,6 +128,7 @@ def _handler(request: TRequest) -> TResponse: url=workflow_url, env=env, retries=retries, + failure_url=workflow_failure_url, ) auth_check = _DisabledWorkflowContext[Any].try_authentication( @@ -120,6 +150,7 @@ def _handler(request: TRequest) -> TResponse: raw_initial_payload, qstash_client, workflow_url, + workflow_failure_url, retries, ) @@ -163,6 +194,10 @@ def serve( env: Optional[Dict[str, Optional[str]]] = None, retries: Optional[int] = None, url: Optional[str] = None, + failure_function: Optional[ + Callable[[WorkflowContext, int, str, Dict[str, str]], Any] + ] = None, + failure_url: Optional[str] = None, ) -> Dict[str, Callable[[TRequest], TResponse]]: """ Creates a method that handles incoming requests and runs the provided @@ -188,4 +223,6 @@ def serve( env=env, retries=retries, url=url, + failure_function=failure_function, + failure_url=failure_url, ) diff --git a/upstash_workflow/types.py b/upstash_workflow/types.py index 89166cd..73efdf8 100644 --- a/upstash_workflow/types.py +++ b/upstash_workflow/types.py @@ -1,5 +1,4 @@ from typing import ( - Callable, Literal, Optional, Dict, @@ -10,10 +9,8 @@ Any, TypedDict, ) -from qstash import QStash, Receiver from dataclasses import dataclass - _FinishCondition = Literal[ "success", "duplicate-step", @@ -26,24 +23,6 @@ TResponse = TypeVar("TResponse") -@dataclass -class ServeOptions(Generic[TInitialPayload, TResponse]): - qstash_client: QStash - initial_payload_parser: Callable[[str], TInitialPayload] - receiver: Optional[Receiver] - base_url: Optional[str] - env: Dict[str, Optional[str]] - retries: int - url: Optional[str] - - -@dataclass -class ServeBaseOptions( - Generic[TInitialPayload, TResponse], ServeOptions[TInitialPayload, TResponse] -): - on_step_finish: Callable[[str, _FinishCondition], TResponse] - - StepTypes = [ "Initial", "Run", diff --git a/upstash_workflow/workflow_parser.py b/upstash_workflow/workflow_parser.py index 5ef9e67..454d592 100644 --- a/upstash_workflow/workflow_parser.py +++ b/upstash_workflow/workflow_parser.py @@ -1,12 +1,25 @@ import json -from typing import Optional, List, Tuple, Union +from typing import ( + Optional, + List, + Tuple, + Union, + Callable, + Dict, + Any, + Literal, + TypeVar, + cast, +) from upstash_workflow.utils import _nanoid, _decode_base64 from upstash_workflow.constants import ( WORKFLOW_PROTOCOL_VERSION, WORKFLOW_PROTOCOL_VERSION_HEADER, + WORKFLOW_FAILURE_HEADER, WORKFLOW_ID_HEADER, NO_CONCURRENCY, ) +from qstash import QStash from upstash_workflow.error import WorkflowError from upstash_workflow.types import ( Step, @@ -15,6 +28,9 @@ _ParseRequestResponse, ) from upstash_workflow.workflow_types import _SyncRequest, _AsyncRequest +from upstash_workflow import WorkflowContext +from upstash_workflow.workflow_requests import _recreate_user_headers +from upstash_workflow.serve.authorization import _DisabledWorkflowContext def _get_payload(request: _SyncRequest) -> Optional[str]: @@ -168,3 +184,76 @@ def _parse_request( return _ParseRequestResponse( raw_initial_payload=raw_initial_payload, steps=steps ) + + +TInitialPayload = TypeVar("TInitialPayload") +TRequest = TypeVar("TRequest", bound=_SyncRequest) + + +def _handle_failure( + request: TRequest, + request_payload: str, + qstash_client: QStash, + initial_payload_parser: Callable[[str], Any], + route_function: Callable[[WorkflowContext[TInitialPayload]], None], + failure_function: Optional[ + Callable[[WorkflowContext[TInitialPayload], int, str, Dict[str, str]], Any] + ], + env: Dict[str, Any], + retries: int, +) -> Literal["not-failure-callback", "is-failure-callback"]: + if request.headers and request.headers.get(WORKFLOW_FAILURE_HEADER) != "true": + return "not-failure-callback" + + if not failure_function: + raise WorkflowError( + "Workflow endpoint is called to handle a failure, " + "but a failure_function is not provided in serve options. " + "Either provide a failure_url or a failure_function." + ) + + try: + payload = json.loads(request_payload) + status = payload["status"] + header = payload["header"] + body = payload["body"] + url = payload["url"] + source_body = payload["sourceBody"] + workflow_run_id = payload["workflowRunId"] + + decoded_body = _decode_base64(body) if body else "{}" + error_payload = json.loads(decoded_body) + + # Create context + workflow_context = WorkflowContext( + qstash_client=qstash_client, + workflow_run_id=workflow_run_id, + initial_payload=initial_payload_parser(_decode_base64(source_body)) + if source_body + else None, + headers=_recreate_user_headers(request.headers or {}), + steps=[], + url=url, + failure_url=url, + env=env, + retries=retries, + ) + + # Attempt running route_function until the first step + auth_check = _DisabledWorkflowContext[Any].try_authentication( + route_function, cast(WorkflowContext[TInitialPayload], workflow_context) + ) + + if auth_check == "run-ended": + raise WorkflowError("Not authorized to run the failure function.") + + failure_function( + cast(WorkflowContext[TInitialPayload], workflow_context), + status, + error_payload.get("message"), + header, + ) + except Exception as error: + raise error + + return "is-failure-callback" diff --git a/upstash_workflow/workflow_requests.py b/upstash_workflow/workflow_requests.py index d7e13ef..9c5cdcd 100644 --- a/upstash_workflow/workflow_requests.py +++ b/upstash_workflow/workflow_requests.py @@ -20,8 +20,10 @@ WORKFLOW_URL_HEADER, WORKFLOW_PROTOCOL_VERSION, WORKFLOW_PROTOCOL_VERSION_HEADER, - DEFAULT_CONTENT_TYPE, WORKFLOW_FEATURE_HEADER, + WORKFLOW_FAILURE_HEADER, + DEFAULT_CONTENT_TYPE, + DEFAULT_RETRIES, ) from upstash_workflow.types import StepTypes, DefaultStep, _HeadersResponse from upstash_workflow.workflow_types import _SyncRequest @@ -109,6 +111,7 @@ def _handle_third_party_call_result( request_payload: str, client: QStash, workflow_url: str, + workflow_failure_url: Optional[str], retries: int, ) -> Literal["call-will-retry", "is-call-return", "continue-workflow"]: """ @@ -203,6 +206,7 @@ def _handle_third_party_call_result( user_headers, None, retries, + workflow_failure_url=workflow_failure_url, ).headers call_response = { @@ -238,6 +242,10 @@ def _handle_third_party_call_result( ) +def _should_set_retries(retries: Optional[int]): + return retries is not None and retries != DEFAULT_RETRIES + + def _get_headers( init_header_value: Literal["true", "false"], workflow_run_id: str, @@ -247,6 +255,7 @@ def _get_headers( retries: Optional[int] = None, call_retries: Optional[int] = None, call_timeout: Optional[Union[int, str]] = None, + workflow_failure_url: Optional[str] = None, ) -> _HeadersResponse: """ Gets headers for calling QStash @@ -273,6 +282,45 @@ def _get_headers( if call_timeout: base_headers["Upstash-Timeout"] = str(call_timeout) + if workflow_failure_url: + base_headers[f"Upstash-Failure-Callback-Forward-{WORKFLOW_FAILURE_HEADER}"] = ( + "true" + ) + base_headers[ + "Upstash-Failure-Callback-Forward-Upstash-Workflow-Failure-Callback" + ] = "true" + base_headers["Upstash-Failure-Callback-Workflow-Runid"] = workflow_run_id + base_headers["Upstash-Failure-Callback-Workflow-Init"] = "false" + base_headers["Upstash-Failure-Callback-Workflow-Url"] = workflow_url + base_headers["Upstash-Failure-Callback-Workflow-Calltype"] = "failureCall" + if step and step.call_url: + base_headers[ + f"Upstash-Callback-Failure-Callback-Forward-{WORKFLOW_FAILURE_HEADER}" + ] = "true" + base_headers[ + "Upstash-Callback-Failure-Callback-Forward-Upstash-Workflow-Failure-Callback" + ] = "true" + base_headers["Upstash-Callback-Failure-Callback-Workflow-Runid"] = ( + workflow_run_id + ) + base_headers["Upstash-Callback-Failure-Callback-Workflow-Init"] = "false" + base_headers["Upstash-Callback-Failure-Callback-Workflow-Url"] = ( + workflow_url + ) + base_headers["Upstash-Callback-Failure-Callback-Workflow-Calltype"] = ( + "failureCall" + ) + + if _should_set_retries(retries): + base_headers["Upstash-Failure-Callback-Retries"] = str(retries) + if step and step.call_url: + base_headers["Upstash-Callback-Failure-Callback-Retries"] = str(retries) + + if not step or not step.call_url: + base_headers["Upstash-Failure-Callback"] = workflow_failure_url + if step and step.call_url: + base_headers["Upstash-Callback-Failure-Callback"] = workflow_failure_url + if step and step.call_url: base_headers["Upstash-Retries"] = str( call_retries if call_retries is not None else 0 @@ -282,7 +330,7 @@ def _get_headers( if retries is not None: base_headers["Upstash-Callback-Retries"] = str(retries) base_headers["Upstash-Failure-Callback-Retries"] = str(retries) - elif retries is not None: + elif _should_set_retries(retries): base_headers["Upstash-Retries"] = str(retries) base_headers["Upstash-Failure-Callback-Retries"] = str(retries)