diff --git a/ads/model/datascience_model_group.py b/ads/model/datascience_model_group.py new file mode 100644 index 000000000..cc32ffa9c --- /dev/null +++ b/ads/model/datascience_model_group.py @@ -0,0 +1,837 @@ +#!/usr/bin/env python + +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import copy +from typing import Dict, List, Union + +from ads.common.utils import batch_convert_case +from ads.config import COMPARTMENT_OCID, PROJECT_OCID +from ads.jobs.builders.base import Builder +from ads.model.model_metadata import ModelCustomMetadata +from ads.model.service.oci_datascience_model_group import OCIDataScienceModelGroup + +try: + from oci.data_science.models import ( + CreateModelGroupDetails, + CustomMetadata, + HomogeneousModelGroupDetails, + MemberModelDetails, + MemberModelEntries, + ModelGroup, + ModelGroupDetails, + ModelGroupSummary, + UpdateModelGroupDetails, + ) +except ModuleNotFoundError as err: + raise ModuleNotFoundError( + "The oci model group module was not found. Please run `pip install oci` " + "to install the latest oci sdk." + ) from err + +DEFAULT_WAIT_TIME = 1200 +DEFAULT_POLL_INTERVAL = 10 +ALLOWED_CREATE_TYPES = ["CREATE", "CLONE"] +MODEL_GROUP_KIND = "datascienceModelGroup" + + +class DataScienceModelGroup(Builder): + """Represents a Data Science Model Group. + + Attributes + ---------- + id: str + Model group ID. + project_id: str + Project OCID. + compartment_id: str + Compartment OCID. + display_name: str + Model group name. + description: str + Model group description. + freeform_tags: Dict[str, str] + Model group freeform tags. + defined_tags: Dict[str, Dict[str, object]] + Model group defined tags. + custom_metadata_list: ModelCustomMetadata + Model group custom metadata. + model_group_version_history_name: str + Model group version history name + model_group_version_history_id: str + Model group version history ID + version_label: str + Model group version label + version_id: str + Model group version id + lifecycle_state: str + Model group lifecycle state + lifecycle_details: str + Model group lifecycle details + + Methods + ------- + activate(self, ...) -> "DataScienceModelGroup" + Activates model group. + create(self, ...) -> "DataScienceModelGroup" + Creates model group. + deactivate(self, ...) -> "DataScienceModelGroup" + Deactivates model group. + delete(self, ...) -> "DataScienceModelGroup": + Deletes model group. + to_dict(self) -> dict + Serializes model group to a dictionary. + from_id(cls, id: str) -> "DataScienceModelGroup" + Gets an existing model group by OCID. + from_dict(cls, config: dict) -> "DataScienceModelGroup" + Loads model group instance from a dictionary of configurations. + update(self, ...) -> "DataScienceModelGroup" + Updates datascience model group in model catalog. + list(cls, compartment_id: str = None, **kwargs) -> List["DataScienceModelGroup"] + Lists datascience model groups in a given compartment. + sync(self): + Sync up a datascience model group with OCI datascience model group. + with_project_id(self, project_id: str) -> "DataScienceModelGroup" + Sets the project ID. + with_description(self, description: str) -> "DataScienceModelGroup" + Sets the description. + with_compartment_id(self, compartment_id: str) -> "DataScienceModelGroup" + Sets the compartment ID. + with_display_name(self, name: str) -> "DataScienceModelGroup" + Sets the name. + with_freeform_tags(self, **kwargs: Dict[str, str]) -> "DataScienceModelGroup" + Sets freeform tags. + with_defined_tags(self, **kwargs: Dict[str, Dict[str, object]]) -> "DataScienceModelGroup" + Sets defined tags. + with_custom_metadata_list(self, metadata: Union[ModelCustomMetadata, Dict]) -> "DataScienceModelGroup" + Sets model group custom metadata. + with_model_group_version_history_id(self, model_group_version_history_id: str) -> "DataScienceModelGroup": + Sets the model group version history ID. + with_version_label(self, version_label: str) -> "DataScienceModelGroup": + Sets the model group version label. + with_base_model_id(self, base_model_id) -> "DataScienceModelGroup": + Sets the base model ID. + with_member_models(self, member_models: List[Dict[str, str]]) -> "DataScienceModelGroup": + Sets the list of member models to be grouped. + + + Examples + -------- + >>> ds_model_group = (DataScienceModelGroup() + ... .with_compartment_id(os.environ["NB_SESSION_COMPARTMENT_OCID"]) + ... .with_project_id(os.environ["PROJECT_OCID"]) + ... .with_display_name("TestModelGroup") + ... .with_description("Testing the model group") + ... .with_freeform_tags(tag1="val1", tag2="val2") + >>> ds_model_group.create() + >>> ds_model_group.with_description("new description").update() + >>> ds_model_group.delete() + >>> DataScienceModelGroup.list() + """ + + CONST_ID = "id" + CONST_CREATE_TYPE = "createType" + CONST_COMPARTMENT_ID = "compartmentId" + CONST_PROJECT_ID = "projectId" + CONST_DISPLAY_NAME = "displayName" + CONST_DESCRIPTION = "description" + CONST_FREEFORM_TAG = "freeformTags" + CONST_DEFINED_TAG = "definedTags" + CONST_MODEL_GROUP_DETAILS = "modelGroupDetails" + CONST_MEMBER_MODEL_ENTRIES = "memberModelEntries" + CONST_CUSTOM_METADATA_LIST = "customMetadataList" + CONST_BASE_MODEL_ID = "baseModelId" + CONST_MEMBER_MODELS = "memberModels" + CONST_MODEL_GROUP_VERSION_HISTORY_ID = "modelGroupVersionHistoryId" + CONST_MODEL_GROUP_VERSION_HISTORY_NAME = "modelGroupVersionHistoryName" + CONST_LIFECYCLE_STATE = "lifecycleState" + CONST_LIFECYCLE_DETAILS = "lifecycleDetails" + CONST_TIME_CREATED = "timeCreated" + CONST_TIME_UPDATED = "timeUpdated" + CONST_CREATED_BY = "createdBy" + CONST_VERSION_LABEL = "versionLabel" + CONST_VERSION_ID = "versionId" + + attribute_map = { + CONST_ID: "id", + CONST_COMPARTMENT_ID: "compartment_id", + CONST_PROJECT_ID: "project_id", + CONST_DISPLAY_NAME: "display_name", + CONST_DESCRIPTION: "description", + CONST_FREEFORM_TAG: "freeform_tags", + CONST_DEFINED_TAG: "defined_tags", + CONST_LIFECYCLE_STATE: "lifecycle_state", + CONST_LIFECYCLE_DETAILS: "lifecycle_details", + CONST_TIME_CREATED: "time_created", + CONST_TIME_UPDATED: "time_updated", + CONST_CREATED_BY: "created_by", + CONST_MODEL_GROUP_VERSION_HISTORY_ID: "model_group_version_history_id", + CONST_MODEL_GROUP_VERSION_HISTORY_NAME: "model_group_version_history_name", + CONST_VERSION_LABEL: "version_label", + CONST_VERSION_ID: "version_id", + } + + def __init__(self, spec=None, **kwargs): + """Initializes datascience model group. + + Parameters + ---------- + spec: (Dict, optional). Defaults to None. + Object specification. + + kwargs: Dict + Specification as keyword arguments. + If 'spec' contains the same key as the one in kwargs, + the value from kwargs will be used. + + - project_id: str + - compartment_id: str + - display_name: str + - description: str + - defined_tags: Dict[str, Dict[str, object]] + - freeform_tags: Dict[str, str] + - custom_metadata_list: Union[ModelCustomMetadata, Dict] + - base_model_id: str + - member_models: List[Dict[str, str]] + - model_group_version_history_id: str + - version_label: str + """ + super().__init__(spec, **kwargs) + self.dsc_model_group = OCIDataScienceModelGroup() + + @property + def kind(self) -> str: + """The kind of the model group as showing in a YAML.""" + return MODEL_GROUP_KIND + + @property + def id(self) -> str: + """The model group OCID.""" + return self.get_spec(self.CONST_ID) + + @property + def lifecycle_state(self) -> str: + """The model group lifecycle state.""" + return self.get_spec(self.CONST_LIFECYCLE_STATE) + + @property + def lifecycle_details(self) -> str: + """The model group lifecycle details.""" + return self.get_spec(self.CONST_LIFECYCLE_DETAILS) + + @property + def create_type(self) -> str: + """The model group create type.""" + return self.get_spec(self.CONST_CREATE_TYPE) + + @property + def model_group_version_history_name(self) -> str: + """The model group version history name.""" + return self.get_spec(self.CONST_MODEL_GROUP_VERSION_HISTORY_NAME) + + @property + def version_id(self) -> str: + """The model group version id.""" + return self.get_spec(self.CONST_VERSION_ID) + + def with_create_type(self, create_type: str) -> "DataScienceModelGroup": + """Sets the create type. + + Parameters + ---------- + create_type: str + The create type of model group. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + if create_type not in ALLOWED_CREATE_TYPES: + raise ValueError( + f"Invalid create type. Allowed create type are {ALLOWED_CREATE_TYPES}." + ) + return self.set_spec(self.CONST_CREATE_TYPE, create_type) + + @property + def compartment_id(self) -> str: + """The model group compartment id.""" + return self.get_spec(self.CONST_COMPARTMENT_ID) + + def with_compartment_id(self, compartment_id: str) -> "DataScienceModelGroup": + """Sets the compartment OCID. + + Parameters + ---------- + compartment_id: str + The compartment id of model group. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + return self.set_spec(self.CONST_COMPARTMENT_ID, compartment_id) + + @property + def project_id(self) -> str: + """The model group project id.""" + return self.get_spec(self.CONST_PROJECT_ID) + + def with_project_id(self, project_id: str) -> "DataScienceModelGroup": + """Sets the project OCID. + + Parameters + ---------- + project_id: str + The project id of model group. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + return self.set_spec(self.CONST_PROJECT_ID, project_id) + + @property + def display_name(self) -> str: + """The model group display name.""" + return self.get_spec(self.CONST_DISPLAY_NAME) + + def with_display_name(self, display_name: str) -> "DataScienceModelGroup": + """Sets the display name. + + Parameters + ---------- + display_name: str + The display name of model group. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + return self.set_spec(self.CONST_DISPLAY_NAME, display_name) + + @property + def description(self) -> str: + """The model group description.""" + return self.get_spec(self.CONST_DESCRIPTION) + + def with_description(self, description: str) -> "DataScienceModelGroup": + """Sets the description. + + Parameters + ---------- + description: str + The description of model group. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + return self.set_spec(self.CONST_DESCRIPTION, description) + + @property + def freeform_tags(self) -> Dict[str, str]: + """The model group freeform tags.""" + return self.get_spec(self.CONST_FREEFORM_TAG) + + def with_freeform_tags(self, **kwargs) -> "DataScienceModelGroup": + """Sets the freeform tags. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + return self.set_spec(self.CONST_FREEFORM_TAG, kwargs) + + @property + def defined_tags(self) -> Dict[str, Dict[str, object]]: + """The model group defined tags.""" + return self.get_spec(self.CONST_DEFINED_TAG) + + def with_defined_tags(self, **kwargs) -> "DataScienceModelGroup": + """Sets the defined tags. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + return self.set_spec(self.CONST_DEFINED_TAG, kwargs) + + @property + def custom_metadata_list(self) -> ModelCustomMetadata: + """The model group custom metadata list.""" + return self.get_spec(self.CONST_CUSTOM_METADATA_LIST) + + def with_custom_metadata_list( + self, metadata: Union[ModelCustomMetadata, Dict] + ) -> "DataScienceModelGroup": + """Sets model group custom metadata. + + Parameters + ---------- + metadata: Union[ModelCustomMetadata, Dict] + The custom metadata. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + if metadata and isinstance(metadata, Dict): + metadata = ModelCustomMetadata.from_dict(metadata) + return self.set_spec(self.CONST_CUSTOM_METADATA_LIST, metadata) + + @property + def base_model_id(self) -> str: + """The model group base model id.""" + return self.get_spec(self.CONST_BASE_MODEL_ID) + + def with_base_model_id(self, base_model_id: str) -> "DataScienceModelGroup": + """Sets base model id. + + Parameters + ---------- + base_model_id: str + The base model id. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + return self.set_spec(self.CONST_BASE_MODEL_ID, base_model_id) + + @property + def member_models(self) -> List[Dict[str, str]]: + """The member models of model group.""" + return self.get_spec(self.CONST_MEMBER_MODELS) + + def with_member_models( + self, member_models: List[Dict[str, str]] + ) -> "DataScienceModelGroup": + """Sets member models to be grouped. + + Parameters + ---------- + member_models: List[Dict[str, str]] + The member models to be grouped. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + return self.set_spec(self.CONST_MEMBER_MODELS, member_models) + + @property + def model_group_version_history_id(self) -> str: + """The model group version history id.""" + return self.get_spec(self.CONST_MODEL_GROUP_VERSION_HISTORY_ID) + + def with_model_group_version_history_id( + self, model_group_version_history_id: str + ) -> "DataScienceModelGroup": + """Sets model group version history id. + + Parameters + ---------- + model_group_version_history_id: str + The model group version history id. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + return self.set_spec( + self.CONST_MODEL_GROUP_VERSION_HISTORY_ID, model_group_version_history_id + ) + + @property + def version_label(self) -> str: + """The model group version label.""" + return self.get_spec(self.CONST_VERSION_LABEL) + + def with_version_label(self, version_label: str) -> "DataScienceModelGroup": + """Sets model group version label. + + Parameters + ---------- + version_label: str + The model group version label. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + return self.set_spec(self.CONST_VERSION_LABEL, version_label) + + def create( + self, + wait_for_completion: bool = True, + max_wait_time: int = DEFAULT_WAIT_TIME, + poll_interval: int = DEFAULT_POLL_INTERVAL, + ) -> "DataScienceModelGroup": + """Creates the datascience model group. + + Parameters + ---------- + wait_for_completion: bool + Flag set for whether to wait for model group to be created before proceeding. + Defaults to True. + max_wait_time: int + Maximum amount of time to wait in seconds (Defaults to 1200). + Negative implies infinite wait time. + poll_interval: int + Poll interval in seconds (Defaults to 10). + + Returns + ------- + DataScienceModelGroup + The instance of DataScienceModelGroup. + """ + response = self.dsc_model_group.create( + create_model_group_details=CreateModelGroupDetails( + **batch_convert_case(self._build_model_group_details(), "snake") + ), + wait_for_completion=wait_for_completion, + max_wait_time=max_wait_time, + poll_interval=poll_interval, + ) + + return self._update_from_oci_model(response) + + def _build_model_group_details(self) -> dict: + """Builds model group details dict for creating or updating oci model group.""" + model_group_details = HomogeneousModelGroupDetails( + custom_metadata_list=[ + CustomMetadata( + key=custom_metadata.key, + value=custom_metadata.value, + description=custom_metadata.description, + category=custom_metadata.category, + ) + for custom_metadata in self.custom_metadata_list._to_oci_metadata() + ] + ) + + member_model_entries = MemberModelEntries( + member_model_details=[ + MemberModelDetails(**member_model) + for member_model in self.member_models + ] + ) + + build_model_group_details = copy.deepcopy(self._spec) + build_model_group_details.pop(self.CONST_CUSTOM_METADATA_LIST) + build_model_group_details.pop(self.CONST_MEMBER_MODELS) + build_model_group_details.update( + { + self.CONST_COMPARTMENT_ID: self.compartment_id or COMPARTMENT_OCID, + self.CONST_PROJECT_ID: self.project_id or PROJECT_OCID, + self.CONST_MODEL_GROUP_DETAILS: model_group_details, + self.CONST_MEMBER_MODEL_ENTRIES: member_model_entries, + } + ) + + return build_model_group_details + + def _update_from_oci_model( + self, oci_model_group_instance: Union[ModelGroup, ModelGroupSummary] + ) -> "DataScienceModelGroup": + """Updates self spec from oci model group instance. + + Parameters + ---------- + oci_model_group_instance: Union[ModelGroup, ModelGroupSummary] + The oci model group instance, could be an instance of oci.data_science.models.ModelGroup + or oci.data_science.models.ModelGroupSummary. + + Returns + ------- + DataScienceModelGroup + The instance of DataScienceModelGroup. + """ + self.dsc_model_group = oci_model_group_instance + for key, value in self.attribute_map.items(): + if hasattr(oci_model_group_instance, value): + self.set_spec(key, getattr(oci_model_group_instance, value)) + + model_group_details: ModelGroupDetails = ( + oci_model_group_instance.model_group_details + ) + custom_metadata_list: List[CustomMetadata] = ( + model_group_details.custom_metadata_list + ) + model_custom_metadata = ModelCustomMetadata() + for metadata in custom_metadata_list: + model_custom_metadata.add( + key=metadata.key, + value=metadata.value, + description=metadata.description, + category=metadata.category, + ) + self.set_spec(self.CONST_CUSTOM_METADATA_LIST, model_custom_metadata) + + # only updates member_models when oci_model_group_instance is an instance of + # oci.data_science.models.ModelGroup as oci.data_science.models.ModelGroupSummary + # doesn't have member_model_entries property. + if isinstance(oci_model_group_instance, ModelGroup): + member_model_entries: MemberModelEntries = ( + oci_model_group_instance.member_model_entries + ) + member_model_details: List[MemberModelDetails] = ( + member_model_entries.member_model_details + ) + + self.set_spec( + self.CONST_MEMBER_MODELS, + [ + { + "inference_key": member_model_detail.inference_key, + "model_id": member_model_detail.model_id, + } + for member_model_detail in member_model_details + ], + ) + + return self + + def update( + self, + wait_for_completion: bool = True, + max_wait_time: int = DEFAULT_WAIT_TIME, + poll_interval: int = DEFAULT_POLL_INTERVAL, + ) -> "DataScienceModelGroup": + """Updates a datascience model group. + + Parameters + ---------- + wait_for_completion: bool + Flag set for whether to wait for model group to be updated before proceeding. + Defaults to True. + max_wait_time: int + Maximum amount of time to wait in seconds (Defaults to 1200). + Negative implies infinite wait time. + poll_interval: int + Poll interval in seconds (Defaults to 10). + + Returns + ------- + DataScienceModelGroup + The instance of DataScienceModelGroup. + """ + update_model_group_details = OCIDataScienceModelGroup( + **self._build_model_group_details() + ).to_oci_model(UpdateModelGroupDetails) + + response = self.dsc_model_group.update( + update_model_group_details=update_model_group_details, + wait_for_completion=wait_for_completion, + max_wait_time=max_wait_time, + poll_interval=poll_interval, + ) + + return self._update_from_oci_model(response) + + def activate( + self, + wait_for_completion: bool = True, + max_wait_time: int = DEFAULT_WAIT_TIME, + poll_interval: int = DEFAULT_POLL_INTERVAL, + ) -> "DataScienceModelGroup": + """Activates a datascience model group. + + Parameters + ---------- + wait_for_completion: bool + Flag set for whether to wait for model group to be activated before proceeding. + Defaults to True. + max_wait_time: int + Maximum amount of time to wait in seconds (Defaults to 1200). + Negative implies infinite wait time. + poll_interval: int + Poll interval in seconds (Defaults to 10). + + Returns + ------- + DataScienceModelGroup + The instance of DataScienceModelGroup. + """ + response = self.dsc_model_group.activate( + wait_for_completion=wait_for_completion, + max_wait_time=max_wait_time, + poll_interval=poll_interval, + ) + + return self._update_from_oci_model(response) + + def deactivate( + self, + wait_for_completion: bool = True, + max_wait_time: int = DEFAULT_WAIT_TIME, + poll_interval: int = DEFAULT_POLL_INTERVAL, + ) -> "DataScienceModelGroup": + """Deactivates a datascience model group. + + Parameters + ---------- + wait_for_completion: bool + Flag set for whether to wait for model group to be deactivated before proceeding. + Defaults to True. + max_wait_time: int + Maximum amount of time to wait in seconds (Defaults to 1200). + Negative implies infinite wait time. + poll_interval: int + Poll interval in seconds (Defaults to 10). + + Returns + ------- + DataScienceModelGroup + The instance of DataScienceModelGroup. + """ + response = self.dsc_model_group.deactivate( + wait_for_completion=wait_for_completion, + max_wait_time=max_wait_time, + poll_interval=poll_interval, + ) + + return self._update_from_oci_model(response) + + def delete( + self, + wait_for_completion: bool = True, + max_wait_time: int = DEFAULT_WAIT_TIME, + poll_interval: int = DEFAULT_POLL_INTERVAL, + ) -> "DataScienceModelGroup": + """Deletes a datascience model group. + + Parameters + ---------- + wait_for_completion: bool + Flag set for whether to wait for model group to be deleted before proceeding. + Defaults to True. + max_wait_time: int + Maximum amount of time to wait in seconds (Defaults to 1200). + Negative implies infinite wait time. + poll_interval: int + Poll interval in seconds (Defaults to 10). + + Returns + ------- + DataScienceModelGroup + The instance of DataScienceModelGroup. + """ + response = self.dsc_model_group.delete( + wait_for_completion=wait_for_completion, + max_wait_time=max_wait_time, + poll_interval=poll_interval, + ) + return self._update_from_oci_model(response) + + def sync(self) -> "DataScienceModelGroup": + """Updates the model group instance from backend. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self). + """ + if not self.id: + raise ValueError( + "Model group needs to be created before it can be fetched." + ) + return self._update_from_oci_model(OCIDataScienceModelGroup.from_id(self.id)) + + @classmethod + def list( + cls, + status: str = None, + compartment_id: str = None, + **kwargs, + ) -> List["DataScienceModelGroup"]: + """Lists datascience model groups in a given compartment. + + Parameters + ---------- + status: (str, optional). Defaults to `None`. + The status of model group. Allowed values: `ACTIVE`, `CREATING`, `DELETED`, `DELETING`, `FAILED` and `INACTIVE`. + compartment_id: (str, optional). Defaults to `None`. + The compartment OCID. + kwargs + Additional keyword arguments for filtering model groups. + + Returns + ------- + List[DataScienceModelGroup] + The list of the datascience model groups. + """ + return [ + cls()._update_from_oci_model(model_group_summary) + for model_group_summary in OCIDataScienceModelGroup.list( + status=status, + compartment_id=compartment_id, + **kwargs, + ) + ] + + @classmethod + def from_id(cls, model_group_id: str) -> "DataScienceModelGroup": + """Loads the model group instance from ocid. + + Parameters + ---------- + model_group_id: str + The ocid of model group. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self). + """ + oci_model_group = OCIDataScienceModelGroup.from_id(model_group_id) + return cls()._update_from_oci_model(oci_model_group) + + def to_dict(self) -> Dict: + """Serializes model group to a dictionary. + + Returns + ------- + dict + The model group serialized as a dictionary. + """ + spec = copy.deepcopy(self._spec) + for key, value in spec.items(): + if hasattr(value, "to_dict"): + value = value.to_dict() + spec[key] = value + + return { + "kind": self.kind, + "type": self.type, + "spec": batch_convert_case(spec, "camel"), + } + + @classmethod + def from_dict(cls, config: Dict) -> "DataScienceModelGroup": + """Loads model group instance from a dictionary of configurations. + + Parameters + ---------- + config: Dict + A dictionary of configurations. + + Returns + ------- + DataScienceModelGroup + The model group instance. + """ + return cls(spec=batch_convert_case(copy.deepcopy(config["spec"]), "snake")) diff --git a/ads/model/service/oci_datascience_model_group.py b/ads/model/service/oci_datascience_model_group.py new file mode 100644 index 000000000..b1fb72a3c --- /dev/null +++ b/ads/model/service/oci_datascience_model_group.py @@ -0,0 +1,488 @@ +#!/usr/bin/env python + +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import logging +from functools import wraps +from typing import Callable + +import oci + +from ads.common.oci_datascience import OCIDataScienceMixin +from ads.common.work_request import DataScienceWorkRequest +from ads.model.deployment.common.utils import OCIClientManager, State + +try: + from oci.data_science.models import CreateModelGroupDetails, UpdateModelGroupDetails +except ModuleNotFoundError as err: + raise ModuleNotFoundError( + "The oci model group module was not found. Please run `pip install oci` " + "to install the latest oci sdk." + ) from err + +logger = logging.getLogger(__name__) + +DEFAULT_WAIT_TIME = 1200 +DEFAULT_POLL_INTERVAL = 10 +ALLOWED_STATUS = [ + State.ACTIVE.name, + State.CREATING.name, + State.DELETED.name, + State.DELETING.name, + State.FAILED.name, + State.INACTIVE.name, +] +MODEL_GROUP_NEEDS_TO_BE_CREATED = ( + "Missing model group id. Model group needs to be created before it can be accessed." +) + + +def check_for_model_group_id(msg: str = MODEL_GROUP_NEEDS_TO_BE_CREATED): + """The decorator helping to check if the ID attribute sepcified for a datascience model group. + + Parameters + ---------- + msg: str + The message that will be thrown. + + Raises + ------ + MissingModelGroupIdError + In case if the ID attribute not specified. + + Examples + -------- + >>> @check_for_id(msg="Some message.") + ... def test_function(self, name: str, last_name: str) + ... pass + """ + + def decorator(func: Callable): + @wraps(func) + def wrapper(self, *args, **kwargs): + if not self.id: + raise MissingModelGroupIdError(msg) + return func(self, *args, **kwargs) + + return wrapper + + return decorator + + +class MissingModelGroupIdError(Exception): # pragma: no cover + pass + + +class OCIDataScienceModelGroup( + OCIDataScienceMixin, + oci.data_science.models.ModelGroup, +): + """Represents an OCI Data Science Model Group. + This class contains all attributes of the `oci.data_science.models.ModelGroup`. + The main purpose of this class is to link the `oci.data_science.models.ModelGroup` + and the related client methods. + Linking the `ModelGroup` (payload) to Create/Update/Delete/Activate/Deactivate methods. + + The `OCIDataScienceModelGroup` can be initialized by unpacking the properties stored in a dictionary: + + .. code-block:: python + + properties = { + "compartment_id": "", + "name": "", + "description": "", + } + ds_model_group = OCIDataScienceModelGroup(**properties) + + The properties can also be OCI REST API payload, in which the keys are in camel format. + + .. code-block:: python + + payload = { + "compartmentId": "", + "name": "", + "description": "", + } + ds_model_group = OCIDataScienceModelGroup(**payload) + + Methods + ------- + activate(self, ...) -> "OCIDataScienceModelGroup": + Activates datascience model group. + create(self, ...) -> "OCIDataScienceModelGroup" + Creates datascience model group. + deactivate(self, ...) -> "OCIDataScienceModelGroup": + Deactivates datascience model group. + delete(self, ...) -> "OCIDataScienceModelGroup": + Deletes datascience model group. + update(self, ...) -> "OCIDataScienceModelGroup": + Updates datascience model group. + list(self, ...) -> list[oci.data_science.models.ModelGroupSummary]: + List oci.data_science.models.ModelGroupSummary instances within given compartment. + from_id(cls, model_group: str) -> "OCIDataScienceModelGroup": + Gets model group by OCID. + + Examples + -------- + >>> oci_model_group = OCIDataScienceModelGroup.from_id() + >>> oci_model_group.deactivate() + >>> oci_model_group.activate(wait_for_completion=False) + >>> oci_model_group.description = "A brand new description" + ... oci_model_group.update() + >>> oci_model_group.sync() + >>> oci_model_group.list(status="ACTIVE") + >>> oci_model_group.delete(wait_for_completion=False) + """ + + def __init__(self, config=None, signer=None, client_kwargs=None, **kwargs): + super().__init__(config, signer, client_kwargs, **kwargs) + self.workflow_req_id = None + + def create( + self, + create_model_group_details: CreateModelGroupDetails, + wait_for_completion: bool = True, + max_wait_time: int = DEFAULT_WAIT_TIME, + poll_interval: int = DEFAULT_POLL_INTERVAL, + ) -> "OCIDataScienceModelGroup": + """Creates datascience model group. + + Parameters + ---------- + create_model_group_details: CreateModelGroupDetails + An instance of CreateModelGroupDetails which consists of all + necessary parameters to create a data science model group. + wait_for_completion: bool + Flag set for whether to wait for process to be completed. + Defaults to True. + max_wait_time: int + Maximum amount of time to wait in seconds (Defaults to 1200). + Negative implies infinite wait time. + poll_interval: int + Poll interval in seconds (Defaults to 10). + + Returns + ------- + OCIDataScienceModelGroup + The `OCIDataScienceModelGroup` instance (self). + """ + response = self.client.create_model_group(create_model_group_details) + self.update_from_oci_model(response.data) + logger.info(f"Creating model group `{self.id}`.") + print(f"Model Group OCID: {self.id}") + + if wait_for_completion: + self.workflow_req_id = response.headers.get("opc-work-request-id", None) + + try: + DataScienceWorkRequest(self.workflow_req_id).wait_work_request( + progress_bar_description="Creating model group", + max_wait_time=max_wait_time, + poll_interval=poll_interval, + ) + except Exception as e: + logger.error("Error while trying to create model group: " + str(e)) + + return self.sync() + + @check_for_model_group_id( + msg="Model group needs to be created before it can be activated or deactivated.." + ) + def activate( + self, + wait_for_completion: bool = True, + max_wait_time: int = DEFAULT_WAIT_TIME, + poll_interval: int = DEFAULT_POLL_INTERVAL, + ) -> "OCIDataScienceModelGroup": + """Activates datascience model group. + + Parameters + ---------- + wait_for_completion: bool + Flag set for whether to wait for process to be completed. + Defaults to True. + max_wait_time: int + Maximum amount of time to wait in seconds (Defaults to 1200). + Negative implies infinite wait time. + poll_interval: int + Poll interval in seconds (Defaults to 10). + + Returns + ------- + OCIDataScienceModelGroup + The `OCIDataScienceModelGroup` instance (self). + """ + dsc_model_group = OCIDataScienceModelGroup.from_id(self.id) + if dsc_model_group.lifecycle_state == self.LIFECYCLE_STATE_ACTIVE: + raise Exception( + f"Model group {dsc_model_group.id} is already in active state." + ) + + if dsc_model_group.lifecycle_state == self.LIFECYCLE_STATE_INACTIVE: + logger.info(f"Activating model group `{self.id}`.") + response = self.client.activate_model_group( + self.id, + ) + + if wait_for_completion: + self.workflow_req_id = response.headers.get("opc-work-request-id", None) + + try: + DataScienceWorkRequest(self.workflow_req_id).wait_work_request( + progress_bar_description="Activating model group", + max_wait_time=max_wait_time, + poll_interval=poll_interval, + ) + except Exception as e: + logger.error( + "Error while trying to activate model group: " + str(e) + ) + + return self.sync() + else: + raise Exception( + f"Can't activate model group {dsc_model_group.id} when it's in {dsc_model_group.lifecycle_state} state." + ) + + @check_for_model_group_id( + msg="Model group needs to be created before it can be activated or deactivated.." + ) + def deactivate( + self, + wait_for_completion: bool = True, + max_wait_time: int = DEFAULT_WAIT_TIME, + poll_interval: int = DEFAULT_POLL_INTERVAL, + ) -> "OCIDataScienceModelGroup": + """Deactivates datascience model group. + + Parameters + ---------- + wait_for_completion: bool + Flag set for whether to wait for process to be completed. + Defaults to True. + max_wait_time: int + Maximum amount of time to wait in seconds (Defaults to 1200). + Negative implies infinite wait time. + poll_interval: int + Poll interval in seconds (Defaults to 10). + + Returns + ------- + OCIDataScienceModelGroup + The `OCIDataScienceModelGroup` instance (self). + """ + dsc_model_group = self.from_id(self.id) + if dsc_model_group.lifecycle_state == self.LIFECYCLE_STATE_INACTIVE: + raise Exception( + f"Model group {dsc_model_group.id} is already in inactive state." + ) + + if dsc_model_group.lifecycle_state == self.LIFECYCLE_STATE_ACTIVE: + logger.info(f"Deactivating model group `{self.id}`.") + response = self.client.deactivate_model_group( + self.id, + ) + + if wait_for_completion: + self.workflow_req_id = response.headers.get("opc-work-request-id", None) + + try: + DataScienceWorkRequest(self.workflow_req_id).wait_work_request( + progress_bar_description="Deactivating model group", + max_wait_time=max_wait_time, + poll_interval=poll_interval, + ) + except Exception as e: + logger.error( + "Error while trying to deactivate model group: " + str(e) + ) + + return self.sync() + else: + raise Exception( + f"Can't deactivate model group {dsc_model_group.id} when it's in {dsc_model_group.lifecycle_state} state." + ) + + @check_for_model_group_id( + msg="Model group needs to be created before it can be deleted." + ) + def delete( + self, + wait_for_completion: bool = True, + max_wait_time: int = DEFAULT_WAIT_TIME, + poll_interval: int = DEFAULT_POLL_INTERVAL, + ) -> "OCIDataScienceModelGroup": + """Deletes datascience model group. + + Parameters + ---------- + wait_for_completion: bool + Flag set for whether to wait for process to be completed. + Defaults to True. + max_wait_time: int + Maximum amount of time to wait in seconds (Defaults to 1200). + Negative implies infinite wait time. + poll_interval: int + Poll interval in seconds (Defaults to 10). + + Returns + ------- + OCIDataScienceModelGroup + The `OCIDataScienceModelGroup` instance (self). + """ + dsc_model_group = self.from_id(self.id) + if dsc_model_group.lifecycle_state in [ + self.LIFECYCLE_STATE_DELETED, + self.LIFECYCLE_STATE_DELETING, + ]: + raise Exception( + f"Model group {dsc_model_group.id} is either deleted or being deleted." + ) + if dsc_model_group.lifecycle_state not in [ + self.LIFECYCLE_STATE_ACTIVE, + self.LIFECYCLE_STATE_FAILED, + self.LIFECYCLE_STATE_INACTIVE, + ]: + raise Exception( + f"Can't delete model group {dsc_model_group.id} when it's in {dsc_model_group.lifecycle_state} state." + ) + logger.info(f"Deleting model group `{self.id}`.") + response = self.client.delete_model_group( + self.id, + ) + + if wait_for_completion: + self.workflow_req_id = response.headers.get("opc-work-request-id", None) + + try: + DataScienceWorkRequest(self.workflow_req_id).wait_work_request( + progress_bar_description="Deleting model group", + max_wait_time=max_wait_time, + poll_interval=poll_interval, + ) + except Exception as e: + logger.error("Error while trying to delete model group: " + str(e)) + + return self.sync() + + @check_for_model_group_id( + msg="Model group needs to be created before it can be updated." + ) + def update( + self, + update_model_group_details: UpdateModelGroupDetails, + wait_for_completion: bool = True, + max_wait_time: int = DEFAULT_WAIT_TIME, + poll_interval: int = DEFAULT_POLL_INTERVAL, + ) -> "OCIDataScienceModelGroup": + """Updates datascience model group. + + Parameters + ---------- + update_model_group_details: UpdateModelGroupDetails + Details to update model group. + wait_for_completion: bool + Flag set for whether to wait for process to be completed. + Defaults to True. + max_wait_time: int + Maximum amount of time to wait in seconds (Defaults to 1200). + Negative implies infinite wait time. + poll_interval: int + Poll interval in seconds (Defaults to 10). + + Returns + ------- + OCIDataScienceModelGroup + The `OCIDataScienceModelGroup` instance (self). + """ + if wait_for_completion: + wait_for_states = [ + self.LIFECYCLE_STATE_ACTIVE, + self.LIFECYCLE_STATE_FAILED, + ] + else: + wait_for_states = [] + + try: + response = self.client_composite.update_model_group_and_wait_for_state( + self.id, + update_model_group_details, + wait_for_states=wait_for_states, + waiter_kwargs={ + "max_interval_seconds": poll_interval, + "max_wait_seconds": max_wait_time, + }, + ) + self.workflow_req_id = response.headers.get("opc-work-request-id", None) + except Exception as e: + logger.error("Error while trying to update model group: " + str(e)) + + return self.sync() + + @classmethod + def list( + cls, + status: str = None, + compartment_id: str = None, + **kwargs, + ) -> list: + """Lists the model group associated with current compartment id and status + + Parameters + ---------- + status : str + Status of model group. Defaults to None. + Allowed values: `ACTIVE`, `CREATING`, `DELETED`, `DELETING`, `FAILED` and `INACTIVE`. + compartment_id : str + Target compartment to list model groups from. + Defaults to the compartment set in the environment variable "NB_SESSION_COMPARTMENT_OCID". + If "NB_SESSION_COMPARTMENT_OCID" is not set, the root compartment ID will be used. + An ValueError will be raised if root compartment ID cannot be determined. + kwargs : + The values are passed to oci.data_science.DataScienceClient.list_model_groups. + + Returns + ------- + list + A list of oci.data_science.models.ModelGroupSummary objects. + + Raises + ------ + ValueError + If compartment_id is not specified and cannot be determined from the environment. + """ + compartment_id = compartment_id or OCIClientManager().default_compartment_id() + + if not compartment_id: + raise ValueError( + "Unable to determine compartment ID from environment. Specify `compartment_id`." + ) + + if status is not None: + if status not in ALLOWED_STATUS: + raise ValueError( + f"Allowed `status` values are: {', '.join(ALLOWED_STATUS)}." + ) + kwargs["lifecycle_state"] = status + + # https://oracle-cloud-infrastructure-python-sdk.readthedocs.io/en/latest/api/pagination.html#module-oci.pagination + return oci.pagination.list_call_get_all_results( + cls().client.list_model_groups, compartment_id, **kwargs + ).data + + @classmethod + def from_id(cls, model_group_id: str) -> "OCIDataScienceModelGroup": + """Gets datascience model group by OCID. + + Parameters + ---------- + model_group_id: str + The OCID of the datascience model group. + + Returns + ------- + OCIDataScienceModelGroup + An instance of `OCIDataScienceModelGroup`. + """ + return super().from_ocid(model_group_id) diff --git a/tests/unitary/default_setup/model/test_model_group.py b/tests/unitary/default_setup/model/test_model_group.py new file mode 100644 index 000000000..6b7ac2358 --- /dev/null +++ b/tests/unitary/default_setup/model/test_model_group.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python + +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import copy +import unittest +from unittest.mock import patch +from ads.model.datascience_model_group import DataScienceModelGroup +from ads.model.model_metadata import ModelCustomMetadata + +try: + from oci.data_science.models import ( + ModelGroup, + HomogeneousModelGroupDetails, + MemberModelEntries, + CustomMetadata, + MemberModelDetails, + ModelGroupSummary, + ) +except (ImportError, AttributeError) as e: + raise unittest.SkipTest( + "Support for OCI Model Group is not available. Skipping the Model Group tests." + ) + +MODEL_GROUP_DICT = { + "kind": "datascienceModelGroup", + "type": "dataScienceModelGroup", + "spec": { + "displayName": "test_create_model_group", + "description": "test create model group description", + "freeformTags": {"test_key": "test_value"}, + "customMetadataList": { + "data": [ + { + "key": "test_key", + "value": "test_value", + "description": "test_description", + "category": "other", + "has_artifact": False, + } + ] + }, + "memberModels": [ + {"inference_key": "model_one", "model_id": "model_id_one"}, + {"inference_key": "model_two", "model_id": "model_id_two"}, + ], + }, +} + +MODEL_GROUP_SPEC = { + "display_name": "test_create_model_group", + "description": "test create model group description", + "freeform_tags": {"test_key": "test_value"}, + "custom_metadata_list": { + "data": [ + { + "key": "test_key", + "value": "test_value", + "description": "test_description", + "category": "other", + "has_artifact": False, + } + ] + }, + "member_models": [ + {"inference_key": "model_one", "model_id": "model_id_one"}, + {"inference_key": "model_two", "model_id": "model_id_two"}, + ], +} + +OCI_MODEL_GROUP_RESPONSE = { + "id": "test_model_group_id", + "compartment_id": "test_model_group_compartment_id", + "project_id": "test_model_group_project_id", + "display_name": "test_create_model_group", + "description": "test create model group description", + "created_by": "test_create_by", + "time_created": "2025-06-10T18:21:17.613000Z", + "time_updated": "2025-06-10T18:21:17.613000Z", + "lifecycle_state": "ACTIVE", + "lifecycle_details": "test lifecycle details", + "model_group_version_history_id": "test_model_group_version_history_id", + "model_group_version_history_name": "test_model_group_version_history_name", + "version_label": "test_version_label", + "version_id": 1, + "model_group_details": HomogeneousModelGroupDetails( + custom_metadata_list=[ + CustomMetadata( + key="test_key", + value="test_value", + description="test_description", + category="other", + ) + ] + ), + "member_model_entries": MemberModelEntries( + member_model_details=[ + MemberModelDetails(inference_key="model_one", model_id="model_id_one"), + MemberModelDetails(inference_key="model_two", model_id="model_id_two"), + ] + ), + "freeform_tags": {"test_key": "test_value"}, +} + + +class TestModelGroup: + def initialize_model_group(self): + custom_metadata = ModelCustomMetadata() + custom_metadata.add( + key="test_key", + value="test_value", + description="test_description", + category="other", + ) + + model_group = ( + DataScienceModelGroup() + .with_display_name("test_create_model_group") + .with_description("test create model group description") + .with_freeform_tags(**{"test_key": "test_value"}) + .with_custom_metadata_list(custom_metadata) + .with_member_models( + [ + {"inference_key": "model_one", "model_id": "model_id_one"}, + {"inference_key": "model_two", "model_id": "model_id_two"}, + ] + ) + ) + + return model_group + + def test_initialize_model_group(self): + model_group_one = self.initialize_model_group() + assert model_group_one.to_dict() == MODEL_GROUP_DICT + + model_group_two = DataScienceModelGroup.from_dict(MODEL_GROUP_DICT) + assert model_group_two.to_dict() == MODEL_GROUP_DICT + + model_group_three = DataScienceModelGroup(spec=MODEL_GROUP_SPEC) + assert model_group_three.to_dict() == MODEL_GROUP_DICT + + model_group_four = DataScienceModelGroup(**MODEL_GROUP_SPEC) + assert model_group_four.to_dict() == MODEL_GROUP_DICT + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.create" + ) + def test_create(self, mock_dsc_model_group_create): + mock_dsc_model_group_create.return_value = ModelGroup( + **OCI_MODEL_GROUP_RESPONSE + ) + model_group = self.initialize_model_group() + model_group.create() + + mock_dsc_model_group_create.assert_called() + + assert model_group.id == OCI_MODEL_GROUP_RESPONSE["id"] + assert model_group.display_name == OCI_MODEL_GROUP_RESPONSE["display_name"] + assert model_group.description == OCI_MODEL_GROUP_RESPONSE["description"] + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.activate" + ) + def test_activate(self, mock_dsc_model_group_activate): + mock_dsc_model_group_activate.return_value = ModelGroup( + **OCI_MODEL_GROUP_RESPONSE + ) + model_group = self.initialize_model_group() + model_group.activate( + wait_for_completion=False, + max_wait_time=1, + poll_interval=2, + ) + + mock_dsc_model_group_activate.assert_called_with( + wait_for_completion=False, + max_wait_time=1, + poll_interval=2, + ) + + assert model_group.lifecycle_state == "ACTIVE" + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.deactivate" + ) + def test_deactivate(self, mock_dsc_model_group_deactivate): + mock_dsc_model_group_deactivate_response = copy.deepcopy( + OCI_MODEL_GROUP_RESPONSE + ) + mock_dsc_model_group_deactivate_response["lifecycle_state"] = "INACTIVE" + + mock_dsc_model_group_deactivate.return_value = ModelGroup( + **mock_dsc_model_group_deactivate_response + ) + model_group = self.initialize_model_group() + model_group.deactivate( + wait_for_completion=False, + max_wait_time=1, + poll_interval=2, + ) + + mock_dsc_model_group_deactivate.assert_called_with( + wait_for_completion=False, + max_wait_time=1, + poll_interval=2, + ) + + assert model_group.lifecycle_state == "INACTIVE" + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.delete" + ) + def test_delete(self, mock_dsc_model_group_delete): + mock_dsc_model_group_delete_response = copy.deepcopy(OCI_MODEL_GROUP_RESPONSE) + mock_dsc_model_group_delete_response["lifecycle_state"] = "DELETED" + + mock_dsc_model_group_delete.return_value = ModelGroup( + **mock_dsc_model_group_delete_response + ) + model_group = self.initialize_model_group() + model_group.delete( + wait_for_completion=False, + max_wait_time=1, + poll_interval=2, + ) + + mock_dsc_model_group_delete.assert_called_with( + wait_for_completion=False, + max_wait_time=1, + poll_interval=2, + ) + + assert model_group.lifecycle_state == "DELETED" + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.update" + ) + def test_update(self, mock_dsc_model_group_update): + mock_dsc_model_group_update_response = copy.deepcopy(OCI_MODEL_GROUP_RESPONSE) + mock_dsc_model_group_update_response["display_name"] = "updated display name" + mock_dsc_model_group_update_response["description"] = "updated description" + + mock_dsc_model_group_update.return_value = ModelGroup( + **mock_dsc_model_group_update_response + ) + model_group = self.initialize_model_group() + model_group.update() + + mock_dsc_model_group_update.assert_called() + + assert ( + model_group.display_name + == mock_dsc_model_group_update_response["display_name"] + ) + assert ( + model_group.description + == mock_dsc_model_group_update_response["description"] + ) + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.list" + ) + def test_list(self, mock_dsc_model_group_list): + mock_dsc_model_group_list_response = copy.deepcopy(OCI_MODEL_GROUP_RESPONSE) + mock_dsc_model_group_list_response.pop("member_model_entries") + mock_dsc_model_group_list_response.pop("description") + mock_dsc_model_group_list.return_value = [ + ModelGroupSummary(**mock_dsc_model_group_list_response) + ] + + model_groups = DataScienceModelGroup.list( + status="ACTIVE", compartment_id="test_model_group_compartment_id" + ) + + mock_dsc_model_group_list.assert_called_with( + status="ACTIVE", compartment_id="test_model_group_compartment_id" + ) + + assert len(model_groups) == 1 diff --git a/tests/unitary/default_setup/model/test_oci_model_group.py b/tests/unitary/default_setup/model/test_oci_model_group.py new file mode 100644 index 000000000..90c370609 --- /dev/null +++ b/tests/unitary/default_setup/model/test_oci_model_group.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python + +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import copy +import unittest +from unittest.mock import MagicMock, patch +from ads.model.service.oci_datascience_model_group import OCIDataScienceModelGroup + +try: + from oci.data_science.models import ( + ModelGroup, + HomogeneousModelGroupDetails, + MemberModelEntries, + CustomMetadata, + MemberModelDetails, + CreateModelGroupDetails, + UpdateModelGroupDetails, + ) +except (ImportError, AttributeError) as e: + raise unittest.SkipTest( + "Support for OCI Model Group is not available. Skipping the Model Group tests." + ) + +CREATE_MODEL_GROUP_DETAILS = { + "create_type": "CREATE", + "compartment_id": "test_model_group_compartment_id", + "project_id": "test_model_group_project_id", + "display_name": "test_create_model_group", + "description": "test create model group description", + "model_group_details": HomogeneousModelGroupDetails( + custom_metadata_list=[ + CustomMetadata( + key="test_key", + value="test_value", + description="test_description", + category="other", + ) + ] + ), + "member_model_entries": MemberModelEntries( + member_model_details=[ + MemberModelDetails(inference_key="model_one", model_id="model_id_one"), + MemberModelDetails(inference_key="model_two", model_id="model_id_two"), + ] + ), + "freeform_tags": {"test_key": "test_value"}, + "model_group_version_history_id": "test_model_group_version_history_id", + "version_label": "test_version_label", +} + +UPDATE_MODEL_GROUP_DETAILS = { + "display_name": "test_update_model_group", + "description": "test update model group description", + "model_group_version_history_id": "test_model_group_version_history_id", + "version_label": "test_version_label", + "freeform_tags": {"test_updated_key": "test_updated_value"}, +} + +OCI_MODEL_GROUP_RESPONSE = { + "id": "test_model_group_id", + "compartment_id": "test_model_group_compartment_id", + "project_id": "test_model_group_project_id", + "display_name": "test_create_model_group", + "description": "test create model group description", + "created_by": "test_create_by", + "time_created": "2025-06-10T18:21:17.613000Z", + "time_updated": "2025-06-10T18:21:17.613000Z", + "lifecycle_state": "ACTIVE", + "lifecycle_details": "test lifecycle details", + "model_group_version_history_id": "test_model_group_version_history_id", + "model_group_version_history_name": "test_model_group_version_history_name", + "version_label": "test_version_label", + "version_id": 1, + "model_group_details": HomogeneousModelGroupDetails( + custom_metadata_list=[ + CustomMetadata( + key="test_key", + value="test_value", + description="test_description", + category="other", + ) + ] + ), + "member_model_entries": MemberModelEntries( + member_model_details=[ + MemberModelDetails(inference_key="model_one", model_id="model_id_one"), + MemberModelDetails(inference_key="model_two", model_id="model_id_two"), + ] + ), + "freeform_tags": {"test_key": "test_value"}, +} + + +class TestOCIModelGroup: + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.sync" + ) + @patch("oci.data_science.DataScienceClient.create_model_group") + def test_create(self, mock_create_model_group, mock_sync): + mock_sync.return_value = ModelGroup(**OCI_MODEL_GROUP_RESPONSE) + create_model_group_details = CreateModelGroupDetails( + **CREATE_MODEL_GROUP_DETAILS + ) + oci_model_group = OCIDataScienceModelGroup().create( + create_model_group_details=create_model_group_details, + wait_for_completion=False, + max_wait_time=1, + poll_interval=2, + ) + + mock_create_model_group.assert_called_with(create_model_group_details) + + assert oci_model_group.id == OCI_MODEL_GROUP_RESPONSE["id"] + assert oci_model_group.display_name == OCI_MODEL_GROUP_RESPONSE["display_name"] + assert oci_model_group.description == OCI_MODEL_GROUP_RESPONSE["description"] + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.sync" + ) + @patch("oci.data_science.DataScienceClient.activate_model_group") + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.from_id" + ) + def test_activate(self, mock_from_id, mock_activate_model_group, mock_sync): + mock_oci_model_group_activate_response = copy.deepcopy(OCI_MODEL_GROUP_RESPONSE) + mock_oci_model_group_activate_response["lifecycle_state"] = "INACTIVE" + mock_from_id.return_value = ModelGroup(**mock_oci_model_group_activate_response) + mock_sync.return_value = ModelGroup(**OCI_MODEL_GROUP_RESPONSE) + oci_model_group = OCIDataScienceModelGroup(**OCI_MODEL_GROUP_RESPONSE).activate( + wait_for_completion=False, max_wait_time=1, poll_interval=2 + ) + + mock_activate_model_group.assert_called_with(oci_model_group.id) + assert oci_model_group.lifecycle_state == "ACTIVE" + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.sync" + ) + @patch("oci.data_science.DataScienceClient.deactivate_model_group") + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.from_id" + ) + def test_deactivate(self, mock_from_id, mock_deactivate_model_group, mock_sync): + mock_oci_model_group_deactivate_response = copy.deepcopy( + OCI_MODEL_GROUP_RESPONSE + ) + mock_oci_model_group_deactivate_response["lifecycle_state"] = "INACTIVE" + mock_from_id.return_value = ModelGroup(**OCI_MODEL_GROUP_RESPONSE) + mock_sync.return_value = ModelGroup(**mock_oci_model_group_deactivate_response) + oci_model_group = OCIDataScienceModelGroup( + **mock_oci_model_group_deactivate_response + ).deactivate(wait_for_completion=False, max_wait_time=1, poll_interval=2) + + mock_deactivate_model_group.assert_called_with(oci_model_group.id) + assert oci_model_group.lifecycle_state == "INACTIVE" + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.sync" + ) + @patch("oci.data_science.DataScienceClient.delete_model_group") + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.from_id" + ) + def test_delete(self, mock_from_id, mock_delete_model_group, mock_sync): + mock_oci_model_group_delete_response = copy.deepcopy(OCI_MODEL_GROUP_RESPONSE) + mock_oci_model_group_delete_response["lifecycle_state"] = "DELETED" + mock_from_id.return_value = ModelGroup(**OCI_MODEL_GROUP_RESPONSE) + mock_sync.return_value = ModelGroup(**mock_oci_model_group_delete_response) + + oci_model_group = OCIDataScienceModelGroup(**OCI_MODEL_GROUP_RESPONSE).delete( + wait_for_completion=False, max_wait_time=1, poll_interval=2 + ) + + mock_delete_model_group.assert_called_with(oci_model_group.id) + assert oci_model_group.lifecycle_state == "DELETED" + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.sync" + ) + @patch( + "oci.data_science.DataScienceClientCompositeOperations.update_model_group_and_wait_for_state" + ) + def test_update(self, mock_update_model_group, mock_sync): + mock_oci_model_group_update_response = copy.deepcopy(OCI_MODEL_GROUP_RESPONSE) + mock_oci_model_group_update_response.update(**UPDATE_MODEL_GROUP_DETAILS) + mock_sync.return_value = ModelGroup(**mock_oci_model_group_update_response) + update_model_group_details = UpdateModelGroupDetails( + **UPDATE_MODEL_GROUP_DETAILS + ) + oci_model_group = OCIDataScienceModelGroup(**OCI_MODEL_GROUP_RESPONSE).update( + update_model_group_details=update_model_group_details, + wait_for_completion=False, + max_wait_time=1, + poll_interval=2, + ) + + mock_update_model_group.assert_called_with( + oci_model_group.id, + update_model_group_details, + wait_for_states=[], + waiter_kwargs={ + "max_interval_seconds": 2, + "max_wait_seconds": 1, + }, + ) + + assert oci_model_group.id == mock_oci_model_group_update_response["id"] + assert ( + oci_model_group.display_name + == mock_oci_model_group_update_response["display_name"] + ) + assert ( + oci_model_group.description + == mock_oci_model_group_update_response["description"] + ) + assert ( + oci_model_group.freeform_tags + == mock_oci_model_group_update_response["freeform_tags"] + ) + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.from_id" + ) + def test_from_id(self, mock_from_id): + OCIDataScienceModelGroup.from_id(OCI_MODEL_GROUP_RESPONSE["id"]) + mock_from_id.assert_called_with(OCI_MODEL_GROUP_RESPONSE["id"]) + + @patch("oci.pagination.list_call_get_all_results") + def test_list(self, mock_list_call_get_all_results): + response = MagicMock() + response.data = [MagicMock()] + mock_list_call_get_all_results.return_value = response + model_groups = OCIDataScienceModelGroup.list( + status="ACTIVE", + compartment_id="test_compartment_id", + ) + mock_list_call_get_all_results.assert_called() + assert isinstance(model_groups, list)