diff --git a/apps/dataset/serializers/common_serializers.py b/apps/dataset/serializers/common_serializers.py index 16d33e66202..099d1f24e7f 100644 --- a/apps/dataset/serializers/common_serializers.py +++ b/apps/dataset/serializers/common_serializers.py @@ -7,6 +7,7 @@ @desc: """ import os +import uuid from typing import List from django.db.models import QuerySet @@ -20,7 +21,7 @@ from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.fork import Fork -from dataset.models import Paragraph +from dataset.models import Paragraph, Problem, ProblemParagraphMapping from smartdoc.conf import PROJECT_DIR @@ -79,3 +80,53 @@ def get_request_body_api(): description="主键id列表") } ) + + +class ProblemParagraphObject: + def __init__(self, dataset_id: str, document_id: str, paragraph_id: str, problem_content: str): + self.dataset_id = dataset_id + self.document_id = document_id + self.paragraph_id = paragraph_id + self.problem_content = problem_content + + +def or_get(exists_problem_list, content, dataset_id, document_id, paragraph_id, problem_content_dict): + if content in problem_content_dict: + return problem_content_dict.get(content)[0], document_id, paragraph_id + exists = [row for row in exists_problem_list if row.content == content] + if len(exists) > 0: + problem_content_dict[content] = exists[0], False + return exists[0], document_id, paragraph_id + else: + problem = Problem(id=uuid.uuid1(), content=content, dataset_id=dataset_id) + problem_content_dict[content] = problem, True + return problem, document_id, paragraph_id + + +class ProblemParagraphManage: + def __init__(self, problemParagraphObjectList: [ProblemParagraphObject], dataset_id): + self.dataset_id = dataset_id + self.problemParagraphObjectList = problemParagraphObjectList + + def to_problem_model_list(self): + problem_list = [item.problem_content for item in self.problemParagraphObjectList] + exists_problem_list = [] + if len(self.problemParagraphObjectList) > 0: + # 查询到已存在的问题列表 + exists_problem_list = QuerySet(Problem).filter(dataset_id=self.dataset_id, + content__in=problem_list).all() + problem_content_dict = {} + problem_model_list = [ + or_get(exists_problem_list, problemParagraphObject.problem_content, problemParagraphObject.dataset_id, + problemParagraphObject.document_id, problemParagraphObject.paragraph_id, problem_content_dict) for + problemParagraphObject in self.problemParagraphObjectList] + + problem_paragraph_mapping_list = [ + ProblemParagraphMapping(id=uuid.uuid1(), document_id=document_id, problem_id=problem_model.id, + paragraph_id=paragraph_id, + dataset_id=self.dataset_id) for + problem_model, document_id, paragraph_id in problem_model_list] + + result = [problem_model for problem_model, is_create in problem_content_dict.values() if + is_create], problem_paragraph_mapping_list + return result diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index 1e60a910313..629719bff81 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -37,7 +37,7 @@ from common.util.fork import ChildLink, Fork from common.util.split_model import get_split_model from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping -from dataset.serializers.common_serializers import list_paragraph, MetaSerializer +from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer from embedding.models import SearchMode from setting.models import AuthOperate @@ -383,8 +383,7 @@ def save(self, instance: Dict, with_valid=True): document_model_list = [] paragraph_model_list = [] - problem_model_list = [] - problem_paragraph_mapping_list = [] + problem_paragraph_object_list = [] # 插入文档 for document in instance.get('documents') if 'documents' in instance else []: document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id, @@ -392,12 +391,12 @@ def save(self, instance: Dict, with_valid=True): document_model_list.append(document_paragraph_dict_model.get('document')) for paragraph in document_paragraph_dict_model.get('paragraph_model_list'): paragraph_model_list.append(paragraph) - for problem in document_paragraph_dict_model.get('problem_model_list'): - problem_model_list.append(problem) - for problem_paragraph_mapping in document_paragraph_dict_model.get('problem_paragraph_mapping_list'): - problem_paragraph_mapping_list.append(problem_paragraph_mapping) - problem_model_list, problem_paragraph_mapping_list = DocumentSerializers.Create.reset_problem_model( - problem_model_list, problem_paragraph_mapping_list) + for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'): + problem_paragraph_object_list.append(problem_paragraph_object) + + problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list, + dataset_id) + .to_problem_model_list()) # 插入知识库 dataset.save() # 插入文档 diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index 90b2701b469..263f64c5879 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -41,7 +41,7 @@ from common.util.fork import Fork from common.util.split_model import get_split_model from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image -from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer +from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer from smartdoc.conf import PROJECT_DIR @@ -380,8 +380,9 @@ def sync(self, with_valid=True, with_embedding=True): document_paragraph_model = DocumentSerializers.Create.get_paragraph_model(document, paragraphs) paragraph_model_list = document_paragraph_model.get('paragraph_model_list') - problem_model_list = document_paragraph_model.get('problem_model_list') - problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list') + problem_paragraph_object_list = document_paragraph_model.get('problem_paragraph_object_list') + problem_model_list, problem_paragraph_mapping_list = ProblemParagraphManage( + problem_paragraph_object_list, document.dataset_id).to_problem_model_list() # 批量插入段落 QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None # 批量插入问题 @@ -626,11 +627,13 @@ def save(self, instance: Dict, with_valid=False, **kwargs): self.is_valid(raise_exception=True) dataset_id = self.data.get('dataset_id') document_paragraph_model = self.get_document_paragraph_model(dataset_id, instance) + document_model = document_paragraph_model.get('document') paragraph_model_list = document_paragraph_model.get('paragraph_model_list') - problem_model_list = document_paragraph_model.get('problem_model_list') - problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list') - + problem_paragraph_object_list = document_paragraph_model.get('problem_paragraph_object_list') + problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list, + dataset_id) + .to_problem_model_list()) # 插入文档 document_model.save() # 批量插入段落 @@ -685,35 +688,15 @@ def get_paragraph_model(document_model, paragraph_list: List): dataset_id, document_model.id, paragraph) for paragraph in paragraph_list] paragraph_model_list = [] - problem_model_list = [] - problem_paragraph_mapping_list = [] + problem_paragraph_object_list = [] for paragraphs in paragraph_model_dict_list: paragraph = paragraphs.get('paragraph') - for problem_model in paragraphs.get('problem_model_list'): - problem_model_list.append(problem_model) - for problem_paragraph_mapping in paragraphs.get('problem_paragraph_mapping_list'): - problem_paragraph_mapping_list.append(problem_paragraph_mapping) + for problem_model in paragraphs.get('problem_paragraph_object_list'): + problem_paragraph_object_list.append(problem_model) paragraph_model_list.append(paragraph) - problem_model_list, problem_paragraph_mapping_list = DocumentSerializers.Create.reset_problem_model( - problem_model_list, problem_paragraph_mapping_list) - return {'document': document_model, 'paragraph_model_list': paragraph_model_list, - 'problem_model_list': problem_model_list, - 'problem_paragraph_mapping_list': problem_paragraph_mapping_list} - - @staticmethod - def reset_problem_model(problem_model_list, problem_paragraph_mapping_list): - new_problem_model_list = [x for i, x in enumerate(problem_model_list) if - len([item for item in problem_model_list[:i] if item.content == x.content]) <= 0] - - for new_problem_model in new_problem_model_list: - old_model_list = [problem.id for problem in problem_model_list if - problem.content == new_problem_model.content] - for problem_paragraph_mapping in problem_paragraph_mapping_list: - if old_model_list.__contains__(problem_paragraph_mapping.problem_id): - problem_paragraph_mapping.problem_id = new_problem_model.id - return new_problem_model_list, problem_paragraph_mapping_list + 'problem_paragraph_object_list': problem_paragraph_object_list} @staticmethod def get_document_paragraph_model(dataset_id, instance: Dict): @@ -834,8 +817,7 @@ def batch_save(self, instance_list: List[Dict], with_valid=True): dataset_id = self.data.get("dataset_id") document_model_list = [] paragraph_model_list = [] - problem_model_list = [] - problem_paragraph_mapping_list = [] + problem_paragraph_object_list = [] # 插入文档 for document in instance_list: document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id, @@ -843,11 +825,12 @@ def batch_save(self, instance_list: List[Dict], with_valid=True): document_model_list.append(document_paragraph_dict_model.get('document')) for paragraph in document_paragraph_dict_model.get('paragraph_model_list'): paragraph_model_list.append(paragraph) - for problem in document_paragraph_dict_model.get('problem_model_list'): - problem_model_list.append(problem) - for problem_paragraph_mapping in document_paragraph_dict_model.get('problem_paragraph_mapping_list'): - problem_paragraph_mapping_list.append(problem_paragraph_mapping) + for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'): + problem_paragraph_object_list.append(problem_paragraph_object) + problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list, + dataset_id) + .to_problem_model_list()) # 插入文档 QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None # 批量插入段落 diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py index 3188766a782..61ae860b62d 100644 --- a/apps/dataset/serializers/paragraph_serializers.py +++ b/apps/dataset/serializers/paragraph_serializers.py @@ -21,7 +21,8 @@ from common.util.common import post from common.util.field_message import ErrMessage from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping -from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer +from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \ + ProblemParagraphManage from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers from embedding.models import SourceType @@ -567,8 +568,10 @@ def save(self, instance: Dict, with_valid=True, with_embedding=True): document_id = self.data.get('document_id') paragraph_problem_model = self.get_paragraph_problem_model(dataset_id, document_id, instance) paragraph = paragraph_problem_model.get('paragraph') - problem_model_list = paragraph_problem_model.get('problem_model_list') - problem_paragraph_mapping_list = paragraph_problem_model.get('problem_paragraph_mapping_list') + problem_paragraph_object_list = paragraph_problem_model.get('problem_paragraph_object_list') + problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list, + dataset_id). + to_problem_model_list()) # 插入段落 paragraph_problem_model.get('paragraph').save() # 插入問題 @@ -591,30 +594,12 @@ def get_paragraph_problem_model(dataset_id: str, document_id: str, instance: Dic content=instance.get("content"), dataset_id=dataset_id, title=instance.get("title") if 'title' in instance else '') - problem_list = instance.get('problem_list') - exists_problem_list = [] - if 'problem_list' in instance and len(problem_list) > 0: - exists_problem_list = QuerySet(Problem).filter(dataset_id=dataset_id, - content__in=[p.get('content') for p in - problem_list]).all() - - problem_model_list = [ - ParagraphSerializers.Create.or_get(exists_problem_list, problem.get('content'), dataset_id) for - problem in ( - instance.get('problem_list') if 'problem_list' in instance else [])] - # 问题去重 - problem_model_list = [x for i, x in enumerate(problem_model_list) if - len([item for item in problem_model_list[:i] if item.content == x.content]) <= 0] - - problem_paragraph_mapping_list = [ - ProblemParagraphMapping(id=uuid.uuid1(), document_id=document_id, problem_id=problem_model.id, - paragraph_id=paragraph.id, - dataset_id=dataset_id) for - problem_model in problem_model_list] + problem_paragraph_object_list = [ + ProblemParagraphObject(dataset_id, document_id, paragraph.id, problem.get('content')) for problem in + (instance.get('problem_list') if 'problem_list' in instance else [])] + return {'paragraph': paragraph, - 'problem_model_list': [problem_model for problem_model in problem_model_list if - not list(exists_problem_list).__contains__(problem_model)], - 'problem_paragraph_mapping_list': problem_paragraph_mapping_list} + 'problem_paragraph_object_list': problem_paragraph_object_list} @staticmethod def or_get(exists_problem_list, content, dataset_id):