
Block legacy ft model #1238


Open
wants to merge 8 commits into feature/model_group
18 changes: 16 additions & 2 deletions ads/aqua/common/utils.py
@@ -643,14 +643,18 @@ def get_resource_name(ocid: str) -> str:
return name


def get_model_by_reference_paths(model_file_description: dict):
def get_model_by_reference_paths(
model_file_description: dict, is_ft_model_v2: bool = False
):
"""Reads the model file description json dict and returns the base model path and fine-tuned path for
models created by reference.

Parameters
----------
model_file_description: dict
json dict containing model paths and objects for models created by reference.
is_ft_model_v2: bool
Flag indicating whether the model is a fine-tuned model v2. Defaults to False.

Returns
-------
@@ -666,8 +670,18 @@ def get_model_by_reference_paths(model_file_description: dict):
"Please check if the model created by reference has the correct artifact."
)

if is_ft_model_v2:
# the model_file_description json for a fine-tuned model v2 contains only the fine-tuned model artifacts,
# so the first model is always the fine-tuned model
ft_model_artifact = models[0]
fine_tune_output_path = f"oci://{ft_model_artifact['bucketName']}@{ft_model_artifact['namespace']}/{ft_model_artifact['prefix']}".rstrip(
"/"
)

return UNKNOWN, fine_tune_output_path

if len(models) > 0:
# since the model_file_description json does not have a flag to identify the base model, we consider
# since the model_file_description json for legacy fine tuned model does not have a flag to identify the base model, we consider
# the first instance to be the base model.
base_model_artifact = models[0]
base_model_path = f"oci://{base_model_artifact['bucketName']}@{base_model_artifact['namespace']}/{base_model_artifact['prefix']}".rstrip(
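
For context, a self-contained sketch of the OCI path the new v2 branch builds from a single artifact entry; the field names come from the diff above, while the values are illustrative:

ft_model_artifact = {
    "bucketName": "ft-bucket",
    "namespace": "my-namespace",
    "prefix": "ft-weights/",
}
# same construction as in the v2 branch above
fine_tune_output_path = (
    f"oci://{ft_model_artifact['bucketName']}"
    f"@{ft_model_artifact['namespace']}"
    f"/{ft_model_artifact['prefix']}"
).rstrip("/")
# -> "oci://ft-bucket@my-namespace/ft-weights"
# for a fine-tuned model v2 the helper then returns (UNKNOWN, fine_tune_output_path),
# since the description contains no base model artifact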
1 change: 1 addition & 0 deletions ads/aqua/constants.py
@@ -46,6 +46,7 @@
MODEL_FILE_DESCRIPTION_VERSION = "1.0"
MODEL_FILE_DESCRIPTION_TYPE = "modelOSSReferenceDescription"
AQUA_FINE_TUNE_MODEL_VERSION = "v2"
INCLUDE_BASE_MODEL = 1

TRAINING_METRICS_FINAL = "training_metrics_final"
VALIDATION_METRICS_FINAL = "validation_metrics_final"
5 changes: 4 additions & 1 deletion ads/aqua/finetuning/finetuning.py
@@ -25,6 +25,7 @@
upload_local_to_os,
)
from ads.aqua.constants import (
AQUA_FINE_TUNE_MODEL_VERSION,
DEFAULT_FT_BATCH_SIZE,
DEFAULT_FT_BLOCK_STORAGE_SIZE,
DEFAULT_FT_REPLICA,
@@ -306,7 +307,9 @@ def create(
}
# need to add the 'fine_tune_model_version' tag when creating the ft model so that the
# ft container blocks merging the base model artifact with the ft model artifact.
ft_model_freeform_tags = {Tags.AQUA_FINE_TUNE_MODEL_VERSION: "v2"}
ft_model_freeform_tags = {
Tags.AQUA_FINE_TUNE_MODEL_VERSION: AQUA_FINE_TUNE_MODEL_VERSION
}

ft_model = self.create_model_catalog(
display_name=create_fine_tuning_details.ft_name,
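
In effect, the freeform tag written onto the newly created fine-tuned model resolves to the literal below; the tag key string is taken from the comment above and the value from AQUA_FINE_TUNE_MODEL_VERSION:

ft_model_freeform_tags = {"fine_tune_model_version": "v2"}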
28 changes: 19 additions & 9 deletions ads/aqua/model/model.py
@@ -172,8 +172,10 @@ def create(
The instance of DataScienceModel or DataScienceModelGroup.
"""
fine_tune_weights = []
model_name = ""
if isinstance(model, AquaMultiModelRef):
fine_tune_weights = model.fine_tune_weights
model_name = model.model_name
model = model.model_id

service_model = DataScienceModel.from_id(model)
@@ -194,6 +196,7 @@
if fine_tune_weights:
custom_model = self._create_model_group(
model_id=model,
model_name=model_name,
compartment_id=target_compartment,
project_id=target_project,
freeform_tags=combined_freeform_tags,
@@ -268,6 +271,7 @@ def _create_model(
def _create_model_group(
self,
model_id: str,
model_name: str,
compartment_id: str,
project_id: str,
freeform_tags: Dict,
@@ -276,6 +280,20 @@ def _create_model_group(
service_model: DataScienceModel,
):
"""Creates a data science model group."""
member_models = [
{
"inference_key": fine_tune_weight.model_name,
"model_id": fine_tune_weight.model_id,
}
for fine_tune_weight in fine_tune_weights
]
# must also include the base model info in member models to create a stacked model group
member_models.append(
{
"inference_key": model_name or service_model.display_name,
"model_id": model_id,
}
)
custom_model = (
DataScienceModelGroup()
.with_compartment_id(compartment_id)
@@ -286,15 +304,7 @@ def _create_model_group(
.with_defined_tags(**defined_tags)
.with_custom_metadata_list(service_model.custom_metadata_list)
.with_base_model_id(model_id)
.with_member_models(
[
{
"inference_key": fine_tune_weight.model_name,
"model_id": fine_tune_weight.model_id,
}
for fine_tune_weight in fine_tune_weights
]
)
.with_member_models(member_models)
.create()
)

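
For illustration, a sketch of the member model list this change now passes to with_member_models, with the base model appended after the fine-tuned weights; inference keys and OCIDs are placeholders:

member_models = [
    {"inference_key": "llama-ft-alpha", "model_id": "<fine-tuned-model-ocid-1>"},
    {"inference_key": "llama-ft-beta", "model_id": "<fine-tuned-model-ocid-2>"},
    # base model entry appended last so the stacked model group includes it
    {"inference_key": "llama-base", "model_id": "<base-model-ocid>"},
]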
8 changes: 7 additions & 1 deletion ads/aqua/model/utils.py
@@ -5,8 +5,10 @@

from typing import Tuple

from ads.aqua.common.enums import Tags
from ads.aqua.common.errors import AquaValueError
from ads.aqua.common.utils import get_model_by_reference_paths
from ads.aqua.constants import AQUA_FINE_TUNE_MODEL_VERSION
from ads.aqua.finetuning.constants import FineTuneCustomMetadata
from ads.common.object_storage_details import ObjectStorageDetails
from ads.model.datascience_model import DataScienceModel
@@ -34,8 +36,12 @@ def extract_base_model_from_ft(aqua_model: DataScienceModel) -> Tuple[str, str]:
def extract_fine_tune_artifacts_path(aqua_model: DataScienceModel) -> Tuple[str, str]:
"""Extracts the fine tuning source (fine_tune_output_path) and base model path from the DataScienceModel Object"""

is_ft_model_v2 = (
aqua_model.freeform_tags.get(Tags.AQUA_FINE_TUNE_MODEL_VERSION, "").lower()
== AQUA_FINE_TUNE_MODEL_VERSION
)
base_model_path, fine_tune_output_path = get_model_by_reference_paths(
aqua_model.model_file_description
aqua_model.model_file_description, is_ft_model_v2
)

if not fine_tune_output_path or not ObjectStorageDetails.is_oci_path(
55 changes: 36 additions & 19 deletions ads/aqua/modeldeployment/deployment.py
@@ -228,7 +228,13 @@ def create(
raise AquaValueError(
"Invalid 'models' provided. Only one base model is required for model stack deployment."
)
self._validate_input_models(create_deployment_details)
model = create_deployment_details.models[0]
else:
try:
create_deployment_details.validate_ft_model_v2(model_id=model)
except ConfigValidationError as err:
raise AquaValueError(f"{err}") from err

service_model_id = model if isinstance(model, str) else model.model_id
logger.debug(
@@ -258,26 +264,9 @@
)
# TODO: add multi model validation from deployment_type
else:
# Collect all unique model IDs (including fine-tuned models)
source_model_ids = list(
{
model_id
for model in create_deployment_details.models
for model_id in model.all_model_ids()
}
)
logger.debug(
"Fetching source model metadata for model IDs: %s", source_model_ids
source_models, source_model_ids = self._validate_input_models(
create_deployment_details
)
# Fetch source model metadata
source_models = self.get_multi_source(source_model_ids) or {}

try:
create_deployment_details.validate_input_models(
model_details=source_models
)
except ConfigValidationError as err:
raise AquaValueError(f"{err}") from err

base_model_ids = [
model.model_id for model in create_deployment_details.models
@@ -394,6 +383,32 @@ def create(
container_config=container_config,
)

def _validate_input_models(
self,
create_deployment_details: CreateModelDeploymentDetails,
):
"""Validates the base models and associated fine tuned models from 'models' in create_deployment_details for stacked or multi model deployment."""
# Collect all unique model IDs (including fine-tuned models)
source_model_ids = list(
{
model_id
for model in create_deployment_details.models
for model_id in model.all_model_ids()
}
)
logger.debug(
"Fetching source model metadata for model IDs: %s", source_model_ids
)
# Fetch source model metadata
source_models = self.get_multi_source(source_model_ids) or {}

try:
create_deployment_details.validate_input_models(model_details=source_models)
except ConfigValidationError as err:
raise AquaValueError(f"{err}") from err

return source_models, source_model_ids

def _build_model_group_configs(
self,
models: List[AquaMultiModelRef],
@@ -909,6 +924,8 @@ def _create(
params_dict = get_params_dict(params)
# updates `--served-model-name` with service model id
params_dict.update({"--served-model-name": aqua_model.base_model_id})
# TODO: set `--max-lora-rank` to 32 in params for now; revisit later
params_dict.update({"--max-lora-rank": 32})
# adds `--enable_lora` to parameters
params_dict.update({"--enable_lora": UNKNOWN})
params = build_params_string(params_dict)
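
To make the parameter handling concrete, a self-contained sketch of the updated dict and how it flattens back into a string; build_params_string is assumed to behave like the join below, and the values are placeholders:

params_dict = {
    "--served-model-name": "<base model OCID>",
    "--seed": "42",                # pre-existing user parameter, illustrative
    "--max-lora-rank": 32,         # fixed for now, per the TODO above
    "--enable_lora": "",           # flag-only option; UNKNOWN assumed to be an empty string
}
params = " ".join(f"{key} {value}".strip() for key, value in params_dict.items())
# -> "--served-model-name <base model OCID> --seed 42 --max-lora-rank 32 --enable_lora"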
67 changes: 65 additions & 2 deletions ads/aqua/modeldeployment/entities.py
@@ -12,7 +12,11 @@
from ads.aqua.common.enums import Tags
from ads.aqua.common.errors import AquaValueError
from ads.aqua.config.utils.serializer import Serializable
from ads.aqua.constants import UNKNOWN_DICT
from ads.aqua.constants import (
AQUA_FINE_TUNE_MODEL_VERSION,
INCLUDE_BASE_MODEL,
UNKNOWN_DICT,
)
from ads.aqua.data import AquaResourceIdentifier
from ads.aqua.finetuning.constants import FineTuneCustomMetadata
from ads.aqua.modeldeployment.config_loader import (
@@ -509,11 +513,13 @@ def validate_multimodel_deployment_feasibility(

def validate_input_models(self, model_details: Dict[str, DataScienceModel]) -> None:
"""
Validates the input models for a multi-model deployment configuration.
Validates the input models for a stacked-model or multi-model deployment configuration.

Validation Criteria:
- The base model must be explicitly provided.
- The base model must be in 'ACTIVE' state.
- Fine-tuned models must have the tag 'fine_tune_model_version' set to 'v2' to be supported.
- Fine-tuned models must not have the custom metadata 'include_base_model_artifact' set to 1.
- Fine-tuned model IDs must refer to valid, tagged fine-tuned models.
- Fine-tuned models must refer back to the same base model.
- All model names (including fine-tuned variants) must be unique.
@@ -609,6 +615,8 @@ def validate_input_models(self, model_details: Dict[str, DataScienceModel]) -> None:
f"Invalid fine-tuned model ID '{ft_model_id}': missing tag '{Tags.AQUA_FINE_TUNED_MODEL_TAG}'."
)

self.validate_ft_model_v2(model=ft_model)

ft_base_model_id = ft_model.custom_metadata_list.get(
FineTuneCustomMetadata.FINE_TUNE_SOURCE,
ModelCustomMetadataItem(
@@ -650,6 +658,61 @@ def validate_input_models(self, model_details: Dict[str, DataScienceModel]) -> None:
f"{', '.join(sorted(duplicate_names))}. Model names must be unique for proper routing in multi-model deployments."
)

def validate_ft_model_v2(
self, model_id: Optional[str] = None, model: Optional[DataScienceModel] = None
) -> None:
"""
Validates the input fine-tuned model for the model deployment configuration.

Validation Criteria:
- Fine-tuned models must have the tag 'fine_tune_model_version' set to 'v2' to be supported.
- Fine-tuned models must not have the custom metadata 'include_base_model_artifact' set to '1'.

Parameters
----------
model_id : str, optional
The OCID of the DataScienceModel instance.
model : DataScienceModel, optional
The DataScienceModel instance.

Raises
------
ConfigValidationError
If any of the above conditions are violated.
"""
base_model = DataScienceModel.from_id(model_id) if model_id else model
if Tags.AQUA_FINE_TUNED_MODEL_TAG in base_model.freeform_tags:
if (
base_model.freeform_tags.get(
Tags.AQUA_FINE_TUNE_MODEL_VERSION, UNKNOWN
).lower()
!= AQUA_FINE_TUNE_MODEL_VERSION
):
logger.error(
"Validation failed: Fine-tuned model ID '%s' is not supported for model deployment.",
base_model.id,
)
raise ConfigValidationError(
f"Invalid fine-tuned model ID '{base_model.id}': only fine tune model {AQUA_FINE_TUNE_MODEL_VERSION} is supported for model deployment. "
f"Run 'ads aqua model convert_fine_tune --model_id {base_model.id}' to convert legacy AQUA fine tuned model to version {AQUA_FINE_TUNE_MODEL_VERSION} for deployment."
)

include_base_model_artifact = base_model.custom_metadata_list.get(
FineTuneCustomMetadata.FINE_TUNE_INCLUDE_BASE_MODEL_ARTIFACT,
ModelCustomMetadataItem(
key=FineTuneCustomMetadata.FINE_TUNE_INCLUDE_BASE_MODEL_ARTIFACT
),
).value

if include_base_model_artifact == INCLUDE_BASE_MODEL:
logger.error(
"Validation failed: Fine-tuned model ID '%s' is not supported for model deployment.",
base_model.id,
)
raise ConfigValidationError(
f"Invalid fine-tuned model ID '{base_model.id}': for fine tuned models like Phi4, the deployment is not supported. "
)

class Config:
extra = "allow"
protected_namespaces = ()
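
In plain terms, the new check reduces to the standalone sketch below; the tag and metadata key strings are taken from the docstring above, and the helper name is hypothetical:

AQUA_FINE_TUNE_MODEL_VERSION = "v2"
INCLUDE_BASE_MODEL = 1

def is_deployable_ft_model(freeform_tags: dict, include_base_model_artifact) -> bool:
    """Rough equivalent of validate_ft_model_v2 for a model already tagged as fine-tuned."""
    if freeform_tags.get("fine_tune_model_version", "").lower() != AQUA_FINE_TUNE_MODEL_VERSION:
        return False  # legacy fine-tuned model: must be converted to v2 first
    if include_base_model_artifact == INCLUDE_BASE_MODEL:
        return False  # artifact bundles the base model (e.g. Phi4): not deployable
    return True

# a converted v2 fine-tuned model without the base model artifact passes
assert is_deployable_ft_model({"fine_tune_model_version": "v2"}, None)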
1 change: 0 additions & 1 deletion ads/model/datascience_model_group.py
@@ -530,7 +530,6 @@ def _build_model_group_details(self) -> dict:
custom_metadata_list=custom_metadata_list,
base_model_id=self.base_model_id,
)
member_model_details.append(MemberModelDetails(model_id=self.base_model_id))
else:
model_group_details = HomogeneousModelGroupDetails(
custom_metadata_list=custom_metadata_list