88from __future__ import annotations as _annotations
99
1010import sys
11- import uuid
12- from collections .abc import AsyncGenerator , AsyncIterable , AsyncIterator , Iterable , Sequence
11+ from collections .abc import AsyncGenerator , AsyncIterable , AsyncIterator , Callable , Iterable , Sequence
1312from contextlib import AbstractContextManager , AsyncExitStack , ExitStack , asynccontextmanager , contextmanager
1413from dataclasses import dataclass , field
1514from typing import TYPE_CHECKING , Any , Generic , Literal , TypeGuard , cast , get_args , get_origin , overload
2221from pydantic_graph import exceptions
2322from pydantic_graph ._utils import AbstractSpan , get_traceparent , infer_obj_name , logfire_span
2423from pydantic_graph .beta .decision import Decision
25- from pydantic_graph .beta .id_types import ForkID , ForkStack , ForkStackItem , GraphRunID , JoinID , NodeID , NodeRunID , TaskID
24+ from pydantic_graph .beta .id_types import ForkID , ForkStack , ForkStackItem , JoinID , NodeID , NodeRunID , TaskID
2625from pydantic_graph .beta .join import Join , JoinNode , JoinState , ReducerContext
2726from pydantic_graph .beta .node import (
2827 EndNode ,
@@ -306,14 +305,13 @@ def __str__(self) -> str:
306305
307306
308307@dataclass
309- class GraphTask :
310- """A single task representing the execution of a node in the graph.
308+ class GraphTaskRequest :
309+ """A request to run a task representing the execution of a node in the graph.
311310
312- GraphTask encapsulates all the information needed to execute a specific
311+ GraphTaskRequest encapsulates all the information needed to execute a specific
313312 node, including its inputs and the fork context it's executing within.
314313 """
315314
316- # With our current BaseNode thing, next_node_id and next_node_inputs are merged into `next_node` itself
317315 node_id : NodeID
318316 """The ID of the node to execute."""
319317
@@ -326,9 +324,26 @@ class GraphTask:
326324 Used by the GraphRun to decide when to proceed through joins.
327325 """
328326
329- task_id : TaskID = field (default_factory = lambda : TaskID (str (uuid .uuid4 ())), repr = False )
327+
328+ @dataclass
329+ class GraphTask (GraphTaskRequest ):
330+ """A task representing the execution of a node in the graph.
331+
332+ GraphTask encapsulates all the information needed to execute a specific
333+ node, including its inputs and the fork context it's executing within,
334+ and has a unique ID to identify the task within the graph run.
335+ """
336+
337+ task_id : TaskID = field (repr = False )
330338 """Unique identifier for this task."""
331339
340+ @staticmethod
341+ def from_request (request : GraphTaskRequest , get_task_id : Callable [[], TaskID ]) -> GraphTask :
342+ # Don't call the get_task_id callable, this is already a task
343+ if isinstance (request , GraphTask ):
344+ return request
345+ return GraphTask (request .node_id , request .inputs , request .fork_stack , get_task_id ())
346+
332347
333348class GraphRun (Generic [StateT , DepsT , OutputT ]):
334349 """A single execution instance of a graph.
@@ -378,12 +393,20 @@ def __init__(
378393 self ._next : EndMarker [OutputT ] | Sequence [GraphTask ] | None = None
379394 """The next item to be processed."""
380395
381- run_id = GraphRunID (str (uuid .uuid4 ()))
382- initial_fork_stack : ForkStack = (ForkStackItem (StartNode .id , NodeRunID (run_id ), 0 ),)
383- self ._first_task = GraphTask (node_id = StartNode .id , inputs = inputs , fork_stack = initial_fork_stack )
396+ self ._next_task_id = 0
397+ self ._next_node_run_id = 0
398+ initial_fork_stack : ForkStack = (ForkStackItem (StartNode .id , self ._get_next_node_run_id (), 0 ),)
399+ self ._first_task = GraphTask (
400+ node_id = StartNode .id , inputs = inputs , fork_stack = initial_fork_stack , task_id = self ._get_next_task_id ()
401+ )
384402 self ._iterator_task_group = create_task_group ()
385403 self ._iterator_instance = _GraphIterator [StateT , DepsT , OutputT ](
386- self .graph , self .state , self .deps , self ._iterator_task_group
404+ self .graph ,
405+ self .state ,
406+ self .deps ,
407+ self ._iterator_task_group ,
408+ self ._get_next_node_run_id ,
409+ self ._get_next_task_id ,
387410 )
388411 self ._iterator = self ._iterator_instance .iter_graph (self ._first_task )
389412
@@ -449,7 +472,7 @@ async def __anext__(self) -> EndMarker[OutputT] | Sequence[GraphTask]:
449472 return self ._next
450473
451474 async def next (
452- self , value : EndMarker [OutputT ] | Sequence [GraphTask ] | None = None
475+ self , value : EndMarker [OutputT ] | Sequence [GraphTaskRequest ] | None = None
453476 ) -> EndMarker [OutputT ] | Sequence [GraphTask ]:
454477 """Advance the graph execution by one step.
455478
@@ -467,7 +490,10 @@ async def next(
467490 # if `next` is called before the `first_node` has run.
468491 await anext (self )
469492 if value is not None :
470- self ._next = value
493+ if isinstance (value , EndMarker ):
494+ self ._next = value
495+ else :
496+ self ._next = [GraphTask .from_request (gtr , self ._get_next_task_id ) for gtr in value ]
471497 return await anext (self )
472498
473499 @property
@@ -490,6 +516,16 @@ def output(self) -> OutputT | None:
490516 return self ._next .value
491517 return None
492518
519+ def _get_next_task_id (self ) -> TaskID :
520+ next_id = TaskID (f'task:{ self ._next_task_id } ' )
521+ self ._next_task_id += 1
522+ return next_id
523+
524+ def _get_next_node_run_id (self ) -> NodeRunID :
525+ next_id = NodeRunID (f'task:{ self ._next_node_run_id } ' )
526+ self ._next_node_run_id += 1
527+ return next_id
528+
493529
494530@dataclass
495531class _GraphTaskAsyncIterable :
@@ -510,6 +546,8 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
510546 state : StateT
511547 deps : DepsT
512548 task_group : TaskGroup
549+ get_next_node_run_id : Callable [[], NodeRunID ]
550+ get_next_task_id : Callable [[], TaskID ]
513551
514552 cancel_scopes : dict [TaskID , CancelScope ] = field (init = False )
515553 active_tasks : dict [TaskID , GraphTask ] = field (init = False )
@@ -522,6 +560,7 @@ def __post_init__(self):
522560 self .active_tasks = {}
523561 self .active_reducers = {}
524562 self .iter_stream_sender , self .iter_stream_receiver = create_memory_object_stream [_GraphTaskResult ]()
563+ self ._next_node_run_id = 1
525564
526565 async def iter_graph ( # noqa C901
527566 self , first_task : GraphTask
@@ -782,12 +821,12 @@ def _handle_node(
782821 fork_stack : ForkStack ,
783822 ) -> Sequence [GraphTask ] | JoinItem | EndMarker [OutputT ]:
784823 if isinstance (next_node , StepNode ):
785- return [GraphTask (next_node .step .id , next_node .inputs , fork_stack )]
824+ return [GraphTask (next_node .step .id , next_node .inputs , fork_stack , self . get_next_task_id () )]
786825 elif isinstance (next_node , JoinNode ):
787826 return JoinItem (next_node .join .id , next_node .inputs , fork_stack )
788827 elif isinstance (next_node , BaseNode ):
789828 node_step = NodeStep (next_node .__class__ )
790- return [GraphTask (node_step .id , next_node , fork_stack )]
829+ return [GraphTask (node_step .id , next_node , fork_stack , self . get_next_task_id () )]
791830 elif isinstance (next_node , End ):
792831 return EndMarker (next_node .data )
793832 else :
@@ -821,7 +860,7 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
821860 'These markers should be removed from paths during graph building'
822861 )
823862 if isinstance (item , DestinationMarker ):
824- return [GraphTask (item .destination_id , inputs , fork_stack )]
863+ return [GraphTask (item .destination_id , inputs , fork_stack , self . get_next_task_id () )]
825864 elif isinstance (item , TransformMarker ):
826865 inputs = item .transform (StepContext (state = self .state , deps = self .deps , inputs = inputs ))
827866 return self ._handle_path (path .next_path , inputs , fork_stack )
@@ -853,7 +892,7 @@ def _handle_fork_edges(
853892 ) # this should have already been ensured during graph building
854893
855894 new_tasks : list [GraphTask ] = []
856- node_run_id = NodeRunID ( str ( uuid . uuid4 ()) )
895+ node_run_id = self . get_next_node_run_id ( )
857896 if node .is_map :
858897 # If the map specifies a downstream join id, eagerly create a join state for it
859898 if (join_id := node .downstream_join_id ) is not None :
0 commit comments