diff --git a/ads/aqua/common/utils.py b/ads/aqua/common/utils.py index 2d64fd42f..eedbd86be 100644 --- a/ads/aqua/common/utils.py +++ b/ads/aqua/common/utils.py @@ -643,7 +643,9 @@ def get_resource_name(ocid: str) -> str: return name -def get_model_by_reference_paths(model_file_description: dict): +def get_model_by_reference_paths( + model_file_description: dict, is_ft_model_v2: bool = False +): """Reads the model file description json dict and returns the base model path and fine-tuned path for models created by reference. @@ -651,6 +653,8 @@ def get_model_by_reference_paths(model_file_description: dict): ---------- model_file_description: dict json dict containing model paths and objects for models created by reference. + is_ft_model_v2: bool + Flag to indicate if it's fine tuned model v2. Defaults to False. Returns ------- @@ -666,8 +670,18 @@ def get_model_by_reference_paths(model_file_description: dict): "Please check if the model created by reference has the correct artifact." ) + if is_ft_model_v2: + # model_file_description json for fine tuned model v2 contains only fine tuned model artifacts + # so first model is always the fine tuned model + ft_model_artifact = models[0] + fine_tune_output_path = f"oci://{ft_model_artifact['bucketName']}@{ft_model_artifact['namespace']}/{ft_model_artifact['prefix']}".rstrip( + "/" + ) + + return UNKNOWN, fine_tune_output_path + if len(models) > 0: - # since the model_file_description json does not have a flag to identify the base model, we consider + # since the model_file_description json for legacy fine tuned model does not have a flag to identify the base model, we consider # the first instance to be the base model. base_model_artifact = models[0] base_model_path = f"oci://{base_model_artifact['bucketName']}@{base_model_artifact['namespace']}/{base_model_artifact['prefix']}".rstrip( diff --git a/ads/aqua/constants.py b/ads/aqua/constants.py index ca6e5ed6a..a6795497b 100644 --- a/ads/aqua/constants.py +++ b/ads/aqua/constants.py @@ -46,6 +46,7 @@ MODEL_FILE_DESCRIPTION_VERSION = "1.0" MODEL_FILE_DESCRIPTION_TYPE = "modelOSSReferenceDescription" AQUA_FINE_TUNE_MODEL_VERSION = "v2" +INCLUDE_BASE_MODEL = 1 TRAINING_METRICS_FINAL = "training_metrics_final" VALIDATION_METRICS_FINAL = "validation_metrics_final" diff --git a/ads/aqua/finetuning/finetuning.py b/ads/aqua/finetuning/finetuning.py index f79348d1a..cabe41b9a 100644 --- a/ads/aqua/finetuning/finetuning.py +++ b/ads/aqua/finetuning/finetuning.py @@ -25,6 +25,7 @@ upload_local_to_os, ) from ads.aqua.constants import ( + AQUA_FINE_TUNE_MODEL_VERSION, DEFAULT_FT_BATCH_SIZE, DEFAULT_FT_BLOCK_STORAGE_SIZE, DEFAULT_FT_REPLICA, @@ -306,7 +307,9 @@ def create( } # needs to add 'fine_tune_model_version' tag when creating the ft model for the # ft container to block merging base model artifact with ft model artifact. - ft_model_freeform_tags = {Tags.AQUA_FINE_TUNE_MODEL_VERSION: "v2"} + ft_model_freeform_tags = { + Tags.AQUA_FINE_TUNE_MODEL_VERSION: AQUA_FINE_TUNE_MODEL_VERSION + } ft_model = self.create_model_catalog( display_name=create_fine_tuning_details.ft_name, diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index a149610f4..9663e460f 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -172,8 +172,10 @@ def create( The instance of DataScienceModel or DataScienceModelGroup. """ fine_tune_weights = [] + model_name = "" if isinstance(model, AquaMultiModelRef): fine_tune_weights = model.fine_tune_weights + model_name = model.model_name model = model.model_id service_model = DataScienceModel.from_id(model) @@ -194,6 +196,7 @@ def create( if fine_tune_weights: custom_model = self._create_model_group( model_id=model, + model_name=model_name, compartment_id=target_compartment, project_id=target_project, freeform_tags=combined_freeform_tags, @@ -268,6 +271,7 @@ def _create_model( def _create_model_group( self, model_id: str, + model_name: str, compartment_id: str, project_id: str, freeform_tags: Dict, @@ -276,6 +280,20 @@ def _create_model_group( service_model: DataScienceModel, ): """Creates a data science model group.""" + member_models = [ + { + "inference_key": fine_tune_weight.model_name, + "model_id": fine_tune_weight.model_id, + } + for fine_tune_weight in fine_tune_weights + ] + # must also include base model info in member models to create stacked model group + member_models.append( + { + "inference_key": model_name or service_model.display_name, + "model_id": model_id, + } + ) custom_model = ( DataScienceModelGroup() .with_compartment_id(compartment_id) @@ -286,15 +304,7 @@ def _create_model_group( .with_defined_tags(**defined_tags) .with_custom_metadata_list(service_model.custom_metadata_list) .with_base_model_id(model_id) - .with_member_models( - [ - { - "inference_key": fine_tune_weight.model_name, - "model_id": fine_tune_weight.model_id, - } - for fine_tune_weight in fine_tune_weights - ] - ) + .with_member_models(member_models) .create() ) diff --git a/ads/aqua/model/utils.py b/ads/aqua/model/utils.py index 3968b40f0..f92bd10f1 100644 --- a/ads/aqua/model/utils.py +++ b/ads/aqua/model/utils.py @@ -5,8 +5,10 @@ from typing import Tuple +from ads.aqua.common.enums import Tags from ads.aqua.common.errors import AquaValueError from ads.aqua.common.utils import get_model_by_reference_paths +from ads.aqua.constants import AQUA_FINE_TUNE_MODEL_VERSION from ads.aqua.finetuning.constants import FineTuneCustomMetadata from ads.common.object_storage_details import ObjectStorageDetails from ads.model.datascience_model import DataScienceModel @@ -34,8 +36,12 @@ def extract_base_model_from_ft(aqua_model: DataScienceModel) -> Tuple[str, str]: def extract_fine_tune_artifacts_path(aqua_model: DataScienceModel) -> Tuple[str, str]: """Extracts the fine tuning source (fine_tune_output_path) and base model path from the DataScienceModel Object""" + is_ft_model_v2 = ( + aqua_model.freeform_tags.get(Tags.AQUA_FINE_TUNE_MODEL_VERSION, "").lower() + == AQUA_FINE_TUNE_MODEL_VERSION + ) base_model_path, fine_tune_output_path = get_model_by_reference_paths( - aqua_model.model_file_description + aqua_model.model_file_description, is_ft_model_v2 ) if not fine_tune_output_path or not ObjectStorageDetails.is_oci_path( diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index fa2e372c6..84bd5a9ee 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -228,7 +228,13 @@ def create( raise AquaValueError( "Invalid 'models' provided. Only one base model is required for model stack deployment." ) + self._validate_input_models(create_deployment_details) model = create_deployment_details.models[0] + else: + try: + create_deployment_details.validate_ft_model_v2(model_id=model) + except ConfigValidationError as err: + raise AquaValueError(f"{err}") from err service_model_id = model if isinstance(model, str) else model.model_id logger.debug( @@ -258,26 +264,9 @@ def create( ) # TODO: add multi model validation from deployment_type else: - # Collect all unique model IDs (including fine-tuned models) - source_model_ids = list( - { - model_id - for model in create_deployment_details.models - for model_id in model.all_model_ids() - } - ) - logger.debug( - "Fetching source model metadata for model IDs: %s", source_model_ids + source_models, source_model_ids = self._validate_input_models( + create_deployment_details ) - # Fetch source model metadata - source_models = self.get_multi_source(source_model_ids) or {} - - try: - create_deployment_details.validate_input_models( - model_details=source_models - ) - except ConfigValidationError as err: - raise AquaValueError(f"{err}") from err base_model_ids = [ model.model_id for model in create_deployment_details.models @@ -394,6 +383,32 @@ def create( container_config=container_config, ) + def _validate_input_models( + self, + create_deployment_details: CreateModelDeploymentDetails, + ): + """Validates the base models and associated fine tuned models from 'models' in create_deployment_details for stacked or multi model deployment.""" + # Collect all unique model IDs (including fine-tuned models) + source_model_ids = list( + { + model_id + for model in create_deployment_details.models + for model_id in model.all_model_ids() + } + ) + logger.debug( + "Fetching source model metadata for model IDs: %s", source_model_ids + ) + # Fetch source model metadata + source_models = self.get_multi_source(source_model_ids) or {} + + try: + create_deployment_details.validate_input_models(model_details=source_models) + except ConfigValidationError as err: + raise AquaValueError(f"{err}") from err + + return source_models, source_model_ids + def _build_model_group_configs( self, models: List[AquaMultiModelRef], @@ -909,6 +924,8 @@ def _create( params_dict = get_params_dict(params) # updates `--served-model-name` with service model id params_dict.update({"--served-model-name": aqua_model.base_model_id}) + # TODO: sets `--max-lora-rank` as 32 in params for now, will revisit later + params_dict.update({"--max-lora-rank": 32}) # adds `--enable_lora` to parameters params_dict.update({"--enable_lora": UNKNOWN}) params = build_params_string(params_dict) diff --git a/ads/aqua/modeldeployment/entities.py b/ads/aqua/modeldeployment/entities.py index 3a95dc37c..cbcb499ad 100644 --- a/ads/aqua/modeldeployment/entities.py +++ b/ads/aqua/modeldeployment/entities.py @@ -12,7 +12,11 @@ from ads.aqua.common.enums import Tags from ads.aqua.common.errors import AquaValueError from ads.aqua.config.utils.serializer import Serializable -from ads.aqua.constants import UNKNOWN_DICT +from ads.aqua.constants import ( + AQUA_FINE_TUNE_MODEL_VERSION, + INCLUDE_BASE_MODEL, + UNKNOWN_DICT, +) from ads.aqua.data import AquaResourceIdentifier from ads.aqua.finetuning.constants import FineTuneCustomMetadata from ads.aqua.modeldeployment.config_loader import ( @@ -509,11 +513,13 @@ def validate_multimodel_deployment_feasibility( def validate_input_models(self, model_details: Dict[str, DataScienceModel]) -> None: """ - Validates the input models for a multi-model deployment configuration. + Validates the input models for a stacked-model or multi-model deployment configuration. Validation Criteria: - The base model must be explicitly provided. - The base model must be in 'ACTIVE' state. + - Fine-tuned models must have a tag 'fine_tune_model_version' as v2 to be supported. + - Fine-tuned models must not have custom metadata 'include_base_model_artifact' as 1. - Fine-tuned model IDs must refer to valid, tagged fine-tuned models. - Fine-tuned models must refer back to the same base model. - All model names (including fine-tuned variants) must be unique. @@ -609,6 +615,8 @@ def validate_input_models(self, model_details: Dict[str, DataScienceModel]) -> N f"Invalid fine-tuned model ID '{ft_model_id}': missing tag '{Tags.AQUA_FINE_TUNED_MODEL_TAG}'." ) + self.validate_ft_model_v2(model=ft_model) + ft_base_model_id = ft_model.custom_metadata_list.get( FineTuneCustomMetadata.FINE_TUNE_SOURCE, ModelCustomMetadataItem( @@ -650,6 +658,61 @@ def validate_input_models(self, model_details: Dict[str, DataScienceModel]) -> N f"{', '.join(sorted(duplicate_names))}. Model names must be unique for proper routing in multi-model deployments." ) + def validate_ft_model_v2( + self, model_id: Optional[str] = None, model: Optional[DataScienceModel] = None + ) -> None: + """ + Validates the input fine tuned model for model deployment configuration. + + Validation Criteria: + - Fine-tuned models must have a tag 'fine_tune_model_version' as v2 to be supported. + - Fine-tuned models must not have custom metadata 'include_base_model_artifact' as '1'. + + Parameters + ---------- + model_id : str + The OCID of DataScienceModel instance. + model : DataScienceModel + The DataScienceModel instance. + + Raises + ------ + ConfigValidationError + If any of the above conditions are violated. + """ + base_model = DataScienceModel.from_id(model_id) if model_id else model + if Tags.AQUA_FINE_TUNED_MODEL_TAG in base_model.freeform_tags: + if ( + base_model.freeform_tags.get( + Tags.AQUA_FINE_TUNE_MODEL_VERSION, UNKNOWN + ).lower() + != AQUA_FINE_TUNE_MODEL_VERSION + ): + logger.error( + "Validation failed: Fine-tuned model ID '%s' is not supported for model deployment.", + base_model.id, + ) + raise ConfigValidationError( + f"Invalid fine-tuned model ID '{base_model.id}': only fine tune model {AQUA_FINE_TUNE_MODEL_VERSION} is supported for model deployment. " + f"Run 'ads aqua model convert_fine_tune --model_id {base_model.id}' to convert legacy AQUA fine tuned model to version {AQUA_FINE_TUNE_MODEL_VERSION} for deployment." + ) + + include_base_model_artifact = base_model.custom_metadata_list.get( + FineTuneCustomMetadata.FINE_TUNE_INCLUDE_BASE_MODEL_ARTIFACT, + ModelCustomMetadataItem( + key=FineTuneCustomMetadata.FINE_TUNE_INCLUDE_BASE_MODEL_ARTIFACT + ), + ).value + + if include_base_model_artifact == INCLUDE_BASE_MODEL: + logger.error( + "Validation failed: Fine-tuned model ID '%s' is not supported for model deployment.", + base_model.id, + ) + raise ConfigValidationError( + f"Invalid fine-tuned model ID '{base_model.id}': for fine tuned models like Phi4, the deployment is not supported. " + ) + class Config: extra = "allow" protected_namespaces = () diff --git a/ads/model/datascience_model_group.py b/ads/model/datascience_model_group.py index 3443e42ec..64898e1fb 100644 --- a/ads/model/datascience_model_group.py +++ b/ads/model/datascience_model_group.py @@ -530,7 +530,6 @@ def _build_model_group_details(self) -> dict: custom_metadata_list=custom_metadata_list, base_model_id=self.base_model_id, ) - member_model_details.append(MemberModelDetails(model_id=self.base_model_id)) else: model_group_details = HomogeneousModelGroupDetails( custom_metadata_list=custom_metadata_list diff --git a/tests/unitary/with_extras/aqua/test_deployment.py b/tests/unitary/with_extras/aqua/test_deployment.py index cf3c222b4..172937a22 100644 --- a/tests/unitary/with_extras/aqua/test_deployment.py +++ b/tests/unitary/with_extras/aqua/test_deployment.py @@ -1437,8 +1437,12 @@ def test_verify_compatibility(self): @patch.object(AquaApp, "get_container_image") @patch("ads.model.deployment.model_deployment.ModelDeployment.deploy") @patch.object(AquaApp, "get_container_config") + @patch( + "ads.aqua.modeldeployment.entities.CreateModelDeploymentDetails.validate_ft_model_v2" + ) def test_create_deployment_for_foundation_model( self, + mock_validate_ft_model_v2, mock_get_container_config, mock_deploy, mock_get_container_image, @@ -1514,6 +1518,7 @@ def test_create_deployment_for_foundation_model( defined_tags=defined_tags, ) + mock_validate_ft_model_v2.assert_called() mock_create.assert_called_with( model=TestDataset.MODEL_ID, compartment_id=TestDataset.USER_COMPARTMENT_ID, @@ -1538,8 +1543,12 @@ def test_create_deployment_for_foundation_model( @patch.object(AquaApp, "get_container_image") @patch("ads.model.deployment.model_deployment.ModelDeployment.deploy") @patch.object(AquaApp, "get_container_config") + @patch( + "ads.aqua.modeldeployment.entities.CreateModelDeploymentDetails.validate_ft_model_v2" + ) def test_create_deployment_for_fine_tuned_model( self, + mock_validate_ft_model_v2, mock_get_container_config, mock_deploy, mock_get_container_image, @@ -1610,6 +1619,7 @@ def test_create_deployment_for_fine_tuned_model( predict_log_id="ocid1.log.oc1..", ) + mock_validate_ft_model_v2.assert_called() mock_create.assert_called_with( model=TestDataset.MODEL_ID, compartment_id=TestDataset.USER_COMPARTMENT_ID, @@ -1632,8 +1642,12 @@ def test_create_deployment_for_fine_tuned_model( @patch.object(AquaApp, "get_container_image") @patch("ads.model.deployment.model_deployment.ModelDeployment.deploy") @patch.object(AquaApp, "get_container_config") + @patch( + "ads.aqua.modeldeployment.entities.CreateModelDeploymentDetails.validate_ft_model_v2" + ) def test_create_deployment_for_gguf_model( self, + mock_validate_ft_model_v2, mock_get_container_config, mock_deploy, mock_get_container_image, @@ -1706,6 +1720,7 @@ def test_create_deployment_for_gguf_model( memory_in_gbs=60.0, ) + mock_validate_ft_model_v2.assert_called() mock_create.assert_called_with( model=TestDataset.MODEL_ID, compartment_id=TestDataset.USER_COMPARTMENT_ID, @@ -1732,8 +1747,12 @@ def test_create_deployment_for_gguf_model( @patch.object(AquaApp, "get_container_image") @patch("ads.model.deployment.model_deployment.ModelDeployment.deploy") @patch.object(AquaApp, "get_container_config") + @patch( + "ads.aqua.modeldeployment.entities.CreateModelDeploymentDetails.validate_ft_model_v2" + ) def test_create_deployment_for_tei_byoc_embedding_model( self, + mock_validate_ft_model_v2, mock_get_container_config, mock_deploy, mock_get_container_image, @@ -1809,6 +1828,7 @@ def test_create_deployment_for_tei_byoc_embedding_model( cmd_var=[], ) + mock_validate_ft_model_v2.assert_called() mock_create.assert_called_with( model=TestDataset.MODEL_ID, compartment_id=TestDataset.USER_COMPARTMENT_ID, @@ -1838,8 +1858,14 @@ def test_create_deployment_for_tei_byoc_embedding_model( @patch.object(AquaApp, "get_container_image") @patch("ads.model.deployment.model_deployment.ModelDeployment.deploy") @patch.object(AquaApp, "get_container_config") + @patch( + "ads.aqua.modeldeployment.entities.CreateModelDeploymentDetails.validate_input_models" + ) + @patch.object(AquaApp, "get_multi_source") def test_create_deployment_for_stack_model( self, + mock_get_multi_source, + mock_validate_input_models, mock_get_container_config, mock_deploy, mock_get_container_image, @@ -1931,6 +1957,8 @@ def test_create_deployment_for_stack_model( deployment_type="STACKED", ) + mock_get_multi_source.assert_called() + mock_validate_input_models.assert_called() mock_create.assert_called() mock_get_container_image.assert_called() mock_deploy.assert_called()