88"""
99
1010from dataclasses import dataclass
11- from typing import TYPE_CHECKING , Any , Callable , Generator , Protocol , Type , TypeVar
11+ from typing import TYPE_CHECKING , Callable , Generator , Generic , Protocol , Type , TypeVar
1212
1313if TYPE_CHECKING :
1414 from ..agent import Agent
@@ -36,7 +36,7 @@ def should_reverse_callbacks(self) -> bool:
3636
3737
3838T = TypeVar ("T" , bound = Callable )
39- TEvent = TypeVar ("TEvent" , bound = HookEvent )
39+ TEvent = TypeVar ("TEvent" , bound = HookEvent , contravariant = True )
4040
4141
4242class HookProvider (Protocol ):
@@ -66,7 +66,7 @@ def register_hooks(self, registry: "HookRegistry") -> None:
6666 ...
6767
6868
69- class HookCallback (Protocol ):
69+ class HookCallback (Protocol , Generic [ TEvent ] ):
7070 """Protocol for callback functions that handle hook events.
7171
7272 Hook callbacks are functions that receive a single strongly-typed event
@@ -80,7 +80,7 @@ def my_callback(event: StartRequestEvent) -> None:
8080 ```
8181 """
8282
83- def __call__ (self , event : Any ) -> None :
83+ def __call__ (self , event : TEvent ) -> None :
8484 """Handle a hook event.
8585
8686 Args:
@@ -104,7 +104,7 @@ def __init__(self) -> None:
104104 """Initialize an empty hook registry."""
105105 self ._registered_callbacks : dict [Type , list [HookCallback ]] = {}
106106
107- def add_callback (self , event_type : Type , callback : HookCallback ) -> None :
107+ def add_callback (self , event_type : Type [ TEvent ] , callback : HookCallback [ TEvent ] ) -> None :
108108 """Register a callback function for a specific event type.
109109
110110 Args:
@@ -166,7 +166,7 @@ def invoke_callbacks(self, event: TEvent) -> None:
166166 for callback in self .get_callbacks_for (event ):
167167 callback (event )
168168
169- def get_callbacks_for (self , event : TEvent ) -> Generator [HookCallback , None , None ]:
169+ def get_callbacks_for (self , event : TEvent ) -> Generator [HookCallback [ TEvent ] , None , None ]:
170170 """Get callbacks registered for the given event in the appropriate order.
171171
172172 This method returns callbacks in registration order for normal events,
0 commit comments