@@ -52,23 +52,26 @@ def __init__(self, span, num_tokens=0):
52
52
self .num_tokens = num_tokens
53
53
54
54
55
- class SentryLangchainCallback (BaseCallbackHandler ):
55
+ class SentryLangchainCallback (BaseCallbackHandler ): # type: ignore[misc]
56
56
"""Base callback handler that can be used to handle callbacks from langchain."""
57
57
58
58
span_map = OrderedDict () # type: OrderedDict[UUID, WatchedSpan]
59
59
60
60
max_span_map_size = 0
61
61
62
62
def __init__ (self , max_span_map_size , include_prompts ):
63
+ # type: (int, bool) -> None
63
64
self .max_span_map_size = max_span_map_size
64
65
self .include_prompts = include_prompts
65
66
66
67
def gc_span_map (self ):
68
+ # type: () -> None
69
+
67
70
while len (self .span_map ) > self .max_span_map_size :
68
71
self .span_map .popitem (last = False )[1 ].span .__exit__ (None , None , None )
69
72
70
73
def _handle_error (self , run_id , error ):
71
- # type: (str , Any) -> None
74
+ # type: (UUID , Any) -> None
72
75
if not run_id or not self .span_map [run_id ]:
73
76
return
74
77
@@ -80,7 +83,7 @@ def _handle_error(self, run_id, error):
80
83
del self .span_map [run_id ]
81
84
82
85
def _normalize_langchain_message (self , message ):
83
- # type: (BaseMessage) -> dict
86
+ # type: (BaseMessage) -> Any
84
87
parsed = {"content" : message .content , "role" : message .type }
85
88
parsed .update (message .additional_kwargs )
86
89
return parsed
@@ -113,7 +116,7 @@ def on_llm_start(
113
116
metadata = None ,
114
117
** kwargs ,
115
118
):
116
- # type: (Dict[str, Any], List[str], Any, UUID, Optional[List[str]], Optional[UUID], Optional[Dict[str, Any]], Any) -> Any
119
+ # type: (SentryLangchainCallback, Dict[str, Any], List[str], UUID, Optional[List[str]], Optional[UUID], Optional[Dict[str, Any]], Dict[str, Any] ) -> Any
117
120
"""Run when LLM starts running."""
118
121
with capture_internal_exceptions ():
119
122
if not run_id :
@@ -128,7 +131,7 @@ def on_llm_start(
128
131
set_data_normalized (span , SPANDATA .AI_INPUT_MESSAGES , prompts )
129
132
130
133
def on_chat_model_start (self , serialized , messages , * , run_id , ** kwargs ):
131
- # type: (Dict[str, Any], List[List[BaseMessage]], Any, UUID , Any) -> Any
134
+ # type: (SentryLangchainCallback, Dict[str, Any], List[List[BaseMessage]], UUID, Dict[str , Any] ) -> Any
132
135
"""Run when Chat Model starts running."""
133
136
if not run_id :
134
137
return
@@ -151,7 +154,7 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
151
154
)
152
155
153
156
def on_llm_new_token (self , token , * , run_id , ** kwargs ):
154
- # type: (str, Any , UUID, Any) -> Any
157
+ # type: (SentryLangchainCallback, str , UUID, Dict[str, Any] ) -> Any
155
158
"""Run on new LLM token. Only available when streaming is enabled."""
156
159
with capture_internal_exceptions ():
157
160
if not run_id or not self .span_map [run_id ]:
@@ -162,7 +165,7 @@ def on_llm_new_token(self, token, *, run_id, **kwargs):
162
165
span_data .num_tokens += 1
163
166
164
167
def on_llm_end (self , response , * , run_id , ** kwargs ):
165
- # type: (LLMResult, Any , UUID, Any) -> Any
168
+ # type: (SentryLangchainCallback, LLMResult , UUID, Dict[str, Any] ) -> Any
166
169
"""Run when LLM ends running."""
167
170
with capture_internal_exceptions ():
168
171
if not run_id :
@@ -206,13 +209,13 @@ def on_llm_end(self, response, *, run_id, **kwargs):
206
209
del self .span_map [run_id ]
207
210
208
211
def on_llm_error (self , error , * , run_id , ** kwargs ):
209
- # type: (Union[Exception, KeyboardInterrupt], Any, UUID , Any) -> Any
212
+ # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Dict[str , Any] ) -> Any
210
213
"""Run when LLM errors."""
211
214
with capture_internal_exceptions ():
212
215
self ._handle_error (run_id , error )
213
216
214
217
def on_chain_start (self , serialized , inputs , * , run_id , ** kwargs ):
215
- # type: (Dict[str, Any], Dict[str, Any], Any, UUID , Any) -> Any
218
+ # type: (SentryLangchainCallback, Dict[str, Any], Dict[str, Any], UUID, Dict[str , Any] ) -> Any
216
219
"""Run when chain starts running."""
217
220
with capture_internal_exceptions ():
218
221
if not run_id :
@@ -225,7 +228,7 @@ def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):
225
228
)
226
229
227
230
def on_chain_end (self , outputs , * , run_id , ** kwargs ):
228
- # type: (Dict[str, Any], Any, UUID , Any) -> Any
231
+ # type: (SentryLangchainCallback, Dict[str, Any], UUID, Dict[str , Any] ) -> Any
229
232
"""Run when chain ends running."""
230
233
with capture_internal_exceptions ():
231
234
if not run_id or not self .span_map [run_id ]:
@@ -238,12 +241,12 @@ def on_chain_end(self, outputs, *, run_id, **kwargs):
238
241
del self .span_map [run_id ]
239
242
240
243
def on_chain_error (self , error , * , run_id , ** kwargs ):
241
- # type: (Union[Exception, KeyboardInterrupt], Any, UUID , Any) -> Any
244
+ # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Dict[str , Any] ) -> Any
242
245
"""Run when chain errors."""
243
246
self ._handle_error (run_id , error )
244
247
245
248
def on_tool_start (self , serialized , input_str , * , run_id , ** kwargs ):
246
- # type: (Dict[str, Any], str, Any, UUID , Any) -> Any
249
+ # type: (SentryLangchainCallback, Dict[str, Any], str, UUID, Dict[str , Any] ) -> Any
247
250
"""Run when tool starts running."""
248
251
with capture_internal_exceptions ():
249
252
if not run_id :
@@ -260,7 +263,7 @@ def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
260
263
)
261
264
262
265
def on_tool_end (self , output , * , run_id , ** kwargs ):
263
- # type: (str, Any , UUID, Any) -> Any
266
+ # type: (SentryLangchainCallback, str , UUID, Dict[str, Any] ) -> Any
264
267
"""Run when tool ends running."""
265
268
with capture_internal_exceptions ():
266
269
if not run_id or not self .span_map [run_id ]:
@@ -275,7 +278,7 @@ def on_tool_end(self, output, *, run_id, **kwargs):
275
278
del self .span_map [run_id ]
276
279
277
280
def on_tool_error (self , error , * args , run_id , ** kwargs ):
278
- # type: (Union[Exception, KeyboardInterrupt], Any, UUID , Any) -> Any
281
+ # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Dict[str , Any] ) -> Any
279
282
"""Run when tool errors."""
280
283
self ._handle_error (run_id , error )
281
284
@@ -290,7 +293,7 @@ def new_configure(*args, **kwargs):
290
293
integration = sentry_sdk .get_client ().get_integration (LangchainIntegration )
291
294
292
295
with capture_internal_exceptions ():
293
- new_callbacks = []
296
+ new_callbacks = [] # type: List[BaseCallbackHandler]
294
297
if "local_callbacks" in kwargs :
295
298
existing_callbacks = kwargs ["local_callbacks" ]
296
299
kwargs ["local_callbacks" ] = new_callbacks
0 commit comments