diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index dc57bfaafde..e0d56e41c79 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -6,26 +6,22 @@ @date:2023/10/20 14:01 @desc: """ -import datetime import logging import os import threading -import time import traceback from typing import List import django.db.models -from django.db import models, transaction from django.db.models import QuerySet from django.db.models.functions import Substr, Reverse from langchain_core.embeddings import Embeddings from common.config.embedding_config import VectorStore from common.db.search import native_search, get_dynamics_model, native_update -from common.db.sql_execute import sql_execute, update_execute from common.util.file_util import get_file_content from common.util.lock import try_lock, un_lock -from common.util.page_utils import page +from common.util.page_utils import page_desc from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State from embedding.models import SourceType, SearchMode from smartdoc.conf import PROJECT_DIR @@ -162,7 +158,7 @@ def embedding_paragraph_apply(paragraph_list): if is_the_task_interrupted(): break ListenerManagement.embedding_by_paragraph(str(paragraph.get('id')), embedding_model) - post_apply() + post_apply() return embedding_paragraph_apply @@ -241,13 +237,16 @@ def update_status(query_set: QuerySet, taskType: TaskType, state: State): lock.release() @staticmethod - def embedding_by_document(document_id, embedding_model: Embeddings): + def embedding_by_document(document_id, embedding_model: Embeddings, state_list=None): """ 向量化文档 + @param state_list: @param document_id: 文档id @param embedding_model 向量模型 :return: None """ + if state_list is None: + state_list = [State.PENDING, State.SUCCESS, State.FAILURE, State.REVOKE, State.REVOKED] if not try_lock('embedding' + str(document_id)): return try: @@ -268,11 +267,17 @@ def is_the_task_interrupted(): VectorStore.get_embedding_vector().delete_by_document_id(document_id) # 根据段落进行向量化处理 - page(QuerySet(Paragraph).filter(document_id=document_id).values('id'), 5, - ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted, - ListenerManagement.get_aggregation_document_status( - document_id)), - is_the_task_interrupted) + page_desc(QuerySet(Paragraph) + .annotate( + reversed_status=Reverse('status'), + task_type_status=Substr('reversed_status', TaskType.EMBEDDING.value, + 1), + ).filter(task_type_status__in=state_list, document_id=document_id) + .values('id'), 5, + ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted, + ListenerManagement.get_aggregation_document_status( + document_id)), + is_the_task_interrupted) except Exception as e: max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}') finally: diff --git a/apps/common/util/page_utils.py b/apps/common/util/page_utils.py index 92f21849b6d..61c52920d9a 100644 --- a/apps/common/util/page_utils.py +++ b/apps/common/util/page_utils.py @@ -26,3 +26,22 @@ def page(query_set, page_size, handler, is_the_task_interrupted=lambda: False): offset = i * page_size paragraph_list = query.all()[offset: offset + page_size] handler(paragraph_list) + + +def page_desc(query_set, page_size, handler, is_the_task_interrupted=lambda: False): + """ + + @param query_set: 查询query_set + @param page_size: 每次查询大小 + @param handler: 数据处理器 + @param is_the_task_interrupted: 任务是否被中断 + @return: + """ + query = query_set.order_by("id") + count = query_set.count() + for i in sorted(range(0, ceil(count / page_size)), reverse=True): + if is_the_task_interrupted(): + return + offset = i * page_size + paragraph_list = query.all()[offset: offset + page_size] + handler(paragraph_list) diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index b0c2e043016..3d24c01e6f3 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -700,20 +700,24 @@ def edit(self, instance: Dict, with_valid=False): _document.save() return self.one() - @transaction.atomic - def refresh(self, with_valid=True): + def refresh(self, state_list, with_valid=True): if with_valid: self.is_valid(raise_exception=True) document_id = self.data.get("document_id") ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, State.PENDING) - ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id), + ListenerManagement.update_status(QuerySet(Paragraph).annotate( + reversed_status=Reverse('status'), + task_type_status=Substr('reversed_status', TaskType.EMBEDDING.value, + 1), + ).filter(task_type_status__in=state_list, document_id=document_id) + .values('id'), TaskType.EMBEDDING, State.PENDING) ListenerManagement.get_aggregation_document_status(document_id)() embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=self.data.get('dataset_id')) try: - embedding_by_document.delay(document_id, embedding_model_id) + embedding_by_document.delay(document_id, embedding_model_id, state_list) except AlreadyQueued as e: raise AppApiException(500, "任务正在执行中,请勿重复下发") @@ -1122,14 +1126,14 @@ def batch_refresh(self, instance: Dict, with_valid=True): if with_valid: self.is_valid(raise_exception=True) document_id_list = instance.get("id_list") - with transaction.atomic(): - dataset_id = self.data.get('dataset_id') - for document_id in document_id_list: - try: - DocumentSerializers.Operate( - data={'dataset_id': dataset_id, 'document_id': document_id}).refresh() - except AlreadyQueued as e: - pass + state_list = instance.get("state_list") + dataset_id = self.data.get('dataset_id') + for document_id in document_id_list: + try: + DocumentSerializers.Operate( + data={'dataset_id': dataset_id, 'document_id': document_id}).refresh(state_list) + except AlreadyQueued as e: + pass class GenerateRelated(ApiMixin, serializers.Serializer): document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id")) diff --git a/apps/dataset/swagger_api/document_api.py b/apps/dataset/swagger_api/document_api.py index ff72b638f5b..66be0d7f936 100644 --- a/apps/dataset/swagger_api/document_api.py +++ b/apps/dataset/swagger_api/document_api.py @@ -51,3 +51,16 @@ def get_request_body_api(): description="1|2|3 1:向量化|2:生成问题|3:同步文档", default=1) } ) + + class EmbeddingState(ApiMixin): + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'state_list': openapi.Schema(type=openapi.TYPE_ARRAY, + items=openapi.Schema(type=openapi.TYPE_STRING), + title="状态列表", + description="状态列表") + } + ) diff --git a/apps/dataset/views/document.py b/apps/dataset/views/document.py index 42218e948db..c074fc543fb 100644 --- a/apps/dataset/views/document.py +++ b/apps/dataset/views/document.py @@ -262,6 +262,7 @@ class Refresh(APIView): @action(methods=['PUT'], detail=False) @swagger_auto_schema(operation_summary="刷新文档向量库", operation_id="刷新文档向量库", + request_body=DocumentApi.EmbeddingState.get_request_body_api(), manual_parameters=DocumentSerializers.Operate.get_request_params_api(), responses=result.get_default_response(), tags=["知识库/文档"] @@ -272,6 +273,7 @@ class Refresh(APIView): def put(self, request: Request, dataset_id: str, document_id: str): return result.success( DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).refresh( + request.data.get('state_list') )) class BatchRefresh(APIView): diff --git a/apps/embedding/task/embedding.py b/apps/embedding/task/embedding.py index 7c39ebcddee..026a109ce90 100644 --- a/apps/embedding/task/embedding.py +++ b/apps/embedding/task/embedding.py @@ -56,14 +56,20 @@ def embedding_by_paragraph_list(paragraph_id_list, model_id): @celery_app.task(base=QueueOnce, once={'keys': ['document_id']}, name='celery:embedding_by_document') -def embedding_by_document(document_id, model_id): +def embedding_by_document(document_id, model_id, state_list=None): """ 向量化文档 + @param state_list: @param document_id: 文档id @param model_id 向量模型 :return: None """ + if state_list is None: + state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value, + State.REVOKE.value, + State.REVOKED.value, State.IGNORED.value] + def exception_handler(e): ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, State.FAILURE) @@ -71,7 +77,7 @@ def exception_handler(e): f'获取向量模型失败:{str(e)}{traceback.format_exc()}') embedding_model = get_embedding_model(model_id, exception_handler) - ListenerManagement.embedding_by_document(document_id, embedding_model) + ListenerManagement.embedding_by_document(document_id, embedding_model, state_list) @celery_app.task(name='celery:embedding_by_document_list') diff --git a/ui/src/api/document.ts b/ui/src/api/document.ts index 7ad275949c1..307eb34883c 100644 --- a/ui/src/api/document.ts +++ b/ui/src/api/document.ts @@ -129,11 +129,12 @@ const delMulDocument: ( const batchRefresh: ( dataset_id: string, data: any, + stateList: Array, loading?: Ref -) => Promise> = (dataset_id, data, loading) => { +) => Promise> = (dataset_id, data, stateList, loading) => { return put( `${prefix}/${dataset_id}/document/batch_refresh`, - { id_list: data }, + { id_list: data, state_list: stateList }, undefined, loading ) @@ -157,11 +158,12 @@ const getDocumentDetail: (dataset_id: string, document_id: string) => Promise, loading?: Ref -) => Promise> = (dataset_id, document_id, loading) => { +) => Promise> = (dataset_id, document_id, state_list, loading) => { return put( `${prefix}/${dataset_id}/document/${document_id}/refresh`, - undefined, + { state_list }, undefined, loading ) diff --git a/ui/src/views/document/component/EmbeddingContentDialog.vue b/ui/src/views/document/component/EmbeddingContentDialog.vue new file mode 100644 index 00000000000..64cb6f8a80b --- /dev/null +++ b/ui/src/views/document/component/EmbeddingContentDialog.vue @@ -0,0 +1,41 @@ + + + diff --git a/ui/src/views/document/index.vue b/ui/src/views/document/index.vue index c7d0d3cf82e..3f79be80eee 100644 --- a/ui/src/views/document/index.vue +++ b/ui/src/views/document/index.vue @@ -422,6 +422,7 @@ 清空 +