Skip to content

Commit d4bda26

Browse files
Add model manager for machine learning models
Signed-off-by: Idlir Shkurti <[email protected]>
1 parent 9fd4865 commit d4bda26

File tree

4 files changed

+419
-0
lines changed

4 files changed

+419
-0
lines changed

src/frequenz/sdk/model/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# License: MIT
2+
# Copyright © 2023 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Model interface."""
5+
6+
from ._manager import (
7+
DirectModelRepository,
8+
ModelManager,
9+
ModelRepository,
10+
SingleModelRepository,
11+
)
12+
13+
# Explicitly declare the public API.
14+
__all__ = [
15+
"DirectModelRepository",
16+
"ModelManager",
17+
"ModelRepository",
18+
"SingleModelRepository",
19+
]

src/frequenz/sdk/model/_manager.py

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
# License: MIT
2+
# Copyright © 2023 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Load, update, monitor and retrieve machine learning models."""
5+
6+
import abc
7+
import asyncio
8+
import logging
9+
import os
10+
import pickle
11+
from dataclasses import dataclass
12+
from pathlib import Path
13+
from typing import Any, Generic, TypeVar
14+
15+
from frequenz.channels.file_watcher import EventType, FileWatcher
16+
17+
from frequenz.sdk._internal._asyncio import cancel_and_await
18+
19+
_logger = logging.getLogger(__name__)
20+
21+
T = TypeVar("T")
22+
"""The model type."""
23+
24+
25+
class ModelRepository(abc.ABC):
    """Describe where the models of a repository live.

    Concrete repositories subclass this interface and decide which keys
    exist and which file path each key maps to.
    """

    @abc.abstractmethod
    def _get_allowed_keys(self) -> list[Any]:
        """Return every key this repository knows about.

        Returns:
            the keys that may be used to look up a model.  # noqa: DAR202

        Raises:
            NotImplementedError: if the subclass does not override this method.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _build_path(self, key: Any) -> str:
        """Return the model file path that belongs to a key.

        Args:
            key: the key whose model path is requested.

        Returns:
            the path of the model file.  # noqa: DAR202

        Raises:
            NotImplementedError: if the subclass does not override this method.
        """
        raise NotImplementedError

    def get_models_config(self) -> dict[Any, str]:
        """Build the complete key-to-path mapping of the repository.

        Returns:
            a dictionary with one entry per allowed key, whose value is the
                path built for that key.
        """
        keys = self._get_allowed_keys()
        assert keys, "Allowed keys must be set."

        config: dict[Any, str] = {}
        for key in keys:
            config[key] = self._build_path(key)
        return config
68+
69+
70+
@dataclass
class SingleModelRepository(ModelRepository):
    """Repository that exposes exactly one model file."""

    model_path: str
    """Location of the only model file in this repository."""

    def _get_allowed_keys(self) -> list[Any]:
        """Return the one fixed key under which the model is registered.

        Returns:
            a one-element list holding the single key.
        """
        only_key = "Single"
        return [only_key]

    def _build_path(self, key: Any) -> str:
        """Return the configured model path, whatever key is given.

        Args:
            key: the key for which the path is requested; it is not used
                because there is only one model.

        Returns:
            the path of the single model file.
        """
        return self.model_path
95+
96+
97+
@dataclass
class DirectModelRepository(ModelRepository):
    """Repository whose models are listed in an explicit key-to-path mapping."""

    model_config: dict[str, str]
    """Mapping from each model key to the full path of its model file."""

    def _get_allowed_keys(self) -> list[str]:
        """Return the keys present in the configured mapping.

        Returns:
            The keys of `model_config`, in the mapping's iteration order.
        """
        assert self.model_config, "Model configuration must be set."
        return [*self.model_config]

    def _build_path(self, key: str) -> str:
        """Look up the path configured for a key.

        Args:
            key: The key whose model path is requested; it has to be one of
                the keys in `model_config`.

        Returns:
            The full path stored in `model_config` for the given key.

        Raises:
            KeyError: If `model_config` has no entry for the given key.
        """
        assert self.model_config, "Model configuration must be set."
        if key in self.model_config:
            return self.model_config[key]
        raise KeyError(f"Key '{key}' not found in the model configuration.")
129+
130+
131+
@dataclass
class _Model(Generic[T]):
    """Pair a loaded machine learning model with the file it came from."""

    data: T
    """The loaded model object."""

    path: str
    """The path of the file the model was loaded from."""
140+
141+
142+
class ModelManager(Generic[T]):
    """Manage machine learning models.

    The model manager acts as a central hub for retrieving models, reloading
    models, and monitoring model paths for changes.

    Model files are deserialized with `pickle`, so they must come from a
    trusted source: unpickling untrusted data can execute arbitrary code.
    """

    def __init__(self, model_repository: ModelRepository) -> None:
        """Initialize the model manager.

        Every model listed by the repository is loaded right away and a
        background task that watches the model files is created with
        `asyncio.create_task()`, so the manager has to be instantiated while
        an event loop is running.

        Args:
            model_repository: a model repository that defines the mapping
                between model keys and paths.
        """
        self._models: dict[Any, _Model[T]] = {}

        models_config = model_repository.get_models_config()
        for key, model_path in models_config.items():
            self._models[key] = _Model(data=self._load(model_path), path=model_path)

        self._monitoring_task: asyncio.Task[None] = asyncio.create_task(
            self._start_model_monitor()
        )

    def __getitem__(self, key: Any) -> T:
        """Get a specific loaded model using the subscript operator.

        Args:
            key: the key to identify the model.

        Returns:
            the model instance corresponding to the key.
        """
        return self.get_model(key)

    def get_model(self, key: Any = None) -> T:
        """Get a specific loaded model.

        Args:
            key: the key to identify the model. It may be omitted when exactly
                one model is loaded, in which case that model is returned.

        Raises:
            KeyError: if the key is invalid or it is not provided for
                multi-model retrieval.

        Returns:
            the model instance corresponding to the key.
        """
        if key is None:
            if len(self._models) != 1:
                raise KeyError("No key provided for multi-model retrieval")
            # Exactly one model is loaded: fall back to its key.
            key = next(iter(self._models))

        if key not in self._models:
            raise KeyError("Invalid key")

        return self._models[key].data

    def get_paths(self) -> list[Path | str]:
        """Get the paths of all loaded models.

        Returns:
            the paths of all loaded models.
        """
        return [model.path for model in self._models.values()]

    def reload_model(self, model_path: str) -> None:
        """Reload a model from the given path.

        Every managed model whose path equals `model_path` is loaded again
        from disk.

        Args:
            model_path: the path to the model file.

        Raises:
            ValueError: if the model with the specified path is not found.
        """
        matching_models = [
            model for model in self._models.values() if model.path == model_path
        ]
        if not matching_models:
            raise ValueError(f"No model found with path: {model_path}")

        for model in matching_models:
            model.data = self._load(model_path)

    async def join(self) -> None:
        """Await the monitoring task, and return when the task completes."""
        if self._monitoring_task and not self._monitoring_task.done():
            await self._monitoring_task

    async def _start_model_monitor(self) -> None:
        """Start monitoring the model paths for changes.

        A model is reloaded whenever its file is created or modified. If a
        reload fails, the error is logged, the previously loaded model is kept
        and monitoring continues.
        """
        model_paths = self.get_paths()
        file_watcher = FileWatcher(paths=model_paths)

        _logger.info("Monitoring the paths of the models for changes: %s", model_paths)

        async for event in file_watcher:
            # The model could be deleted and created again.
            if event.type not in (EventType.CREATE, EventType.MODIFY):
                continue

            changed_path = str(event.path)
            _logger.info(
                "Model file %s has been modified. Reloading model...",
                changed_path,
            )
            # A single failed reload (e.g. a file that is still being written
            # or a path that is not managed) must not terminate the monitoring
            # task. Cancellation is unaffected: `asyncio.CancelledError` is
            # not a subclass of `Exception`.
            try:
                self.reload_model(changed_path)
            except Exception:  # pylint: disable=broad-except
                _logger.exception(
                    "Failed to reload the model from %s, "
                    "keeping the previously loaded model.",
                    changed_path,
                )

        _logger.debug("ModelManager stopped.")

    async def _stop_model_monitor(self) -> None:
        """Stop monitoring the model paths for changes."""
        if self._monitoring_task:
            await cancel_and_await(self._monitoring_task)

    def _load(self, model_path: str) -> T:
        """Load the model file.

        Args:
            model_path: the path of the model to be loaded.

        Raises:
            FileNotFoundError: if the model path does not exist.

        Returns:
            the model instance.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"The model path {model_path} does not exist.")

        # NOTE(security): pickle can execute arbitrary code while loading, so
        # only model files from trusted sources must be used here.
        with open(model_path, "rb") as file:
            model: T = pickle.load(file)
        return model

tests/model/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# License: MIT
2+
# Copyright © 2023 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Tests for the model package."""

0 commit comments

Comments
 (0)