Skip to content

[Ready For Review][AQUA] Add Supporting Fine-Tuned Models in Multi-Model Deployment #1186

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jun 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions ads/aqua/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ads.aqua import logger
from ads.aqua.common.entities import ModelConfigResult
from ads.aqua.common.enums import ConfigFolder, Tags
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
from ads.aqua.common.errors import AquaValueError
from ads.aqua.common.utils import (
_is_valid_mvs,
get_artifact_path,
Expand Down Expand Up @@ -284,8 +284,11 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
logger.info(f"Artifact not found in model {model_id}.")
return False

@cached(cache=TTLCache(maxsize=5, ttl=timedelta(minutes=1), timer=datetime.now))
def get_config_from_metadata(
self, model_id: str, metadata_key: str
self,
model_id: str,
metadata_key: str,
) -> ModelConfigResult:
"""Gets the config for the given Aqua model from model catalog metadata content.

Expand All @@ -300,8 +303,9 @@ def get_config_from_metadata(
ModelConfigResult
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
"""
config = {}
config: Dict[str, Any] = {}
oci_model = self.ds_client.get_model(model_id).data

try:
config = self.ds_client.get_model_defined_metadatum_artifact_content(
model_id, metadata_key
Expand All @@ -321,7 +325,7 @@ def get_config_from_metadata(
)
return ModelConfigResult(config=config, model_details=oci_model)

@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now))
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=5), timer=datetime.now))
def get_config(
self,
model_id: str,
Expand All @@ -346,8 +350,10 @@ def get_config(
ModelConfigResult
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
"""
config_folder = config_folder or ConfigFolder.CONFIG
config: Dict[str, Any] = {}
oci_model = self.ds_client.get_model(model_id).data

config_folder = config_folder or ConfigFolder.CONFIG
oci_aqua = (
(
Tags.AQUA_TAG in oci_model.freeform_tags
Expand All @@ -357,9 +363,9 @@ def get_config(
else False
)
if not oci_aqua:
raise AquaRuntimeError(f"Target model {oci_model.id} is not an Aqua model.")
logger.debug(f"Target model {oci_model.id} is not an Aqua model.")
return ModelConfigResult(config=config, model_details=oci_model)

config: Dict[str, Any] = {}
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
if not artifact_path:
logger.debug(
Expand Down
36 changes: 31 additions & 5 deletions ads/aqua/common/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import re
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from oci.data_science.models import Model
from pydantic import BaseModel, Field, model_validator
Expand Down Expand Up @@ -136,6 +136,28 @@ def set_gpu_specs(cls, model: "ComputeShapeSummary") -> "ComputeShapeSummary":
return model


class LoraModuleSpec(Serializable):
"""
Lightweight descriptor for LoRA Modules used in fine-tuning models.

Attributes
----------
model_id : str
The unique identifier of the fine tuned model.
model_name : str
The name of the fine-tuned model.
model_path : str
The model-by-reference path to the LoRA Module within the model artifact
"""

model_id: Optional[str] = Field(None, description="The fine tuned model OCID to deploy.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needs protected_namespaces = () in config to avoid warning messages showing up when running via CLI.

model_name: Optional[str] = Field(None, description="The name of the fine-tuned model.")
model_path: Optional[str] = Field(
None,
description="The model-by-reference path to the LoRA Module within the model artifact.",
)


class AquaMultiModelRef(Serializable):
"""
Lightweight model descriptor used for multi-model deployment.
Expand All @@ -157,7 +179,7 @@ class AquaMultiModelRef(Serializable):
Optional environment variables to override during deployment.
artifact_location : Optional[str]
Artifact path of model in the multimodel group.
fine_tune_weights_location : Optional[str]
fine_tune_weights : Optional[List[LoraModuleSpec]]
For fine tuned models, the artifact path of the modified model weights
"""

Expand All @@ -166,15 +188,19 @@ class AquaMultiModelRef(Serializable):
gpu_count: Optional[int] = Field(
None, description="The gpu count allocation for the model."
)
model_task: Optional[str] = Field(None, description="The task that model operates on. Supported tasks are in MultiModelSupportedTaskType")
model_task: Optional[str] = Field(
None,
description="The task that model operates on. Supported tasks are in MultiModelSupportedTaskType",
)
env_var: Optional[dict] = Field(
default_factory=dict, description="The environment variables of the model."
)
artifact_location: Optional[str] = Field(
None, description="Artifact path of model in the multimodel group."
)
fine_tune_weights_location: Optional[str] = Field(
None, description="For fine tuned models, the artifact path of the modified model weights"
fine_tune_weights: Optional[List[LoraModuleSpec]] = Field(
None,
description="For fine tuned models, the artifact path of the modified model weights",
)

class Config:
Expand Down
35 changes: 35 additions & 0 deletions ads/aqua/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,41 @@ def get_combined_params(params1: str = None, params2: str = None) -> str:
return " ".join(combined_params)


def find_restricted_params(
default_params: Union[str, List[str]],
user_params: Union[str, List[str]],
container_family: str,
) -> List[str]:
"""Returns a list of restricted params that user chooses to override when creating an Aqua deployment.
The default parameters coming from the container index json file cannot be overridden.

Parameters
----------
default_params:
Inference container parameter string with default values.
user_params:
Inference container parameter string with user provided values.
container_family: str
The image family of model deployment container runtime.

Returns
-------
A list with params keys common between params1 and params2.

"""
restricted_params = []
if default_params and user_params:
default_params_dict = get_params_dict(default_params)
user_params_dict = get_params_dict(user_params)

restricted_params_set = get_restricted_params_by_container(container_family)
for key, _items in user_params_dict.items():
if key in default_params_dict or key in restricted_params_set:
restricted_params.append(key.lstrip("-"))

return restricted_params


def build_params_string(params: dict) -> str:
"""Builds params string from params dict

Expand Down
9 changes: 5 additions & 4 deletions ads/aqua/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,10 +727,11 @@ def validate_model_name(
raise AquaRuntimeError(error_message) from ex

# Build the list of valid model names from custom metadata.
model_names = [
AquaMultiModelRef(**metadata).model_name
for metadata in multi_model_metadata
]
model_names = []
for metadata in multi_model_metadata:
model = AquaMultiModelRef(**metadata)
model_names.append(model.model_name)
model_names.extend(ft.model_name for ft in (model.fine_tune_weights or []) if ft.model_name)

# Check if the provided model name is among the valid names.
if user_model_name not in model_names:
Expand Down
20 changes: 19 additions & 1 deletion ads/aqua/model/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,24 @@ class FineTuningCustomMetadata(ExtendedEnum):

class MultiModelSupportedTaskType(ExtendedEnum):
TEXT_GENERATION = "text_generation"
TEXT_GENERATION_INFERENCE = "text_generation_inference"
TEXT2TEXT_GENERATION = "text2text_generation"
SUMMARIZATION = "summarization"
TRANSLATION = "translation"
CONVERSATIONAL = "conversational"
FEATURE_EXTRACTION = "feature_extraction"
SENTENCE_SIMILARITY = "sentence_similarity"
AUTOMATIC_SPEECH_RECOGNITION = "automatic_speech_recognition"
TEXT_TO_SPEECH = "text_to_speech"
TEXT_TO_IMAGE = "text_to_image"
TEXT_EMBEDDING = "text_embedding"
IMAGE_TEXT_TO_TEXT = "image_text_to_text"
CODE_SYNTHESIS = "code_synthesis"
EMBEDDING = "text_embedding"
QUESTION_ANSWERING = "question_answering"
AUDIO_CLASSIFICATION = "audio_classification"
AUDIO_TO_AUDIO = "audio_to_audio"
IMAGE_CLASSIFICATION = "image_classification"
IMAGE_TO_TEXT = "image_to_text"
IMAGE_TO_IMAGE = "image_to_image"
VIDEO_CLASSIFICATION = "video_classification"
TIME_SERIES_FORECASTING = "time_series_forecasting"
79 changes: 45 additions & 34 deletions ads/aqua/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from ads.aqua import logger
from ads.aqua.app import AquaApp
from ads.aqua.common.entities import AquaMultiModelRef
from ads.aqua.common.entities import AquaMultiModelRef, LoraModuleSpec
from ads.aqua.common.enums import (
ConfigFolder,
CustomInferenceContainerTypeFamily,
Expand Down Expand Up @@ -89,12 +89,7 @@
)
from ads.common.auth import default_signer
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
from ads.common.utils import (
UNKNOWN,
get_console_link,
is_path_exists,
read_file,
)
from ads.common.utils import UNKNOWN, get_console_link, is_path_exists, read_file
from ads.config import (
AQUA_DEPLOYMENT_CONTAINER_CMD_VAR_METADATA_NAME,
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME,
Expand Down Expand Up @@ -300,57 +295,73 @@ def create_multi(

selected_models_deployment_containers = set()

# Process each model
# Process each model in the input list
for model in models:
# Retrieve model metadata from the Model Catalog using the model ID
source_model = DataScienceModel.from_id(model.model_id)
display_name = source_model.display_name
model_file_description = source_model.model_file_description
# Update model name in user's input model
# If model_name is not explicitly provided, use the model's display name
model.model_name = model.model_name or display_name

# TODO Uncomment the section below, if only service models should be allowed for multi-model deployment
# if not source_model.freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, UNKNOWN):
# raise AquaValueError(
# f"Invalid selected model {display_name}. "
# "Currently only service models are supported for multi model deployment."
# )
if not model_file_description:
raise AquaValueError(
f"Model '{source_model.display_name}' (ID: {model.model_id}) has no file description. "
"Please register the model first."
)

# check if model is a fine-tuned model and if so, add the fine tuned weights path to the fine_tune_weights_location pydantic field
# Check if the model is a fine-tuned model based on its tags
is_fine_tuned_model = (
Tags.AQUA_FINE_TUNED_MODEL_TAG in source_model.freeform_tags
)

base_model_artifact_path = ""
fine_tune_path = ""

if is_fine_tuned_model:
model.model_id, model.model_name = extract_base_model_from_ft(
source_model
)
model_artifact_path, model.fine_tune_weights_location = (
# Extract artifact paths for the base and fine-tuned model
base_model_artifact_path, fine_tune_path = (
extract_fine_tune_artifacts_path(source_model)
)

else:
# Retrieve model artifact for base models
model_artifact_path = source_model.artifact
# Create a single LoRA module specification for the fine-tuned model
# TODO: Support multiple LoRA modules in the future
model.fine_tune_weights = [
LoraModuleSpec(
model_id=model.model_id,
model_name=model.model_name,
model_path=fine_tune_path,
)
]

display_name_list.append(display_name)
# Use the LoRA module name as the model's display name
display_name = model.model_name

self._extract_model_task(model, source_model)
# Temporarily override model ID and name with those of the base model
# TODO: Revisit this logic once proper base/FT model handling is implemented
model.model_id, model.model_name = extract_base_model_from_ft(
source_model
)
else:
# For base models, use the original artifact path
base_model_artifact_path = source_model.artifact
display_name = model.model_name

if not model_artifact_path:
if not base_model_artifact_path:
# Fail if no artifact is found for the base model model
raise AquaValueError(
f"Model '{display_name}' (ID: {model.model_id}) has no artifacts. "
f"Model '{model.model_name}' (ID: {model.model_id}) has no artifacts. "
"Please register the model first."
)

# Update model artifact location in user's input model
model.artifact_location = model_artifact_path
# Update the artifact path in the model configuration
model.artifact_location = base_model_artifact_path
display_name_list.append(display_name)

if not model_file_description:
raise AquaValueError(
f"Model '{display_name}' (ID: {model.model_id}) has no file description. "
"Please register the model first."
)
# Extract model task metadata from source model
self._extract_model_task(model, source_model)

# Track model file description in a validated structure
model_file_description_list.append(
ModelFileDescription(**model_file_description)
)
Expand Down
3 changes: 1 addition & 2 deletions ads/aqua/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
"""AQUA model utils"""

from typing import Dict, Optional, Tuple
from typing import Tuple

from ads.aqua.common.entities import AquaMultiModelRef
from ads.aqua.common.errors import AquaValueError
from ads.aqua.common.utils import get_model_by_reference_paths
from ads.aqua.finetuning.constants import FineTuneCustomMetadata
Expand Down
Loading