|
15 | 15 | import csv
|
16 | 16 | import json
|
17 | 17 |
|
| 18 | +import mock |
18 | 19 | import numpy as np
|
19 | 20 | import pytest
|
20 | 21 | import torch
|
@@ -177,3 +178,32 @@ def test_default_output_fn_gpu(inference_handler):
|
177 | 178 | output = inference_handler.default_output_fn(tensor_gpu, content_types.CSV)
|
178 | 179 |
|
179 | 180 | assert "1,2,3\n4,5,6\n".encode("utf-8") == output
|
| 181 | + |
def test_eia_default_model_fn(eia_inference_handler):
    """The EIA handler loads a TorchScript model when the checkpoint file exists."""
    os_target = "sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os"
    with mock.patch(os_target) as mock_os, mock.patch("torch.jit.load") as mock_load:
        mock_os.getenv.return_value = "true"
        mock_os.path.join.return_value = "model_dir"
        mock_os.path.exists.return_value = True
        mock_load.return_value = DummyModel()
        loaded = eia_inference_handler.default_model_fn("model_dir")
    assert loaded is not None
| 191 | + |
| 192 | + |
def test_eia_default_model_fn_error(eia_inference_handler):
    """A missing checkpoint file makes the EIA handler raise FileNotFoundError."""
    os_target = "sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os"
    with mock.patch(os_target) as mock_os:
        mock_os.getenv.return_value = "true"
        mock_os.path.join.return_value = "model_dir"
        # exists() -> False is what drives default_model_fn down the error path.
        mock_os.path.exists.return_value = False
        with pytest.raises(FileNotFoundError):
            eia_inference_handler.default_model_fn("model_dir")
| 200 | + |
| 201 | + |
def test_eia_default_predict_fn(eia_inference_handler, tensor):
    """The EIA handler runs prediction inside torch.jit.optimized_execution."""
    model = DummyModel()
    with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os:
        mock_os.getenv.return_value = "true"
        with mock.patch("torch.jit.optimized_execution") as mock_torch:
            # mock.patch replaces the context-manager *factory*: the `with`
            # statement inside default_predict_fn enters mock_torch.return_value,
            # so __enter__ must be configured on return_value. (The original set
            # mock_torch.__enter__ directly, which was a silent no-op.)
            mock_torch.return_value.__enter__.return_value = "dummy"
            eia_inference_handler.default_predict_fn(tensor, model)
            mock_torch.assert_called_once()
0 commit comments