Skip to content

Commit 92f10f8

Browse files
Add model manager as BackgroundService for handling machine learning models
Signed-off-by: Idlir Shkurti <[email protected]>
1 parent ceef3ba commit 92f10f8

File tree

4 files changed

+262
-0
lines changed

4 files changed

+262
-0
lines changed

src/frequenz/sdk/ml/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# License: MIT
2+
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Model interface."""
5+
6+
from ._model_manager import ModelManager
7+
8+
__all__ = [
9+
"ModelManager",
10+
]
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# License: MIT
2+
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Load, update, monitor and retrieve machine learning models."""
5+
6+
import asyncio
7+
import logging
8+
import pickle
9+
from dataclasses import dataclass
10+
from pathlib import Path
11+
from typing import Generic, TypeVar, cast
12+
13+
from frequenz.channels.file_watcher import EventType, FileWatcher
14+
from typing_extensions import override
15+
16+
from frequenz.sdk.actor import BackgroundService
17+
18+
_logger = logging.getLogger(__name__)
19+
20+
T = TypeVar("T")
21+
22+
23+
@dataclass
class _Model(Generic[T]):
    """Pair a loaded model object with the file it was read from."""

    # The deserialized model object.
    data: T

    # Location of the file that `data` was loaded from.
    path: Path
29+
30+
31+
class ModelManager(BackgroundService, Generic[T]):
    """Load, update, monitor and retrieve machine learning models.

    All configured models are unpickled when the manager is constructed. Once
    the service is started, a background task watches the model files and
    reloads a model whenever its file is created or modified.
    """

    def __init__(self, model_paths: dict[str, Path], *, name: str | None = None):
        """Initialize the model manager with the specified model paths.

        Every model is loaded eagerly, so construction fails if any of the
        files is missing.

        Args:
            model_paths: A dictionary of model keys and their corresponding file paths.
            name: The name of the model manager service.

        Raises:
            FileNotFoundError: If any of the model files does not exist.
        """
        super().__init__(name=name)
        self._models: dict[str, _Model[T]] = {}
        self.model_paths = model_paths
        self.load_models()

    def load_models(self) -> None:
        """Load the models from the specified paths.

        Raises:
            FileNotFoundError: If any of the model files does not exist.
        """
        for key, path in self.model_paths.items():
            self._models[key] = _Model(data=self._load(path), path=path)

    @staticmethod
    def _load(path: Path) -> T:
        """Load the model from the specified path.

        Args:
            path: The path to the model file.

        Returns:
            T: The loaded model data.

        Raises:
            FileNotFoundError: If the model file does not exist.
        """
        if not path.exists():
            raise FileNotFoundError(f"The model path {path} does not exist.")
        # SECURITY: unpickling can execute arbitrary code embedded in the file,
        # so model files must only ever come from a trusted source.
        with path.open("rb") as file:
            return cast(T, pickle.load(file))

    @override
    def start(self) -> None:
        """Start the model monitoring service by creating a background task.

        Does nothing if the service is already running.
        """
        if not self.is_running:
            task = asyncio.create_task(self._monitor_paths())
            self._tasks.add(task)
            _logger.info(
                "%s: Started ModelManager service with task %s",
                self.name,
                task,
            )

    async def _monitor_paths(self) -> None:
        """Monitor model file paths and reload models as necessary."""
        model_paths = [model.path for model in self._models.values()]
        file_watcher = FileWatcher(
            paths=model_paths, event_types=[EventType.CREATE, EventType.MODIFY]
        )
        _logger.info("%s: Monitoring model paths for changes.", self.name)
        async for event in file_watcher:
            _logger.info(
                "%s: Reloading model from file %s due to a %s event...",
                self.name,
                event.path,
                event.type.name,
            )
            self.reload_model(Path(event.path))

    def reload_model(self, path: Path) -> None:
        """Reload the model from the specified path.

        Every registered model whose path equals `path` is reloaded. A failed
        reload is logged and the previously loaded model is kept, so a broken
        file never takes a working model out of service. If no registered model
        uses `path`, a warning is logged and nothing else happens.

        Args:
            path: The path to the model file.
        """
        matched = False
        for model in self._models.values():
            if model.path != path:
                continue
            matched = True
            try:
                model.data = self._load(path)
            except Exception:  # pylint: disable=broad-except
                _logger.exception(
                    "%s: Failed to reload model from %s",
                    self.name,
                    path,
                )
            else:
                _logger.info(
                    "%s: Successfully reloaded model from %s",
                    self.name,
                    path,
                )

        if not matched:
            # Surface the mismatch instead of silently ignoring it, e.g. when a
            # caller (or the file watcher) reports the path in a different form
            # than the one the model was registered with.
            _logger.warning(
                "%s: No model is registered for path %s, nothing was reloaded.",
                self.name,
                path,
            )

    def get_model(self, key: str) -> T:
        """Retrieve a loaded model by key.

        Args:
            key: The key of the model to retrieve.

        Returns:
            The loaded model data.

        Raises:
            KeyError: If the model with the specified key is not found.
        """
        try:
            return self._models[key].data
        except KeyError as exc:
            raise KeyError(f"Model with key '{key}' is not found.") from exc

tests/model/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# License: MIT
2+
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Tests for the model package."""

tests/model/test_model_manager.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# License: MIT
2+
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Tests for machine learning model manager."""
5+
6+
import pickle
7+
from dataclasses import dataclass
8+
from pathlib import Path
9+
from typing import Any
10+
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
11+
12+
import pytest
13+
14+
from frequenz.sdk.ml import ModelManager
15+
16+
17+
@dataclass
18+
class MockModel:
19+
"""Mock model for unit testing purposes."""
20+
21+
data: int | str
22+
23+
def predict(self) -> int | str:
24+
"""Make a prediction based on the model data."""
25+
return self.data
26+
27+
28+
async def test_model_manager_loading() -> None:
    """Test loading models using ModelManager with direct configuration."""
    model1 = MockModel("Model 1 Data")
    model2 = MockModel("Model 2 Data")
    pickled_model1 = pickle.dumps(model1)
    pickled_model2 = pickle.dumps(model2)

    model_paths = {
        "model1": Path("path/to/model1.pkl"),
        "model2": Path("path/to/model2.pkl"),
    }

    # One pre-built mock file handle per model file, keyed by the path string.
    mock_files = {
        "path/to/model1.pkl": mock_open(read_data=pickled_model1)(),
        "path/to/model2.pkl": mock_open(read_data=pickled_model2)(),
    }

    def mock_open_func(file_path: Path, *__args: Any, **__kwargs: Any) -> Any:
        """Mock open function to return the correct mock file object.

        Args:
            file_path: The path to the file to open.
            *__args: Variable length argument list. This can be used to pass additional
                positional parameters typically used in file opening operations,
                such as `mode` or `buffering`.
            **__kwargs: Arbitrary keyword arguments. This can include parameters like
                `encoding` and `errors`, common in file opening operations.

        Returns:
            Any: The mock file object.

        Raises:
            FileNotFoundError: If the file path is not in the mock files dictionary.
        """
        file_path_str = str(file_path)
        if file_path_str in mock_files:
            file_handle = MagicMock()
            file_handle.__enter__.return_value = mock_files[file_path_str]
            return file_handle
        raise FileNotFoundError(f"No mock setup for {file_path_str}")

    with (
        # Patch the name where ModelManager looks it up: its module imports
        # FileWatcher by name, so patching frequenz.channels.file_watcher would
        # leave that already-bound reference untouched.
        patch(
            "frequenz.sdk.ml._model_manager.FileWatcher", new_callable=AsyncMock
        ),
        patch("pathlib.Path.open", new=mock_open_func),
        patch.object(Path, "exists", return_value=True),
    ):
        model_manager: ModelManager[MockModel] = ModelManager(
            model_paths=model_paths
        )
        model_manager.start()  # Start the service
        try:
            assert isinstance(model_manager.get_model("model1"), MockModel)
            assert model_manager.get_model("model1").data == "Model 1 Data"
            assert model_manager.get_model("model2").data == "Model 2 Data"

            with pytest.raises(KeyError):
                model_manager.get_model("key3")
        finally:
            # Always stop the service so a failing assertion cannot leave the
            # background task pending.
            await model_manager.stop()
88+
89+
90+
async def test_model_manager_update() -> None:
    """Test updating a model in ModelManager."""
    original_model = MockModel("Original Data")
    updated_model = MockModel("Updated Data")
    pickled_original_model = pickle.dumps(original_model)
    pickled_updated_model = pickle.dumps(updated_model)

    model_paths = {"model1": Path("path/to/model1.pkl")}

    mock_file = mock_open(read_data=pickled_original_model)
    with (
        # Patch the name where ModelManager looks it up: its module imports
        # FileWatcher by name, so patching frequenz.channels.file_watcher would
        # leave that already-bound reference untouched.
        patch(
            "frequenz.sdk.ml._model_manager.FileWatcher", new_callable=AsyncMock
        ),
        patch("pathlib.Path.open", mock_file),
        patch.object(Path, "exists", return_value=True),
    ):
        model_manager = ModelManager[MockModel](model_paths=model_paths)
        model_manager.start()  # Start the service
        try:
            assert model_manager.get_model("model1").data == "Original Data"

            # Simulate updating the model file. Path.open and Path.exists are
            # still patched by the enclosing block, so the reload reads the new
            # payload from the same mock.
            mock_file.return_value.read.return_value = pickled_updated_model
            model_manager.reload_model(Path("path/to/model1.pkl"))
            assert model_manager.get_model("model1").data == "Updated Data"
        finally:
            # Always stop the service so a failing assertion cannot leave the
            # background task pending.
            await model_manager.stop()

0 commit comments

Comments
 (0)