992. Method-style for direct tool access: `agent.tool.tool_name(param1="value")`
1010"""
1111
12- import asyncio
1312import json
1413import logging
1514import os
1615import random
1716from concurrent .futures import ThreadPoolExecutor
18- from threading import Thread
19- from typing import Any , AsyncIterator , Callable , Dict , List , Mapping , Optional , Type , TypeVar , Union
20- from uuid import uuid4
17+ from typing import Any , AsyncIterator , Callable , Generator , Mapping , Optional , Type , TypeVar , Union , cast
2118
2219from opentelemetry import trace
2320from pydantic import BaseModel
2421
2522from ..event_loop .event_loop import event_loop_cycle
26- from ..handlers .callback_handler import CompositeCallbackHandler , PrintingCallbackHandler , null_callback_handler
23+ from ..handlers .callback_handler import PrintingCallbackHandler , null_callback_handler
2724from ..handlers .tool_handler import AgentToolHandler
2825from ..models .bedrock import BedrockModel
2926from ..telemetry .metrics import EventLoopMetrics
@@ -183,7 +180,7 @@ def __init__(
183180 self ,
184181 model : Union [Model , str , None ] = None ,
185182 messages : Optional [Messages ] = None ,
186- tools : Optional [List [Union [str , Dict [str , str ], Any ]]] = None ,
183+ tools : Optional [list [Union [str , dict [str , str ], Any ]]] = None ,
187184 system_prompt : Optional [str ] = None ,
188185 callback_handler : Optional [
189186 Union [Callable [..., Any ], _DefaultCallbackHandlerSentinel ]
@@ -255,7 +252,7 @@ def __init__(
255252 self .conversation_manager = conversation_manager if conversation_manager else SlidingWindowConversationManager ()
256253
257254 # Process trace attributes to ensure they're of compatible types
258- self .trace_attributes : Dict [str , AttributeValue ] = {}
255+ self .trace_attributes : dict [str , AttributeValue ] = {}
259256 if trace_attributes :
260257 for k , v in trace_attributes .items ():
261258 if isinstance (v , (str , int , float , bool )) or (
@@ -312,7 +309,7 @@ def tool(self) -> ToolCaller:
312309 return self .tool_caller
313310
314311 @property
315- def tool_names (self ) -> List [str ]:
312+ def tool_names (self ) -> list [str ]:
316313 """Get a list of all registered tool names.
317314
318315 Returns:
@@ -357,19 +354,25 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
357354 - metrics: Performance metrics from the event loop
358355 - state: The final state of the event loop
359356 """
357+ callback_handler = kwargs .get ("callback_handler" , self .callback_handler )
358+
360359 self ._start_agent_trace_span (prompt )
361360
362361 try :
363- # Run the event loop and get the result
364- result = self ._run_loop (prompt , kwargs )
362+ events = self ._run_loop (callback_handler , prompt , kwargs )
363+ for event in events :
364+ if "callback" in event :
365+ callback_handler (** event ["callback" ])
366+
367+ stop_reason , message , metrics , state = event ["stop" ]
368+ result = AgentResult (stop_reason , message , metrics , state )
365369
366370 self ._end_agent_trace_span (response = result )
367371
368372 return result
373+
369374 except Exception as e :
370375 self ._end_agent_trace_span (error = e )
371-
372- # Re-raise the exception to preserve original behavior
373376 raise
374377
375378 def structured_output (self , output_model : Type [T ], prompt : Optional [str ] = None ) -> T :
@@ -383,9 +386,9 @@ def structured_output(self, output_model: Type[T], prompt: Optional[str] = None)
383386 instruct the model to output the structured data.
384387
385388 Args:
386- output_model(Type[BaseModel]) : The output model (a JSON schema written as a Pydantic BaseModel)
389+ output_model: The output model (a JSON schema written as a Pydantic BaseModel)
387390 that the agent will use when responding.
388- prompt(Optional[str]) : The prompt to use for the agent.
391+ prompt: The prompt to use for the agent.
389392 """
390393 messages = self .messages
391394 if not messages and not prompt :
@@ -396,7 +399,12 @@ def structured_output(self, output_model: Type[T], prompt: Optional[str] = None)
396399 messages .append ({"role" : "user" , "content" : [{"text" : prompt }]})
397400
398401 # get the structured output from the model
399- return self .model .structured_output (output_model , messages , self .callback_handler )
402+ events = self .model .structured_output (output_model , messages )
403+ for event in events :
404+ if "callback" in event :
405+ self .callback_handler (** cast (dict , event ["callback" ]))
406+
407+ return event ["output" ]
400408
401409 async def stream_async (self , prompt : str , ** kwargs : Any ) -> AsyncIterator [Any ]:
402410 """Process a natural language prompt and yield events as an async iterator.
@@ -428,94 +436,63 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
428436 yield event["data"]
429437 ```
430438 """
431- self . _start_agent_trace_span ( prompt )
439+ callback_handler = kwargs . get ( "callback_handler" , self . callback_handler )
432440
433- _stop_event = uuid4 ()
434-
435- queue = asyncio .Queue [Any ]()
436- loop = asyncio .get_event_loop ()
437-
438- def enqueue (an_item : Any ) -> None :
439- nonlocal queue
440- nonlocal loop
441- loop .call_soon_threadsafe (queue .put_nowait , an_item )
442-
443- def queuing_callback_handler (** handler_kwargs : Any ) -> None :
444- enqueue (handler_kwargs .copy ())
441+ self ._start_agent_trace_span (prompt )
445442
446- def target_callback () -> None :
447- nonlocal kwargs
443+ try :
444+ events = self ._run_loop (callback_handler , prompt , kwargs )
445+ for event in events :
446+ if "callback" in event :
447+ callback_handler (** event ["callback" ])
448+ yield event ["callback" ]
448449
449- try :
450- result = self ._run_loop (prompt , kwargs , supplementary_callback_handler = queuing_callback_handler )
451- self ._end_agent_trace_span (response = result )
452- except Exception as e :
453- self ._end_agent_trace_span (error = e )
454- enqueue (e )
455- finally :
456- enqueue (_stop_event )
450+ stop_reason , message , metrics , state = event ["stop" ]
451+ result = AgentResult (stop_reason , message , metrics , state )
457452
458- thread = Thread (target = target_callback , daemon = True )
459- thread .start ()
453+ self ._end_agent_trace_span (response = result )
460454
461- try :
462- while True :
463- item = await queue .get ()
464- if item == _stop_event :
465- break
466- if isinstance (item , Exception ):
467- raise item
468- yield item
469- finally :
470- thread .join ()
455+ except Exception as e :
456+ self ._end_agent_trace_span (error = e )
457+ raise
471458
472459 def _run_loop (
473- self , prompt : str , kwargs : Dict [ str , Any ], supplementary_callback_handler : Optional [ Callable [... , Any ]] = None
474- ) -> AgentResult :
460+ self , callback_handler : Callable [... , Any ], prompt : str , kwargs : dict [ str , Any ]
461+ ) -> Generator [ dict [ str , Any ], None , None ] :
475462 """Execute the agent's event loop with the given prompt and parameters."""
476463 try :
477- # If the call had a callback_handler passed in, then for this event_loop
478- # cycle we call both handlers as the callback_handler
479- invocation_callback_handler = (
480- CompositeCallbackHandler (self .callback_handler , supplementary_callback_handler )
481- if supplementary_callback_handler is not None
482- else self .callback_handler
483- )
484-
485464 # Extract key parameters
486- invocation_callback_handler ( init_event_loop = True , ** kwargs )
465+ yield { "callback" : { " init_event_loop" : True , ** kwargs }}
487466
488467 # Set up the user message with optional knowledge base retrieval
489- message_content : List [ContentBlock ] = [{"text" : prompt }]
468+ message_content : list [ContentBlock ] = [{"text" : prompt }]
490469 new_message : Message = {"role" : "user" , "content" : message_content }
491470 self .messages .append (new_message )
492471
493472 # Execute the event loop cycle with retry logic for context limits
494- return self ._execute_event_loop_cycle (invocation_callback_handler , kwargs )
473+ yield from self ._execute_event_loop_cycle (callback_handler , kwargs )
495474
496475 finally :
497476 self .conversation_manager .apply_management (self )
498477
499- def _execute_event_loop_cycle (self , callback_handler : Callable [..., Any ], kwargs : dict [str , Any ]) -> AgentResult :
478+ def _execute_event_loop_cycle (
479+ self , callback_handler : Callable [..., Any ], kwargs : dict [str , Any ]
480+ ) -> Generator [dict [str , Any ], None , None ]:
500481 """Execute the event loop cycle with retry logic for context window limits.
501482
502483 This internal method handles the execution of the event loop cycle and implements
503484 retry logic for handling context window overflow exceptions by reducing the
504485 conversation context and retrying.
505486
506- Args:
507- callback_handler: The callback handler to use for events.
508- kwargs: Additional parameters to pass through event loop.
509-
510- Returns:
511- The result of the event loop cycle.
487+ Yields:
488+ Events of the loop cycle.
512489 """
513490 # Add `Agent` to kwargs to keep backwards-compatibility
514491 kwargs ["agent" ] = self
515492
516493 try :
517494 # Execute the main event loop cycle
518- events = event_loop_cycle (
495+ yield from event_loop_cycle (
519496 model = self .model ,
520497 system_prompt = self .system_prompt ,
521498 messages = self .messages , # will be modified by event_loop_cycle
@@ -527,19 +504,11 @@ def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs
527504 event_loop_parent_span = self .trace_span ,
528505 kwargs = kwargs ,
529506 )
530- for event in events :
531- if "callback" in event :
532- callback_handler (** event ["callback" ])
533-
534- stop_reason , message , metrics , state = event ["stop" ]
535-
536- return AgentResult (stop_reason , message , metrics , state )
537507
538508 except ContextWindowOverflowException as e :
539509 # Try reducing the context size and retrying
540-
541510 self .conversation_manager .reduce_context (self , e = e )
542- return self ._execute_event_loop_cycle (callback_handler , kwargs )
511+ yield from self ._execute_event_loop_cycle (callback_handler_override , kwargs )
543512
544513 def _record_tool_execution (
545514 self ,
@@ -625,7 +594,7 @@ def _end_agent_trace_span(
625594 error: Error to record as a trace attribute.
626595 """
627596 if self .trace_span :
628- trace_attributes : Dict [str , Any ] = {
597+ trace_attributes : dict [str , Any ] = {
629598 "span" : self .trace_span ,
630599 }
631600
0 commit comments