diff --git a/apps/dataset/serializers/common_serializers.py b/apps/dataset/serializers/common_serializers.py index a990b811637..856f3da1584 100644 --- a/apps/dataset/serializers/common_serializers.py +++ b/apps/dataset/serializers/common_serializers.py @@ -222,3 +222,26 @@ def get_embedding_model_id_by_dataset_id_list(dataset_id_list: List): if len(dataset_list) == 0: raise Exception(_('Knowledge base setting error, please reset the knowledge base')) return str(dataset_list[0].embedding_mode_id) + + +class GenerateRelatedSerializer(ApiMixin, serializers.Serializer): + model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('Model id'))) + prompt = serializers.CharField(required=True, error_messages=ErrMessage.uuid(_('Prompt word'))) + state_list = serializers.ListField(required=False, child=serializers.CharField(required=True), + error_messages=ErrMessage.list("state list")) + + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'model_id': openapi.Schema(type=openapi.TYPE_STRING, + title=_('Model id'), + description=_('Model id')), + 'prompt': openapi.Schema(type=openapi.TYPE_STRING, title=_('Prompt word'), + description=_("Prompt word")), + 'state_list': openapi.Schema(type=openapi.TYPE_ARRAY, + items=openapi.Schema(type=openapi.TYPE_STRING), + title=_('state list')) + } + ) diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index f9e9a713258..5b94c6c0ce9 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -23,6 +23,7 @@ from django.core import validators from django.db import transaction, models from django.db.models import QuerySet +from django.db.models.functions import Reverse, Substr from django.http import HttpResponse from drf_yasg import openapi from rest_framework import serializers @@ -42,9 +43,10 @@ from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, TaskType, \ State, File, Image from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \ - get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id, write_image, zip_dir + get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id, write_image, zip_dir, \ + GenerateRelatedSerializer from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer -from dataset.task import sync_web_dataset, sync_replace_web_dataset +from dataset.task import sync_web_dataset, sync_replace_web_dataset, generate_related_by_dataset_id from embedding.models import SearchMode from embedding.task import embedding_by_dataset, delete_embedding_by_dataset from setting.models import AuthOperate, Model @@ -814,6 +816,31 @@ def re_embedding(self, with_valid=True): except AlreadyQueued as e: raise AppApiException(500, _('Failed to send the vectorization task, please try again later!')) + def generate_related(self, instance: Dict, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + GenerateRelatedSerializer(data=instance).is_valid(raise_exception=True) + dataset_id = self.data.get('id') + model_id = instance.get("model_id") + prompt = instance.get("prompt") + state_list = instance.get('state_list') + ListenerManagement.update_status(QuerySet(Document).filter(dataset_id=dataset_id), + TaskType.GENERATE_PROBLEM, + State.PENDING) + ListenerManagement.update_status(QuerySet(Paragraph).annotate( + reversed_status=Reverse('status'), + task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value, + 1), + ).filter(task_type_status__in=state_list, dataset_id=dataset_id) + .values('id'), + TaskType.GENERATE_PROBLEM, + State.PENDING) + ListenerManagement.get_aggregation_document_status_by_dataset_id(dataset_id)() + try: + generate_related_by_dataset_id.delay(dataset_id, model_id, prompt, state_list) + except AlreadyQueued as e: + raise AppApiException(500, _('Failed to send the vectorization task, please try again later!')) + def list_application(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) diff --git a/apps/dataset/task/generate.py b/apps/dataset/task/generate.py index fdcd171af79..53b0c71ff06 100644 --- a/apps/dataset/task/generate.py +++ b/apps/dataset/task/generate.py @@ -64,6 +64,17 @@ def is_the_task_interrupted(): return is_the_task_interrupted +@celery_app.task(base=QueueOnce, once={'keys': ['dataset_id']}, + name='celery:generate_related_by_dataset') +def generate_related_by_dataset_id(dataset_id, model_id, prompt, state_list=None): + document_list = QuerySet(Document).filter(dataset_id=dataset_id) + for document in document_list: + try: + generate_related_by_document_id.delay(document.id, model_id, prompt, state_list) + except Exception as e: + pass + + @celery_app.task(base=QueueOnce, once={'keys': ['document_id']}, name='celery:generate_related_by_document') def generate_related_by_document_id(document_id, model_id, prompt, state_list=None): diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py index 50278ac9b0b..302b953ec36 100644 --- a/apps/dataset/urls.py +++ b/apps/dataset/urls.py @@ -11,6 +11,8 @@ path('dataset//export', views.Dataset.Export.as_view(), name="export"), path('dataset//export_zip', views.Dataset.ExportZip.as_view(), name="export_zip"), path('dataset//re_embedding', views.Dataset.Embedding.as_view(), name="dataset_key"), + path('dataset//generate_related', views.Dataset.GenerateRelated.as_view(), + name="dataset_generate_related"), path('dataset//application', views.Dataset.Application.as_view()), path('dataset//', views.Dataset.Page.as_view(), name="dataset"), path('dataset//sync_web', views.Dataset.SyncWeb.as_view()), diff --git a/apps/dataset/views/dataset.py b/apps/dataset/views/dataset.py index efd50c6bfc3..bbb9e033980 100644 --- a/apps/dataset/views/dataset.py +++ b/apps/dataset/views/dataset.py @@ -21,6 +21,7 @@ from common.response import result from common.response.result import get_page_request_params, get_page_api_response, get_api_response from common.swagger_api.common_api import CommonApi +from dataset.serializers.common_serializers import GenerateRelatedSerializer from dataset.serializers.dataset_serializers import DataSetSerializers from dataset.views.common import get_dataset_operation_object from setting.serializers.provider_serializers import ModelSerializer @@ -173,6 +174,23 @@ def put(self, request: Request, dataset_id: str): return result.success( DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).re_embedding()) + class GenerateRelated(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary=_('Generate related'), operation_id=_('Generate related'), + manual_parameters=DataSetSerializers.Operate.get_request_params_api(), + request_body=GenerateRelatedSerializer.get_request_body_api(), + tags=[_('Knowledge Base')] + ) + @log(menu='document', operate="Generate related documents", + get_operation_object=lambda r, keywords: get_dataset_operation_object(keywords.get('dataset_id')) + ) + def put(self, request: Request, dataset_id: str): + return result.success( + DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).generate_related( + request.data)) + class Export(APIView): authentication_classes = [TokenAuth] diff --git a/apps/locales/en_US/LC_MESSAGES/django.po b/apps/locales/en_US/LC_MESSAGES/django.po index f3cfb56c822..d13912928b9 100644 --- a/apps/locales/en_US/LC_MESSAGES/django.po +++ b/apps/locales/en_US/LC_MESSAGES/django.po @@ -7487,4 +7487,7 @@ msgid "Field: {name} Type: {_type} Value: {value} Unsupported types" msgstr "" msgid "Field: {name} No value set" +msgstr "" + +msgid "Generate related" msgstr "" \ No newline at end of file diff --git a/apps/locales/zh_CN/LC_MESSAGES/django.po b/apps/locales/zh_CN/LC_MESSAGES/django.po index 57f2195cd77..b0ab7871bf6 100644 --- a/apps/locales/zh_CN/LC_MESSAGES/django.po +++ b/apps/locales/zh_CN/LC_MESSAGES/django.po @@ -7650,4 +7650,7 @@ msgid "Field: {name} Type: {_type} Value: {value} Unsupported types" msgstr "字段: {name} 类型: {_type} 值: {value} 不支持的类型" msgid "Field: {name} No value set" -msgstr "字段: {name} 未设置值" \ No newline at end of file +msgstr "字段: {name} 未设置值" + +msgid "Generate related" +msgstr "生成问题" \ No newline at end of file diff --git a/apps/locales/zh_Hant/LC_MESSAGES/django.po b/apps/locales/zh_Hant/LC_MESSAGES/django.po index 99c9c2ca4a2..dab1d176c26 100644 --- a/apps/locales/zh_Hant/LC_MESSAGES/django.po +++ b/apps/locales/zh_Hant/LC_MESSAGES/django.po @@ -7660,4 +7660,7 @@ msgid "Field: {name} Type: {_type} Value: {value} Unsupported types" msgstr "欄位: {name} 類型: {_type} 值: {value} 不支持的類型" msgid "Field: {name} No value set" -msgstr "欄位: {name} 未設定值" \ No newline at end of file +msgstr "欄位: {name} 未設定值" + +msgid "Generate related" +msgstr "生成問題" \ No newline at end of file diff --git a/ui/src/api/dataset.ts b/ui/src/api/dataset.ts index 4223cb1e61a..a5a663b03c7 100644 --- a/ui/src/api/dataset.ts +++ b/ui/src/api/dataset.ts @@ -277,6 +277,20 @@ const importLarkDocument: ( ) => Promise>> = (dataset_id, data, loading) => { return post(`${prefix}/lark/${dataset_id}/import`, data, null, loading) } +/** + * 生成关联问题 + * @param dataset_id 知识库id + * @param data + * @param loading + * @returns + */ +const generateRelated: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise>> = (dataset_id, data, loading) => { + return put(`${prefix}/${dataset_id}/generate_related`, data, null, loading) +} export default { getDataset, @@ -297,5 +311,6 @@ export default { postLarkDataset, getLarkDocumentList, importLarkDocument, - putLarkDataset + putLarkDataset, + generateRelated } diff --git a/ui/src/components/generate-related-dialog/index.vue b/ui/src/components/generate-related-dialog/index.vue index 01fb68ac4e0..2284e3a590c 100644 --- a/ui/src/components/generate-related-dialog/index.vue +++ b/ui/src/components/generate-related-dialog/index.vue @@ -51,7 +51,7 @@ /> @@ -107,6 +107,7 @@ const stateMap = { error: ['0', '1', '3', '4', '5', 'n'] } const FormRef = ref() +const datasetId = ref() const userId = user.userInfo?.id as string const form = ref(prompt.get(userId)) const rules = reactive({ @@ -133,7 +134,8 @@ watch(dialogVisible, (bool) => { } }) -const open = (ids: string[], type: string) => { +const open = (ids: string[], type: string, _datasetId?: string) => { + datasetId.value = _datasetId getModel() idList.value = ids apiType.value = type @@ -169,6 +171,15 @@ const submitHandle = async (formEl: FormInstance) => { emit('refresh') dialogVisible.value = false }) + } else if (apiType.value === 'dataset') { + const data = { + ...form.value, + state_list: stateMap[state.value] + } + datasetApi.generateRelated(id ? id : datasetId.value, data, loading).then(() => { + MsgSuccess(t('views.document.generateQuestion.successMessage')) + dialogVisible.value = false + }) } } }) @@ -177,7 +188,7 @@ const submitHandle = async (formEl: FormInstance) => { function getModel() { loading.value = true datasetApi - .getDatasetModel(id) + .getDatasetModel(id ? id : datasetId.value) .then((res: any) => { modelOptions.value = groupBy(res?.data, 'provider') loading.value = false diff --git a/ui/src/views/dataset/index.vue b/ui/src/views/dataset/index.vue index f1cd824480d..d13e2dc2dd9 100644 --- a/ui/src/views/dataset/index.vue +++ b/ui/src/views/dataset/index.vue @@ -127,6 +127,7 @@ v-if="item.type === '1'" >{{ $t('views.dataset.setting.sync') }} + {{ $t('views.dataset.setting.vectorization') }} + {{ $t('views.document.generateQuestion.title') }} +