diff --git a/apps/application/flow/step_node/__init__.py b/apps/application/flow/step_node/__init__.py index f3602901aa5..0ce1d5fedd1 100644 --- a/apps/application/flow/step_node/__init__.py +++ b/apps/application/flow/step_node/__init__.py @@ -25,13 +25,14 @@ from .start_node import * from .text_to_speech_step_node.impl.base_text_to_speech_node import BaseTextToSpeechNode from .variable_assign_node import BaseVariableAssignNode +from .mcp_node import BaseMcpNode node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode, BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode, BaseDocumentExtractNode, BaseImageUnderstandNode, BaseFormNode, BaseSpeechToTextNode, BaseTextToSpeechNode, - BaseImageGenerateNode, BaseVariableAssignNode] + BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode] def get_node(node_type): diff --git a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py index 336753450fa..a83d2ef5771 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py @@ -33,6 +33,9 @@ class ChatNodeSerializer(serializers.Serializer): error_messages=ErrMessage.dict('Model settings')) dialogue_type = serializers.CharField(required=False, allow_blank=True, allow_null=True, error_messages=ErrMessage.char(_("Context Type"))) + mcp_enable = serializers.BooleanField(required=False, + error_messages=ErrMessage.boolean(_("Whether to enable MCP"))) + mcp_servers = serializers.JSONField(required=False, error_messages=ErrMessage.list(_("MCP Server"))) class IChatNode(INode): @@ -49,5 +52,7 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record model_params_setting=None, dialogue_type=None, model_setting=None, + mcp_enable=False, + mcp_servers=None, **kwargs) -> NodeResult: pass diff --git 
a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index 7da004822fe..ae84c4bbbd7 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -6,14 +6,19 @@ @date:2024/6/4 14:30 @desc: """ +import asyncio +import json import re import time from functools import reduce +from types import AsyncGeneratorType from typing import List, Dict from django.db.models import QuerySet from langchain.schema import HumanMessage, SystemMessage -from langchain_core.messages import BaseMessage, AIMessage +from langchain_core.messages import BaseMessage, AIMessage, AIMessageChunk, ToolMessage +from langchain_mcp_adapters.client import MultiServerMCPClient +from langgraph.prebuilt import create_react_agent from application.flow.i_step_node import NodeResult, INode from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode @@ -56,6 +61,7 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo reasoning = Reasoning(model_setting.get('reasoning_content_start', ''), model_setting.get('reasoning_content_end', '')) response_reasoning_content = False + for chunk in response: reasoning_chunk = reasoning.get_reasoning_content(chunk) content_chunk = reasoning_chunk.get('content') @@ -84,6 +90,47 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo _write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content) + +async def _yield_mcp_response(chat_model, message_list, mcp_servers): + async with MultiServerMCPClient(json.loads(mcp_servers)) as client: + agent = create_react_agent(chat_model, client.get_tools()) + response = agent.astream({"messages": message_list}, stream_mode='messages') + async for chunk in response: + # if isinstance(chunk[0], ToolMessage): + # print(chunk[0]) + if 
isinstance(chunk[0], AIMessageChunk): + yield chunk[0] + +def mcp_response_generator(chat_model, message_list, mcp_servers): + loop = asyncio.new_event_loop() + try: + async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers) + while True: + try: + chunk = loop.run_until_complete(anext_async(async_gen)) + yield chunk + except StopAsyncIteration: + break + except Exception as e: + print(f'exception: {e}') + finally: + loop.close() + +async def anext_async(agen): + return await agen.__anext__() + +async def _get_mcp_response(chat_model, message_list, mcp_servers): + async with MultiServerMCPClient(json.loads(mcp_servers)) as client: + agent = create_react_agent(chat_model, client.get_tools()) + response = agent.astream({"messages": message_list}, stream_mode='messages') + result = [] + async for chunk in response: + # if isinstance(chunk[0], ToolMessage): + # print(chunk[0].content) + if isinstance(chunk[0], AIMessageChunk): + result.append(chunk[0]) + return result + def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): """ 写入上下文数据 @@ -142,6 +189,8 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record model_params_setting=None, dialogue_type=None, model_setting=None, + mcp_enable=False, + mcp_servers=None, **kwargs) -> NodeResult: if dialogue_type is None: dialogue_type = 'WORKFLOW' @@ -163,6 +212,14 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record self.context['system'] = system message_list = self.generate_message_list(system, prompt, history_message) self.context['message_list'] = message_list + + if mcp_enable and mcp_servers is not None: + r = mcp_response_generator(chat_model, message_list, mcp_servers) + return NodeResult( + {'result': r, 'chat_model': chat_model, 'message_list': message_list, + 'history_message': history_message, 'question': question.content}, {}, + _write_context=write_context_stream) + if stream: r = 
chat_model.stream(message_list) return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list, diff --git a/apps/application/flow/step_node/mcp_node/__init__.py b/apps/application/flow/step_node/mcp_node/__init__.py new file mode 100644 index 00000000000..f3feecc9ce2 --- /dev/null +++ b/apps/application/flow/step_node/mcp_node/__init__.py @@ -0,0 +1,3 @@ +# coding=utf-8 + +from .impl import * diff --git a/apps/application/flow/step_node/mcp_node/i_mcp_node.py b/apps/application/flow/step_node/mcp_node/i_mcp_node.py new file mode 100644 index 00000000000..94cb4da7729 --- /dev/null +++ b/apps/application/flow/step_node/mcp_node/i_mcp_node.py @@ -0,0 +1,35 @@ +# coding=utf-8 + +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage +from django.utils.translation import gettext_lazy as _ + + +class McpNodeSerializer(serializers.Serializer): + mcp_servers = serializers.JSONField(required=True, + error_messages=ErrMessage.char(_("Mcp servers"))) + + mcp_server = serializers.CharField(required=True, + error_messages=ErrMessage.char(_("Mcp server"))) + + mcp_tool = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Mcp tool"))) + + tool_params = serializers.DictField(required=True, + error_messages=ErrMessage.char(_("Tool parameters"))) + + +class IMcpNode(INode): + type = 'mcp-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return McpNodeSerializer + + def _run(self): + return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) + + def execute(self, mcp_servers, mcp_server, mcp_tool, tool_params, **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/mcp_node/impl/__init__.py b/apps/application/flow/step_node/mcp_node/impl/__init__.py new file mode 100644 index 00000000000..8c9a5ee197c --- /dev/null +++ 
b/apps/application/flow/step_node/mcp_node/impl/__init__.py @@ -0,0 +1,3 @@ +# coding=utf-8 + +from .base_mcp_node import BaseMcpNode diff --git a/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py b/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py new file mode 100644 index 00000000000..b3c1f2d9bed --- /dev/null +++ b/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py @@ -0,0 +1,56 @@ +# coding=utf-8 +import asyncio +import json +from typing import List + +from langchain_mcp_adapters.client import MultiServerMCPClient + +from application.flow.i_step_node import NodeResult +from application.flow.step_node.mcp_node.i_mcp_node import IMcpNode + + +class BaseMcpNode(IMcpNode): + def save_context(self, details, workflow_manage): + self.context['result'] = details.get('result') + self.context['tool_params'] = details.get('tool_params') + self.context['mcp_tool'] = details.get('mcp_tool') + self.answer_text = details.get('result') + + def execute(self, mcp_servers, mcp_server, mcp_tool, tool_params, **kwargs) -> NodeResult: + servers = json.loads(mcp_servers) + params = self.handle_variables(tool_params) + + async def call_tool(s, session, t, a): + async with MultiServerMCPClient(s) as client: + s = await client.sessions[session].call_tool(t, a) + return s + + res = asyncio.run(call_tool(servers, mcp_server, mcp_tool, params)) + return NodeResult({'result': [content.text for content in res.content], 'tool_params': params, 'mcp_tool': mcp_tool}, {}) + + def handle_variables(self, tool_params): + # 处理参数中的变量 + for k, v in tool_params.items(): + if type(v) == str: + tool_params[k] = self.workflow_manage.generate_prompt(tool_params[k]) + if type(v) == dict: + self.handle_variables(v) + return tool_params + + def get_reference_content(self, fields: List[str]): + return str(self.workflow_manage.get_reference_field( + fields[0], + fields[1:])) + + def get_details(self, index: int, **kwargs): + return { + 'name': 
self.node.properties.get('stepName'), + "index": index, + 'run_time': self.context.get('run_time'), + 'status': self.status, + 'err_message': self.err_message, + 'type': self.node.type, + 'mcp_tool': self.context.get('mcp_tool'), + 'tool_params': self.context.get('tool_params'), + 'result': self.context.get('result'), + } diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 69d518cfd1c..514a9e5f00a 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -6,6 +6,7 @@ @date:2023/11/7 10:02 @desc: """ +import asyncio import datetime import hashlib import json @@ -23,6 +24,8 @@ from django.db.models.expressions import RawSQL from django.http import HttpResponse from django.template import Template, Context +from langchain_mcp_adapters.client import MultiServerMCPClient +from mcp.client.sse import sse_client from rest_framework import serializers, status from rest_framework.utils.formatting import lazy_format @@ -1305,3 +1308,28 @@ def edit(self, instance, with_valid=True): application_api_key.save() # 写入缓存 get_application_api_key(application_api_key.secret_key, False) + + class McpServers(serializers.Serializer): + mcp_servers = serializers.JSONField(required=True) + + def get_mcp_servers(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + servers = json.loads(self.data.get('mcp_servers')) + + async def get_mcp_tools(servers): + async with MultiServerMCPClient(servers) as client: + return client.get_tools() + + tools = [] + for server in servers: + tools += [ + { + 'server': server, + 'name': tool.name, + 'description': tool.description, + 'args_schema': tool.args_schema, + } + for tool in asyncio.run(get_mcp_tools({server: servers[server]}))] + return tools + diff --git a/apps/application/urls.py b/apps/application/urls.py index 6dc2ae5af63..b294289541e 100644 --- 
a/apps/application/urls.py +++ b/apps/application/urls.py @@ -9,6 +9,7 @@ path('application/profile', views.Application.Profile.as_view(), name='application/profile'), path('application/embed', views.Application.Embed.as_view()), path('application/authentication', views.Application.Authentication.as_view()), + path('application/mcp_servers', views.Application.McpServers.as_view()), path('application/<str:application_id>/publish', views.Application.Publish.as_view()), path('application/<str:application_id>/edit_icon', views.Application.EditIcon.as_view()), path('application/<str:application_id>/export', views.Application.Export.as_view()), diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index ab97de6262e..991a6f4d5dc 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -700,3 +700,13 @@ def post(self, request: Request, application_id: str): data={'application_id': application_id, 'user_id': request.user.id}).play_demo_text(request.data) return HttpResponse(byte_data, status=200, headers={'Content-Type': 'audio/mp3', 'Content-Disposition': 'attachment; filename="abc.mp3"'}) + + class McpServers(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @has_permissions(PermissionConstants.APPLICATION_READ, compare=CompareConstants.AND) + @log(menu='Application', operate="Get the MCP server tools") + def get(self, request: Request): + return result.success(ApplicationSerializer.McpServers( + data={'mcp_servers': request.query_params.get('mcp_servers')}).get_mcp_servers()) diff --git a/ui/src/api/application.ts b/ui/src/api/application.ts index a85b031c00f..efd4a4985a8 100644 --- a/ui/src/api/application.ts +++ b/ui/src/api/application.ts @@ -350,6 +350,13 @@ const getFunctionLib: ( return get(`${prefix}/${application_id}/function_lib/${function_lib_id}`, undefined, loading) } +const getMcpTools: ( + data: any, + loading?: Ref<boolean> +) => Promise<Result<any>> = (data, loading) => { + return 
get(`${prefix}/mcp_servers`, data, loading) +} + const getApplicationById: ( application_id: String, app_id: String, @@ -576,5 +583,6 @@ export default { uploadFile, exportApplication, importApplication, - getApplicationById + getApplicationById, + getMcpTools } diff --git a/ui/src/components/ai-chat/ExecutionDetailDialog.vue b/ui/src/components/ai-chat/ExecutionDetailDialog.vue index cfb6f54d0ae..c36e36e52c6 100644 --- a/ui/src/components/ai-chat/ExecutionDetailDialog.vue +++ b/ui/src/components/ai-chat/ExecutionDetailDialog.vue @@ -639,6 +639,40 @@ + + +