Skip to content

Commit 1f722cd

Browse files
authored
fix(api): Some params were ignored when creating empty Datasets through API (#17932)
1 parent 4aecc9f commit 1f722cd

File tree

9 files changed

+115
-20
lines changed

9 files changed

+115
-20
lines changed

api/controllers/console/app/annotation.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def get(self, app_id, job_id, action):
8989
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
9090
cache_result = redis_client.get(app_annotation_job_key)
9191
if cache_result is None:
92-
raise ValueError("The job is not exist.")
92+
raise ValueError("The job does not exist.")
9393

9494
job_status = cache_result.decode()
9595
error_msg = ""
@@ -226,7 +226,7 @@ def get(self, app_id, job_id):
226226
indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
227227
cache_result = redis_client.get(indexing_cache_key)
228228
if cache_result is None:
229-
raise ValueError("The job is not exist.")
229+
raise ValueError("The job does not exist.")
230230
job_status = cache_result.decode()
231231
error_msg = ""
232232
if job_status == "error":

api/controllers/console/datasets/datasets_segments.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ def get(self, job_id):
398398
indexing_cache_key = "segment_batch_import_{}".format(job_id)
399399
cache_result = redis_client.get(indexing_cache_key)
400400
if cache_result is None:
401-
raise ValueError("The job is not exist.")
401+
raise ValueError("The job does not exist.")
402402

403403
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
404404

api/controllers/service_api/dataset/dataset.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from libs.login import current_user
1414
from models.dataset import Dataset, DatasetPermissionEnum
1515
from services.dataset_service import DatasetPermissionService, DatasetService
16+
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
1617

1718

1819
def _validate_name(name):
@@ -120,8 +121,11 @@ def post(self, tenant_id):
120121
nullable=True,
121122
required=False,
122123
)
123-
args = parser.parse_args()
124+
parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
125+
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
126+
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
124127

128+
args = parser.parse_args()
125129
try:
126130
dataset = DatasetService.create_empty_dataset(
127131
tenant_id=tenant_id,
@@ -133,6 +137,9 @@ def post(self, tenant_id):
133137
provider=args["provider"],
134138
external_knowledge_api_id=args["external_knowledge_api_id"],
135139
external_knowledge_id=args["external_knowledge_id"],
140+
embedding_model_provider=args["embedding_model_provider"],
141+
embedding_model_name=args["embedding_model"],
142+
retrieval_model=RetrievalModel(**args["retrieval_model"]),
136143
)
137144
except services.errors.dataset.DatasetNameDuplicateError:
138145
raise DatasetNameDuplicateError()

api/controllers/service_api/dataset/document.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,17 @@ def post(self, tenant_id, dataset_id):
4949
parser.add_argument(
5050
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
5151
)
52-
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
52+
parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
53+
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
54+
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
5355

5456
args = parser.parse_args()
5557
dataset_id = str(dataset_id)
5658
tenant_id = str(tenant_id)
5759
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
5860

5961
if not dataset:
60-
raise ValueError("Dataset is not exist.")
62+
raise ValueError("Dataset does not exist.")
6163

6264
if not dataset.indexing_technique and not args["indexing_technique"]:
6365
raise ValueError("indexing_technique is required.")
@@ -114,7 +116,7 @@ def post(self, tenant_id, dataset_id, document_id):
114116
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
115117

116118
if not dataset:
117-
raise ValueError("Dataset is not exist.")
119+
raise ValueError("Dataset does not exist.")
118120

119121
# indexing_technique is already set in dataset since this is an update
120122
args["indexing_technique"] = dataset.indexing_technique
@@ -172,7 +174,7 @@ def post(self, tenant_id, dataset_id):
172174
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
173175

174176
if not dataset:
175-
raise ValueError("Dataset is not exist.")
177+
raise ValueError("Dataset does not exist.")
176178
if not dataset.indexing_technique and not args.get("indexing_technique"):
177179
raise ValueError("indexing_technique is required.")
178180

@@ -239,7 +241,7 @@ def post(self, tenant_id, dataset_id, document_id):
239241
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
240242

241243
if not dataset:
242-
raise ValueError("Dataset is not exist.")
244+
raise ValueError("Dataset does not exist.")
243245

244246
# indexing_technique is already set in dataset since this is an update
245247
args["indexing_technique"] = dataset.indexing_technique
@@ -303,7 +305,7 @@ def delete(self, tenant_id, dataset_id, document_id):
303305
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
304306

305307
if not dataset:
306-
raise ValueError("Dataset is not exist.")
308+
raise ValueError("Dataset does not exist.")
307309

308310
document = DocumentService.get_document(dataset.id, document_id)
309311

api/core/rag/datasource/vdb/qdrant/qdrant_vector.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
444444
if dataset_collection_binding:
445445
collection_name = dataset_collection_binding.collection_name
446446
else:
447-
raise ValueError("Dataset Collection Bindings is not exist!")
447+
raise ValueError("Dataset Collection Bindings does not exist!")
448448
else:
449449
if dataset.index_struct_dict:
450450
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]

api/services/dataset_service.py

+33-8
Original file line numberDiff line numberDiff line change
@@ -169,16 +169,40 @@ def create_empty_dataset(
169169
provider: str = "vendor",
170170
external_knowledge_api_id: Optional[str] = None,
171171
external_knowledge_id: Optional[str] = None,
172+
embedding_model_provider: Optional[str] = None,
173+
embedding_model_name: Optional[str] = None,
174+
retrieval_model: Optional[RetrievalModel] = None,
172175
):
173176
# check if dataset name already exists
174177
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
175178
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
176179
embedding_model = None
177180
if indexing_technique == "high_quality":
178181
model_manager = ModelManager()
179-
embedding_model = model_manager.get_default_model_instance(
180-
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
181-
)
182+
if embedding_model_provider and embedding_model_name:
183+
# check if embedding model setting is valid
184+
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model_name)
185+
embedding_model = model_manager.get_model_instance(
186+
tenant_id=tenant_id,
187+
provider=embedding_model_provider,
188+
model_type=ModelType.TEXT_EMBEDDING,
189+
model=embedding_model_name,
190+
)
191+
else:
192+
embedding_model = model_manager.get_default_model_instance(
193+
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
194+
)
195+
if retrieval_model and retrieval_model.reranking_model:
196+
if (
197+
retrieval_model.reranking_model.reranking_provider_name
198+
and retrieval_model.reranking_model.reranking_model_name
199+
):
200+
# check if reranking model setting is valid
201+
DatasetService.check_embedding_model_setting(
202+
tenant_id,
203+
retrieval_model.reranking_model.reranking_provider_name,
204+
retrieval_model.reranking_model.reranking_model_name,
205+
)
182206
dataset = Dataset(name=name, indexing_technique=indexing_technique)
183207
# dataset = Dataset(name=name, provider=provider, config=config)
184208
dataset.description = description
@@ -187,6 +211,7 @@ def create_empty_dataset(
187211
dataset.tenant_id = tenant_id
188212
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
189213
dataset.embedding_model = embedding_model.model if embedding_model else None
214+
dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
190215
dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
191216
dataset.provider = provider
192217
db.session.add(dataset)
@@ -923,11 +948,11 @@ def save_document_with_dataset_id(
923948
"score_threshold_enabled": False,
924949
}
925950

926-
dataset.retrieval_model = (
927-
knowledge_config.retrieval_model.model_dump()
928-
if knowledge_config.retrieval_model
929-
else default_retrieval_model
930-
) # type: ignore
951+
dataset.retrieval_model = (
952+
knowledge_config.retrieval_model.model_dump()
953+
if knowledge_config.retrieval_model
954+
else default_retrieval_model
955+
) # type: ignore
931956

932957
documents = []
933958
if knowledge_config.original_document_id:

web/app/(commonLayout)/datasets/template/template.en.mdx

+22-1
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
314314
</Property>
315315
<Property name='indexing_technique' type='string' key='indexing_technique'>
316316
Index technique (optional)
317+
If this is not set, embedding_model, embedding_model_provider and retrieval_model will be set to null
317318
- <code>high_quality</code> High quality
318319
- <code>economy</code> Economy
319320
</Property>
@@ -334,6 +335,26 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
334335
<Property name='external_knowledge_id' type='str' key='external_knowledge_id'>
335336
External knowledge ID (optional)
336337
</Property>
338+
<Property name='embedding_model' type='str' key='embedding_model'>
339+
Embedding model name (optional)
340+
</Property>
341+
<Property name='embedding_model_provider' type='str' key='embedding_model_provider'>
342+
Embedding model provider name (optional)
343+
</Property>
344+
<Property name='retrieval_model' type='object' key='retrieval_model'>
345+
Retrieval model (optional)
346+
- <code>search_method</code> (string) Search method
347+
- <code>hybrid_search</code> Hybrid search
348+
- <code>semantic_search</code> Semantic search
349+
- <code>full_text_search</code> Full-text search
350+
- <code>reranking_enable</code> (bool) Whether to enable reranking
351+
- <code>reranking_model</code> (object) Rerank model configuration
352+
- <code>reranking_provider_name</code> (string) Rerank model provider
353+
- <code>reranking_model_name</code> (string) Rerank model name
354+
- <code>top_k</code> (int) Number of results to return
355+
- <code>score_threshold_enabled</code> (bool) Whether to enable score threshold
356+
- <code>score_threshold</code> (float) Score threshold
357+
</Property>
337358
</Properties>
338359
</Col>
339360
<Col sticky>
@@ -2281,4 +2302,4 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
22812302
</tr>
22822303
</tbody>
22832304
</table>
2284-
<div className="pb-4" />
2305+
<div className="pb-4" />

web/app/(commonLayout)/datasets/template/template.ja.mdx

+20
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,26 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
334334
<Property name='external_knowledge_id' type='str' key='external_knowledge_id'>
335335
外部ナレッジ ID (オプション)
336336
</Property>
337+
<Property name='embedding_model' type='str' key='embedding_model'>
338+
埋め込みモデル名(任意)
339+
</Property>
340+
<Property name='embedding_model_provider' type='str' key='embedding_model_provider'>
341+
埋め込みモデルのプロバイダ名(任意)
342+
</Property>
343+
<Property name='retrieval_model' type='object' key='retrieval_model'>
344+
検索モデル(任意)
345+
- <code>search_method</code> (文字列) 検索方法
346+
- <code>hybrid_search</code> ハイブリッド検索
347+
- <code>semantic_search</code> セマンティック検索
348+
- <code>full_text_search</code> 全文検索
349+
- <code>reranking_enable</code> (ブール値) リランキングを有効にするかどうか
350+
- <code>reranking_model</code> (オブジェクト) リランクモデルの設定
351+
- <code>reranking_provider_name</code> (文字列) リランクモデルのプロバイダ
352+
- <code>reranking_model_name</code> (文字列) リランクモデル名
353+
- <code>top_k</code> (整数) 返される結果の数
354+
- <code>score_threshold_enabled</code> (ブール値) スコア閾値を有効にするかどうか
355+
- <code>score_threshold</code> (浮動小数点数) スコア閾値
356+
</Property>
337357
</Properties>
338358
</Col>
339359
<Col sticky>

web/app/(commonLayout)/datasets/template/template.zh.mdx

+20
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,26 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
335335
<Property name='external_knowledge_id' type='str' key='external_knowledge_id'>
336336
外部知识库 ID(选填)
337337
</Property>
338+
<Property name='embedding_model' type='str' key='embedding_model'>
339+
Embedding 模型名称(选填)
340+
</Property>
341+
<Property name='embedding_model_provider' type='str' key='embedding_model_provider'>
342+
Embedding 模型供应商(选填)
343+
</Property>
344+
<Property name='retrieval_model' type='object' key='retrieval_model'>
345+
检索模式(选填)
346+
- <code>search_method</code> (string) 检索方法
347+
- <code>hybrid_search</code> 混合检索
348+
- <code>semantic_search</code> 语义检索
349+
- <code>full_text_search</code> 全文检索
350+
- <code>reranking_enable</code> (bool) 是否开启rerank
351+
- <code>reranking_model</code> (object) Rerank 模型配置
352+
- <code>reranking_provider_name</code> (string) Rerank 模型的提供商
353+
- <code>reranking_model_name</code> (string) Rerank 模型的名称
354+
- <code>top_k</code> (int) 召回条数
355+
- <code>score_threshold_enabled</code> (bool) 是否开启召回分数限制
356+
- <code>score_threshold</code> (float) 召回分数限制
357+
</Property>
338358
</Properties>
339359
</Col>
340360
<Col sticky>

0 commit comments

Comments
 (0)