Skip to content

feat(langchain): Support BaseCallbackManager #4486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 38 additions & 19 deletions sentry_sdk/integrations/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from langchain_core.callbacks import (
manager,
BaseCallbackHandler,
BaseCallbackManager,
Callbacks,
)
from langchain_core.agents import AgentAction, AgentFinish
Expand Down Expand Up @@ -434,12 +435,20 @@ def new_configure(
**kwargs,
)

callbacks_list = local_callbacks or []

if isinstance(callbacks_list, BaseCallbackHandler):
callbacks_list = [callbacks_list]
elif not isinstance(callbacks_list, list):
logger.debug("Unknown callback type: %s", callbacks_list)
local_callbacks = local_callbacks or []

# Handle each possible type of local_callbacks. For each type, we
# extract the list of callbacks to check for SentryLangchainCallback,
# and define a function that would add the SentryLangchainCallback
# to the existing callbacks list.
if isinstance(local_callbacks, BaseCallbackManager):
callbacks_list = local_callbacks.handlers
elif isinstance(local_callbacks, BaseCallbackHandler):
callbacks_list = [local_callbacks]
elif isinstance(local_callbacks, list):
callbacks_list = local_callbacks
else:
logger.debug("Unknown callback type: %s", local_callbacks)
# Just proceed with original function call
return f(
callback_manager_cls,
Expand All @@ -449,28 +458,38 @@ def new_configure(
**kwargs,
)

inheritable_callbacks_list = (
inheritable_callbacks if isinstance(inheritable_callbacks, list) else []
)
# Handle each possible type of inheritable_callbacks.
if isinstance(inheritable_callbacks, BaseCallbackManager):
inheritable_callbacks_list = inheritable_callbacks.handlers
elif isinstance(inheritable_callbacks, list):
inheritable_callbacks_list = inheritable_callbacks
else:
inheritable_callbacks_list = []

if not any(
isinstance(cb, SentryLangchainCallback)
for cb in itertools.chain(callbacks_list, inheritable_callbacks_list)
):
# Avoid mutating the existing callbacks list
callbacks_list = [
*callbacks_list,
SentryLangchainCallback(
integration.max_spans,
integration.include_prompts,
integration.tiktoken_encoding_name,
),
]
sentry_handler = SentryLangchainCallback(
integration.max_spans,
integration.include_prompts,
integration.tiktoken_encoding_name,
)
if isinstance(local_callbacks, BaseCallbackManager):
local_callbacks = local_callbacks.copy()
local_callbacks.handlers = [
*local_callbacks.handlers,
sentry_handler,
]
elif isinstance(local_callbacks, BaseCallbackHandler):
local_callbacks = [local_callbacks, sentry_handler]
else: # local_callbacks is a list
local_callbacks = [*local_callbacks, sentry_handler]

return f(
callback_manager_cls,
inheritable_callbacks,
callbacks_list,
local_callbacks,
*args,
**kwargs,
)
Expand Down
131 changes: 130 additions & 1 deletion tests/integrations/langchain/test_langchain.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Optional, Any, Iterator
from unittest import mock
from unittest.mock import Mock

import pytest
Expand All @@ -12,7 +13,7 @@
# Langchain < 0.2
from langchain_community.chat_models import ChatOpenAI

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.callbacks import BaseCallbackManager, CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_core.runnables import RunnableConfig
Expand Down Expand Up @@ -428,3 +429,131 @@ def test_span_map_is_instance_variable():
assert (
callback1.span_map is not callback2.span_map
), "span_map should be an instance variable, not shared between instances"


def test_langchain_callback_manager(sentry_init):
sentry_init(
integrations=[LangchainIntegration()],
traces_sample_rate=1.0,
)
local_manager = BaseCallbackManager(handlers=[])

with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
mock_configure = mock_manager_module._configure

# Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
LangchainIntegration.setup_once()

callback_manager_cls = Mock()

mock_manager_module._configure(
callback_manager_cls, local_callbacks=local_manager
)

assert mock_configure.call_count == 1

call_args = mock_configure.call_args
assert call_args.args[0] is callback_manager_cls

passed_manager = call_args.args[2]
assert passed_manager is not local_manager
assert local_manager.handlers == []

[handler] = passed_manager.handlers
assert isinstance(handler, SentryLangchainCallback)


def test_langchain_callback_manager_with_sentry_callback(sentry_init):
sentry_init(
integrations=[LangchainIntegration()],
traces_sample_rate=1.0,
)
sentry_callback = SentryLangchainCallback(0, False)
local_manager = BaseCallbackManager(handlers=[sentry_callback])

with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
mock_configure = mock_manager_module._configure

# Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
LangchainIntegration.setup_once()

callback_manager_cls = Mock()

mock_manager_module._configure(
callback_manager_cls, local_callbacks=local_manager
)

assert mock_configure.call_count == 1

call_args = mock_configure.call_args
assert call_args.args[0] is callback_manager_cls

passed_manager = call_args.args[2]
assert passed_manager is local_manager

[handler] = passed_manager.handlers
assert handler is sentry_callback


def test_langchain_callback_list(sentry_init):
sentry_init(
integrations=[LangchainIntegration()],
traces_sample_rate=1.0,
)
local_callbacks = []

with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
mock_configure = mock_manager_module._configure

# Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
LangchainIntegration.setup_once()

callback_manager_cls = Mock()

mock_manager_module._configure(
callback_manager_cls, local_callbacks=local_callbacks
)

assert mock_configure.call_count == 1

call_args = mock_configure.call_args
assert call_args.args[0] is callback_manager_cls

passed_callbacks = call_args.args[2]
assert passed_callbacks is not local_callbacks
assert local_callbacks == []

[handler] = passed_callbacks
assert isinstance(handler, SentryLangchainCallback)


def test_langchain_callback_list_existing_callback(sentry_init):
sentry_init(
integrations=[LangchainIntegration()],
traces_sample_rate=1.0,
)
sentry_callback = SentryLangchainCallback(0, False)
local_callbacks = [sentry_callback]

with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
mock_configure = mock_manager_module._configure

# Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
LangchainIntegration.setup_once()

callback_manager_cls = Mock()

mock_manager_module._configure(
callback_manager_cls, local_callbacks=local_callbacks
)

assert mock_configure.call_count == 1

call_args = mock_configure.call_args
assert call_args.args[0] is callback_manager_cls

passed_callbacks = call_args.args[2]
assert passed_callbacks is local_callbacks

[handler] = passed_callbacks
assert handler is sentry_callback
Loading