Skip to content

Commit 9bee258

Browse files
committed
feat: add failure functions
1 parent 35351ff commit 9bee258

File tree

19 files changed

+371
-33
lines changed

19 files changed

+371
-33
lines changed

tests/asyncio/test_context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ async def test_workflow_headers(qstash_client: AsyncQStash) -> None:
3131
initial_payload="my-payload",
3232
env=None,
3333
retries=None,
34+
failure_url=None
3435
)
3536

3637
async def execute() -> None:

tests/test_context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def test_workflow_headers(qstash_client: QStash) -> None:
3030
initial_payload="my-payload",
3131
env=None,
3232
retries=None,
33+
failure_url=WORKFLOW_ENDPOINT
3334
)
3435

3536
def execute() -> None:

upstash_workflow/asyncio/context/auto_executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ async def submit_steps_to_qstash(
8484
self.context.retries,
8585
lazy_step.retries if isinstance(lazy_step, _LazyCallStep) else None,
8686
lazy_step.timeout if isinstance(lazy_step, _LazyCallStep) else None,
87+
workflow_failure_url=self.context.failure_url,
8788
).headers
8889

8990
will_wait = (

upstash_workflow/asyncio/context/context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def __init__(
4747
headers: Dict[str, str],
4848
steps: List[DefaultStep],
4949
url: str,
50+
failure_url: Optional[str],
5051
initial_payload: TInitialPayload,
5152
env: Optional[Dict[str, Optional[str]]] = None,
5253
retries: Optional[int] = None,
@@ -55,6 +56,7 @@ def __init__(
5556
self.workflow_run_id: str = workflow_run_id
5657
self._steps: List[DefaultStep] = steps
5758
self.url: str = url
59+
self.failure_url = failure_url
5860
self.headers: Dict[str, str] = headers
5961
self.request_payload: TInitialPayload = initial_payload
6062
self.env: Dict[str, Optional[str]] = env or {}

upstash_workflow/asyncio/serve/authorization.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ async def try_authentication(
3838
initial_payload=context.request_payload,
3939
env=context.env,
4040
retries=context.retries,
41+
failure_url=context.failure_url,
4142
)
4243

4344
try:

upstash_workflow/asyncio/serve/options.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import os
22
import json
33
import logging
4-
from typing import Callable, Dict, Optional, cast, TypeVar
4+
from typing import Callable, Dict, Optional, cast, TypeVar, Any
55
from qstash import AsyncQStash, Receiver
66
from upstash_workflow.workflow_types import _Response
77
from upstash_workflow.constants import DEFAULT_RETRIES
88
from upstash_workflow.types import (
99
_FinishCondition,
1010
)
1111
from upstash_workflow.asyncio.types import ServeBaseOptions
12+
from upstash_workflow import AsyncWorkflowContext
1213

1314
_logger = logging.getLogger(__name__)
1415

@@ -26,6 +27,10 @@ def _process_options(
2627
env: Optional[Dict[str, Optional[str]]] = None,
2728
retries: Optional[int] = DEFAULT_RETRIES,
2829
url: Optional[str] = None,
30+
failure_function: Optional[
31+
Callable[[AsyncWorkflowContext, int, str, Dict[str, str]], Any]
32+
] = None,
33+
failure_url: Optional[str] = None,
2934
) -> ServeBaseOptions[TInitialPayload, TResponse]:
3035
environment = env if env is not None else dict(os.environ)
3136

@@ -94,6 +99,8 @@ def _initial_payload_parser(initial_request: str) -> TInitialPayload:
9499
env=environment,
95100
retries=DEFAULT_RETRIES if retries is None else retries,
96101
url=url,
102+
failure_url=failure_url,
103+
failure_function=failure_function,
97104
)
98105

99106

upstash_workflow/asyncio/serve/serve.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from upstash_workflow.workflow_types import _Response, _AsyncRequest
66
from upstash_workflow.asyncio.workflow_parser import (
77
_get_payload,
8+
_handle_failure,
89
)
910
from upstash_workflow.workflow_parser import _validate_request, _parse_request
1011
from upstash_workflow.asyncio.workflow_requests import (
@@ -39,6 +40,10 @@ def _serve_base(
3940
env: Optional[Dict[str, Optional[str]]] = None,
4041
retries: Optional[int] = None,
4142
url: Optional[str] = None,
43+
failure_function: Optional[
44+
Callable[[AsyncWorkflowContext, int, str, Dict[str, str]], Awaitable[Any]]
45+
] = None,
46+
failure_url: Optional[str] = None,
4247
) -> Dict[str, Callable[[TRequest], Awaitable[TResponse]]]:
4348
processed_options = _process_options(
4449
qstash_client=qstash_client,
@@ -49,6 +54,8 @@ def _serve_base(
4954
env=env,
5055
retries=retries,
5156
url=url,
57+
failure_function=failure_function,
58+
failure_url=failure_url,
5259
)
5360
qstash_client = processed_options.qstash_client
5461
on_step_finish = processed_options.on_step_finish
@@ -58,9 +65,13 @@ def _serve_base(
5865
env = processed_options.env
5966
retries = processed_options.retries
6067
url = processed_options.url
68+
failure_url = processed_options.failure_url
69+
failure_function = processed_options.failure_function
6170

6271
async def _handler(request: TRequest) -> TResponse:
63-
workflow_url = _determine_urls(cast(_AsyncRequest, request), url, base_url)
72+
workflow_url, workflow_failure_url = _determine_urls(
73+
cast(_AsyncRequest, request), url, base_url, failure_function, failure_url
74+
)
6475

6576
request_payload = await _get_payload(request) or ""
6677
_verify_request(
@@ -78,6 +89,20 @@ async def _handler(request: TRequest) -> TResponse:
7889
raw_initial_payload = parse_request_response.raw_initial_payload
7990
steps = parse_request_response.steps
8091

92+
failure_check = await _handle_failure(
93+
request,
94+
request_payload,
95+
qstash_client,
96+
initial_payload_parser,
97+
route_function,
98+
failure_function,
99+
env,
100+
retries,
101+
)
102+
103+
if failure_check == "is-failure-callback":
104+
return on_step_finish(workflow_run_id, "failure-callback")
105+
81106
workflow_context = AsyncWorkflowContext(
82107
qstash_client=qstash_client,
83108
workflow_run_id=workflow_run_id,
@@ -89,6 +114,7 @@ async def _handler(request: TRequest) -> TResponse:
89114
url=workflow_url,
90115
env=env,
91116
retries=retries,
117+
failure_url=workflow_failure_url,
92118
)
93119

94120
auth_check = await _DisabledWorkflowContext[Any].try_authentication(
@@ -110,6 +136,7 @@ async def _handler(request: TRequest) -> TResponse:
110136
raw_initial_payload,
111137
qstash_client,
112138
workflow_url,
139+
workflow_failure_url,
113140
retries,
114141
)
115142

@@ -153,6 +180,10 @@ def serve(
153180
env: Optional[Dict[str, Optional[str]]] = None,
154181
retries: Optional[int] = None,
155182
url: Optional[str] = None,
183+
failure_function: Optional[
184+
Callable[[AsyncWorkflowContext, int, str, Dict[str, str]], Awaitable[Any]]
185+
] = None,
186+
failure_url: Optional[str] = None,
156187
) -> Dict[str, Callable[[TRequest], Awaitable[TResponse]]]:
157188
"""
158189
Creates a method that handles incoming requests and runs the provided
@@ -178,4 +209,6 @@ def serve(
178209
env=env,
179210
retries=retries,
180211
url=url,
212+
failure_function=failure_function,
213+
failure_url=failure_url,
181214
)

upstash_workflow/asyncio/types.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
Dict,
55
Generic,
66
TypeVar,
7+
Any,
78
)
89
from qstash import AsyncQStash, Receiver
910
from dataclasses import dataclass
1011
from upstash_workflow.types import _FinishCondition
12+
from upstash_workflow import AsyncWorkflowContext
1113

1214
TInitialPayload = TypeVar("TInitialPayload")
1315
TResponse = TypeVar("TResponse")
@@ -22,6 +24,10 @@ class ServeOptions(Generic[TInitialPayload, TResponse]):
2224
env: Dict[str, Optional[str]]
2325
retries: int
2426
url: Optional[str]
27+
failure_function: Optional[
28+
Callable[[AsyncWorkflowContext, int, str, Dict[str, str]], Any]
29+
]
30+
failure_url: Optional[str]
2531

2632

2733
@dataclass

upstash_workflow/asyncio/workflow_parser.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
1-
from typing import Optional
1+
from typing import Optional, cast
22
from upstash_workflow.workflow_types import _AsyncRequest
3+
import json
4+
from typing import Callable, Dict, Any, Literal, Awaitable, TypeVar
5+
from upstash_workflow.utils import _decode_base64
6+
from upstash_workflow.constants import (
7+
WORKFLOW_FAILURE_HEADER,
8+
)
9+
from upstash_workflow.error import WorkflowError
10+
from upstash_workflow.workflow_requests import _recreate_user_headers
11+
from upstash_workflow.asyncio.serve.authorization import _DisabledWorkflowContext
12+
from qstash import AsyncQStash
13+
from upstash_workflow import AsyncWorkflowContext
314

415

516
async def _get_payload(request: _AsyncRequest) -> Optional[str]:
@@ -13,3 +24,79 @@ async def _get_payload(request: _AsyncRequest) -> Optional[str]:
1324
return (await request.body()).decode()
1425
except Exception:
1526
return None
27+
28+
29+
TInitialPayload = TypeVar("TInitialPayload")
30+
TRequest = TypeVar("TRequest", bound=_AsyncRequest)
31+
32+
33+
async def _handle_failure(
34+
request: TRequest,
35+
request_payload: str,
36+
qstash_client: AsyncQStash,
37+
initial_payload_parser: Callable[[str], Any],
38+
route_function: Callable[[AsyncWorkflowContext[TInitialPayload]], Awaitable[None]],
39+
failure_function: Optional[
40+
Callable[
41+
[AsyncWorkflowContext[TInitialPayload], int, str, Dict[str, str]],
42+
Awaitable[None],
43+
]
44+
],
45+
env: Dict[str, Any],
46+
retries: int,
47+
) -> Literal["not-failure-callback", "is-failure-callback"]:
48+
if request.headers and request.headers.get(WORKFLOW_FAILURE_HEADER) != "true":
49+
return "not-failure-callback"
50+
51+
if not failure_function:
52+
raise WorkflowError(
53+
"Workflow endpoint is called to handle a failure, "
54+
"but a failure_function is not provided in serve options. "
55+
"Either provide a failure_url or a failure_function."
56+
)
57+
58+
try:
59+
payload = json.loads(request_payload)
60+
status = payload["status"]
61+
header = payload["header"]
62+
body = payload["body"]
63+
url = payload["url"]
64+
source_body = payload["sourceBody"]
65+
workflow_run_id = payload["workflowRunId"]
66+
67+
decoded_body = _decode_base64(body) if body else "{}"
68+
error_payload = json.loads(decoded_body)
69+
70+
# Create context
71+
workflow_context = AsyncWorkflowContext(
72+
qstash_client=qstash_client,
73+
workflow_run_id=workflow_run_id,
74+
initial_payload=initial_payload_parser(_decode_base64(source_body))
75+
if source_body
76+
else None,
77+
headers=_recreate_user_headers(request.headers or {}),
78+
steps=[],
79+
url=url,
80+
failure_url=url,
81+
env=env,
82+
retries=retries,
83+
)
84+
85+
# Attempt running route_function until the first step
86+
auth_check = await _DisabledWorkflowContext[Any].try_authentication(
87+
route_function, cast(AsyncWorkflowContext[TInitialPayload], workflow_context)
88+
)
89+
90+
if auth_check == "run-ended":
91+
raise WorkflowError("Not authorized to run the failure function.")
92+
93+
await failure_function(
94+
cast(AsyncWorkflowContext[TInitialPayload], workflow_context),
95+
status,
96+
error_payload.get("message"),
97+
header,
98+
)
99+
except Exception as error:
100+
raise error
101+
102+
return "is-failure-callback"

upstash_workflow/asyncio/workflow_requests.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ async def _handle_third_party_call_result(
7979
request_payload: str,
8080
client: AsyncQStash,
8181
workflow_url: str,
82+
workflow_failure_url: Optional[str],
8283
retries: int,
8384
) -> Literal["call-will-retry", "is-call-return", "continue-workflow"]:
8485
"""
@@ -173,6 +174,7 @@ async def _handle_third_party_call_result(
173174
user_headers,
174175
None,
175176
retries,
177+
workflow_failure_url=workflow_failure_url,
176178
).headers
177179

178180
call_response = {

0 commit comments

Comments
 (0)