diff --git a/drucker/drucker_dashboard_servicer.py b/drucker/drucker_dashboard_servicer.py index ef9637a..0bf5c91 100644 --- a/drucker/drucker_dashboard_servicer.py +++ b/drucker/drucker_dashboard_servicer.py @@ -99,6 +99,11 @@ def on_error(self, error: Exception): self.logger.error(str(error)) self.logger.error(traceback.format_exc()) + def is_valid_upload_filename(self, filename: str) -> bool: + if Path(filename).name == filename: + return True + return False + def ServiceInfo(self, request: drucker_pb2.ServiceInfoRequest, context: _Context @@ -117,15 +122,20 @@ def UploadModel(self, ) -> drucker_pb2.ModelResponse: """ Upload your latest ML model. """ - save_path = None + first_req = next(request_iterator) + save_path = first_req.path + if not self.is_valid_upload_filename(save_path): + raise Exception(f'Error: Invalid model path specified -> {save_path}') + tmp_path = self.app.get_model_path(uuid.uuid4().hex) Path(tmp_path).parent.mkdir(parents=True, exist_ok=True) with open(tmp_path, 'wb') as f: + f.write(first_req.data) for request in request_iterator: - save_path = request.path - model_data = request.data - f.write(model_data) + f.write(request.data) + del first_req f.close() + model_path = self.app.get_model_path(save_path) Path(model_path).parent.mkdir(parents=True, exist_ok=True) shutil.move(tmp_path, model_path) @@ -139,17 +149,20 @@ def SwitchModel(self, ) -> drucker_pb2.ModelResponse: """ Switch your ML model to run. """ + if not self.is_valid_upload_filename(request.path): + raise Exception(f'Error: Invalid model path specified -> {request.path}') + model_assignment = self.app.db.session.query(ModelAssignment).filter(ModelAssignment.service_name == self.app.config.SERVICE_NAME).one() model_assignment.model_path = request.path model_assignment.first_boot = False self.app.db.session.commit() - model_path = self.app.get_model_path() # :TODO: Use enum for SERVICE_INFRA if self.app.config.SERVICE_INFRA == "kubernetes": pass elif self.app.config.SERVICE_INFRA == "default": - self.app.load_model(model_path) + self.app.model_path = self.app.get_model_path() + self.app.load_model() return drucker_pb2.ModelResponse(status=1, message='Success: Switching model file.') @@ -163,6 +176,8 @@ def EvaluateModel(self, """ first_req = next(request_iterator) save_path = first_req.data_path + if not self.is_valid_upload_filename(save_path): + raise Exception(f'Error: Invalid evaluation file path specified -> {save_path}') test_data = b''.join([first_req.data] + [r.data for r in request_iterator]) result, details = self.app.evaluate(test_data) diff --git a/drucker/test/test_dashboard_servicer.py b/drucker/test/test_dashboard_servicer.py index 42d7765..e148d07 100644 --- a/drucker/test/test_dashboard_servicer.py +++ b/drucker/test/test_dashboard_servicer.py @@ -15,7 +15,7 @@ class DruckerWorkerServicerTest(unittest.TestCase): def test_ServiceInfo(self): servicer = DruckerDashboardServicer(logger=system_logger, app=app) request = drucker_pb2.ServiceInfoRequest() - response = servicer.ServiceInfo(request=request, context=None) + response = servicer.ServiceInfo(request, Mock()) self.assertEqual(response.application_name, 'test') self.assertEqual(response.service_name, 'test-001') self.assertEqual(response.service_level, 'development') @@ -27,12 +27,13 @@ def test_ServiceInfo(self): def test_UploadModel(self, mock_file, mock_path_class, mock_shutil, mock_uuid): # mock setting mock_path_class.return_value = Mock() + mock_path_class.return_value.name = 'my_path' mock_shutil.move.return_value = True mock_uuid.uuid4.return_value = Mock(hex='my_uuid') servicer = DruckerDashboardServicer(logger=system_logger, app=app) requests = iter(drucker_pb2.UploadModelRequest(path='my_path', data=b'data') for _ in range(1, 3)) - response = servicer.UploadModel(request_iterator=requests, context=None) + response = servicer.UploadModel(requests, Mock()) tmp_path = './test-model/test/my_uuid' save_path = './test-model/test/my_path' @@ -44,6 +45,23 @@ def test_UploadModel(self, mock_file, mock_path_class, mock_shutil, mock_uuid): ], any_order=True) mock_shutil.move.assert_called_once_with(tmp_path, save_path) + @patch('drucker.drucker_dashboard_servicer.uuid') + @patch('drucker.drucker_dashboard_servicer.shutil') + @patch('drucker.drucker_dashboard_servicer.Path') + @patch("builtins.open", new_callable=mock_open) + def test_InvalidUploadModel(self, mock_file, mock_path_class, mock_shutil, mock_uuid): + # mock setting + mock_path_class.return_value = Mock() + mock_path_class.return_value.name = 'my_path' + mock_shutil.move.return_value = True + mock_uuid.uuid4.return_value = Mock(hex='my_uuid') + + servicer = DruckerDashboardServicer(logger=system_logger, app=app) + requests = iter(drucker_pb2.UploadModelRequest(path='../../../my_path', data=b'data') for _ in range(1, 3)) + response = servicer.UploadModel(requests, Mock()) + + self.assertEqual(response.status, 0) + @patch('drucker.test.DummyApp') def test_SwitchModel(self, mock_app): # mock setting @@ -52,10 +70,25 @@ def test_SwitchModel(self, mock_app): servicer = DruckerDashboardServicer(logger=system_logger, app=mock_app) request = drucker_pb2.SwitchModelRequest(path='my_path') - response = servicer.SwitchModel(request=request, context=None) + response = servicer.SwitchModel(request, Mock()) self.assertEqual(response.status, 1) - mock_app.load_model.assert_called_once_with('test/my_path') + mock_app.load_model.assert_called_once_with() + + @patch('drucker.test.DummyApp') + @patch('drucker.drucker_dashboard_servicer.Path') + def test_InvalidSwitchModel(self, mock_path_class, mock_app): + # mock setting + mock_path_class.return_value = Mock() + mock_path_class.return_value.name = 'my_path' + mock_app.get_model_path.return_value = 'test/my_path' + mock_app.config.SERVICE_INFRA = 'default' + + servicer = DruckerDashboardServicer(logger=system_logger, app=mock_app) + request = drucker_pb2.SwitchModelRequest(path='../../my_path') + response = servicer.SwitchModel(request, Mock()) + + self.assertEqual(response.status, 0) @patch("builtins.open", new_callable=mock_open) @patch('drucker.drucker_dashboard_servicer.pickle') @@ -67,7 +100,7 @@ def test_EvalauteModel(self, mock_pickle, mock_file): servicer = DruckerDashboardServicer(logger=system_logger, app=app) requests = iter(drucker_pb2.EvaluateModelRequest(data_path='my_path', data=b'data_') for _ in range(1, 3)) - response = servicer.EvaluateModel(request_iterator=requests, context=None) + response = servicer.EvaluateModel(requests, Mock()) self.assertEqual(round(response.metrics.num, 3), eval_result.num) self.assertEqual(round(response.metrics.accuracy, 3), eval_result.accuracy) @@ -83,6 +116,23 @@ def test_EvalauteModel(self, mock_pickle, mock_file): call("./eval/test/my_path_eval_detail.pkl", "wb") ], any_order=True) + @patch("builtins.open", new_callable=mock_open) + @patch('drucker.drucker_dashboard_servicer.pickle') + @patch('drucker.drucker_dashboard_servicer.Path') + def test_InvalidEvalauteModel(self, mock_path_class, mock_pickle, mock_file): + # mock setting + mock_path_class.return_value = Mock() + mock_path_class.return_value.name = 'my_path' + eval_result = EvaluateResult(1, 0.8, [0.7], [0.6], [0.5], {'dummy': 0.4}) + details = [EvaluateDetail('test_input', 'test_label', PredictResult('pre_label', 0.9), False)] + app.evaluate = Mock(return_value=(eval_result, details)) + + servicer = DruckerDashboardServicer(logger=system_logger, app=app) + requests = iter(drucker_pb2.EvaluateModelRequest(data_path='../../my_path', data=b'data_') for _ in range(1, 3)) + response = servicer.EvaluateModel(requests, Mock()) + + self.assertEqual(response.metrics.num, 0) + def test_error_handling(self): # mock setting app.get_model_path = Mock(side_effect=Exception('dummy exception'))