@@ -52,23 +52,26 @@ def __init__(self, span, num_tokens=0):
5252 self .num_tokens = num_tokens
5353
5454
55- class SentryLangchainCallback (BaseCallbackHandler ):
55+ class SentryLangchainCallback (BaseCallbackHandler ): # type: ignore[misc]
5656 """Base callback handler that can be used to handle callbacks from langchain."""
5757
5858 span_map = OrderedDict () # type: OrderedDict[UUID, WatchedSpan]
5959
6060 max_span_map_size = 0
6161
6262 def __init__ (self , max_span_map_size , include_prompts ):
63+ # type: (int, bool) -> None
6364 self .max_span_map_size = max_span_map_size
6465 self .include_prompts = include_prompts
6566
6667 def gc_span_map (self ):
68+ # type: () -> None
69+
6770 while len (self .span_map ) > self .max_span_map_size :
6871 self .span_map .popitem (last = False )[1 ].span .__exit__ (None , None , None )
6972
7073 def _handle_error (self , run_id , error ):
71- # type: (str , Any) -> None
74+ # type: (UUID , Any) -> None
7275 if not run_id or not self .span_map [run_id ]:
7376 return
7477
@@ -80,7 +83,7 @@ def _handle_error(self, run_id, error):
8083 del self .span_map [run_id ]
8184
8285 def _normalize_langchain_message (self , message ):
83- # type: (BaseMessage) -> dict
86+ # type: (BaseMessage) -> Any
8487 parsed = {"content" : message .content , "role" : message .type }
8588 parsed .update (message .additional_kwargs )
8689 return parsed
@@ -113,7 +116,7 @@ def on_llm_start(
113116 metadata = None ,
114117 ** kwargs ,
115118 ):
116- # type: (Dict[str, Any], List[str], Any, UUID, Optional[List[str]], Optional[UUID], Optional[Dict[str, Any]], Any) -> Any
119+ # type: (SentryLangchainCallback, Dict[str, Any], List[str], UUID, Optional[List[str]], Optional[UUID], Optional[Dict[str, Any]], Dict[str, Any] ) -> Any
117120 """Run when LLM starts running."""
118121 with capture_internal_exceptions ():
119122 if not run_id :
@@ -128,7 +131,7 @@ def on_llm_start(
128131 set_data_normalized (span , SPANDATA .AI_INPUT_MESSAGES , prompts )
129132
130133 def on_chat_model_start (self , serialized , messages , * , run_id , ** kwargs ):
131- # type: (Dict[str, Any], List[List[BaseMessage]], Any, UUID , Any) -> Any
134+ # type: (SentryLangchainCallback, Dict[str, Any], List[List[BaseMessage]], UUID, Dict[str , Any] ) -> Any
132135 """Run when Chat Model starts running."""
133136 if not run_id :
134137 return
@@ -151,7 +154,7 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
151154 )
152155
153156 def on_llm_new_token (self , token , * , run_id , ** kwargs ):
154- # type: (str, Any , UUID, Any) -> Any
157+ # type: (SentryLangchainCallback, str , UUID, Dict[str, Any] ) -> Any
155158 """Run on new LLM token. Only available when streaming is enabled."""
156159 with capture_internal_exceptions ():
157160 if not run_id or not self .span_map [run_id ]:
@@ -162,7 +165,7 @@ def on_llm_new_token(self, token, *, run_id, **kwargs):
162165 span_data .num_tokens += 1
163166
164167 def on_llm_end (self , response , * , run_id , ** kwargs ):
165- # type: (LLMResult, Any , UUID, Any) -> Any
168+ # type: (SentryLangchainCallback, LLMResult , UUID, Dict[str, Any] ) -> Any
166169 """Run when LLM ends running."""
167170 with capture_internal_exceptions ():
168171 if not run_id :
@@ -206,13 +209,13 @@ def on_llm_end(self, response, *, run_id, **kwargs):
206209 del self .span_map [run_id ]
207210
208211 def on_llm_error (self , error , * , run_id , ** kwargs ):
209- # type: (Union[Exception, KeyboardInterrupt], Any, UUID , Any) -> Any
212+ # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Dict[str , Any] ) -> Any
210213 """Run when LLM errors."""
211214 with capture_internal_exceptions ():
212215 self ._handle_error (run_id , error )
213216
214217 def on_chain_start (self , serialized , inputs , * , run_id , ** kwargs ):
215- # type: (Dict[str, Any], Dict[str, Any], Any, UUID , Any) -> Any
218+ # type: (SentryLangchainCallback, Dict[str, Any], Dict[str, Any], UUID, Dict[str , Any] ) -> Any
216219 """Run when chain starts running."""
217220 with capture_internal_exceptions ():
218221 if not run_id :
@@ -225,7 +228,7 @@ def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):
225228 )
226229
227230 def on_chain_end (self , outputs , * , run_id , ** kwargs ):
228- # type: (Dict[str, Any], Any, UUID , Any) -> Any
231+ # type: (SentryLangchainCallback, Dict[str, Any], UUID, Dict[str , Any] ) -> Any
229232 """Run when chain ends running."""
230233 with capture_internal_exceptions ():
231234 if not run_id or not self .span_map [run_id ]:
@@ -238,12 +241,12 @@ def on_chain_end(self, outputs, *, run_id, **kwargs):
238241 del self .span_map [run_id ]
239242
240243 def on_chain_error (self , error , * , run_id , ** kwargs ):
241- # type: (Union[Exception, KeyboardInterrupt], Any, UUID , Any) -> Any
244+ # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Dict[str , Any] ) -> Any
242245 """Run when chain errors."""
243246 self ._handle_error (run_id , error )
244247
245248 def on_tool_start (self , serialized , input_str , * , run_id , ** kwargs ):
246- # type: (Dict[str, Any], str, Any, UUID , Any) -> Any
249+ # type: (SentryLangchainCallback, Dict[str, Any], str, UUID, Dict[str , Any] ) -> Any
247250 """Run when tool starts running."""
248251 with capture_internal_exceptions ():
249252 if not run_id :
@@ -260,7 +263,7 @@ def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
260263 )
261264
262265 def on_tool_end (self , output , * , run_id , ** kwargs ):
263- # type: (str, Any , UUID, Any) -> Any
266+ # type: (SentryLangchainCallback, str , UUID, Dict[str, Any] ) -> Any
264267 """Run when tool ends running."""
265268 with capture_internal_exceptions ():
266269 if not run_id or not self .span_map [run_id ]:
@@ -275,7 +278,7 @@ def on_tool_end(self, output, *, run_id, **kwargs):
275278 del self .span_map [run_id ]
276279
277280 def on_tool_error (self , error , * args , run_id , ** kwargs ):
278- # type: (Union[Exception, KeyboardInterrupt], Any, UUID , Any) -> Any
281+ # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Dict[str , Any] ) -> Any
279282 """Run when tool errors."""
280283 self ._handle_error (run_id , error )
281284
@@ -290,7 +293,7 @@ def new_configure(*args, **kwargs):
290293 integration = sentry_sdk .get_client ().get_integration (LangchainIntegration )
291294
292295 with capture_internal_exceptions ():
293- new_callbacks = []
296+ new_callbacks = [] # type: List[BaseCallbackHandler]
294297 if "local_callbacks" in kwargs :
295298 existing_callbacks = kwargs ["local_callbacks" ]
296299 kwargs ["local_callbacks" ] = new_callbacks
0 commit comments