Skip to content

Commit 536517a

Browse files
committed
post-merge tweaks
1 parent 116e2b4 commit 536517a

File tree

7 files changed

+56
-102
lines changed

7 files changed

+56
-102
lines changed

ex/high_throughput_inference/mock_app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def run_model(self, model: bytes | str, batch: torch.Tensor):
115115
self.measure_time("build_tensor_descriptor")
116116
built_model = None
117117
if isinstance(model, str):
118-
model_arg = MessageHandler.build_model_key(model)
118+
model_arg = MessageHandler.build_model_key(model) # todo: this needs FSD
119119
else:
120120
model_arg = MessageHandler.build_model(model, "resnet-50", "1.0")
121121
request = MessageHandler.build_request(

smartsim/_core/mli/infrastructure/control/workermanager.py

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from ...comm.channel.channel import CommChannelBase
3535
from ...comm.channel.dragonchannel import DragonCommChannel
3636
from ...infrastructure.environmentloader import EnvironmentConfigLoader
37-
from ...infrastructure.storage.featurestore import FeatureStore
3837
from ...infrastructure.worker.worker import (
3938
InferenceReply,
4039
InferenceRequest,
@@ -54,32 +53,18 @@
5453
logger = get_logger(__name__)
5554

5655

57-
def build_failure_reply(status: "StatusEnum", message: str) -> Response:
56+
def build_failure_reply(status: "Status", message: str) -> ResponseBuilder:
5857
"""Build a response indicating a failure occurred
5958
:param status: The status of the response
6059
:param message: The error message to include in the response"""
6160
return MessageHandler.build_response(
62-
status=status, # todo: need to indicate correct status
63-
message=message, # todo: decide what these will be
61+
status=status,
62+
message=message,
6463
result=None,
6564
custom_attributes=None,
6665
)
6766

6867

69-
def build_reply(worker: MachineLearningWorkerBase, reply: InferenceReply) -> Response:
70-
"""Builds a response for a successful inference request
71-
:param worker: A worker to process the reply with
72-
:param reply: The internal representation of the reply"""
73-
results = worker.prepare_outputs(reply)
74-
75-
return MessageHandler.build_response(
76-
status=reply.status_enum,
77-
message=reply.message,
78-
result=results,
79-
custom_attributes=None,
80-
)
81-
82-
8368
def exception_handler(
8469
exc: Exception, reply_channel: t.Optional[CommChannelBase], failure_message: str
8570
) -> None:
@@ -143,13 +128,15 @@ def _check_feature_stores(self, request: InferenceRequest) -> bool:
143128
"""Ensures that all feature stores required by the request are available
144129
:param request: The request to validate"""
145130
# collect all feature stores required by the request
146-
fs_model = {request.model_key.descriptor}
131+
fs_model: t.Set[str] = set()
132+
if request.model_key:
133+
fs_model = {request.model_key.descriptor}
147134
fs_inputs = {key.descriptor for key in request.input_keys}
148135
fs_outputs = {key.descriptor for key in request.output_keys}
149136

150137
# identify which feature stores are requested and unknown
151-
fs_desired = fs_model + fs_inputs + fs_outputs
152-
fs_actual = {key for key in self._feature_stores}
138+
fs_desired = fs_model.union(fs_inputs).union(fs_outputs)
139+
fs_actual = {item.descriptor for item in self._feature_stores.values()}
153140
fs_missing = fs_desired - fs_actual
154141

155142
# exit if all desired feature stores are not available
@@ -259,7 +246,7 @@ def _on_iteration(self) -> None:
259246
interm = time.perf_counter() # timing
260247
try:
261248
fetch_model_result = self._worker.fetch_model(
262-
request, self._feature_store
249+
request, self._feature_stores
263250
)
264251
except Exception as e:
265252
exception_handler(
@@ -287,7 +274,7 @@ def _on_iteration(self) -> None:
287274
interm = time.perf_counter() # timing
288275
try:
289276
fetch_model_result = self._worker.fetch_model(
290-
request, self._feature_store
277+
request, self._feature_stores
291278
)
292279
except Exception as e:
293280
exception_handler(
@@ -310,7 +297,9 @@ def _on_iteration(self) -> None:
310297
timings.append(time.perf_counter() - interm) # timing
311298
interm = time.perf_counter() # timing
312299
try:
313-
fetch_input_result = self._worker.fetch_inputs(request, self._feature_store)
300+
fetch_input_result = self._worker.fetch_inputs(
301+
request, self._feature_stores
302+
)
314303
except Exception as e:
315304
exception_handler(e, request.callback, "Failed while fetching the inputs.")
316305
return
@@ -370,10 +359,16 @@ def _on_iteration(self) -> None:
370359
if reply.outputs is None or not reply.outputs:
371360
response = build_failure_reply("fail", "Outputs not found.")
372361
else:
373-
if reply.outputs is None or not reply.outputs:
374-
response = build_failure_reply("fail", "no-results")
375-
376-
response = build_reply(self._worker, reply)
362+
reply.status_enum = "complete"
363+
reply.message = "Success"
364+
365+
results = self._worker.prepare_outputs(reply)
366+
response = MessageHandler.build_response(
367+
status=reply.status_enum,
368+
message=reply.message,
369+
result=results,
370+
custom_attributes=None,
371+
)
377372

378373
timings.append(time.perf_counter() - interm) # timing
379374
interm = time.perf_counter() # timing

smartsim/_core/mli/infrastructure/environmentloader.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,10 @@ class EnvironmentConfigLoader:
4646
"""
4747

4848
def __init__(self) -> None:
49-
self._feature_store_descriptor: t.Optional[str] = os.getenv(
50-
"SSFeatureStore", None
51-
)
5249
self._queue_descriptor: t.Optional[str] = os.getenv("SSQueue", None)
53-
self.feature_store: t.Optional[FeatureStore] = None
54-
self.feature_stores: t.Optional[t.Dict[FeatureStore]] = None
50+
self.feature_stores: t.Optional[t.Dict[str, FeatureStore]] = None
5551
self.queue: t.Optional[DragonFLIChannel] = None
52+
self._prefix = "SSFeatureStore"
5653

5754
def _load_feature_store(self, env_var: str) -> FeatureStore:
5855
"""Load a feature store from a descriptor
@@ -62,10 +59,12 @@ def _load_feature_store(self, env_var: str) -> FeatureStore:
6259

6360
value = os.getenv(env_var)
6461
if not value:
65-
raise SmartSimError(f"Empty feature store descriptor in environment: {env_var}")
62+
raise SmartSimError(
63+
f"Empty feature store descriptor in environment: {env_var}"
64+
)
6665

6766
try:
68-
return pickle.loads(base64.b64decode(value))
67+
return t.cast(FeatureStore, pickle.loads(base64.b64decode(value)))
6968
except:
7069
raise SmartSimError(
7170
f"Invalid feature store descriptor in environment: {env_var}"
@@ -74,9 +73,8 @@ def _load_feature_store(self, env_var: str) -> FeatureStore:
7473
def get_feature_stores(self) -> t.Dict[str, FeatureStore]:
7574
"""Loads multiple Feature Stores by scanning environment for variables
7675
prefixed with `SSFeatureStore`"""
77-
prefix = "SSFeatureStore"
78-
if self.feature_stores is None:
79-
env_vars = [var for var in os.environ if var.startswith(prefix)]
76+
if not self.feature_stores:
77+
env_vars = [var for var in os.environ if var.startswith(self._prefix)]
8078
stores = [self._load_feature_store(var) for var in env_vars]
8179
self.feature_stores = {fs.descriptor: fs for fs in stores}
8280
return self.feature_stores

smartsim/_core/mli/infrastructure/worker/worker.py

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@
2727
import typing as t
2828
from abc import ABC, abstractmethod
2929

30-
import numpy as np
31-
3230
from .....error import SmartSimError
3331
from .....log import get_logger
3432
from ...comm.channel.channel import CommChannelBase
@@ -38,6 +36,7 @@
3836

3937
if t.TYPE_CHECKING:
4038
from smartsim._core.mli.mli_schemas.response.response_capnp import Status
39+
from smartsim._core.mli.mli_schemas.tensor.tensor_capnp import TensorDescriptor
4140

4241
logger = get_logger(__name__)
4342

@@ -81,13 +80,13 @@ class InferenceReply:
8180
def __init__(
8281
self,
8382
outputs: t.Optional[t.Collection[t.Any]] = None,
84-
output_keys: t.Optional[t.Collection[str]] = None,
83+
output_keys: t.Optional[t.Collection[FeatureStoreKey]] = None,
8584
status_enum: "Status" = "running",
8685
message: str = "In progress",
8786
) -> None:
8887
"""Initialize the object"""
8988
self.outputs: t.Collection[t.Any] = outputs or []
90-
self.output_keys: t.Collection[t.Optional[str]] = output_keys or []
89+
self.output_keys: t.Collection[t.Optional[FeatureStoreKey]] = output_keys or []
9190
self.status_enum = status_enum
9291
self.message = message
9392

@@ -175,27 +174,25 @@ def deserialize_message(
175174
elif request.model.which() == "data":
176175
model_bytes = request.model.data
177176

178-
callback_key = request.replyChannel.reply
177+
callback_key = request.replyChannel.descriptor
179178
comm_channel = channel_type(callback_key)
180-
181179
input_keys: t.Optional[t.List[FeatureStoreKey]] = None
182180
input_bytes: t.Optional[t.List[bytes]] = None
183-
input_meta: t.List[t.Any] = []
181+
output_keys: t.Optional[t.List[FeatureStoreKey]] = None
182+
input_meta: t.Optional[t.List[TensorDescriptor]] = None
184183

185184
if request.input.which() == "keys":
186185
input_keys = [
187-
FeatureStoreKey(input_key.key, input_key.featureStoreDescriptor)
188-
for input_key in request.input.keys
186+
FeatureStoreKey(value.key, value.featureStoreDescriptor)
187+
for value in request.input.keys
189188
]
190-
elif request.input.which() == "data":
191-
input_bytes = [data.blob for data in request.input.data]
192-
input_meta = [data.tensorDescriptor for data in request.input.data]
189+
elif request.input.which() == "descriptors":
190+
input_meta = request.input.descriptors # type: ignore
193191

194-
output_keys: t.List[FeatureStoreKey] = []
195192
if request.output:
196193
output_keys = [
197-
FeatureStoreKey(output_key.key, output_key.featureStoreDescriptor)
198-
for output_key in request.output
194+
FeatureStoreKey(value.key, value.featureStoreDescriptor)
195+
for value in request.output
199196
]
200197

201198
inference_request = InferenceRequest(
@@ -214,27 +211,19 @@ def deserialize_message(
214211
def prepare_outputs(reply: InferenceReply) -> t.List[t.Any]:
215212
prepared_outputs: t.List[t.Any] = []
216213
if reply.output_keys:
217-
for fs_key in reply.output_keys:
218-
if not fs_key:
214+
for value in reply.output_keys:
215+
if not value:
219216
continue
220-
221-
msg_key = MessageHandler.build_tensor_key(fs_key.key, fs_key.descriptor)
217+
msg_key = MessageHandler.build_tensor_key(value.key, value.descriptor)
222218
prepared_outputs.append(msg_key)
223219
elif reply.outputs:
224-
arrays: t.List[np.ndarray[t.Any, np.dtype[t.Any]]] = [
225-
output.numpy() for output in reply.outputs
226-
]
227-
for tensor in arrays:
228-
# todo: need to have the output attributes specified in the req?
229-
# maybe, add `MessageHandler.dtype_of(tensor)`?
230-
# can `build_tensor` do dtype and shape?
231-
msg_tensor = MessageHandler.build_tensor(
232-
tensor,
220+
for _ in reply.outputs:
221+
msg_tensor_desc = MessageHandler.build_tensor_descriptor(
233222
"c",
234223
"float32",
235224
[1],
236225
)
237-
prepared_outputs.append(msg_tensor)
226+
prepared_outputs.append(msg_tensor_desc)
238227
return prepared_outputs
239228

240229
@staticmethod

smartsim/_core/mli/message_handler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ def _assign_result(
439439
result: t.Union[
440440
t.List[tensor_capnp.TensorDescriptor],
441441
t.List[data_references_capnp.TensorKey],
442+
None,
442443
],
443444
) -> None:
444445
"""
@@ -504,7 +505,7 @@ def build_response(
504505
result: t.Union[
505506
t.List[tensor_capnp.TensorDescriptor],
506507
t.List[data_references_capnp.TensorKey],
507-
None
508+
None,
508509
],
509510
custom_attributes: t.Union[
510511
response_attributes_capnp.TorchResponseAttributes,

tests/dragon/test_reply_building.py

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,7 @@
3030

3131
dragon = pytest.importorskip("dragon")
3232

33-
from smartsim._core.mli.infrastructure.control.workermanager import (
34-
build_failure_reply,
35-
build_reply,
36-
)
33+
from smartsim._core.mli.infrastructure.control.workermanager import build_failure_reply
3734
from smartsim._core.mli.infrastructure.worker.worker import InferenceReply
3835

3936
if t.TYPE_CHECKING:
@@ -64,28 +61,3 @@ def test_build_failure_reply_fails():
6461

6562
assert "Error assigning status to response" in ex.value.args[0]
6663

67-
68-
@pytest.mark.parametrize(
69-
"status, message",
70-
[
71-
pytest.param("complete", "Success", id="complete"),
72-
],
73-
)
74-
def test_build_reply(status: "Status", message: str):
75-
"Ensures replies can be built successfully"
76-
reply = InferenceReply()
77-
reply.status_enum = status
78-
reply.message = message
79-
response = build_reply(reply)
80-
assert response.status == status
81-
assert response.message == message
82-
83-
84-
def test_build_reply_fails():
85-
"Ensures ValueError is raised if a Status Enum is not used"
86-
with pytest.raises(ValueError) as ex:
87-
reply = InferenceReply()
88-
reply.status_enum = "not a status enum"
89-
response = build_reply(reply)
90-
91-
assert "Error assigning status to response" in ex.value.args[0]

tests/mli/test_worker_manager.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232

3333
import pytest
3434

35+
from tests.mli.featurestore import FileSystemFeatureStore
36+
3537
torch = pytest.importorskip("torch")
3638
dragon = pytest.importorskip("dragon")
3739

@@ -183,14 +185,11 @@ def test_worker_manager(prepare_environment: pathlib.Path) -> None:
183185
)
184186

185187
# create a mock client application to populate the request queue
186-
feature_stores = config_loader.get_feature_stores()
187-
fs_list = list(feature_stores.values())
188-
189188
msg_pump = mp.Process(
190189
target=mock_messages,
191190
args=(
192191
config_loader.get_queue(),
193-
fs_list[0],
192+
FileSystemFeatureStore(fs_path),
194193
fs_path,
195194
comm_path,
196195
),

0 commit comments

Comments
 (0)