Skip to content

Commit a1f62ba

Browse files
feat(langchain): Support BaseCallbackManager (#4486)
While implementing #4479, I noticed that our Langchain integration lacks support for the `local_callbacks` having type `BaseCallbackManager`, which according to the type hint is possible. This change adds support for this case. Fixes #4537 --- Thank you for contributing to `sentry-python`! Please add tests to validate your changes, and lint your code using `tox -e linters`. Running the test suite on your PR might require maintainer approval.
1 parent c31ba06 commit a1f62ba

File tree

2 files changed

+168
-20
lines changed

2 files changed

+168
-20
lines changed

sentry_sdk/integrations/langchain.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from langchain_core.callbacks import (
2424
manager,
2525
BaseCallbackHandler,
26+
BaseCallbackManager,
2627
Callbacks,
2728
)
2829
from langchain_core.agents import AgentAction, AgentFinish
@@ -434,12 +435,20 @@ def new_configure(
434435
**kwargs,
435436
)
436437

437-
callbacks_list = local_callbacks or []
438-
439-
if isinstance(callbacks_list, BaseCallbackHandler):
440-
callbacks_list = [callbacks_list]
441-
elif not isinstance(callbacks_list, list):
442-
logger.debug("Unknown callback type: %s", callbacks_list)
438+
local_callbacks = local_callbacks or []
439+
440+
# Handle each possible type of local_callbacks. For each type, we
441+
# extract the list of callbacks to check for SentryLangchainCallback,
442+
# and define a function that would add the SentryLangchainCallback
443+
# to the existing callbacks list.
444+
if isinstance(local_callbacks, BaseCallbackManager):
445+
callbacks_list = local_callbacks.handlers
446+
elif isinstance(local_callbacks, BaseCallbackHandler):
447+
callbacks_list = [local_callbacks]
448+
elif isinstance(local_callbacks, list):
449+
callbacks_list = local_callbacks
450+
else:
451+
logger.debug("Unknown callback type: %s", local_callbacks)
443452
# Just proceed with original function call
444453
return f(
445454
callback_manager_cls,
@@ -449,28 +458,38 @@ def new_configure(
449458
**kwargs,
450459
)
451460

452-
inheritable_callbacks_list = (
453-
inheritable_callbacks if isinstance(inheritable_callbacks, list) else []
454-
)
461+
# Handle each possible type of inheritable_callbacks.
462+
if isinstance(inheritable_callbacks, BaseCallbackManager):
463+
inheritable_callbacks_list = inheritable_callbacks.handlers
464+
elif isinstance(inheritable_callbacks, list):
465+
inheritable_callbacks_list = inheritable_callbacks
466+
else:
467+
inheritable_callbacks_list = []
455468

456469
if not any(
457470
isinstance(cb, SentryLangchainCallback)
458471
for cb in itertools.chain(callbacks_list, inheritable_callbacks_list)
459472
):
460-
# Avoid mutating the existing callbacks list
461-
callbacks_list = [
462-
*callbacks_list,
463-
SentryLangchainCallback(
464-
integration.max_spans,
465-
integration.include_prompts,
466-
integration.tiktoken_encoding_name,
467-
),
468-
]
473+
sentry_handler = SentryLangchainCallback(
474+
integration.max_spans,
475+
integration.include_prompts,
476+
integration.tiktoken_encoding_name,
477+
)
478+
if isinstance(local_callbacks, BaseCallbackManager):
479+
local_callbacks = local_callbacks.copy()
480+
local_callbacks.handlers = [
481+
*local_callbacks.handlers,
482+
sentry_handler,
483+
]
484+
elif isinstance(local_callbacks, BaseCallbackHandler):
485+
local_callbacks = [local_callbacks, sentry_handler]
486+
else: # local_callbacks is a list
487+
local_callbacks = [*local_callbacks, sentry_handler]
469488

470489
return f(
471490
callback_manager_cls,
472491
inheritable_callbacks,
473-
callbacks_list,
492+
local_callbacks,
474493
*args,
475494
**kwargs,
476495
)

tests/integrations/langchain/test_langchain.py

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List, Optional, Any, Iterator
2+
from unittest import mock
23
from unittest.mock import Mock
34

45
import pytest
@@ -12,7 +13,7 @@
1213
# Langchain < 0.2
1314
from langchain_community.chat_models import ChatOpenAI
1415

15-
from langchain_core.callbacks import CallbackManagerForLLMRun
16+
from langchain_core.callbacks import BaseCallbackManager, CallbackManagerForLLMRun
1617
from langchain_core.messages import BaseMessage, AIMessageChunk
1718
from langchain_core.outputs import ChatGenerationChunk, ChatResult
1819
from langchain_core.runnables import RunnableConfig
@@ -428,3 +429,131 @@ def test_span_map_is_instance_variable():
428429
assert (
429430
callback1.span_map is not callback2.span_map
430431
), "span_map should be an instance variable, not shared between instances"
432+
433+
434+
def test_langchain_callback_manager(sentry_init):
435+
sentry_init(
436+
integrations=[LangchainIntegration()],
437+
traces_sample_rate=1.0,
438+
)
439+
local_manager = BaseCallbackManager(handlers=[])
440+
441+
with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
442+
mock_configure = mock_manager_module._configure
443+
444+
# Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
445+
LangchainIntegration.setup_once()
446+
447+
callback_manager_cls = Mock()
448+
449+
mock_manager_module._configure(
450+
callback_manager_cls, local_callbacks=local_manager
451+
)
452+
453+
assert mock_configure.call_count == 1
454+
455+
call_args = mock_configure.call_args
456+
assert call_args.args[0] is callback_manager_cls
457+
458+
passed_manager = call_args.args[2]
459+
assert passed_manager is not local_manager
460+
assert local_manager.handlers == []
461+
462+
[handler] = passed_manager.handlers
463+
assert isinstance(handler, SentryLangchainCallback)
464+
465+
466+
def test_langchain_callback_manager_with_sentry_callback(sentry_init):
467+
sentry_init(
468+
integrations=[LangchainIntegration()],
469+
traces_sample_rate=1.0,
470+
)
471+
sentry_callback = SentryLangchainCallback(0, False)
472+
local_manager = BaseCallbackManager(handlers=[sentry_callback])
473+
474+
with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
475+
mock_configure = mock_manager_module._configure
476+
477+
# Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
478+
LangchainIntegration.setup_once()
479+
480+
callback_manager_cls = Mock()
481+
482+
mock_manager_module._configure(
483+
callback_manager_cls, local_callbacks=local_manager
484+
)
485+
486+
assert mock_configure.call_count == 1
487+
488+
call_args = mock_configure.call_args
489+
assert call_args.args[0] is callback_manager_cls
490+
491+
passed_manager = call_args.args[2]
492+
assert passed_manager is local_manager
493+
494+
[handler] = passed_manager.handlers
495+
assert handler is sentry_callback
496+
497+
498+
def test_langchain_callback_list(sentry_init):
499+
sentry_init(
500+
integrations=[LangchainIntegration()],
501+
traces_sample_rate=1.0,
502+
)
503+
local_callbacks = []
504+
505+
with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
506+
mock_configure = mock_manager_module._configure
507+
508+
# Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
509+
LangchainIntegration.setup_once()
510+
511+
callback_manager_cls = Mock()
512+
513+
mock_manager_module._configure(
514+
callback_manager_cls, local_callbacks=local_callbacks
515+
)
516+
517+
assert mock_configure.call_count == 1
518+
519+
call_args = mock_configure.call_args
520+
assert call_args.args[0] is callback_manager_cls
521+
522+
passed_callbacks = call_args.args[2]
523+
assert passed_callbacks is not local_callbacks
524+
assert local_callbacks == []
525+
526+
[handler] = passed_callbacks
527+
assert isinstance(handler, SentryLangchainCallback)
528+
529+
530+
def test_langchain_callback_list_existing_callback(sentry_init):
531+
sentry_init(
532+
integrations=[LangchainIntegration()],
533+
traces_sample_rate=1.0,
534+
)
535+
sentry_callback = SentryLangchainCallback(0, False)
536+
local_callbacks = [sentry_callback]
537+
538+
with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
539+
mock_configure = mock_manager_module._configure
540+
541+
# Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
542+
LangchainIntegration.setup_once()
543+
544+
callback_manager_cls = Mock()
545+
546+
mock_manager_module._configure(
547+
callback_manager_cls, local_callbacks=local_callbacks
548+
)
549+
550+
assert mock_configure.call_count == 1
551+
552+
call_args = mock_configure.call_args
553+
assert call_args.args[0] is callback_manager_cls
554+
555+
passed_callbacks = call_args.args[2]
556+
assert passed_callbacks is local_callbacks
557+
558+
[handler] = passed_callbacks
559+
assert handler is sentry_callback

0 commit comments

Comments
 (0)