@@ -141,19 +141,19 @@ class AquaModelApp(AquaApp):
141141 @telemetry (entry_point = "plugin=model&action=create" , name = "aqua" )
142142 def create (
143143 self ,
144- model_id : Union [str , AquaMultiModelRef ],
144+ model : Union [str , AquaMultiModelRef ],
145145 project_id : Optional [str ] = None ,
146146 compartment_id : Optional [str ] = None ,
147147 freeform_tags : Optional [Dict ] = None ,
148148 defined_tags : Optional [Dict ] = None ,
149149 ** kwargs ,
150- ) -> DataScienceModel :
150+ ) -> Union [ DataScienceModel , DataScienceModelGroup ] :
151151 """
152- Creates a custom Aqua model from a service model.
152+ Creates a custom Aqua model or model group from a service model.
153153
154154 Parameters
155155 ----------
156- model_id : Union[str, AquaMultiModelRef]
156+ model : Union[str, AquaMultiModelRef]
157157 The model ID as a string or a AquaMultiModelRef instance to be deployed.
158158 project_id : Optional[str]
159159 The project ID for the custom model.
@@ -167,28 +167,18 @@ def create(
167167
168168 Returns
169169 -------
170- DataScienceModel
171- The instance of DataScienceModel.
170+ Union[ DataScienceModel, DataScienceModelGroup]
171+ The instance of DataScienceModel or DataScienceModelGroup .
172172 """
173- model_id = (
174- model_id .model_id if isinstance (model_id , AquaMultiModelRef ) else model_id
175- )
176- service_model = DataScienceModel .from_id (model_id )
173+ fine_tune_weights = []
174+ if isinstance (model , AquaMultiModelRef ):
175+ fine_tune_weights = model .fine_tune_weights
176+ model = model .model_id
177+
178+ service_model = DataScienceModel .from_id (model )
177179 target_project = project_id or PROJECT_OCID
178180 target_compartment = compartment_id or COMPARTMENT_OCID
179181
180- # Skip model copying if it is registered model or fine-tuned model
181- if (
182- service_model .freeform_tags .get (Tags .BASE_MODEL_CUSTOM , None ) is not None
183- or service_model .freeform_tags .get (Tags .AQUA_FINE_TUNED_MODEL_TAG )
184- is not None
185- ):
186- logger .info (
187- f"Aqua Model { model_id } already exists in the user's compartment."
188- "Skipped copying."
189- )
190- return service_model
191-
192182 # combine tags
193183 combined_freeform_tags = {
194184 ** (service_model .freeform_tags or {}),
@@ -199,29 +189,112 @@ def create(
199189 ** (defined_tags or {}),
200190 }
201191
192+ custom_model = None
193+ if fine_tune_weights :
194+ custom_model = self ._create_model_group (
195+ model_id = model ,
196+ compartment_id = target_compartment ,
197+ project_id = target_project ,
198+ freeform_tags = combined_freeform_tags ,
199+ defined_tags = combined_defined_tags ,
200+ fine_tune_weights = fine_tune_weights ,
201+ service_model = service_model ,
202+ )
203+
204+ logger .info (
205+ f"Aqua Model Group { custom_model .id } created with the service model { model } ."
206+ )
207+ else :
208+ # Skip model copying if it is registered model or fine-tuned model
209+ if (
210+ Tags .BASE_MODEL_CUSTOM in service_model .freeform_tags
211+ or Tags .AQUA_FINE_TUNED_MODEL_TAG in service_model .freeform_tags
212+ ):
213+ logger .info (
214+ f"Aqua Model { model } already exists in the user's compartment."
215+ "Skipped copying."
216+ )
217+ return service_model
218+
219+ custom_model = self ._create_model (
220+ compartment_id = target_compartment ,
221+ project_id = target_project ,
222+ freeform_tags = combined_freeform_tags ,
223+ defined_tags = combined_defined_tags ,
224+ service_model = service_model ,
225+ ** kwargs ,
226+ )
227+ logger .info (
228+ f"Aqua Model { custom_model .id } created with the service model { model } ."
229+ )
230+
231+ # Track unique models that were created in the user's compartment
232+ self .telemetry .record_event_async (
233+ category = "aqua/service/model" ,
234+ action = "create" ,
235+ detail = service_model .display_name ,
236+ )
237+
238+ return custom_model
239+
240+ def _create_model (
241+ self ,
242+ compartment_id : str ,
243+ project_id : str ,
244+ freeform_tags : Dict ,
245+ defined_tags : Dict ,
246+ service_model : DataScienceModel ,
247+ ** kwargs ,
248+ ):
249+ """Creates a data science model by reference."""
202250 custom_model = (
203251 DataScienceModel ()
204- .with_compartment_id (target_compartment )
205- .with_project_id (target_project )
252+ .with_compartment_id (compartment_id )
253+ .with_project_id (project_id )
206254 .with_model_file_description (json_dict = service_model .model_file_description )
207255 .with_display_name (service_model .display_name )
208256 .with_description (service_model .description )
209- .with_freeform_tags (** combined_freeform_tags )
210- .with_defined_tags (** combined_defined_tags )
257+ .with_freeform_tags (** freeform_tags )
258+ .with_defined_tags (** defined_tags )
211259 .with_custom_metadata_list (service_model .custom_metadata_list )
212260 .with_defined_metadata_list (service_model .defined_metadata_list )
213261 .with_provenance_metadata (service_model .provenance_metadata )
214262 .create (model_by_reference = True , ** kwargs )
215263 )
216- logger .info (
217- f"Aqua Model { custom_model .id } created with the service model { model_id } ."
218- )
219264
220- # Track unique models that were created in the user's compartment
221- self .telemetry .record_event_async (
222- category = "aqua/service/model" ,
223- action = "create" ,
224- detail = service_model .display_name ,
265+ return custom_model
266+
267+ def _create_model_group (
268+ self ,
269+ model_id : str ,
270+ compartment_id : str ,
271+ project_id : str ,
272+ freeform_tags : Dict ,
273+ defined_tags : Dict ,
274+ fine_tune_weights : List ,
275+ service_model : DataScienceModel ,
276+ ):
277+ """Creates a data science model group."""
278+ custom_model = (
279+ DataScienceModelGroup ()
280+ .with_compartment_id (compartment_id )
281+ .with_project_id (project_id )
282+ .with_display_name (service_model .display_name )
283+ .with_description (service_model .description )
284+ .with_freeform_tags (** freeform_tags )
285+ .with_defined_tags (** defined_tags )
286+ .with_custom_metadata_list (service_model .custom_metadata_list )
287+ .with_base_model_id (model_id )
288+ .with_member_models (
289+ [
290+ {
291+ "inference_key" : fine_tune_weight .model_name ,
292+ "model_id" : fine_tune_weight .model_id ,
293+ }
294+ for fine_tune_weight in fine_tune_weights
295+ ]
296+ )
297+ .create ()
225298 )
226299
227300 return custom_model
@@ -271,6 +344,16 @@ def create_multi(
271344 DataScienceModelGroup
272345 Instance of DataScienceModelGroup object.
273346 """
347+ member_model_ids = [{"model_id" : model .model_id } for model in models ]
348+ for model in models :
349+ if model .fine_tune_weights :
350+ member_model_ids .extend (
351+ [
352+ {"model_id" : fine_tune_model .model_id }
353+ for fine_tune_model in model .fine_tune_weights
354+ ]
355+ )
356+
274357 custom_model_group = (
275358 DataScienceModelGroup ()
276359 .with_compartment_id (compartment_id )
@@ -281,7 +364,7 @@ def create_multi(
281364 .with_defined_tags (** (defined_tags or {}))
282365 .with_custom_metadata_list (model_custom_metadata )
283366 # TODO: add member model inference key
284- .with_member_models ([{ "model_id" : model . model_id for model in models }] )
367+ .with_member_models (member_model_ids )
285368 )
286369 custom_model_group .create ()
287370
0 commit comments