
Commit e7b2f82

dmontagu and DouweM authored
Fix bug with running graphs in temporal workflows (#3460)
Co-authored-by: Douwe Maan <[email protected]>
1 parent 80a1ce0 commit e7b2f82

File tree

5 files changed: +170 -27 lines changed


pydantic_ai_slim/pydantic_ai/run.py

Lines changed: 4 additions & 4 deletions
@@ -7,7 +7,7 @@
 from typing import TYPE_CHECKING, Any, Generic, Literal, overload

 from pydantic_graph import BaseNode, End, GraphRunContext
-from pydantic_graph.beta.graph import EndMarker, GraphRun, GraphTask, JoinItem
+from pydantic_graph.beta.graph import EndMarker, GraphRun, GraphTaskRequest, JoinItem
 from pydantic_graph.beta.step import NodeStep

 from . import (
@@ -181,7 +181,7 @@ async def __anext__(
         return self._task_to_node(task)

     def _task_to_node(
-        self, task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTask]
+        self, task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTaskRequest]
     ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
         if isinstance(task, Sequence) and len(task) == 1:
             first_task = task[0]
@@ -197,8 +197,8 @@ def _task_to_node(
             return End(task.value)
         raise exceptions.AgentRunError(f'Unexpected node: {task}')  # pragma: no cover

-    def _node_to_task(self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT]) -> GraphTask:
-        return GraphTask(NodeStep(type(node)).id, inputs=node, fork_stack=())
+    def _node_to_task(self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT]) -> GraphTaskRequest:
+        return GraphTaskRequest(NodeStep(type(node)).id, inputs=node, fork_stack=())

     async def next(
         self,
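Note that `_node_to_task` now returns an ID-less `GraphTaskRequest`; the task ID is assigned later by the graph run itself (see the `pydantic_graph` changes below). A minimal illustration of the request object, assuming this commit's beta API is importable (the `'UserPromptNode'` node ID is a made-up placeholder, not taken from the diff):

    from pydantic_graph.beta.graph import GraphTaskRequest
    from pydantic_graph.beta.id_types import NodeID

    # A request carries the node ID, its inputs, and the fork context,
    # but no task_id; the GraphRun generates IDs itself.
    request = GraphTaskRequest(NodeID('UserPromptNode'), inputs=None, fork_stack=())
    print(request)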

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 57 additions & 18 deletions
@@ -8,8 +8,7 @@
 from __future__ import annotations as _annotations

 import sys
-import uuid
-from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Sequence
+from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Sequence
 from contextlib import AbstractContextManager, AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload
@@ -22,7 +21,7 @@
 from pydantic_graph import exceptions
 from pydantic_graph._utils import AbstractSpan, get_traceparent, infer_obj_name, logfire_span
 from pydantic_graph.beta.decision import Decision
-from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, GraphRunID, JoinID, NodeID, NodeRunID, TaskID
+from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, JoinID, NodeID, NodeRunID, TaskID
 from pydantic_graph.beta.join import Join, JoinNode, JoinState, ReducerContext
 from pydantic_graph.beta.node import (
     EndNode,
@@ -306,14 +305,13 @@ def __str__(self) -> str:


 @dataclass
-class GraphTask:
-    """A single task representing the execution of a node in the graph.
+class GraphTaskRequest:
+    """A request to run a task representing the execution of a node in the graph.

-    GraphTask encapsulates all the information needed to execute a specific
+    GraphTaskRequest encapsulates all the information needed to execute a specific
     node, including its inputs and the fork context it's executing within.
     """

-    # With our current BaseNode thing, next_node_id and next_node_inputs are merged into `next_node` itself
     node_id: NodeID
     """The ID of the node to execute."""

@@ -326,9 +324,26 @@ class GraphTask:
     Used by the GraphRun to decide when to proceed through joins.
     """

-    task_id: TaskID = field(default_factory=lambda: TaskID(str(uuid.uuid4())), repr=False)
+
+@dataclass
+class GraphTask(GraphTaskRequest):
+    """A task representing the execution of a node in the graph.
+
+    GraphTask encapsulates all the information needed to execute a specific
+    node, including its inputs and the fork context it's executing within,
+    and has a unique ID to identify the task within the graph run.
+    """
+
+    task_id: TaskID = field(repr=False)
     """Unique identifier for this task."""

+    @staticmethod
+    def from_request(request: GraphTaskRequest, get_task_id: Callable[[], TaskID]) -> GraphTask:
+        # Don't call the get_task_id callable, this is already a task
+        if isinstance(request, GraphTask):
+            return request
+        return GraphTask(request.node_id, request.inputs, request.fork_stack, get_task_id())
+

 class GraphRun(Generic[StateT, DepsT, OutputT]):
     """A single execution instance of a graph.
@@ -378,12 +393,20 @@ def __init__(
         self._next: EndMarker[OutputT] | Sequence[GraphTask] | None = None
         """The next item to be processed."""

-        run_id = GraphRunID(str(uuid.uuid4()))
-        initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunID(run_id), 0),)
-        self._first_task = GraphTask(node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack)
+        self._next_task_id = 0
+        self._next_node_run_id = 0
+        initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, self._get_next_node_run_id(), 0),)
+        self._first_task = GraphTask(
+            node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack, task_id=self._get_next_task_id()
+        )
         self._iterator_task_group = create_task_group()
         self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT](
-            self.graph, self.state, self.deps, self._iterator_task_group
+            self.graph,
+            self.state,
+            self.deps,
+            self._iterator_task_group,
+            self._get_next_node_run_id,
+            self._get_next_task_id,
         )
         self._iterator = self._iterator_instance.iter_graph(self._first_task)

@@ -449,7 +472,7 @@ async def __anext__(self) -> EndMarker[OutputT] | Sequence[GraphTask]:
         return self._next

     async def next(
-        self, value: EndMarker[OutputT] | Sequence[GraphTask] | None = None
+        self, value: EndMarker[OutputT] | Sequence[GraphTaskRequest] | None = None
     ) -> EndMarker[OutputT] | Sequence[GraphTask]:
         """Advance the graph execution by one step.

@@ -467,7 +490,10 @@ async def next(
         # if `next` is called before the `first_node` has run.
         await anext(self)
         if value is not None:
-            self._next = value
+            if isinstance(value, EndMarker):
+                self._next = value
+            else:
+                self._next = [GraphTask.from_request(gtr, self._get_next_task_id) for gtr in value]
         return await anext(self)

     @property
@@ -490,6 +516,16 @@ def output(self) -> OutputT | None:
             return self._next.value
         return None

+    def _get_next_task_id(self) -> TaskID:
+        next_id = TaskID(f'task:{self._next_task_id}')
+        self._next_task_id += 1
+        return next_id
+
+    def _get_next_node_run_id(self) -> NodeRunID:
+        next_id = NodeRunID(f'task:{self._next_node_run_id}')
+        self._next_node_run_id += 1
+        return next_id
+

 @dataclass
 class _GraphTaskAsyncIterable:
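The UUID-based IDs are replaced here by per-run counters. Sequential IDs are deterministic, which is presumably what matters for Temporal workflows, whose code must behave identically on replay, while `uuid.uuid4()` yields a different value every time. A standalone sketch of the counter pattern (not part of the commit):

    from typing import NewType

    TaskID = NewType('TaskID', str)

    class TaskIdCounter:
        """Deterministic ID source: the same call sequence always yields the same IDs."""

        def __init__(self) -> None:
            self._next = 0

        def __call__(self) -> TaskID:
            task_id = TaskID(f'task:{self._next}')
            self._next += 1
            return task_id

    get_task_id = TaskIdCounter()
    assert get_task_id() == TaskID('task:0')
    assert get_task_id() == TaskID('task:1')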
@@ -510,6 +546,8 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
     state: StateT
     deps: DepsT
     task_group: TaskGroup
+    get_next_node_run_id: Callable[[], NodeRunID]
+    get_next_task_id: Callable[[], TaskID]

     cancel_scopes: dict[TaskID, CancelScope] = field(init=False)
     active_tasks: dict[TaskID, GraphTask] = field(init=False)
@@ -522,6 +560,7 @@ def __post_init__(self):
         self.active_tasks = {}
         self.active_reducers = {}
         self.iter_stream_sender, self.iter_stream_receiver = create_memory_object_stream[_GraphTaskResult]()
+        self._next_node_run_id = 1

     async def iter_graph(  # noqa C901
         self, first_task: GraphTask
@@ -782,12 +821,12 @@ def _handle_node(
         fork_stack: ForkStack,
     ) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]:
         if isinstance(next_node, StepNode):
-            return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)]
+            return [GraphTask(next_node.step.id, next_node.inputs, fork_stack, self.get_next_task_id())]
         elif isinstance(next_node, JoinNode):
             return JoinItem(next_node.join.id, next_node.inputs, fork_stack)
         elif isinstance(next_node, BaseNode):
             node_step = NodeStep(next_node.__class__)
-            return [GraphTask(node_step.id, next_node, fork_stack)]
+            return [GraphTask(node_step.id, next_node, fork_stack, self.get_next_task_id())]
         elif isinstance(next_node, End):
             return EndMarker(next_node.data)
         else:
@@ -821,7 +860,7 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
                 'These markers should be removed from paths during graph building'
             )
         if isinstance(item, DestinationMarker):
-            return [GraphTask(item.destination_id, inputs, fork_stack)]
+            return [GraphTask(item.destination_id, inputs, fork_stack, self.get_next_task_id())]
         elif isinstance(item, TransformMarker):
             inputs = item.transform(StepContext(state=self.state, deps=self.deps, inputs=inputs))
             return self._handle_path(path.next_path, inputs, fork_stack)
@@ -853,7 +892,7 @@ def _handle_fork_edges(
         )  # this should have already been ensured during graph building

         new_tasks: list[GraphTask] = []
-        node_run_id = NodeRunID(str(uuid.uuid4()))
+        node_run_id = self.get_next_node_run_id()
         if node.is_map:
             # If the map specifies a downstream join id, eagerly create a join state for it
             if (join_id := node.downstream_join_id) is not None:

pydantic_graph/pydantic_graph/beta/id_types.py

Lines changed: 0 additions & 3 deletions
@@ -24,9 +24,6 @@
 ForkID = NodeID
 """Alias for NodeId when referring to fork nodes."""

-GraphRunID = NewType('GraphRunID', str)
-"""Unique identifier for a complete graph execution run."""
-
 TaskID = NewType('TaskID', str)
 """Unique identifier for a task within the graph execution."""


tests/graph/beta/test_graph_iteration.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 import pytest

 from pydantic_graph.beta import GraphBuilder, StepContext
-from pydantic_graph.beta.graph import EndMarker, GraphTask
+from pydantic_graph.beta.graph import EndMarker, GraphTask, GraphTaskRequest
 from pydantic_graph.beta.id_types import NodeID
 from pydantic_graph.beta.join import reduce_list_append

@@ -400,7 +400,7 @@ async def second_step(ctx: StepContext[IterState, None, int]) -> int:
         # Get the fork_stack from the EndMarker's source
         fork_stack = run.next_task[0].fork_stack if isinstance(run.next_task, list) else ()

-        new_task = GraphTask(
+        new_task = GraphTaskRequest(
             node_id=NodeID('second_step'),
             inputs=event.value,
             fork_stack=fork_stack,

tests/test_temporal.py

Lines changed: 107 additions & 0 deletions
@@ -49,6 +49,8 @@
 from pydantic_ai.run import AgentRunResult
 from pydantic_ai.tools import DeferredToolRequests, DeferredToolResults, ToolDefinition
 from pydantic_ai.usage import RequestUsage
+from pydantic_graph.beta import GraphBuilder, StepContext
+from pydantic_graph.beta.join import reduce_list_append

 try:
     from temporalio import workflow
@@ -2228,3 +2230,108 @@ async def test_fastmcp_toolset(allow_model_requests: None, client: Client):
     assert output == snapshot(
         'The `pydantic/pydantic-ai` repository is a Python agent framework crafted for developing production-grade Generative AI applications. It emphasizes type safety, model-agnostic design, and extensibility. The framework supports various LLM providers, manages agent workflows using graph-based execution, and ensures structured, reliable LLM outputs. Key packages include core framework components, graph execution engines, evaluation tools, and example applications.'
     )
+
+
+# ============================================================================
+# Beta Graph API Tests - Tests for running pydantic-graph beta API in Temporal
+# ============================================================================
+
+
+@dataclass
+class GraphState:
+    """State for the graph execution test."""
+
+    values: list[int] = field(default_factory=list)
+
+
+# Create a graph with parallel execution using the beta API
+graph_builder = GraphBuilder(
+    name='parallel_test_graph',
+    state_type=GraphState,
+    input_type=int,
+    output_type=list[int],
+)
+
+
+@graph_builder.step
+async def source(ctx: StepContext[GraphState, None, int]) -> int:
+    """Source step that passes through the input value."""
+    return ctx.inputs
+
+
+@graph_builder.step
+async def multiply_by_two(ctx: StepContext[GraphState, None, int]) -> int:
+    """Multiply input by 2."""
+    return ctx.inputs * 2
+
+
+@graph_builder.step
+async def multiply_by_three(ctx: StepContext[GraphState, None, int]) -> int:
+    """Multiply input by 3."""
+    return ctx.inputs * 3
+
+
+@graph_builder.step
+async def multiply_by_four(ctx: StepContext[GraphState, None, int]) -> int:
+    """Multiply input by 4."""
+    return ctx.inputs * 4
+
+
+# Create a join to collect results
+result_collector = graph_builder.join(reduce_list_append, initial_factory=list[int])

+# Build the graph with parallel edges (broadcast pattern)
+graph_builder.add(
+    graph_builder.edge_from(graph_builder.start_node).to(source),
+    # Broadcast: send value to all three parallel steps
+    graph_builder.edge_from(source).to(multiply_by_two, multiply_by_three, multiply_by_four),
+    # Collect all results
+    graph_builder.edge_from(multiply_by_two, multiply_by_three, multiply_by_four).to(result_collector),
+    graph_builder.edge_from(result_collector).to(graph_builder.end_node),
+)
+
+parallel_test_graph = graph_builder.build()
+
+
+@workflow.defn
+class ParallelGraphWorkflow:
+    """Workflow that executes a graph with parallel task execution."""
+
+    @workflow.run
+    async def run(self, input_value: int) -> list[int]:
+        """Run the parallel graph workflow.
+
+        Args:
+            input_value: The input number to process
+
+        Returns:
+            List of results from parallel execution
+        """
+        result = await parallel_test_graph.run(
+            state=GraphState(),
+            inputs=input_value,
+        )
+        return result
+
+
+async def test_beta_graph_parallel_execution_in_workflow(client: Client):
+    """Test that beta graph API with parallel execution works in Temporal workflows.
+
+    This test verifies the fix for the bug where parallel task execution in graphs
+    wasn't working properly with Temporal workflows due to GraphTask/GraphTaskRequest
+    serialization issues.
+    """
+    async with Worker(
+        client,
+        task_queue=TASK_QUEUE,
+        workflows=[ParallelGraphWorkflow],
+    ):
+        output = await client.execute_workflow(
+            ParallelGraphWorkflow.run,
+            args=[10],
+            id=ParallelGraphWorkflow.__name__,
+            task_queue=TASK_QUEUE,
+        )
+        # Results can be in any order due to parallel execution
+        # 10 * 2 = 20, 10 * 3 = 30, 10 * 4 = 40
+        assert sorted(output) == [20, 30, 40]
