Skip to content

Commit 411eba4

Browse files
committed
fix: strongly type the TEvent
1 parent 59a9245 commit 411eba4

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/strands/hooks/registry.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"""
99

1010
from dataclasses import dataclass
11-
from typing import TYPE_CHECKING, Any, Callable, Generator, Protocol, Type, TypeVar
11+
from typing import TYPE_CHECKING, Callable, Generator, Generic, Protocol, Type, TypeVar
1212

1313
if TYPE_CHECKING:
1414
from ..agent import Agent
@@ -36,7 +36,7 @@ def should_reverse_callbacks(self) -> bool:
3636

3737

3838
T = TypeVar("T", bound=Callable)
39-
TEvent = TypeVar("TEvent", bound=HookEvent)
39+
TEvent = TypeVar("TEvent", bound=HookEvent, contravariant=True)
4040

4141

4242
class HookProvider(Protocol):
@@ -66,7 +66,7 @@ def register_hooks(self, registry: "HookRegistry") -> None:
6666
...
6767

6868

69-
class HookCallback(Protocol):
69+
class HookCallback(Protocol, Generic[TEvent]):
7070
"""Protocol for callback functions that handle hook events.
7171
7272
Hook callbacks are functions that receive a single strongly-typed event
@@ -80,7 +80,7 @@ def my_callback(event: StartRequestEvent) -> None:
8080
```
8181
"""
8282

83-
def __call__(self, event: Any) -> None:
83+
def __call__(self, event: TEvent) -> None:
8484
"""Handle a hook event.
8585
8686
Args:
@@ -104,7 +104,7 @@ def __init__(self) -> None:
104104
"""Initialize an empty hook registry."""
105105
self._registered_callbacks: dict[Type, list[HookCallback]] = {}
106106

107-
def add_callback(self, event_type: Type, callback: HookCallback) -> None:
107+
def add_callback(self, event_type: Type[TEvent], callback: HookCallback[TEvent]) -> None:
108108
"""Register a callback function for a specific event type.
109109
110110
Args:
@@ -166,7 +166,7 @@ def invoke_callbacks(self, event: TEvent) -> None:
166166
for callback in self.get_callbacks_for(event):
167167
callback(event)
168168

169-
def get_callbacks_for(self, event: TEvent) -> Generator[HookCallback, None, None]:
169+
def get_callbacks_for(self, event: TEvent) -> Generator[HookCallback[TEvent], None, None]:
170170
"""Get callbacks registered for the given event in the appropriate order.
171171
172172
This method returns callbacks in registration order for normal events,

0 commit comments

Comments
 (0)