From e8b190a3610a8852a77c56a96820d56c5a3fa455 Mon Sep 17 00:00:00 2001 From: Leila Date: Wed, 18 Sep 2024 13:50:49 -0400 Subject: [PATCH 1/7] implement: multiple save and load for mlflow registry --- .gitignore | 5 + numalogic/registry/mlflow_registry.py | 133 ++++++++++++++++++++++++- tests/registry/_mlflow_utils.py | 91 +++++++++++++++++ tests/registry/test_mlflow_registry.py | 56 ++++++++++- 4 files changed, 280 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 3db98ce0..c50fd26c 100644 --- a/.gitignore +++ b/.gitignore @@ -78,6 +78,8 @@ target/ #mlflow /.mlruns *.db +mlruns/ +mlartifacts/ # Jupyter Notebook .ipynb_checkpoints @@ -169,4 +171,7 @@ cython_debug/ # Mac related *.DS_Store +# vscode +.vscode/ + .python-version diff --git a/numalogic/registry/mlflow_registry.py b/numalogic/registry/mlflow_registry.py index 43c5b519..62948928 100644 --- a/numalogic/registry/mlflow_registry.py +++ b/numalogic/registry/mlflow_registry.py @@ -18,15 +18,17 @@ import mlflow.pyfunc import mlflow.pytorch import mlflow.sklearn +import mlflow from mlflow.entities.model_registry import ModelVersion from mlflow.exceptions import RestException from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST from mlflow.tracking import MlflowClient +from sortedcontainers import SortedSet from numalogic.registry import ArtifactManager, ArtifactData from numalogic.registry.artifact import ArtifactCache from numalogic.tools.exceptions import ModelVersionError -from numalogic.tools.types import artifact_t, KEYS, META_VT +from numalogic.tools.types import KeyedArtifact, artifact_t, KEYS, META_VT _LOGGER = logging.getLogger(__name__) @@ -187,6 +189,39 @@ def load( self._save_in_cache(model_key, artifact_data) return artifact_data + def load_multiple( + self, + skeys: KEYS, + dkeys_list: list[list[str]], + ) -> Optional[dict[str, ArtifactData]]: + """ + Load multiple artifacts from the registry for pyfunc models. + Args: + skeys (KEYS): The source keys of the artifacts to load. + dkeys_list (list[list[str]]): + A list of lists containing the dkeys of the artifacts to load. + + Returns + ------- + Optional[dict[str, ArtifactData]]: A dictionary mapping joined dynamic keys + to the loaded artifacts, or None if no artifacts were found. + """ + dkeys = self.__get_sorted_unique_dkeys(dkeys_list) + loaded_model = self.load(skeys=skeys, dkeys=dkeys, artifact_type="pyfunc") + if loaded_model is not None: + metadata = loaded_model.artifact.unwrap_python_model().metadata + dict_artifacts = loaded_model.artifact.unwrap_python_model().dict_artifacts + artifacts_dict = {} + for artifact in dict_artifacts.values(): + artifact_data = ArtifactData( + artifact=artifact.artifact, metadata=metadata, extras=None + ) + dynamic_key = ":".join(artifact.dkeys) + artifacts_dict[dynamic_key] = artifact_data + else: + artifacts_dict = None + return artifacts_dict + @staticmethod def __log_mlflow_err(mlflow_err: RestException, model_key: str) -> None: if ErrorCode.Value(mlflow_err.error_code) == RESOURCE_DOES_NOT_EXIST: @@ -225,7 +260,10 @@ def save( handler = self.handler_from_type(artifact_type) try: mlflow.start_run(run_id=run_id) - handler.log_model(artifact, "model", registered_model_name=model_key) + if artifact_type == "pyfunc": + handler.log_model("model", python_model=artifact, registered_model_name=model_key) + else: + handler.log_model(artifact, "model", registered_model_name=model_key) if metadata: mlflow.log_params(metadata) model_version = self.transition_stage(skeys=skeys, dkeys=dkeys) @@ -238,6 +276,37 @@ def save( finally: mlflow.end_run() + def save_multiple( + self, + skeys: KEYS, + dict_artifacts: dict[str, KeyedArtifact], + **metadata: META_VT, + ) -> Optional[ModelVersion]: + """ + Saves multiple artifacts into mlflow registry. The last save stores all the + artifact versions in the metadata. + + Args: + ---- + skeys: static key fields as list/tuple of strings + dict_artifacts: dict of artifacts to save + metadata: additional metadata surrounding the artifact that needs to be saved. + + Returns + ------- + mlflow ModelVersion instance + """ + multiple_artifacts = CompositeModels(skeys=skeys, dict_artifacts=dict_artifacts, **metadata) + dkeys_list = multiple_artifacts.get_dkeys_list() + sorted_dkeys = self.__get_sorted_unique_dkeys(dkeys_list) + return self.save( + skeys=multiple_artifacts.skeys, + dkeys=sorted_dkeys, + artifact=multiple_artifacts, + artifact_type="pyfunc", + metadata=multiple_artifacts.metadata, + ) + @staticmethod def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool: """Returns whether the given artifact is stale or not, i.e. if @@ -338,3 +407,63 @@ def __load_artifacts( version_info.version, ) return model, metadata + + def __get_sorted_unique_dkeys(self, dkeys_list: list[list]) -> list[str]: + """ + Returns a unique sorted list of all dkeys in the stored artifacts. + + Args: + ---- + dkeys_list: A list of lists containing the destination keys of the artifacts. + + Returns + ------- + List[str] + A list of all unique dkeys in the stored artifacts, sorted in ascending order. + """ + return list(SortedSet([dkey for dkeys in dkeys_list for dkey in dkeys])) + + +class CompositeModels(mlflow.pyfunc.PythonModel): + """A composite model that represents multiple artifacts. + + This class extends the `mlflow.pyfunc.PythonModel` class and is used to store and load + multiple artifacts in the MLflow registry. It provides a convenient way to manage and + organize multiple artifacts associated with a single model. + + Args: + skeys (KEYS): The static keys of the artifacts. + dict_artifacts (dict[str, KeyedArtifact]): A dictionary mapping dynamic keys to + `KeyedArtifact` objects. + **metadata (META_VT): Additional metadata associated with the artifacts. + + Methods + ------- + get_dkeys_list(): Returns a list of all dynamic keys in the stored artifacts. + + Attributes + ---------- + skeys (KEYS): The static keys of the artifacts. + dict_artifacts (dict[str, KeyedArtifact]): A dictionary mapping dynamic keys to + `KeyedArtifact` objects. + metadata (META_VT): Additional metadata associated with the artifacts. + """ + + def __init__(self, skeys: KEYS, dict_artifacts: dict[str, KeyedArtifact], **metadata: META_VT): + self.skeys = skeys + self.dict_artifacts = dict_artifacts + self.metadata = metadata + + def get_dkeys_list(self): + """ + Returns a list of all dynamic keys in the stored artifacts. + + Returns + ------- + list[list[str]]: A list of all dynamic keys in the stored artifacts. + """ + dkeys_list = [] + artifacts = self.dict_artifacts.values() + for artifact in artifacts: + dkeys_list.append(artifact.dkeys) + return dkeys_list diff --git a/tests/registry/_mlflow_utils.py b/tests/registry/_mlflow_utils.py index 4b61eddc..1120c3b7 100644 --- a/tests/registry/_mlflow_utils.py +++ b/tests/registry/_mlflow_utils.py @@ -9,8 +9,12 @@ from mlflow.store.entities import PagedList from sklearn.preprocessing import StandardScaler from torch import tensor +from mlflow.models import Model +from numalogic.models.autoencoder.variants.vanilla import VanillaAE from numalogic.models.threshold import StdDevThreshold +from numalogic.registry.mlflow_registry import CompositeModels +from numalogic.tools.types import KeyedArtifact def create_model(): @@ -135,6 +139,71 @@ def mock_log_model_sklearn(*_, **__): ) +def mock_log_model_pyfunc(*_, **__): + return ModelInfo( + artifact_path="model", + flavors={ + "pyfunc": {"model_data": "data", "pyfunc_version": "1.11.0", "code": None}, + "python_function": { + "pickle_module_name": "mlflow.pyfunc.pickle_module", + "loader_module": "mlflow.pyfunc", + "python_version": "3.8.5", + "data": "data", + "env": "conda.yaml", + }, + }, + model_uri="runs:/a7c0b376530b40d7b23e6ce2081c899c/model", + model_uuid="a7c0b376530b40d7b23e6ce2081c899c", + run_id="a7c0b376530b40d7b23e6ce2081c899c", + saved_input_example_info=None, + signature_dict=None, + utc_time_created="2022-05-23 22:35:59.557372", + mlflow_version="2.0.1", + signature=None, + ) + + +def mock_load_model_pyfunc(*_, **__): + artifact_path = "model" + flavors = { + "python_function": { + "cloudpickle_version": "3.0.0", + "code": None, + "env": {"conda": "conda.yaml", "virtualenv": "python_env.yaml"}, + "loader_module": "mlflow.pyfunc.model", + "python_model": "python_model.pkl", + "python_version": "3.10.14", + "streamable": False, + } + } + model_size_bytes = 8912 + model_uuid = "ae27ecc166c94c01a4f4dccaf84ca5dc" + run_id = "7e85a3fa46d44e668c840f3dddc909c3" + utc_time_created = "2024-09-18 17:12:41.501209" + model = Model( + artifact_path=artifact_path, + flavors=flavors, + model_size_bytes=model_size_bytes, + model_uuid=model_uuid, + run_id=run_id, + utc_time_created=utc_time_created, + mlflow_version="2.16.0", + ) + return mlflow.pyfunc.PyFuncModel( + model_meta=model, + model_impl=TestObject( + python_model=CompositeModels( + skeys=["model"], + dict_artifacts={ + "AE": KeyedArtifact(dkeys=["AE", "infer"], artifact=VanillaAE(10)), + "scaler": KeyedArtifact(dkeys=["scaler", "infer"], artifact=StandardScaler()), + }, + **{"learning_rate": 0.01}, + ) + ), + ) + + def mock_transition_stage(*_, **__): return ModelVersion( creation_timestamp=1653402941169, @@ -303,6 +372,23 @@ def return_sklearn_rundata(): ) +def return_pyfunc_rundata(): + return Run( + run_info=RunInfo( + artifact_uri="mlflow-artifacts:/0/a7c0b376530b40d7b23e6ce2081c899c/artifacts/model", + end_time=None, + experiment_id="0", + lifecycle_stage="active", + run_id="a7c0b376530b40d7b23e6ce2081c899c", + run_uuid="a7c0b376530b40d7b23e6ce2081c899c", + start_time=1658788772612, + status="RUNNING", + user_id="lol", + ), + run_data=RunData(metrics={}, tags={}, params={}), + ) + + def return_pytorch_rundata_dict(): return Run( run_info=RunInfo( @@ -318,3 +404,8 @@ def return_pytorch_rundata_dict(): ), run_data=RunData(metrics={}, tags={}, params=[mlflow.entities.Param("lr", "0.001")]), ) + + +class TestObject(mlflow.pyfunc.PythonModel): + def __init__(self, python_model): + self.python_model = python_model diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py index 8482c6e8..de2cbd44 100644 --- a/tests/registry/test_mlflow_registry.py +++ b/tests/registry/test_mlflow_registry.py @@ -15,7 +15,10 @@ from numalogic.registry.mlflow_registry import ModelStage +from numalogic.tools.types import KeyedArtifact from tests.registry._mlflow_utils import ( + mock_load_model_pyfunc, + mock_log_model_pyfunc, model_sklearn, create_model, mock_log_model_pytorch, @@ -23,6 +26,7 @@ mock_get_model_version, mock_transition_stage, mock_log_model_sklearn, + return_pyfunc_rundata, return_pytorch_rundata_dict, return_empty_rundata, mock_list_of_model_version, @@ -56,22 +60,68 @@ def test_construct_key(self): self.assertEqual("model_:nnet::error1", key) @patch("mlflow.pytorch.log_model", mock_log_model_pytorch) - @patch("mlflow.log_param", mock_log_state_dict) + @patch("mlflow.log_params", mock_log_state_dict) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) - @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version) + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) def test_save_model(self): ml = MLflowRegistry(TRACKING_URI) skeys = self.skeys dkeys = self.dkeys status = ml.save( - skeys=skeys, dkeys=dkeys, artifact=self.model, run_id="1234", artifact_type="pytorch" + skeys=skeys, + dkeys=dkeys, + artifact=self.model, + run_id="1234", + artifact_type="pytorch", + **{"lr": 0.01}, ) mock_status = "READY" self.assertEqual(mock_status, status.status) + @patch("mlflow.pyfunc.log_model", mock_log_model_pyfunc) + @patch("mlflow.log_params", mock_log_state_dict) + @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata()))) + @patch("mlflow.active_run", Mock(return_value=return_pyfunc_rundata())) + @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) + @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) + def test_save_multiple_models_pyfunc(self): + ml = MLflowRegistry(TRACKING_URI) + status = ml.save_multiple( + skeys=self.skeys, + dict_artifacts={ + "AE": KeyedArtifact(dkeys=["AE", "infer"], artifact=VanillaAE(10)), + "scaler": KeyedArtifact(dkeys=["scaler", "infer"], artifact=StandardScaler()), + }, + **{"learning_rate": 0.01}, + ) + self.assertIsNotNone(status) + mock_status = "READY" + self.assertEqual(mock_status, status.status) + + @patch("mlflow.pyfunc.log_model", mock_log_model_pyfunc) + @patch("mlflow.log_params", mock_log_state_dict) + @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata()))) + @patch("mlflow.active_run", Mock(return_value=return_pyfunc_rundata())) + @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) + @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) + @patch("mlflow.pyfunc.load_model", mock_load_model_pyfunc) + @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pyfunc_rundata())) + def test_load_multiple_models_when_pyfunc_model_exist(self): + ml = MLflowRegistry(TRACKING_URI) + skeys = self.skeys + dkeys_list = [["AE", "infer"], ["scaler", "infer"]] + data = ml.load_multiple(skeys=skeys, dkeys_list=dkeys_list) + self.assertIsNotNone(data["AE:infer"].metadata) + self.assertIsNotNone(data["scaler:infer"].metadata) + self.assertIsInstance(data, dict) + self.assertIsInstance(data["AE:infer"].artifact, VanillaAE) + self.assertIsInstance(data["scaler:infer"].artifact, StandardScaler) + @patch("mlflow.sklearn.log_model", mock_log_model_sklearn) @patch("mlflow.log_param", mock_log_state_dict) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_sklearn_rundata()))) From 8db67b1dc173d4839e011f07c184d5e8c094ec8c Mon Sep 17 00:00:00 2001 From: Leila Wang Date: Mon, 23 Sep 2024 16:31:07 -0400 Subject: [PATCH 2/7] fix: load_multiple Signed-off-by: Leila Wang --- numalogic/registry/mlflow_registry.py | 39 +++++++++++++------------- pyproject.toml | 2 +- tests/registry/test_mlflow_registry.py | 15 ++++++++++ 3 files changed, 36 insertions(+), 20 deletions(-) diff --git a/numalogic/registry/mlflow_registry.py b/numalogic/registry/mlflow_registry.py index 62948928..75c4bd7a 100644 --- a/numalogic/registry/mlflow_registry.py +++ b/numalogic/registry/mlflow_registry.py @@ -193,7 +193,7 @@ def load_multiple( self, skeys: KEYS, dkeys_list: list[list[str]], - ) -> Optional[dict[str, ArtifactData]]: + ) -> Optional[ArtifactData]: """ Load multiple artifacts from the registry for pyfunc models. Args: @@ -203,24 +203,23 @@ def load_multiple( Returns ------- - Optional[dict[str, ArtifactData]]: A dictionary mapping joined dynamic keys - to the loaded artifacts, or None if no artifacts were found. + Optional[ArtifactData]: The loaded ArtifactData object if available otherwise None. + ArtifactData should contain a dictionary of artifacts. """ dkeys = self.__get_sorted_unique_dkeys(dkeys_list) loaded_model = self.load(skeys=skeys, dkeys=dkeys, artifact_type="pyfunc") if loaded_model is not None: - metadata = loaded_model.artifact.unwrap_python_model().metadata - dict_artifacts = loaded_model.artifact.unwrap_python_model().dict_artifacts - artifacts_dict = {} - for artifact in dict_artifacts.values(): - artifact_data = ArtifactData( - artifact=artifact.artifact, metadata=metadata, extras=None - ) - dynamic_key = ":".join(artifact.dkeys) - artifacts_dict[dynamic_key] = artifact_data - else: - artifacts_dict = None - return artifacts_dict + try: + unwrapped_composite_model = loaded_model.artifact.unwrap_python_model() + except Exception: + _LOGGER.exception("Error occurred while unwrapping python model") + return None + + dict_artifacts = unwrapped_composite_model.dict_artifacts + metadata = loaded_model.metadata + version_info = loaded_model.extras + return ArtifactData(artifact=dict_artifacts, metadata=metadata, extras=version_info) + return None @staticmethod def __log_mlflow_err(mlflow_err: RestException, model_key: str) -> None: @@ -297,14 +296,14 @@ def save_multiple( mlflow ModelVersion instance """ multiple_artifacts = CompositeModels(skeys=skeys, dict_artifacts=dict_artifacts, **metadata) - dkeys_list = multiple_artifacts.get_dkeys_list() + dkeys_list = multiple_artifacts._get_dkeys_list() sorted_dkeys = self.__get_sorted_unique_dkeys(dkeys_list) return self.save( - skeys=multiple_artifacts.skeys, + skeys=skeys, dkeys=sorted_dkeys, artifact=multiple_artifacts, artifact_type="pyfunc", - metadata=multiple_artifacts.metadata, + **metadata, ) @staticmethod @@ -449,12 +448,14 @@ class CompositeModels(mlflow.pyfunc.PythonModel): metadata (META_VT): Additional metadata associated with the artifacts. """ + __slots__ = ("skeys", "dict_artifacts", "metadata") + def __init__(self, skeys: KEYS, dict_artifacts: dict[str, KeyedArtifact], **metadata: META_VT): self.skeys = skeys self.dict_artifacts = dict_artifacts self.metadata = metadata - def get_dkeys_list(self): + def _get_dkeys_list(self): """ Returns a list of all dynamic keys in the stored artifacts. diff --git a/pyproject.toml b/pyproject.toml index 222b544e..5fb0804b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "numalogic" -version = "0.13.2" +version = "0.13.3" description = "Collection of operational Machine Learning models and tools." authors = ["Numalogic Developers"] packages = [{ include = "numalogic" }] diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py index de2cbd44..f755e0a1 100644 --- a/tests/registry/test_mlflow_registry.py +++ b/tests/registry/test_mlflow_registry.py @@ -457,6 +457,21 @@ def test_cache_loading(self): key = MLflowRegistry.construct_key(self.skeys, self.dkeys) self.assertIsNotNone(ml._load_from_cache(key)) + @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata()))) + @patch("mlflow.active_run", Mock(return_value=return_pyfunc_rundata())) + @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) + @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) + @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pyfunc_rundata())) + @patch("mlflow.pyfunc.load_model", mock_load_model_pyfunc) + def test_cache_loading_pyfunc(self): + cache_registry = LocalLRUCache(ttl=50000) + ml = MLflowRegistry(TRACKING_URI, cache_registry=cache_registry) + dkeys_list = [["AE", "infer"], ["scaler", "infer"]] + ml.load_multiple(skeys=self.skeys, dkeys_list=dkeys_list) + unique_sorted_dkeys = ["AE", "infer", "scaler"] + key = MLflowRegistry.construct_key(self.skeys, unique_sorted_dkeys) + self.assertIsNotNone(ml._load_from_cache(key)) + if __name__ == "__main__": unittest.main() From 82ea731ac25e876fbeef8123aa7cc63eaf6fae1c Mon Sep 17 00:00:00 2001 From: Leila Wang Date: Mon, 23 Sep 2024 16:51:01 -0400 Subject: [PATCH 3/7] fix: test cases Signed-off-by: Leila Wang --- tests/registry/_mlflow_utils.py | 4 +++- tests/registry/test_mlflow_registry.py | 10 +++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/registry/_mlflow_utils.py b/tests/registry/_mlflow_utils.py index 1120c3b7..1ffc0830 100644 --- a/tests/registry/_mlflow_utils.py +++ b/tests/registry/_mlflow_utils.py @@ -385,7 +385,9 @@ def return_pyfunc_rundata(): status="RUNNING", user_id="lol", ), - run_data=RunData(metrics={}, tags={}, params={}), + run_data=RunData( + metrics={}, tags={}, params=[mlflow.entities.Param("learning_rate", "0.01")] + ), ) diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py index f755e0a1..611d803a 100644 --- a/tests/registry/test_mlflow_registry.py +++ b/tests/registry/test_mlflow_registry.py @@ -116,11 +116,11 @@ def test_load_multiple_models_when_pyfunc_model_exist(self): skeys = self.skeys dkeys_list = [["AE", "infer"], ["scaler", "infer"]] data = ml.load_multiple(skeys=skeys, dkeys_list=dkeys_list) - self.assertIsNotNone(data["AE:infer"].metadata) - self.assertIsNotNone(data["scaler:infer"].metadata) - self.assertIsInstance(data, dict) - self.assertIsInstance(data["AE:infer"].artifact, VanillaAE) - self.assertIsInstance(data["scaler:infer"].artifact, StandardScaler) + self.assertIsNotNone(data.metadata) + self.assertIsInstance(data, ArtifactData) + self.assertIsInstance(data.artifact, dict) + self.assertIsInstance(data.artifact["AE"].artifact, VanillaAE) + self.assertIsInstance(data.artifact["scaler"].artifact, StandardScaler) @patch("mlflow.sklearn.log_model", mock_log_model_sklearn) @patch("mlflow.log_param", mock_log_state_dict) From 69caf21da852a447ccc5319eee91bd901aa6c96f Mon Sep 17 00:00:00 2001 From: Leila Wang Date: Wed, 25 Sep 2024 17:55:09 -0400 Subject: [PATCH 4/7] fix: dkeys format Signed-off-by: Leila Wang --- numalogic/registry/mlflow_registry.py | 97 +++++++++++--------------- tests/registry/_mlflow_utils.py | 12 ++-- tests/registry/test_mlflow_registry.py | 33 ++++++--- 3 files changed, 67 insertions(+), 75 deletions(-) diff --git a/numalogic/registry/mlflow_registry.py b/numalogic/registry/mlflow_registry.py index 75c4bd7a..65c6812c 100644 --- a/numalogic/registry/mlflow_registry.py +++ b/numalogic/registry/mlflow_registry.py @@ -23,12 +23,11 @@ from mlflow.exceptions import RestException from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST from mlflow.tracking import MlflowClient -from sortedcontainers import SortedSet from numalogic.registry import ArtifactManager, ArtifactData from numalogic.registry.artifact import ArtifactCache from numalogic.tools.exceptions import ModelVersionError -from numalogic.tools.types import KeyedArtifact, artifact_t, KEYS, META_VT +from numalogic.tools.types import artifact_t, KEYS, META_VT _LOGGER = logging.getLogger(__name__) @@ -192,34 +191,39 @@ def load( def load_multiple( self, skeys: KEYS, - dkeys_list: list[list[str]], + dkeys: KEYS, ) -> Optional[ArtifactData]: """ Load multiple artifacts from the registry for pyfunc models. Args: skeys (KEYS): The source keys of the artifacts to load. - dkeys_list (list[list[str]]): - A list of lists containing the dkeys of the artifacts to load. + dkeys: dynamic key fields as list/tuple of strings. Returns ------- Optional[ArtifactData]: The loaded ArtifactData object if available otherwise None. ArtifactData should contain a dictionary of artifacts. """ - dkeys = self.__get_sorted_unique_dkeys(dkeys_list) loaded_model = self.load(skeys=skeys, dkeys=dkeys, artifact_type="pyfunc") - if loaded_model is not None: - try: - unwrapped_composite_model = loaded_model.artifact.unwrap_python_model() - except Exception: - _LOGGER.exception("Error occurred while unwrapping python model") - return None - - dict_artifacts = unwrapped_composite_model.dict_artifacts - metadata = loaded_model.metadata - version_info = loaded_model.extras - return ArtifactData(artifact=dict_artifacts, metadata=metadata, extras=version_info) - return None + if loaded_model is None: + return None + if loaded_model.artifact.loader_module != "mlflow.pyfunc.model": + raise TypeError("The loaded model is not a valid pyfunc Python model.") + + try: + unwrapped_composite_model = loaded_model.artifact.unwrap_python_model() + except AttributeError: + _LOGGER.exception("The loaded model does not have an unwrap_python_model method") + return None + except Exception: + _LOGGER.exception("Unexpected error occurred while unwrapping python model.") + return None + else: + return ArtifactData( + artifact=unwrapped_composite_model.dict_artifacts, + metadata=loaded_model.metadata, + extras=loaded_model.extras, + ) @staticmethod def __log_mlflow_err(mlflow_err: RestException, model_key: str) -> None: @@ -278,7 +282,8 @@ def save( def save_multiple( self, skeys: KEYS, - dict_artifacts: dict[str, KeyedArtifact], + dkeys: KEYS, + dict_artifacts: dict[str, artifact_t], **metadata: META_VT, ) -> Optional[ModelVersion]: """ @@ -287,20 +292,22 @@ def save_multiple( Args: ---- - skeys: static key fields as list/tuple of strings - dict_artifacts: dict of artifacts to save - metadata: additional metadata surrounding the artifact that needs to be saved. + skeys (KEYS): Static key fields as a list or tuple of strings. + dkeys (KEYS): Dynamic key fields as a list or tuple of strings. + dict_artifacts (dict[str, artifact_t]): Dictionary of artifacts to save. + **metadata (META_VT): Additional metadata to be saved with the artifacts. Returns ------- - mlflow ModelVersion instance + Optional[ModelVersion]: An instance of the MLflow ModelVersion. + """ - multiple_artifacts = CompositeModels(skeys=skeys, dict_artifacts=dict_artifacts, **metadata) - dkeys_list = multiple_artifacts._get_dkeys_list() - sorted_dkeys = self.__get_sorted_unique_dkeys(dkeys_list) + if len(dict_artifacts) == 1: + _LOGGER.warning("Only one element in dict_artifacts. Please use save directly.") + multiple_artifacts = CompositeModel(skeys=skeys, dict_artifacts=dict_artifacts, **metadata) return self.save( skeys=skeys, - dkeys=sorted_dkeys, + dkeys=dkeys, artifact=multiple_artifacts, artifact_type="pyfunc", **metadata, @@ -407,23 +414,8 @@ def __load_artifacts( ) return model, metadata - def __get_sorted_unique_dkeys(self, dkeys_list: list[list]) -> list[str]: - """ - Returns a unique sorted list of all dkeys in the stored artifacts. - - Args: - ---- - dkeys_list: A list of lists containing the destination keys of the artifacts. - - Returns - ------- - List[str] - A list of all unique dkeys in the stored artifacts, sorted in ascending order. - """ - return list(SortedSet([dkey for dkeys in dkeys_list for dkey in dkeys])) - -class CompositeModels(mlflow.pyfunc.PythonModel): +class CompositeModel(mlflow.pyfunc.PythonModel): """A composite model that represents multiple artifacts. This class extends the `mlflow.pyfunc.PythonModel` class and is used to store and load @@ -438,7 +430,7 @@ class CompositeModels(mlflow.pyfunc.PythonModel): Methods ------- - get_dkeys_list(): Returns a list of all dynamic keys in the stored artifacts. + predict: Not implemented for our use case. Attributes ---------- @@ -450,21 +442,10 @@ class CompositeModels(mlflow.pyfunc.PythonModel): __slots__ = ("skeys", "dict_artifacts", "metadata") - def __init__(self, skeys: KEYS, dict_artifacts: dict[str, KeyedArtifact], **metadata: META_VT): + def __init__(self, skeys: KEYS, dict_artifacts: dict[str, artifact_t], **metadata: META_VT): self.skeys = skeys self.dict_artifacts = dict_artifacts self.metadata = metadata - def _get_dkeys_list(self): - """ - Returns a list of all dynamic keys in the stored artifacts. - - Returns - ------- - list[list[str]]: A list of all dynamic keys in the stored artifacts. - """ - dkeys_list = [] - artifacts = self.dict_artifacts.values() - for artifact in artifacts: - dkeys_list.append(artifact.dkeys) - return dkeys_list + def predict(self): + raise NotImplementedError() diff --git a/tests/registry/_mlflow_utils.py b/tests/registry/_mlflow_utils.py index 1ffc0830..ca82f246 100644 --- a/tests/registry/_mlflow_utils.py +++ b/tests/registry/_mlflow_utils.py @@ -13,8 +13,7 @@ from numalogic.models.autoencoder.variants.vanilla import VanillaAE from numalogic.models.threshold import StdDevThreshold -from numalogic.registry.mlflow_registry import CompositeModels -from numalogic.tools.types import KeyedArtifact +from numalogic.registry.mlflow_registry import CompositeModel def create_model(): @@ -192,11 +191,12 @@ def mock_load_model_pyfunc(*_, **__): return mlflow.pyfunc.PyFuncModel( model_meta=model, model_impl=TestObject( - python_model=CompositeModels( - skeys=["model"], + python_model=CompositeModel( + skeys=["error"], dict_artifacts={ - "AE": KeyedArtifact(dkeys=["AE", "infer"], artifact=VanillaAE(10)), - "scaler": KeyedArtifact(dkeys=["scaler", "infer"], artifact=StandardScaler()), + "inference": VanillaAE(10), + "precrocessing": StandardScaler(), + "threshold": StdDevThreshold(), }, **{"learning_rate": 0.01}, ) diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py index 611d803a..e0fa686d 100644 --- a/tests/registry/test_mlflow_registry.py +++ b/tests/registry/test_mlflow_registry.py @@ -11,11 +11,11 @@ from sklearn.preprocessing import StandardScaler from numalogic.models.autoencoder.variants import VanillaAE +from numalogic.models.threshold._std import StdDevThreshold from numalogic.registry import MLflowRegistry, ArtifactData, LocalLRUCache from numalogic.registry.mlflow_registry import ModelStage -from numalogic.tools.types import KeyedArtifact from tests.registry._mlflow_utils import ( mock_load_model_pyfunc, mock_log_model_pyfunc, @@ -93,9 +93,11 @@ def test_save_multiple_models_pyfunc(self): status = ml.save_multiple( skeys=self.skeys, dict_artifacts={ - "AE": KeyedArtifact(dkeys=["AE", "infer"], artifact=VanillaAE(10)), - "scaler": KeyedArtifact(dkeys=["scaler", "infer"], artifact=StandardScaler()), + "inference": VanillaAE(10), + "precrocessing": StandardScaler(), + "threshold": StdDevThreshold(), }, + dkeys=["unique", "sorted"], **{"learning_rate": 0.01}, ) self.assertIsNotNone(status) @@ -114,13 +116,23 @@ def test_save_multiple_models_pyfunc(self): def test_load_multiple_models_when_pyfunc_model_exist(self): ml = MLflowRegistry(TRACKING_URI) skeys = self.skeys - dkeys_list = [["AE", "infer"], ["scaler", "infer"]] - data = ml.load_multiple(skeys=skeys, dkeys_list=dkeys_list) + dkeys = ["unique", "sorted"] + ml.save_multiple( + skeys=self.skeys, + dict_artifacts={ + "inference": VanillaAE(10), + "precrocessing": StandardScaler(), + "threshold": StdDevThreshold(), + }, + dkeys=["unique", "sorted"], + **{"learning_rate": 0.01}, + ) + data = ml.load_multiple(skeys=skeys, dkeys=dkeys) self.assertIsNotNone(data.metadata) self.assertIsInstance(data, ArtifactData) self.assertIsInstance(data.artifact, dict) - self.assertIsInstance(data.artifact["AE"].artifact, VanillaAE) - self.assertIsInstance(data.artifact["scaler"].artifact, StandardScaler) + self.assertIsInstance(data.artifact["inference"], VanillaAE) + self.assertIsInstance(data.artifact["precrocessing"], StandardScaler) @patch("mlflow.sklearn.log_model", mock_log_model_sklearn) @patch("mlflow.log_param", mock_log_state_dict) @@ -466,10 +478,9 @@ def test_cache_loading(self): def test_cache_loading_pyfunc(self): cache_registry = LocalLRUCache(ttl=50000) ml = MLflowRegistry(TRACKING_URI, cache_registry=cache_registry) - dkeys_list = [["AE", "infer"], ["scaler", "infer"]] - ml.load_multiple(skeys=self.skeys, dkeys_list=dkeys_list) - unique_sorted_dkeys = ["AE", "infer", "scaler"] - key = MLflowRegistry.construct_key(self.skeys, unique_sorted_dkeys) + dkeys = ["unique", "sorted"] + ml.load_multiple(skeys=self.skeys, dkeys=dkeys) + key = MLflowRegistry.construct_key(self.skeys, dkeys) self.assertIsNotNone(ml._load_from_cache(key)) From 049d686776480253cf99f9966f65cf72420f205a Mon Sep 17 00:00:00 2001 From: Leila Wang Date: Thu, 26 Sep 2024 11:45:29 -0400 Subject: [PATCH 5/7] fix: code patch test cases Signed-off-by: Leila Wang --- numalogic/registry/mlflow_registry.py | 12 +++- tests/registry/_mlflow_utils.py | 31 +++++++++ tests/registry/test_mlflow_registry.py | 94 +++++++++++++++++++++++++- 3 files changed, 133 insertions(+), 4 deletions(-) diff --git a/numalogic/registry/mlflow_registry.py b/numalogic/registry/mlflow_registry.py index 65c6812c..1eb07939 100644 --- a/numalogic/registry/mlflow_registry.py +++ b/numalogic/registry/mlflow_registry.py @@ -207,11 +207,11 @@ def load_multiple( loaded_model = self.load(skeys=skeys, dkeys=dkeys, artifact_type="pyfunc") if loaded_model is None: return None - if loaded_model.artifact.loader_module != "mlflow.pyfunc.model": - raise TypeError("The loaded model is not a valid pyfunc Python model.") try: unwrapped_composite_model = loaded_model.artifact.unwrap_python_model() + except mlflow.exceptions.MlflowException as e: + raise TypeError("The loaded model is not a valid pyfunc Python model.") from e except AttributeError: _LOGGER.exception("The loaded model does not have an unwrap_python_model method") return None @@ -448,4 +448,10 @@ def __init__(self, skeys: KEYS, dict_artifacts: dict[str, artifact_t], **metadat self.metadata = metadata def predict(self): - raise NotImplementedError() + """ + Predict method is not implemented for our use case. + + The CompositeModel class is designed to store and load multiple artifacts, + and the predict method is not required for this functionality. + """ + raise NotImplementedError("The predict method is not implemented for CompositeModel.") diff --git a/tests/registry/_mlflow_utils.py b/tests/registry/_mlflow_utils.py index ca82f246..3afbc96a 100644 --- a/tests/registry/_mlflow_utils.py +++ b/tests/registry/_mlflow_utils.py @@ -204,6 +204,37 @@ def mock_load_model_pyfunc(*_, **__): ) +def mock_load_model_pyfunc_type_error(*_, **__): + artifact_path = "model" + flavors = { + "python_function": { + "cloudpickle_version": "3.0.0", + "code": None, + "env": {"conda": "conda.yaml", "virtualenv": "python_env.yaml"}, + "loader_module": "mlflow.pytorch.model", + "python_model": "python_model.pkl", + "python_version": "3.10.14", + "streamable": False, + } + } + model_size_bytes = 8912 + model_uuid = "ae27ecc166c94c01a4f4dccaf84ca5dc" + run_id = "7e85a3fa46d44e668c840f3dddc909c3" + utc_time_created = "2024-09-18 17:12:41.501209" + model = Model( + artifact_path=artifact_path, + flavors=flavors, + model_size_bytes=model_size_bytes, + model_uuid=model_uuid, + run_id=run_id, + utc_time_created=utc_time_created, + mlflow_version="2.16.0", + ) + return mlflow.pyfunc.PyFuncModel( + model_meta=model, model_impl=mlflow.pytorch._PyTorchWrapper(VanillaAE(10), device="cpu") + ) + + def mock_transition_stage(*_, **__): return ModelVersion( creation_timestamp=1653402941169, diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py index e0fa686d..13345853 100644 --- a/tests/registry/test_mlflow_registry.py +++ b/tests/registry/test_mlflow_registry.py @@ -15,9 +15,10 @@ from numalogic.registry import MLflowRegistry, ArtifactData, LocalLRUCache -from numalogic.registry.mlflow_registry import ModelStage +from numalogic.registry.mlflow_registry import CompositeModel, ModelStage from tests.registry._mlflow_utils import ( mock_load_model_pyfunc, + mock_load_model_pyfunc_type_error, mock_log_model_pyfunc, model_sklearn, create_model, @@ -104,6 +105,25 @@ def test_save_multiple_models_pyfunc(self): mock_status = "READY" self.assertEqual(mock_status, status.status) + @patch("mlflow.pyfunc.log_model", mock_log_model_pyfunc) + @patch("mlflow.log_params", mock_log_state_dict) + @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata()))) + @patch("mlflow.active_run", Mock(return_value=return_pyfunc_rundata())) + @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) + @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) + def test_save_multiple_models_when_only_one_model(self): + ml = MLflowRegistry(TRACKING_URI) + with self.assertLogs(level="WARNING"): + ml.save_multiple( + skeys=self.skeys, + dict_artifacts={ + "inference": VanillaAE(10), + }, + dkeys=["unique", "sorted"], + **{"learning_rate": 0.01}, + ) + @patch("mlflow.pyfunc.log_model", mock_log_model_pyfunc) @patch("mlflow.log_params", mock_log_state_dict) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata()))) @@ -134,6 +154,78 @@ def test_load_multiple_models_when_pyfunc_model_exist(self): self.assertIsInstance(data.artifact["inference"], VanillaAE) self.assertIsInstance(data.artifact["precrocessing"], StandardScaler) + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) + @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) + @patch( + "mlflow.tracking.MlflowClient.get_latest_versions", + Mock(return_value=PagedList(items=[], token=None)), + ) + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) + @patch("mlflow.pyfunc.load_model", Mock(return_value=None)) + @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata())) + def test_load_model_when_no_model_pyfunc(self): + fake_skeys = ["Fakemodel_"] + fake_dkeys = ["error"] + ml = MLflowRegistry(TRACKING_URI) + with self.assertLogs(level="ERROR") as log: + o = ml.load_multiple(skeys=fake_skeys, dkeys=fake_dkeys) + self.assertIsNone(o) + self.assertTrue(log.output) + + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) + @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) + @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) + @patch("mlflow.tracking.MlflowClient.get_model_version", mock_get_model_version_obj) + @patch( + "mlflow.pyfunc.load_model", + Mock( + return_value=CompositeModel( + skeys=["error"], + dict_artifacts={ + "inference": VanillaAE(10), + "precrocessing": StandardScaler(), + "threshold": StdDevThreshold(), + }, + **{"learning_rate": 0.01}, + ) + ), + ) + @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pyfunc_rundata())) + def test_load_multiple_attribute_error(self): + ml = MLflowRegistry(TRACKING_URI) + skeys = self.skeys + dkeys = ["unique", "sorted"] + with self.assertLogs(level="ERROR") as log: + result = ml.load_multiple(skeys=skeys, dkeys=dkeys) + self.assertIsNone(result) + self.assertTrue( + any( + "The loaded model does not have an unwrap_python_model method" in message + for message in log.output + ) + ) + + @patch("mlflow.pytorch.log_model", mock_log_model_pytorch) + @patch("mlflow.log_params", mock_log_state_dict) + @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) + @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) + @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) + @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) + @patch("mlflow.pyfunc.load_model", mock_load_model_pyfunc_type_error) + @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pyfunc_rundata())) + def test_load_multiple_type_error(self): + ml = MLflowRegistry(TRACKING_URI) + ml.save( + skeys=self.skeys, + dkeys=self.dkeys, + artifact=self.model, + artifact_type="pytorch", + **{"lr": 0.01}, + ) + with self.assertRaises(TypeError): + ml.load_multiple(skeys=self.skeys, dkeys=self.dkeys) + @patch("mlflow.sklearn.log_model", mock_log_model_sklearn) @patch("mlflow.log_param", mock_log_state_dict) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_sklearn_rundata()))) From 74c10bbad19d870ada0bb70789d4eb1010b00440 Mon Sep 17 00:00:00 2001 From: Leila Wang Date: Thu, 26 Sep 2024 13:18:27 -0400 Subject: [PATCH 6/7] fix: minor issues Signed-off-by: Leila Wang --- numalogic/registry/mlflow_registry.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/numalogic/registry/mlflow_registry.py b/numalogic/registry/mlflow_registry.py index 1eb07939..feb5a8d1 100644 --- a/numalogic/registry/mlflow_registry.py +++ b/numalogic/registry/mlflow_registry.py @@ -218,12 +218,12 @@ def load_multiple( except Exception: _LOGGER.exception("Unexpected error occurred while unwrapping python model.") return None - else: - return ArtifactData( - artifact=unwrapped_composite_model.dict_artifacts, - metadata=loaded_model.metadata, - extras=loaded_model.extras, - ) + + return ArtifactData( + artifact=unwrapped_composite_model.dict_artifacts, + metadata=loaded_model.metadata, + extras=loaded_model.extras, + ) @staticmethod def __log_mlflow_err(mlflow_err: RestException, model_key: str) -> None: @@ -303,7 +303,7 @@ def save_multiple( """ if len(dict_artifacts) == 1: - _LOGGER.warning("Only one element in dict_artifacts. Please use save directly.") + _LOGGER.warning("Only one element in dict_artifacts. Saving directly is recommended.") multiple_artifacts = CompositeModel(skeys=skeys, dict_artifacts=dict_artifacts, **metadata) return self.save( skeys=skeys, @@ -447,7 +447,7 @@ def __init__(self, skeys: KEYS, dict_artifacts: dict[str, artifact_t], **metadat self.dict_artifacts = dict_artifacts self.metadata = metadata - def predict(self): + def predict(self, context, model_input, params: Optional[dict[str, Any]] = None): """ Predict method is not implemented for our use case. From 36fd6a40f7667053ce002cb9c7a69771b99c7b0f Mon Sep 17 00:00:00 2001 From: Leila Wang Date: Thu, 26 Sep 2024 14:41:05 -0400 Subject: [PATCH 7/7] fix: imports Signed-off-by: Leila Wang --- numalogic/registry/mlflow_registry.py | 11 +++++------ tests/registry/test_mlflow_registry.py | 3 +++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/numalogic/registry/mlflow_registry.py b/numalogic/registry/mlflow_registry.py index feb5a8d1..3b406e69 100644 --- a/numalogic/registry/mlflow_registry.py +++ b/numalogic/registry/mlflow_registry.py @@ -15,12 +15,9 @@ from enum import Enum from typing import Optional, Any -import mlflow.pyfunc -import mlflow.pytorch -import mlflow.sklearn import mlflow from mlflow.entities.model_registry import ModelVersion -from mlflow.exceptions import RestException +from mlflow.exceptions import RestException, MlflowException from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST from mlflow.tracking import MlflowClient @@ -210,7 +207,7 @@ def load_multiple( try: unwrapped_composite_model = loaded_model.artifact.unwrap_python_model() - except mlflow.exceptions.MlflowException as e: + except MlflowException as e: raise TypeError("The loaded model is not a valid pyfunc Python model.") from e except AttributeError: _LOGGER.exception("The loaded model does not have an unwrap_python_model method") @@ -303,7 +300,9 @@ def save_multiple( """ if len(dict_artifacts) == 1: - _LOGGER.warning("Only one element in dict_artifacts. Saving directly is recommended.") + _LOGGER.warning( + "Only one artifact present in dict_artifacts. Saving directly is recommended." + ) multiple_artifacts = CompositeModel(skeys=skeys, dict_artifacts=dict_artifacts, **metadata) return self.save( skeys=skeys, diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py index 13345853..716265b6 100644 --- a/tests/registry/test_mlflow_registry.py +++ b/tests/registry/test_mlflow_registry.py @@ -3,6 +3,9 @@ from contextlib import contextmanager from unittest.mock import patch, Mock +import mlflow.pytorch # noqa: F401 +import mlflow.pyfunc # noqa: F401 +import mlflow.sklearn # noqa: F401 from freezegun import freeze_time from mlflow import ActiveRun from mlflow.exceptions import RestException