@@ -169,16 +169,40 @@ def create_empty_dataset(
169
169
provider : str = "vendor" ,
170
170
external_knowledge_api_id : Optional [str ] = None ,
171
171
external_knowledge_id : Optional [str ] = None ,
172
+ embedding_model_provider : Optional [str ] = None ,
173
+ embedding_model_name : Optional [str ] = None ,
174
+ retrieval_model : Optional [RetrievalModel ] = None ,
172
175
):
173
176
# check if dataset name already exists
174
177
if Dataset .query .filter_by (name = name , tenant_id = tenant_id ).first ():
175
178
raise DatasetNameDuplicateError (f"Dataset with name { name } already exists." )
176
179
embedding_model = None
177
180
if indexing_technique == "high_quality" :
178
181
model_manager = ModelManager ()
179
- embedding_model = model_manager .get_default_model_instance (
180
- tenant_id = tenant_id , model_type = ModelType .TEXT_EMBEDDING
181
- )
182
+ if embedding_model_provider and embedding_model_name :
183
+ # check if embedding model setting is valid
184
+ DatasetService .check_embedding_model_setting (tenant_id , embedding_model_provider , embedding_model_name )
185
+ embedding_model = model_manager .get_model_instance (
186
+ tenant_id = tenant_id ,
187
+ provider = embedding_model_provider ,
188
+ model_type = ModelType .TEXT_EMBEDDING ,
189
+ model = embedding_model_name ,
190
+ )
191
+ else :
192
+ embedding_model = model_manager .get_default_model_instance (
193
+ tenant_id = tenant_id , model_type = ModelType .TEXT_EMBEDDING
194
+ )
195
+ if retrieval_model and retrieval_model .reranking_model :
196
+ if (
197
+ retrieval_model .reranking_model .reranking_provider_name
198
+ and retrieval_model .reranking_model .reranking_model_name
199
+ ):
200
+ # check if reranking model setting is valid
201
+ DatasetService .check_embedding_model_setting (
202
+ tenant_id ,
203
+ retrieval_model .reranking_model .reranking_provider_name ,
204
+ retrieval_model .reranking_model .reranking_model_name ,
205
+ )
182
206
dataset = Dataset (name = name , indexing_technique = indexing_technique )
183
207
# dataset = Dataset(name=name, provider=provider, config=config)
184
208
dataset .description = description
@@ -187,6 +211,7 @@ def create_empty_dataset(
187
211
dataset .tenant_id = tenant_id
188
212
dataset .embedding_model_provider = embedding_model .provider if embedding_model else None
189
213
dataset .embedding_model = embedding_model .model if embedding_model else None
214
+ dataset .retrieval_model = retrieval_model .model_dump () if retrieval_model else None
190
215
dataset .permission = permission or DatasetPermissionEnum .ONLY_ME
191
216
dataset .provider = provider
192
217
db .session .add (dataset )
@@ -923,11 +948,11 @@ def save_document_with_dataset_id(
923
948
"score_threshold_enabled" : False ,
924
949
}
925
950
926
- dataset .retrieval_model = (
927
- knowledge_config .retrieval_model .model_dump ()
928
- if knowledge_config .retrieval_model
929
- else default_retrieval_model
930
- ) # type: ignore
951
+ dataset .retrieval_model = (
952
+ knowledge_config .retrieval_model .model_dump ()
953
+ if knowledge_config .retrieval_model
954
+ else default_retrieval_model
955
+ ) # type: ignore
931
956
932
957
documents = []
933
958
if knowledge_config .original_document_id :
0 commit comments