-
Notifications
You must be signed in to change notification settings - Fork 18
Add model manager for machine learning models #918
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
idlir-shkurti-frequenz
wants to merge
1
commit into
frequenz-floss:v1.x.x
from
idlir-shkurti-frequenz:model-manager
Closed
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# License: MIT | ||
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH | ||
|
||
"""Model interface.""" | ||
|
||
# `ModelNotFoundError` is re-exported alongside `ModelManager` because the
# manager raises it while loading models; without the export callers would
# have to import it from the private `_model_manager` module to catch it.
from ._model_manager import ModelManager, ModelNotFoundError

__all__ = [
    "ModelManager",
    "ModelNotFoundError",
]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
# License: MIT | ||
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH | ||
|
||
"""Load, update, monitor and retrieve machine learning models.""" | ||
|
||
import asyncio | ||
import logging | ||
import pickle | ||
from dataclasses import dataclass | ||
from pathlib import Path | ||
from typing import Generic, TypeVar, cast | ||
|
||
from frequenz.channels.file_watcher import EventType, FileWatcher | ||
from typing_extensions import override | ||
|
||
from frequenz.sdk.actor import BackgroundService | ||
|
||
_logger = logging.getLogger(__name__) | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
@dataclass | ||
class _Model(Generic[T]): | ||
"""Represent a machine learning model.""" | ||
|
||
data: T | ||
path: Path | ||
|
||
|
||
class ModelNotFoundError(Exception):
    """Signal that a requested model could not be found."""

    def __init__(self, key: str) -> None:
        """Build the error message from the identifier of the missing model.

        Args:
            key: The key of the model that was not found.
        """
        message = f"Model with key '{key}' is not found."
        super().__init__(message)
|
||
|
||
class ModelManager(BackgroundService, Generic[T]):
    """Load, update, monitor and retrieve machine learning models.

    All configured models are unpickled when the manager is constructed.
    After `start()` is called, a background task watches the model files and
    reloads a model whenever its file is created or modified.
    """

    def __init__(self, model_paths: dict[str, Path], *, name: str | None = None):
        """Initialize the model manager with the specified model paths.

        Every model is loaded eagerly as part of construction.

        Args:
            model_paths: A dictionary of model keys and their corresponding file paths.
            name: The name of the model manager service.

        Raises:
            ModelNotFoundError: If any of the model files does not exist.
        """
        super().__init__(name=name)
        self._models: dict[str, _Model[T]] = {}
        self.model_paths = model_paths
        self.load_models()

    def load_models(self) -> None:
        """Load the models from the specified paths.

        Raises:
            ModelNotFoundError: If any of the model files does not exist.
        """
        for key, path in self.model_paths.items():
            self._models[key] = _Model(data=self._load(path), path=path)

    @staticmethod
    def _load(path: Path) -> T:
        """Load the model from the specified path.

        Args:
            path: The path to the model file.

        Returns:
            The loaded model data.

        Raises:
            ModelNotFoundError: If the model file does not exist.
        """
        # NOTE(security): unpickling can execute arbitrary code, so model
        # files must only ever come from a trusted source.
        try:
            with path.open("rb") as file:
                return cast(T, pickle.load(file))
        except FileNotFoundError as exc:
            raise ModelNotFoundError(str(path)) from exc

    @override
    def start(self) -> None:
        """Start the model monitoring service by creating a background task.

        Does nothing if the service is already running.
        """
        if self.is_running:
            return
        task = asyncio.create_task(self._monitor_paths())
        # `_tasks` is presumably the task set managed by `BackgroundService`
        # (used for `is_running`/`stop()`) - confirm against the base class.
        self._tasks.add(task)
        _logger.info(
            "%s: Started ModelManager service with task %s",
            self.name,
            task,
        )

    async def _monitor_paths(self) -> None:
        """Monitor model file paths and reload models as necessary."""
        model_paths = [model.path for model in self._models.values()]
        file_watcher = FileWatcher(
            paths=model_paths, event_types=[EventType.CREATE, EventType.MODIFY]
        )
        _logger.info("%s: Monitoring model paths for changes.", self.name)
        async for event in file_watcher:
            _logger.info(
                "%s: Reloading model from file %s due to a %s event...",
                self.name,
                event.path,
                event.type.name,
            )
            self.reload_model(Path(event.path))

    def reload_model(self, path: Path) -> None:
        """Reload the model from the specified path.

        Every model registered with exactly this path is reloaded. If loading
        fails, the error is logged and the previously loaded model is kept.
        If no model is registered for the path, a warning is logged.

        Args:
            path: The path to the model file.
        """
        matched = False
        for model in self._models.values():
            if model.path != path:
                continue
            matched = True
            try:
                data = self._load(path)
            except Exception:  # pylint: disable=broad-except
                # Best effort: a broken or half-written file must not take
                # down the monitoring task, so keep serving the old model.
                _logger.exception(
                    "%s: Failed to reload model from %s", self.name, path
                )
            else:
                model.data = data
                _logger.info(
                    "%s: Successfully reloaded model from %s",
                    self.name,
                    path,
                )
        if not matched:
            _logger.warning(
                "%s: No model is registered for path %s, nothing was reloaded.",
                self.name,
                path,
            )

    def get_model(self, key: str) -> T:
        """Retrieve a loaded model by key.

        Args:
            key: The key of the model to retrieve.

        Returns:
            The loaded model data.

        Raises:
            KeyError: If the model with the specified key is not found.
        """
        try:
            return self._models[key].data
        except KeyError as exc:
            raise KeyError(f"Model with key '{key}' is not found.") from exc
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# License: MIT | ||
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH | ||
|
||
"""Tests for the model package.""" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
# License: MIT | ||
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH | ||
|
||
"""Tests for machine learning model manager.""" | ||
|
||
import pickle | ||
from dataclasses import dataclass | ||
from pathlib import Path | ||
from typing import Any | ||
from unittest.mock import AsyncMock, MagicMock, mock_open, patch | ||
|
||
import pytest | ||
|
||
from frequenz.sdk.ml import ModelManager | ||
|
||
|
||
@dataclass | ||
class MockModel: | ||
"""Mock model for unit testing purposes.""" | ||
|
||
data: int | str | ||
|
||
def predict(self) -> int | str: | ||
"""Make a prediction based on the model data.""" | ||
return self.data | ||
|
||
|
||
async def test_model_manager_loading() -> None:
    """Check that ModelManager loads every configured model on construction."""
    model_paths = {
        "model1": Path("path/to/model1.pkl"),
        "model2": Path("path/to/model2.pkl"),
    }

    # Pickled payload served for each fake file.
    payloads = {
        "path/to/model1.pkl": pickle.dumps(MockModel("Model 1 Data")),
        "path/to/model2.pkl": pickle.dumps(MockModel("Model 2 Data")),
    }
    # One pre-built mock file handle per fake file.
    handles = {
        name: mock_open(read_data=blob)() for name, blob in payloads.items()
    }

    def fake_open(file_path: Path, *__args: Any, **__kwargs: Any) -> Any:
        """Stand in for `Path.open`, serving the prepared mock file handles.

        Args:
            file_path: The path being opened (the `Path` instance itself, since
                this function is installed as a method on `Path`).
            *__args: Ignored positional arguments such as `mode`.
            **__kwargs: Ignored keyword arguments such as `encoding`.

        Returns:
            A context manager yielding the mock file handle for `file_path`.

        Raises:
            FileNotFoundError: If no mock file was prepared for `file_path`.
        """
        name = str(file_path)
        if name not in handles:
            raise FileNotFoundError(f"No mock setup for {name}")
        context = MagicMock()
        context.__enter__.return_value = handles[name]
        return context

    with (
        patch("pathlib.Path.open", new=fake_open),
        patch.object(Path, "exists", return_value=True),
    ):
        manager: ModelManager[MockModel] = ModelManager(model_paths=model_paths)

        with patch(
            "frequenz.channels.file_watcher.FileWatcher", new_callable=AsyncMock
        ):
            manager.start()  # Start the service

        first = manager.get_model("model1")
        assert isinstance(first, MockModel)
        assert first.data == "Model 1 Data"
        assert manager.get_model("model2").data == "Model 2 Data"

        with pytest.raises(KeyError):
            manager.get_model("key3")

        await manager.stop()  # Stop the service to clean up
|
||
|
||
async def test_model_manager_update() -> None:
    """Check that reloading a model picks up the new contents of its file."""
    model_file = Path("path/to/model1.pkl")
    original_blob = pickle.dumps(MockModel("Original Data"))
    updated_blob = pickle.dumps(MockModel("Updated Data"))

    opener = mock_open(read_data=original_blob)
    with (
        patch("pathlib.Path.open", opener),
        patch.object(Path, "exists", return_value=True),
    ):
        manager = ModelManager[MockModel](model_paths={"model1": model_file})
        with patch(
            "frequenz.channels.file_watcher.FileWatcher", new_callable=AsyncMock
        ):
            manager.start()  # Start the service

        assert manager.get_model("model1").data == "Original Data"

        # Simulate updating the model file: from now on reads return the
        # pickled updated model instead of the original one.
        opener.return_value.read.return_value = updated_blob
        with patch("pathlib.Path.open", opener):
            manager.reload_model(model_file)
            assert manager.get_model("model1").data == "Updated Data"

        await manager.stop()  # Stop the service to clean up
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.