Skip to content

Commit cbb0e58

Browse files
committed
Fix some type issues
1 parent 24301ad commit cbb0e58

File tree

1 file changed

+19
-16
lines changed

1 file changed

+19
-16
lines changed

sentry_sdk/integrations/langchain.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -52,23 +52,26 @@ def __init__(self, span, num_tokens=0):
5252
self.num_tokens = num_tokens
5353

5454

55-
class SentryLangchainCallback(BaseCallbackHandler):
55+
class SentryLangchainCallback(BaseCallbackHandler): # type: ignore[misc]
5656
"""Base callback handler that can be used to handle callbacks from langchain."""
5757

5858
span_map = OrderedDict() # type: OrderedDict[UUID, WatchedSpan]
5959

6060
max_span_map_size = 0
6161

6262
def __init__(self, max_span_map_size, include_prompts):
63+
# type: (int, bool) -> None
6364
self.max_span_map_size = max_span_map_size
6465
self.include_prompts = include_prompts
6566

6667
def gc_span_map(self):
68+
# type: () -> None
69+
6770
while len(self.span_map) > self.max_span_map_size:
6871
self.span_map.popitem(last=False)[1].span.__exit__(None, None, None)
6972

7073
def _handle_error(self, run_id, error):
71-
# type: (str, Any) -> None
74+
# type: (UUID, Any) -> None
7275
if not run_id or not self.span_map[run_id]:
7376
return
7477

@@ -80,13 +83,13 @@ def _handle_error(self, run_id, error):
8083
del self.span_map[run_id]
8184

8285
def _normalize_langchain_message(self, message):
83-
# type: (BaseMessage) -> dict
86+
# type: (BaseMessage) -> Any
8487
parsed = {"content": message.content, "role": message.type}
8588
parsed.update(message.additional_kwargs)
8689
return parsed
8790

8891
def _create_span(self, run_id, parent_id, **kwargs):
89-
# type: (UUID, Optional[UUID], Any) -> Span
92+
# type: (SentryLangchainCallback, UUID, Optional[Any], Dict[str, Any]) -> Span
9093

9194
span = None # type: Optional[Span]
9295
if parent_id:
@@ -113,7 +116,7 @@ def on_llm_start(
113116
metadata=None,
114117
**kwargs,
115118
):
116-
# type: (Dict[str, Any], List[str], Any, UUID, Optional[List[str]], Optional[UUID], Optional[Dict[str, Any]], Any) -> Any
119+
# type: (SentryLangchainCallback, Dict[str, Any], List[str], UUID, Optional[List[str]], Optional[UUID], Optional[Dict[str, Any]], Dict[str, Any]) -> Any
117120
"""Run when LLM starts running."""
118121
with capture_internal_exceptions():
119122
if not run_id:
@@ -128,7 +131,7 @@ def on_llm_start(
128131
set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, prompts)
129132

130133
def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
131-
# type: (Dict[str, Any], List[List[BaseMessage]], Any, UUID, Any) -> Any
134+
# type: (SentryLangchainCallback, Dict[str, Any], List[List[BaseMessage]], UUID, Dict[str, Any]) -> Any
132135
"""Run when Chat Model starts running."""
133136
if not run_id:
134137
return
@@ -151,7 +154,7 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
151154
)
152155

153156
def on_llm_new_token(self, token, *, run_id, **kwargs):
154-
# type: (str, Any, UUID, Any) -> Any
157+
# type: (SentryLangchainCallback, str, UUID, Dict[str, Any]) -> Any
155158
"""Run on new LLM token. Only available when streaming is enabled."""
156159
with capture_internal_exceptions():
157160
if not run_id or not self.span_map[run_id]:
@@ -162,7 +165,7 @@ def on_llm_new_token(self, token, *, run_id, **kwargs):
162165
span_data.num_tokens += 1
163166

164167
def on_llm_end(self, response, *, run_id, **kwargs):
165-
# type: (LLMResult, Any, UUID, Any) -> Any
168+
# type: (SentryLangchainCallback, LLMResult, UUID, Dict[str, Any]) -> Any
166169
"""Run when LLM ends running."""
167170
with capture_internal_exceptions():
168171
if not run_id:
@@ -206,13 +209,13 @@ def on_llm_end(self, response, *, run_id, **kwargs):
206209
del self.span_map[run_id]
207210

208211
def on_llm_error(self, error, *, run_id, **kwargs):
209-
# type: (Union[Exception, KeyboardInterrupt], Any, UUID, Any) -> Any
212+
# type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Dict[str, Any]) -> Any
210213
"""Run when LLM errors."""
211214
with capture_internal_exceptions():
212215
self._handle_error(run_id, error)
213216

214217
def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):
215-
# type: (Dict[str, Any], Dict[str, Any], Any, UUID, Any) -> Any
218+
# type: (SentryLangchainCallback, Dict[str, Any], Dict[str, Any], UUID, Dict[str, Any]) -> Any
216219
"""Run when chain starts running."""
217220
with capture_internal_exceptions():
218221
if not run_id:
@@ -225,7 +228,7 @@ def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):
225228
)
226229

227230
def on_chain_end(self, outputs, *, run_id, **kwargs):
228-
# type: (Dict[str, Any], Any, UUID, Any) -> Any
231+
# type: (SentryLangchainCallback, Dict[str, Any], UUID, Dict[str, Any]) -> Any
229232
"""Run when chain ends running."""
230233
with capture_internal_exceptions():
231234
if not run_id or not self.span_map[run_id]:
@@ -238,12 +241,12 @@ def on_chain_end(self, outputs, *, run_id, **kwargs):
238241
del self.span_map[run_id]
239242

240243
def on_chain_error(self, error, *, run_id, **kwargs):
241-
# type: (Union[Exception, KeyboardInterrupt], Any, UUID, Any) -> Any
244+
# type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Dict[str, Any]) -> Any
242245
"""Run when chain errors."""
243246
self._handle_error(run_id, error)
244247

245248
def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
246-
# type: (Dict[str, Any], str, Any, UUID, Any) -> Any
249+
# type: (SentryLangchainCallback, Dict[str, Any], str, UUID, Dict[str, Any]) -> Any
247250
"""Run when tool starts running."""
248251
with capture_internal_exceptions():
249252
if not run_id:
@@ -260,7 +263,7 @@ def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
260263
)
261264

262265
def on_tool_end(self, output, *, run_id, **kwargs):
263-
# type: (str, Any, UUID, Any) -> Any
266+
# type: (SentryLangchainCallback, str, UUID, Dict[str, Any]) -> Any
264267
"""Run when tool ends running."""
265268
with capture_internal_exceptions():
266269
if not run_id or not self.span_map[run_id]:
@@ -275,7 +278,7 @@ def on_tool_end(self, output, *, run_id, **kwargs):
275278
del self.span_map[run_id]
276279

277280
def on_tool_error(self, error, *args, run_id, **kwargs):
278-
# type: (Union[Exception, KeyboardInterrupt], Any, UUID, Any) -> Any
281+
# type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Dict[str, Any]) -> Any
279282
"""Run when tool errors."""
280283
self._handle_error(run_id, error)
281284

@@ -290,7 +293,7 @@ def new_configure(*args, **kwargs):
290293
integration = sentry_sdk.get_client().get_integration(LangchainIntegration)
291294

292295
with capture_internal_exceptions():
293-
new_callbacks = []
296+
new_callbacks = [] # type: List[BaseCallbackHandler]
294297
if "local_callbacks" in kwargs:
295298
existing_callbacks = kwargs["local_callbacks"]
296299
kwargs["local_callbacks"] = new_callbacks

0 commit comments

Comments
 (0)