Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions drucker/drucker_dashboard_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ def on_error(self, error: Exception):
self.logger.error(str(error))
self.logger.error(traceback.format_exc())

def is_valid_upload_filename(self, filename: str) -> bool:
if Path(filename).name == filename:
return True
return False

def ServiceInfo(self,
request: drucker_pb2.ServiceInfoRequest,
context: _Context
Expand All @@ -117,15 +122,20 @@ def UploadModel(self,
) -> drucker_pb2.ModelResponse:
""" Upload your latest ML model.
"""
save_path = None
first_req = next(request_iterator)
save_path = first_req.path
if not self.is_valid_upload_filename(save_path):
raise Exception(f'Error: Invalid model path specified -> {save_path}')

tmp_path = self.app.get_model_path(uuid.uuid4().hex)
Path(tmp_path).parent.mkdir(parents=True, exist_ok=True)
with open(tmp_path, 'wb') as f:
f.write(first_req.data)
for request in request_iterator:
save_path = request.path
model_data = request.data
f.write(model_data)
f.write(request.data)
del first_req
f.close()

model_path = self.app.get_model_path(save_path)
Path(model_path).parent.mkdir(parents=True, exist_ok=True)
shutil.move(tmp_path, model_path)
Expand All @@ -139,17 +149,20 @@ def SwitchModel(self,
) -> drucker_pb2.ModelResponse:
""" Switch your ML model to run.
"""
if not self.is_valid_upload_filename(request.path):
raise Exception(f'Error: Invalid model path specified -> {request.path}')

model_assignment = self.app.db.session.query(ModelAssignment).filter(ModelAssignment.service_name == self.app.config.SERVICE_NAME).one()
model_assignment.model_path = request.path
model_assignment.first_boot = False
self.app.db.session.commit()
model_path = self.app.get_model_path()

# :TODO: Use enum for SERVICE_INFRA
if self.app.config.SERVICE_INFRA == "kubernetes":
pass
elif self.app.config.SERVICE_INFRA == "default":
self.app.load_model(model_path)
self.app.model_path = self.app.get_model_path()
self.app.load_model()

return drucker_pb2.ModelResponse(status=1,
message='Success: Switching model file.')
Expand All @@ -163,6 +176,8 @@ def EvaluateModel(self,
"""
first_req = next(request_iterator)
save_path = first_req.data_path
if not self.is_valid_upload_filename(save_path):
raise Exception(f'Error: Invalid evaluation file path specified -> {save_path}')

test_data = b''.join([first_req.data] + [r.data for r in request_iterator])
result, details = self.app.evaluate(test_data)
Expand Down
60 changes: 55 additions & 5 deletions drucker/test/test_dashboard_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class DruckerWorkerServicerTest(unittest.TestCase):
def test_ServiceInfo(self):
servicer = DruckerDashboardServicer(logger=system_logger, app=app)
request = drucker_pb2.ServiceInfoRequest()
response = servicer.ServiceInfo(request=request, context=None)
response = servicer.ServiceInfo(request, Mock())
self.assertEqual(response.application_name, 'test')
self.assertEqual(response.service_name, 'test-001')
self.assertEqual(response.service_level, 'development')
Expand All @@ -27,12 +27,13 @@ def test_ServiceInfo(self):
def test_UploadModel(self, mock_file, mock_path_class, mock_shutil, mock_uuid):
# mock setting
mock_path_class.return_value = Mock()
mock_path_class.return_value.name = 'my_path'
mock_shutil.move.return_value = True
mock_uuid.uuid4.return_value = Mock(hex='my_uuid')

servicer = DruckerDashboardServicer(logger=system_logger, app=app)
requests = iter(drucker_pb2.UploadModelRequest(path='my_path', data=b'data') for _ in range(1, 3))
response = servicer.UploadModel(request_iterator=requests, context=None)
response = servicer.UploadModel(requests, Mock())

tmp_path = './test-model/test/my_uuid'
save_path = './test-model/test/my_path'
Expand All @@ -44,6 +45,23 @@ def test_UploadModel(self, mock_file, mock_path_class, mock_shutil, mock_uuid):
], any_order=True)
mock_shutil.move.assert_called_once_with(tmp_path, save_path)

@patch('drucker.drucker_dashboard_servicer.uuid')
@patch('drucker.drucker_dashboard_servicer.shutil')
@patch('drucker.drucker_dashboard_servicer.Path')
@patch("builtins.open", new_callable=mock_open)
def test_InvalidUploadModel(self, mock_file, mock_path_class, mock_shutil, mock_uuid):
# mock setting
mock_path_class.return_value = Mock()
mock_path_class.return_value.name = 'my_path'
mock_shutil.move.return_value = True
mock_uuid.uuid4.return_value = Mock(hex='my_uuid')

servicer = DruckerDashboardServicer(logger=system_logger, app=app)
requests = iter(drucker_pb2.UploadModelRequest(path='../../../my_path', data=b'data') for _ in range(1, 3))
response = servicer.UploadModel(requests, Mock())

self.assertEqual(response.status, 0)

@patch('drucker.test.DummyApp')
def test_SwitchModel(self, mock_app):
# mock setting
Expand All @@ -52,10 +70,25 @@ def test_SwitchModel(self, mock_app):

servicer = DruckerDashboardServicer(logger=system_logger, app=mock_app)
request = drucker_pb2.SwitchModelRequest(path='my_path')
response = servicer.SwitchModel(request=request, context=None)
response = servicer.SwitchModel(request, Mock())

self.assertEqual(response.status, 1)
mock_app.load_model.assert_called_once_with('test/my_path')
mock_app.load_model.assert_called_once_with()

@patch('drucker.test.DummyApp')
@patch('drucker.drucker_dashboard_servicer.Path')
def test_InvalidSwitchModel(self, mock_path_class, mock_app):
# mock setting
mock_path_class.return_value = Mock()
mock_path_class.return_value.name = 'my_path'
mock_app.get_model_path.return_value = 'test/my_path'
mock_app.config.SERVICE_INFRA = 'default'

servicer = DruckerDashboardServicer(logger=system_logger, app=mock_app)
request = drucker_pb2.SwitchModelRequest(path='../../my_path')
response = servicer.SwitchModel(request, Mock())

self.assertEqual(response.status, 0)

@patch("builtins.open", new_callable=mock_open)
@patch('drucker.drucker_dashboard_servicer.pickle')
Expand All @@ -67,7 +100,7 @@ def test_EvalauteModel(self, mock_pickle, mock_file):

servicer = DruckerDashboardServicer(logger=system_logger, app=app)
requests = iter(drucker_pb2.EvaluateModelRequest(data_path='my_path', data=b'data_') for _ in range(1, 3))
response = servicer.EvaluateModel(request_iterator=requests, context=None)
response = servicer.EvaluateModel(requests, Mock())

self.assertEqual(round(response.metrics.num, 3), eval_result.num)
self.assertEqual(round(response.metrics.accuracy, 3), eval_result.accuracy)
Expand All @@ -83,6 +116,23 @@ def test_EvalauteModel(self, mock_pickle, mock_file):
call("./eval/test/my_path_eval_detail.pkl", "wb")
], any_order=True)

@patch("builtins.open", new_callable=mock_open)
@patch('drucker.drucker_dashboard_servicer.pickle')
@patch('drucker.drucker_dashboard_servicer.Path')
def test_InvalidEvalauteModel(self, mock_path_class, mock_pickle, mock_file):
# mock setting
mock_path_class.return_value = Mock()
mock_path_class.return_value.name = 'my_path'
eval_result = EvaluateResult(1, 0.8, [0.7], [0.6], [0.5], {'dummy': 0.4})
details = [EvaluateDetail('test_input', 'test_label', PredictResult('pre_label', 0.9), False)]
app.evaluate = Mock(return_value=(eval_result, details))

servicer = DruckerDashboardServicer(logger=system_logger, app=app)
requests = iter(drucker_pb2.EvaluateModelRequest(data_path='../../my_path', data=b'data_') for _ in range(1, 3))
response = servicer.EvaluateModel(requests, Mock())

self.assertEqual(response.metrics.num, 0)

def test_error_handling(self):
# mock setting
app.get_model_path = Mock(side_effect=Exception('dummy exception'))
Expand Down