Commit 8d51e40

committed: PR Feedback

1 parent: cf38f47

File tree: 2 files changed (+31, -1)

test/container/1.5.0/Dockerfile.pytorch (1 addition, 1 deletion)

@@ -18,7 +18,7 @@ RUN apt-get update \
     libsm6 \
     libxext6 \
     libxrender-dev \
-    openjdk-11-jdk \
+    openjdk-11-jdk-headless \
     && rm -rf /var/lib/apt/lists/*

 RUN conda install -c conda-forge opencv==4.0.1 \

test/unit/test_default_inference_handler.py (30 additions, 0 deletions)

@@ -15,6 +15,7 @@
 import csv
 import json

+import mock
 import numpy as np
 import pytest
 import torch
@@ -177,3 +178,32 @@ def test_default_output_fn_gpu(inference_handler):
     output = inference_handler.default_output_fn(tensor_gpu, content_types.CSV)

     assert "1,2,3\n4,5,6\n".encode("utf-8") == output
+
+
+def test_eia_default_model_fn(eia_inference_handler):
+    with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os:
+        mock_os.getenv.return_value = "true"
+        mock_os.path.join.return_value = "model_dir"
+        mock_os.path.exists.return_value = True
+        with mock.patch("torch.jit.load") as mock_torch:
+            mock_torch.return_value = DummyModel()
+            model = eia_inference_handler.default_model_fn("model_dir")
+            assert model is not None
+
+
+def test_eia_default_model_fn_error(eia_inference_handler):
+    with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os:
+        mock_os.getenv.return_value = "true"
+        mock_os.path.join.return_value = "model_dir"
+        mock_os.path.exists.return_value = False
+        with pytest.raises(FileNotFoundError):
+            eia_inference_handler.default_model_fn("model_dir")
+
+
+def test_eia_default_predict_fn(eia_inference_handler, tensor):
+    model = DummyModel()
+    with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os:
+        mock_os.getenv.return_value = "true"
+        with mock.patch("torch.jit.optimized_execution") as mock_torch:
+            mock_torch.__enter__.return_value = "dummy"
+            eia_inference_handler.default_predict_fn(tensor, model)
+            mock_torch.assert_called_once()
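
The new tests rely on pytest fixtures (eia_inference_handler, tensor) and a DummyModel helper defined elsewhere in the test module and not shown in this diff. A minimal sketch of what that scaffolding might look like, purely as an illustration: the fixture bodies and the DefaultPytorchInferenceHandler class name are assumptions, only the module path appears in the patched targets above.

import pytest
import torch

from sagemaker_pytorch_serving_container import default_pytorch_inference_handler


class DummyModel(torch.nn.Module):
    # Stand-in model (assumed shape): the tests only need an object that can be
    # "loaded" and invoked, not one that produces meaningful predictions.
    def forward(self, x):
        return x


@pytest.fixture()
def tensor():
    # Small input tensor handed to default_predict_fn in the EIA predict test.
    return torch.rand(1, 3)


@pytest.fixture()
def eia_inference_handler():
    # Handler under test; the EIA code path is toggled by the os.getenv mocks
    # inside the tests themselves. Class name is assumed, not shown in the diff.
    return default_pytorch_inference_handler.DefaultPytorchInferenceHandler()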
