feat(langchain): Support BaseCallbackManager #4486

Draft
wants to merge 1 commit into base: szokeasaurusrex/refactor-langchain-args
52 changes: 38 additions & 14 deletions sentry_sdk/integrations/langchain.py
@@ -22,6 +22,7 @@
 from langchain_core.callbacks import (
     manager,
     BaseCallbackHandler,
+    BaseCallbackManager,
     Callbacks,
 )
 from langchain_core.agents import AgentAction, AgentFinish
@@ -436,11 +437,42 @@ def new_configure(
                 **kwargs,
             )
 
-        callbacks_list = local_callbacks or []
+        # Lambda for lazy initialization of the SentryLangchainCallback
+        sentry_handler_factory = lambda: SentryLangchainCallback(
+            integration.max_spans,
+            integration.include_prompts,
+            integration.tiktoken_encoding_name,
+        )
+
+        local_callbacks = local_callbacks or []
 
+        # Handle each possible type of local_callbacks. For each type, we
+        # extract the list of callbacks to check for SentryLangchainCallback,
+        # and define a function that would add the SentryLangchainCallback
+        # to the existing callbacks list.
+        if isinstance(local_callbacks, BaseCallbackManager):
+            callbacks_list = local_callbacks.handlers
+
+            # For BaseCallbackManager, we want to copy the manager and add the
+            # SentryLangchainCallback to the copy.
+            def local_callbacks_with_sentry():
+                new_manager = local_callbacks.copy()
+                new_manager.handlers = [*new_manager.handlers, sentry_handler_factory()]
+                return new_manager
+
+        elif isinstance(local_callbacks, BaseCallbackHandler):
+            callbacks_list = [local_callbacks]
+
-        if isinstance(callbacks_list, BaseCallbackHandler):
-            callbacks_list = [callbacks_list]
-        elif not isinstance(callbacks_list, list):
+            def local_callbacks_with_sentry():
+                return [local_callbacks, sentry_handler_factory()]
+
+        elif isinstance(local_callbacks, list):
+            callbacks_list = local_callbacks
+
+            def local_callbacks_with_sentry():
+                return [*local_callbacks, sentry_handler_factory()]
+
+        else:
             logger.debug("Unknown callback type: %s", callbacks_list)
             # Just proceed with original function call
             return f(
@@ -452,20 +484,12 @@ def new_configure(
             )
 
         if not any(isinstance(cb, SentryLangchainCallback) for cb in callbacks_list):
-            # Avoid mutating the existing callbacks list
-            callbacks_list = [
-                *callbacks_list,
-                SentryLangchainCallback(
-                    integration.max_spans,
-                    integration.include_prompts,
-                    integration.tiktoken_encoding_name,
-                ),
-            ]
+            local_callbacks = local_callbacks_with_sentry()
 
         return f(
             callback_manager_cls,
             inheritable_callbacks,
-            callbacks_list,
+            local_callbacks,
             *args,
             **kwargs,
         )
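
Aside (not part of the diff): the new BaseCallbackManager branch copies the caller's manager before appending the Sentry handler, so the original manager is never mutated. Below is a minimal sketch of that copy-and-append pattern, assuming BaseCallbackManager.copy() behaves as the diff uses it; NoopHandler and the variable names are illustrative, not part of this PR.

```python
from langchain_core.callbacks import BaseCallbackHandler, BaseCallbackManager


class NoopHandler(BaseCallbackHandler):
    """Illustrative stand-in for a user-provided callback handler."""


user_manager = BaseCallbackManager(handlers=[NoopHandler()])

# Mirror what local_callbacks_with_sentry() does for a BaseCallbackManager:
# copy the manager, then rebind .handlers on the copy only.
patched = user_manager.copy()
patched.handlers = [*patched.handlers, NoopHandler()]

assert len(user_manager.handlers) == 1  # caller's manager is not mutated
assert len(patched.handlers) == 2  # the copy carries the extra handler
```
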
136 changes: 134 additions & 2 deletions tests/integrations/langchain/test_langchain.py
@@ -1,4 +1,5 @@
 from typing import List, Optional, Any, Iterator
+from unittest import mock
 from unittest.mock import Mock
 
 import pytest
@@ -12,12 +13,15 @@
     # Langchain < 0.2
     from langchain_community.chat_models import ChatOpenAI
 
-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import BaseCallbackManager, CallbackManagerForLLMRun
 from langchain_core.messages import BaseMessage, AIMessageChunk
 from langchain_core.outputs import ChatGenerationChunk
 
 from sentry_sdk import start_transaction
-from sentry_sdk.integrations.langchain import LangchainIntegration
+from sentry_sdk.integrations.langchain import (
+    LangchainIntegration,
+    SentryLangchainCallback,
+)
 from langchain.agents import tool, AgentExecutor, create_openai_tools_agent
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 
@@ -342,3 +346,131 @@ def test_span_origin(sentry_init, capture_events):
     assert event["contexts"]["trace"]["origin"] == "manual"
     for span in event["spans"]:
         assert span["origin"] == "auto.ai.langchain"
+
+
+def test_langchain_callback_manager(sentry_init):
+    sentry_init(
+        integrations=[LangchainIntegration()],
+        traces_sample_rate=1.0,
+    )
+    local_manager = BaseCallbackManager(handlers=[])
+
+    with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
+        mock_configure = mock_manager_module._configure
+
+        # Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
+        LangchainIntegration.setup_once()
+
+        callback_manager_cls = Mock()
+
+        mock_manager_module._configure(
+            callback_manager_cls, local_callbacks=local_manager
+        )
+
+        assert mock_configure.call_count == 1
+
+        call_args = mock_configure.call_args
+        assert call_args.args[0] is callback_manager_cls
+
+        passed_manager = call_args.args[2]
+        assert passed_manager is not local_manager
+        assert local_manager.handlers == []
+
+        [handler] = passed_manager.handlers
+        assert isinstance(handler, SentryLangchainCallback)
+
+
+def test_langchain_callback_manager_with_sentry_callback(sentry_init):
+    sentry_init(
+        integrations=[LangchainIntegration()],
+        traces_sample_rate=1.0,
+    )
+    sentry_callback = SentryLangchainCallback(0, False)
+    local_manager = BaseCallbackManager(handlers=[sentry_callback])
+
+    with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
+        mock_configure = mock_manager_module._configure
+
+        # Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
+        LangchainIntegration.setup_once()
+
+        callback_manager_cls = Mock()
+
+        mock_manager_module._configure(
+            callback_manager_cls, local_callbacks=local_manager
+        )
+
+        assert mock_configure.call_count == 1
+
+        call_args = mock_configure.call_args
+        assert call_args.args[0] is callback_manager_cls
+
+        passed_manager = call_args.args[2]
+        assert passed_manager is local_manager
+
+        [handler] = passed_manager.handlers
+        assert handler is sentry_callback
+
+
+def test_langchain_callback_list(sentry_init):
+    sentry_init(
+        integrations=[LangchainIntegration()],
+        traces_sample_rate=1.0,
+    )
+    local_callbacks = []
+
+    with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
+        mock_configure = mock_manager_module._configure
+
+        # Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
+        LangchainIntegration.setup_once()
+
+        callback_manager_cls = Mock()
+
+        mock_manager_module._configure(
+            callback_manager_cls, local_callbacks=local_callbacks
+        )
+
+        assert mock_configure.call_count == 1
+
+        call_args = mock_configure.call_args
+        assert call_args.args[0] is callback_manager_cls
+
+        passed_callbacks = call_args.args[2]
+        assert passed_callbacks is not local_callbacks
+        assert local_callbacks == []
+
+        [handler] = passed_callbacks
+        assert isinstance(handler, SentryLangchainCallback)
+
+
+def test_langchain_callback_list_existing_callback(sentry_init):
+    sentry_init(
+        integrations=[LangchainIntegration()],
+        traces_sample_rate=1.0,
+    )
+    sentry_callback = SentryLangchainCallback(0, False)
+    local_callbacks = [sentry_callback]
+
+    with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
+        mock_configure = mock_manager_module._configure
+
+        # Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
+        LangchainIntegration.setup_once()
+
+        callback_manager_cls = Mock()
+
+        mock_manager_module._configure(
+            callback_manager_cls, local_callbacks=local_callbacks
+        )
+
+        assert mock_configure.call_count == 1
+
+        call_args = mock_configure.call_args
+        assert call_args.args[0] is callback_manager_cls
+
+        passed_callbacks = call_args.args[2]
+        assert passed_callbacks is local_callbacks
+
+        [handler] = passed_callbacks
+        assert handler is sentry_callback
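
Aside (not part of the diff): a rough sketch of the user-facing effect. Previously a BaseCallbackManager passed as local callbacks fell through to the "Unknown callback type" debug log; with this change the integration copies the manager and appends SentryLangchainCallback to the copy. The manager name and the commented-out model call are illustrative, and the DSN is omitted.

```python
import sentry_sdk
from sentry_sdk.integrations.langchain import LangchainIntegration
from langchain_core.callbacks import BaseCallbackManager

# DSN omitted here; pass your project DSN in a real application.
sentry_sdk.init(
    traces_sample_rate=1.0,
    integrations=[LangchainIntegration()],
)

# A caller-owned callback manager. With this change, the patched callback
# configuration copies it and appends SentryLangchainCallback to the copy,
# instead of hitting the "Unknown callback type" fallback.
my_manager = BaseCallbackManager(handlers=[])

# Illustrative only (requires a configured chat model, e.g. ChatOpenAI):
# llm.invoke("Hello", config={"callbacks": my_manager})
```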