Skip to content

Commit 292f13d

Browse files
JohnJyongparambharat
authored andcommitted
Support knowledge metadata filter (langgenius#15982)
1 parent 0ff705d commit 292f13d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+2501
-573
lines changed

api/controllers/console/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
datasets_segments,
8282
external,
8383
hit_testing,
84+
metadata,
8485
website,
8586
)
8687

api/controllers/console/datasets/datasets_document.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,7 @@ def get(self, dataset_id, document_id):
621621
raise InvalidMetadataError(f"Invalid metadata value: {metadata}")
622622

623623
if metadata == "only":
624-
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata}
624+
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
625625
elif metadata == "without":
626626
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
627627
document_process_rules = document.dataset_process_rule.to_dict()
@@ -682,7 +682,7 @@ def get(self, dataset_id, document_id):
682682
"disabled_by": document.disabled_by,
683683
"archived": document.archived,
684684
"doc_type": document.doc_type,
685-
"doc_metadata": document.doc_metadata,
685+
"doc_metadata": document.doc_metadata_details,
686686
"segment_count": document.segment_count,
687687
"average_segment_length": document.average_segment_length,
688688
"hit_count": document.hit_count,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
from flask_login import current_user # type: ignore # type: ignore
2+
from flask_restful import Resource, marshal_with, reqparse # type: ignore
3+
from werkzeug.exceptions import NotFound
4+
5+
from controllers.console import api
6+
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
7+
from fields.dataset_fields import dataset_metadata_fields
8+
from libs.login import login_required
9+
from services.dataset_service import DatasetService
10+
from services.entities.knowledge_entities.knowledge_entities import (
11+
MetadataArgs,
12+
MetadataOperationData,
13+
)
14+
from services.metadata_service import MetadataService
15+
16+
17+
def _validate_name(name):
18+
if not name or len(name) < 1 or len(name) > 40:
19+
raise ValueError("Name must be between 1 to 40 characters.")
20+
return name
21+
22+
23+
def _validate_description_length(description):
24+
if len(description) > 400:
25+
raise ValueError("Description cannot exceed 400 characters.")
26+
return description
27+
28+
29+
class DatasetMetadataCreateApi(Resource):
30+
@setup_required
31+
@login_required
32+
@account_initialization_required
33+
@enterprise_license_required
34+
@marshal_with(dataset_metadata_fields)
35+
def post(self, dataset_id):
36+
parser = reqparse.RequestParser()
37+
parser.add_argument("type", type=str, required=True, nullable=True, location="json")
38+
parser.add_argument("name", type=str, required=True, nullable=True, location="json")
39+
args = parser.parse_args()
40+
metadata_args = MetadataArgs(**args)
41+
42+
dataset_id_str = str(dataset_id)
43+
dataset = DatasetService.get_dataset(dataset_id_str)
44+
if dataset is None:
45+
raise NotFound("Dataset not found.")
46+
DatasetService.check_dataset_permission(dataset, current_user)
47+
48+
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
49+
return metadata, 201
50+
51+
@setup_required
52+
@login_required
53+
@account_initialization_required
54+
@enterprise_license_required
55+
def get(self, dataset_id):
56+
dataset_id_str = str(dataset_id)
57+
dataset = DatasetService.get_dataset(dataset_id_str)
58+
if dataset is None:
59+
raise NotFound("Dataset not found.")
60+
return MetadataService.get_dataset_metadatas(dataset), 200
61+
62+
63+
class DatasetMetadataApi(Resource):
64+
@setup_required
65+
@login_required
66+
@account_initialization_required
67+
@enterprise_license_required
68+
@marshal_with(dataset_metadata_fields)
69+
def patch(self, dataset_id, metadata_id):
70+
parser = reqparse.RequestParser()
71+
parser.add_argument("name", type=str, required=True, nullable=True, location="json")
72+
args = parser.parse_args()
73+
74+
dataset_id_str = str(dataset_id)
75+
metadata_id_str = str(metadata_id)
76+
dataset = DatasetService.get_dataset(dataset_id_str)
77+
if dataset is None:
78+
raise NotFound("Dataset not found.")
79+
DatasetService.check_dataset_permission(dataset, current_user)
80+
81+
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
82+
return metadata, 200
83+
84+
@setup_required
85+
@login_required
86+
@account_initialization_required
87+
@enterprise_license_required
88+
def delete(self, dataset_id, metadata_id):
89+
dataset_id_str = str(dataset_id)
90+
metadata_id_str = str(metadata_id)
91+
dataset = DatasetService.get_dataset(dataset_id_str)
92+
if dataset is None:
93+
raise NotFound("Dataset not found.")
94+
DatasetService.check_dataset_permission(dataset, current_user)
95+
96+
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
97+
return 200
98+
99+
100+
class DatasetMetadataBuiltInFieldApi(Resource):
101+
@setup_required
102+
@login_required
103+
@account_initialization_required
104+
@enterprise_license_required
105+
def get(self):
106+
built_in_fields = MetadataService.get_built_in_fields()
107+
return {"fields": built_in_fields}, 200
108+
109+
110+
class DatasetMetadataBuiltInFieldActionApi(Resource):
111+
@setup_required
112+
@login_required
113+
@account_initialization_required
114+
@enterprise_license_required
115+
def post(self, dataset_id, action):
116+
dataset_id_str = str(dataset_id)
117+
dataset = DatasetService.get_dataset(dataset_id_str)
118+
if dataset is None:
119+
raise NotFound("Dataset not found.")
120+
DatasetService.check_dataset_permission(dataset, current_user)
121+
122+
if action == "enable":
123+
MetadataService.enable_built_in_field(dataset)
124+
elif action == "disable":
125+
MetadataService.disable_built_in_field(dataset)
126+
return 200
127+
128+
129+
class DocumentMetadataEditApi(Resource):
130+
@setup_required
131+
@login_required
132+
@account_initialization_required
133+
@enterprise_license_required
134+
def post(self, dataset_id):
135+
dataset_id_str = str(dataset_id)
136+
dataset = DatasetService.get_dataset(dataset_id_str)
137+
if dataset is None:
138+
raise NotFound("Dataset not found.")
139+
DatasetService.check_dataset_permission(dataset, current_user)
140+
141+
parser = reqparse.RequestParser()
142+
parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json")
143+
args = parser.parse_args()
144+
metadata_args = MetadataOperationData(**args)
145+
146+
MetadataService.update_documents_metadata(dataset, metadata_args)
147+
148+
return 200
149+
150+
151+
api.add_resource(DatasetMetadataCreateApi, "/datasets/<uuid:dataset_id>/metadata")
152+
api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
153+
api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in")
154+
api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
155+
api.add_resource(DocumentMetadataEditApi, "/datasets/<uuid:dataset_id>/documents/metadata")

api/core/app/app_config/easy_ui_based_app/dataset/manager.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import uuid
22
from typing import Optional
33

4-
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
4+
from core.app.app_config.entities import (
5+
DatasetEntity,
6+
DatasetRetrieveConfigEntity,
7+
MetadataFilteringCondition,
8+
ModelConfig,
9+
)
510
from core.entities.agent_entities import PlanningStrategy
611
from models.model import AppMode
712
from services.dataset_service import DatasetService
@@ -78,6 +83,15 @@ def convert(cls, config: dict) -> Optional[DatasetEntity]:
7883
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
7984
dataset_configs["retrieval_model"]
8085
),
86+
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
87+
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
88+
if dataset_configs.get("metadata_model_config")
89+
else None,
90+
metadata_filtering_conditions=MetadataFilteringCondition(
91+
**dataset_configs.get("metadata_filtering_conditions", {})
92+
)
93+
if dataset_configs.get("metadata_filtering_conditions")
94+
else None,
8195
),
8296
)
8397
else:
@@ -96,6 +110,15 @@ def convert(cls, config: dict) -> Optional[DatasetEntity]:
96110
weights=dataset_configs.get("weights"),
97111
reranking_enabled=dataset_configs.get("reranking_enabled", True),
98112
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
113+
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
114+
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
115+
if dataset_configs.get("metadata_model_config")
116+
else None,
117+
metadata_filtering_conditions=MetadataFilteringCondition(
118+
**dataset_configs.get("metadata_filtering_conditions", {})
119+
)
120+
if dataset_configs.get("metadata_filtering_conditions")
121+
else None,
99122
),
100123
)
101124

api/core/app/app_config/entities.py

+54-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from collections.abc import Sequence
22
from enum import Enum, StrEnum
3-
from typing import Any, Optional
3+
from typing import Any, Literal, Optional
44

55
from pydantic import BaseModel, Field, field_validator
66

77
from core.file import FileTransferMethod, FileType, FileUploadConfig
8+
from core.model_runtime.entities.llm_entities import LLMMode
89
from core.model_runtime.entities.message_entities import PromptMessageRole
910
from models.model import AppMode
1011

@@ -135,6 +136,55 @@ class ExternalDataVariableEntity(BaseModel):
135136
config: dict[str, Any] = Field(default_factory=dict)
136137

137138

139+
SupportedComparisonOperator = Literal[
140+
# for string or array
141+
"contains",
142+
"not contains",
143+
"start with",
144+
"end with",
145+
"is",
146+
"is not",
147+
"empty",
148+
"not empty",
149+
# for number
150+
"=",
151+
"≠",
152+
">",
153+
"<",
154+
"≥",
155+
"≤",
156+
# for time
157+
"before",
158+
"after",
159+
]
160+
161+
162+
class ModelConfig(BaseModel):
163+
provider: str
164+
name: str
165+
mode: LLMMode
166+
completion_params: dict[str, Any] = {}
167+
168+
169+
class Condition(BaseModel):
170+
"""
171+
Conditon detail
172+
"""
173+
174+
name: str
175+
comparison_operator: SupportedComparisonOperator
176+
value: str | Sequence[str] | None | int | float = None
177+
178+
179+
class MetadataFilteringCondition(BaseModel):
180+
"""
181+
Metadata Filtering Condition.
182+
"""
183+
184+
logical_operator: Optional[Literal["and", "or"]] = "and"
185+
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
186+
187+
138188
class DatasetRetrieveConfigEntity(BaseModel):
139189
"""
140190
Dataset Retrieve Config Entity.
@@ -171,6 +221,9 @@ def value_of(cls, value: str):
171221
reranking_model: Optional[dict] = None
172222
weights: Optional[dict] = None
173223
reranking_enabled: Optional[bool] = True
224+
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
225+
metadata_model_config: Optional[ModelConfig] = None
226+
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
174227

175228

176229
class DatasetEntity(BaseModel):

api/core/app/apps/chat/app_runner.py

+1
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def run(
180180
hit_callback=hit_callback,
181181
memory=memory,
182182
message_id=message.id,
183+
inputs=inputs,
183184
)
184185

185186
# reorganize all inputs and template to prompt messages

api/core/app/apps/completion/app_runner.py

+1
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def run(
139139
show_retrieve_source=app_config.additional_features.show_retrieve_source,
140140
hit_callback=hit_callback,
141141
message_id=message.id,
142+
inputs=inputs,
142143
)
143144

144145
# reorganize all inputs and template to prompt messages

api/core/rag/datasource/keyword/jieba/jieba.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,17 @@ def search(self, query: str, **kwargs: Any) -> list[Document]:
8888
keyword_table = self._get_dataset_keyword_table()
8989

9090
k = kwargs.get("top_k", 4)
91-
91+
document_ids_filter = kwargs.get("document_ids_filter")
9292
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
9393

9494
documents = []
9595
for chunk_index in sorted_chunk_indices:
96-
segment = (
97-
db.session.query(DocumentSegment)
98-
.filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index)
99-
.first()
96+
segment_query = db.session.query(DocumentSegment).filter(
97+
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
10098
)
99+
if document_ids_filter:
100+
segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter))
101+
segment = segment_query.first()
101102

102103
if segment:
103104
documents.append(

0 commit comments

Comments
 (0)