3434from ...comm .channel .channel import CommChannelBase
3535from ...comm .channel .dragonchannel import DragonCommChannel
3636from ...infrastructure .environmentloader import EnvironmentConfigLoader
37- from ...infrastructure .storage .featurestore import FeatureStore
3837from ...infrastructure .worker .worker import (
3938 InferenceReply ,
4039 InferenceRequest ,
5453logger = get_logger (__name__ )
5554
5655
57- def build_failure_reply (status : "StatusEnum " , message : str ) -> Response :
56+ def build_failure_reply (status : "Status " , message : str ) -> ResponseBuilder :
5857 """Build a response indicating a failure occurred
5958 :param status: The status of the response
6059 :param message: The error message to include in the response"""
6160 return MessageHandler .build_response (
62- status = status , # todo: need to indicate correct status
63- message = message , # todo: decide what these will be
61+ status = status ,
62+ message = message ,
6463 result = None ,
6564 custom_attributes = None ,
6665 )
6766
6867
69- def build_reply (worker : MachineLearningWorkerBase , reply : InferenceReply ) -> Response :
70- """Builds a response for a successful inference request
71- :param worker: A worker to process the reply with
72- :param reply: The internal representation of the reply"""
73- results = worker .prepare_outputs (reply )
74-
75- return MessageHandler .build_response (
76- status = reply .status_enum ,
77- message = reply .message ,
78- result = results ,
79- custom_attributes = None ,
80- )
81-
82-
8368def exception_handler (
8469 exc : Exception , reply_channel : t .Optional [CommChannelBase ], failure_message : str
8570) -> None :
@@ -143,13 +128,15 @@ def _check_feature_stores(self, request: InferenceRequest) -> bool:
143128 """Ensures that all feature stores required by the request are available
144129 :param request: The request to validate"""
145130 # collect all feature stores required by the request
146- fs_model = {request .model_key .descriptor }
131+ fs_model : t .Set [str ] = set ()
132+ if request .model_key :
133+ fs_model = {request .model_key .descriptor }
147134 fs_inputs = {key .descriptor for key in request .input_keys }
148135 fs_outputs = {key .descriptor for key in request .output_keys }
149136
150137 # identify which feature stores are requested and unknown
151- fs_desired = fs_model + fs_inputs + fs_outputs
152- fs_actual = {key for key in self ._feature_stores }
138+ fs_desired = fs_model . union ( fs_inputs ). union ( fs_outputs )
139+ fs_actual = {item . descriptor for item in self ._feature_stores . values () }
153140 fs_missing = fs_desired - fs_actual
154141
155142 # exit if all desired feature stores are not available
@@ -259,7 +246,7 @@ def _on_iteration(self) -> None:
259246 interm = time .perf_counter () # timing
260247 try :
261248 fetch_model_result = self ._worker .fetch_model (
262- request , self ._feature_store
249+ request , self ._feature_stores
263250 )
264251 except Exception as e :
265252 exception_handler (
@@ -287,7 +274,7 @@ def _on_iteration(self) -> None:
287274 interm = time .perf_counter () # timing
288275 try :
289276 fetch_model_result = self ._worker .fetch_model (
290- request , self ._feature_store
277+ request , self ._feature_stores
291278 )
292279 except Exception as e :
293280 exception_handler (
@@ -310,7 +297,9 @@ def _on_iteration(self) -> None:
310297 timings .append (time .perf_counter () - interm ) # timing
311298 interm = time .perf_counter () # timing
312299 try :
313- fetch_input_result = self ._worker .fetch_inputs (request , self ._feature_store )
300+ fetch_input_result = self ._worker .fetch_inputs (
301+ request , self ._feature_stores
302+ )
314303 except Exception as e :
315304 exception_handler (e , request .callback , "Failed while fetching the inputs." )
316305 return
@@ -370,10 +359,16 @@ def _on_iteration(self) -> None:
370359 if reply .outputs is None or not reply .outputs :
371360 response = build_failure_reply ("fail" , "Outputs not found." )
372361 else :
373- if reply .outputs is None or not reply .outputs :
374- response = build_failure_reply ("fail" , "no-results" )
375-
376- response = build_reply (self ._worker , reply )
362+ reply .status_enum = "complete"
363+ reply .message = "Success"
364+
365+ results = self ._worker .prepare_outputs (reply )
366+ response = MessageHandler .build_response (
367+ status = reply .status_enum ,
368+ message = reply .message ,
369+ result = results ,
370+ custom_attributes = None ,
371+ )
377372
378373 timings .append (time .perf_counter () - interm ) # timing
379374 interm = time .perf_counter () # timing
0 commit comments