Skip to content

Commit beec36c

Browse files
feat(langchain): Support BaseCallbackManager
While implementing #4479, I noticed that our Langchain integration lacks support for the `local_callbacks` having type `BaseCallbackManager`, which according to the type hint is possible. This change adds support for this case.
1 parent 0a2d858 commit beec36c

File tree

2 files changed

+173
-16
lines changed

2 files changed

+173
-16
lines changed

sentry_sdk/integrations/langchain.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from langchain_core.callbacks import (
2424
manager,
2525
BaseCallbackHandler,
26+
BaseCallbackManager,
2627
Callbacks,
2728
)
2829
from langchain_core.agents import AgentAction, AgentFinish
@@ -437,12 +438,47 @@ def new_configure(
437438
**kwargs,
438439
)
439440

440-
callbacks_list = local_callbacks or []
441+
# Lambda for lazy initialization of the SentryLangchainCallback
442+
sentry_handler_factory = lambda: SentryLangchainCallback(
443+
integration.max_spans,
444+
integration.include_prompts,
445+
integration.tiktoken_encoding_name,
446+
)
447+
448+
local_callbacks = local_callbacks or []
449+
450+
# Handle each possible type of local_callbacks. For each type, we
451+
# extract the list of callbacks to check for SentryLangchainCallback,
452+
# and define a function that would add the SentryLangchainCallback
453+
# to the existing callbacks list.
454+
if isinstance(local_callbacks, BaseCallbackManager):
455+
callbacks_list = local_callbacks.handlers
456+
manager = local_callbacks
457+
458+
# For BaseCallbackManager, we want to copy the manager and add the
459+
# SentryLangchainCallback to the copy.
460+
def local_callbacks_with_sentry():
461+
# type: () -> Union[BaseCallbackManager, list[BaseCallbackHandler]]
462+
new_manager = manager.copy()
463+
new_manager.handlers = [*new_manager.handlers, sentry_handler_factory()]
464+
return new_manager
465+
466+
elif isinstance(local_callbacks, BaseCallbackHandler):
467+
callbacks_list = [local_callbacks]
441468

442-
if isinstance(callbacks_list, BaseCallbackHandler):
443-
callbacks_list = [callbacks_list]
444-
elif not isinstance(callbacks_list, list):
445-
logger.debug("Unknown callback type: %s", callbacks_list)
469+
def local_callbacks_with_sentry():
470+
# type: () -> Union[BaseCallbackManager, list[BaseCallbackHandler]]
471+
return [*callbacks_list, sentry_handler_factory()]
472+
473+
elif isinstance(local_callbacks, list):
474+
callbacks_list = local_callbacks
475+
476+
def local_callbacks_with_sentry():
477+
# type: () -> Union[BaseCallbackManager, list[BaseCallbackHandler]]
478+
return [*callbacks_list, sentry_handler_factory()]
479+
480+
else:
481+
logger.debug("Unknown callback type: %s", local_callbacks)
446482
# Just proceed with original function call
447483
return f(
448484
callback_manager_cls,
@@ -460,20 +496,12 @@ def new_configure(
460496
isinstance(cb, SentryLangchainCallback)
461497
for cb in itertools.chain(callbacks_list, inheritable_callbacks_list)
462498
):
463-
# Avoid mutating the existing callbacks list
464-
callbacks_list = [
465-
*callbacks_list,
466-
SentryLangchainCallback(
467-
integration.max_spans,
468-
integration.include_prompts,
469-
integration.tiktoken_encoding_name,
470-
),
471-
]
499+
local_callbacks = local_callbacks_with_sentry()
472500

473501
return f(
474502
callback_manager_cls,
475503
inheritable_callbacks,
476-
callbacks_list,
504+
local_callbacks,
477505
*args,
478506
**kwargs,
479507
)

tests/integrations/langchain/test_langchain.py

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List, Optional, Any, Iterator
2+
from unittest import mock
23
from unittest.mock import Mock
34

45
import pytest
@@ -12,7 +13,7 @@
1213
# Langchain < 0.2
1314
from langchain_community.chat_models import ChatOpenAI
1415

15-
from langchain_core.callbacks import CallbackManagerForLLMRun
16+
from langchain_core.callbacks import BaseCallbackManager, CallbackManagerForLLMRun
1617
from langchain_core.messages import BaseMessage, AIMessageChunk
1718
from langchain_core.outputs import ChatGenerationChunk, ChatResult
1819
from langchain_core.runnables import RunnableConfig
@@ -416,3 +417,131 @@ def _identifying_params(self):
416417

417418
# Verify the callback ID matches our manual callback
418419
assert id(manual_callback) in tracked_callback_instances
420+
421+
422+
def test_langchain_callback_manager(sentry_init):
423+
sentry_init(
424+
integrations=[LangchainIntegration()],
425+
traces_sample_rate=1.0,
426+
)
427+
local_manager = BaseCallbackManager(handlers=[])
428+
429+
with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
430+
mock_configure = mock_manager_module._configure
431+
432+
# Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
433+
LangchainIntegration.setup_once()
434+
435+
callback_manager_cls = Mock()
436+
437+
mock_manager_module._configure(
438+
callback_manager_cls, local_callbacks=local_manager
439+
)
440+
441+
assert mock_configure.call_count == 1
442+
443+
call_args = mock_configure.call_args
444+
assert call_args.args[0] is callback_manager_cls
445+
446+
passed_manager = call_args.args[2]
447+
assert passed_manager is not local_manager
448+
assert local_manager.handlers == []
449+
450+
[handler] = passed_manager.handlers
451+
assert isinstance(handler, SentryLangchainCallback)
452+
453+
454+
def test_langchain_callback_manager_with_sentry_callback(sentry_init):
455+
sentry_init(
456+
integrations=[LangchainIntegration()],
457+
traces_sample_rate=1.0,
458+
)
459+
sentry_callback = SentryLangchainCallback(0, False)
460+
local_manager = BaseCallbackManager(handlers=[sentry_callback])
461+
462+
with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
463+
mock_configure = mock_manager_module._configure
464+
465+
# Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
466+
LangchainIntegration.setup_once()
467+
468+
callback_manager_cls = Mock()
469+
470+
mock_manager_module._configure(
471+
callback_manager_cls, local_callbacks=local_manager
472+
)
473+
474+
assert mock_configure.call_count == 1
475+
476+
call_args = mock_configure.call_args
477+
assert call_args.args[0] is callback_manager_cls
478+
479+
passed_manager = call_args.args[2]
480+
assert passed_manager is local_manager
481+
482+
[handler] = passed_manager.handlers
483+
assert handler is sentry_callback
484+
485+
486+
def test_langchain_callback_list(sentry_init):
487+
sentry_init(
488+
integrations=[LangchainIntegration()],
489+
traces_sample_rate=1.0,
490+
)
491+
local_callbacks = []
492+
493+
with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
494+
mock_configure = mock_manager_module._configure
495+
496+
# Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
497+
LangchainIntegration.setup_once()
498+
499+
callback_manager_cls = Mock()
500+
501+
mock_manager_module._configure(
502+
callback_manager_cls, local_callbacks=local_callbacks
503+
)
504+
505+
assert mock_configure.call_count == 1
506+
507+
call_args = mock_configure.call_args
508+
assert call_args.args[0] is callback_manager_cls
509+
510+
passed_callbacks = call_args.args[2]
511+
assert passed_callbacks is not local_callbacks
512+
assert local_callbacks == []
513+
514+
[handler] = passed_callbacks
515+
assert isinstance(handler, SentryLangchainCallback)
516+
517+
518+
def test_langchain_callback_list_existing_callback(sentry_init):
519+
sentry_init(
520+
integrations=[LangchainIntegration()],
521+
traces_sample_rate=1.0,
522+
)
523+
sentry_callback = SentryLangchainCallback(0, False)
524+
local_callbacks = [sentry_callback]
525+
526+
with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
527+
mock_configure = mock_manager_module._configure
528+
529+
# Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
530+
LangchainIntegration.setup_once()
531+
532+
callback_manager_cls = Mock()
533+
534+
mock_manager_module._configure(
535+
callback_manager_cls, local_callbacks=local_callbacks
536+
)
537+
538+
assert mock_configure.call_count == 1
539+
540+
call_args = mock_configure.call_args
541+
assert call_args.args[0] is callback_manager_cls
542+
543+
passed_callbacks = call_args.args[2]
544+
assert passed_callbacks is local_callbacks
545+
546+
[handler] = passed_callbacks
547+
assert handler is sentry_callback

0 commit comments

Comments
 (0)