diff --git a/src/sagemaker_inference/environment.py b/src/sagemaker_inference/environment.py index 3415097..89c68e7 100644 --- a/src/sagemaker_inference/environment.py +++ b/src/sagemaker_inference/environment.py @@ -25,6 +25,7 @@ DEFAULT_MODEL_SERVER_TIMEOUT = "60" DEFAULT_STARTUP_TIMEOUT = "600" # 10 minutes DEFAULT_HTTP_PORT = "8080" +DEFAULT_VMARGS = "-XX:-UseContainerSupport" SAGEMAKER_BASE_PATH = os.path.join("/opt", "ml") # type: str @@ -79,6 +80,7 @@ def __init__(self): self._inference_http_port = os.environ.get(parameters.BIND_TO_PORT_ENV, DEFAULT_HTTP_PORT) self._management_http_port = os.environ.get(parameters.BIND_TO_PORT_ENV, DEFAULT_HTTP_PORT) self._safe_port_range = os.environ.get(parameters.SAFE_PORT_RANGE_ENV) + self._vmargs = os.environ.get(parameters.MODEL_SERVER_VMARGS, DEFAULT_VMARGS) @staticmethod def _parse_module_name(program_param): @@ -140,3 +142,8 @@ def safe_port_range(self): # type: () -> str specified by SageMaker for handling pings and invocations. """ return self._safe_port_range + + @property + def vmargs(self): # type: () -> str + """str: vmargs can be provided for the JVM, to be overriden""" + return self._vmargs diff --git a/src/sagemaker_inference/model_server.py b/src/sagemaker_inference/model_server.py index 447b626..4d3be2e 100644 --- a/src/sagemaker_inference/model_server.py +++ b/src/sagemaker_inference/model_server.py @@ -159,7 +159,7 @@ def _generate_mms_config_properties(env, handler_service=None): "default_workers_per_model": env.model_server_workers, "inference_address": "http://0.0.0.0:{}".format(env.inference_http_port), "management_address": "http://0.0.0.0:{}".format(env.management_http_port), - "vmargs": "-XX:-UseContainerSupport", + "vmargs": env.vmargs, } # If provided, add handler service to user config if handler_service: diff --git a/src/sagemaker_inference/parameters.py b/src/sagemaker_inference/parameters.py index c6b7d2d..d1f438b 100644 --- a/src/sagemaker_inference/parameters.py +++ b/src/sagemaker_inference/parameters.py @@ -20,6 +20,7 @@ DEFAULT_INVOCATIONS_ACCEPT_ENV = "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT" # type: str MODEL_SERVER_WORKERS_ENV = "SAGEMAKER_MODEL_SERVER_WORKERS" # type: str MODEL_SERVER_TIMEOUT_ENV = "SAGEMAKER_MODEL_SERVER_TIMEOUT" # type: str +MODEL_SERVER_VMARGS = "SAGEMAKER_MODEL_SERVER_VMARGS" # type: str STARTUP_TIMEOUT_ENV = "SAGEMAKER_STARTUP_TIMEOUT" # type: str BIND_TO_PORT_ENV = "SAGEMAKER_BIND_TO_PORT" # type: str SAFE_PORT_RANGE_ENV = "SAGEMAKER_SAFE_PORT_RANGE" # type: str diff --git a/test/unit/test_default_handler_service.py b/test/unit/test_default_handler_service.py index a6d2ba8..f353797 100644 --- a/test/unit/test_default_handler_service.py +++ b/test/unit/test_default_handler_service.py @@ -43,7 +43,7 @@ def test_handle(): result = handler_service.handle(DATA, CONTEXT) assert result == TRANSFORMED_RESULT - assert transformer.transform.called_once_with(DATA, CONTEXT) + transformer.transform.assert_called_once_with(DATA, CONTEXT) def test_initialize(): @@ -57,4 +57,4 @@ def getitem(key): context.system_properties.__getitem__.side_effect = getitem DefaultHandlerService(transformer).initialize(context) - assert transformer.validate_and_initialize().called_once() + transformer.validate_and_initialize.assert_called_once() diff --git a/test/unit/test_environment.py b/test/unit/test_environment.py index f54c43e..afb90de 100644 --- a/test/unit/test_environment.py +++ b/test/unit/test_environment.py @@ -28,6 +28,7 @@ parameters.DEFAULT_INVOCATIONS_ACCEPT_ENV: "text/html", parameters.BIND_TO_PORT_ENV: "1738", parameters.SAFE_PORT_RANGE_ENV: "1111-2222", + parameters.MODEL_SERVER_VMARGS: "-XX:-UseContainerSupport", }, clear=True, ) @@ -45,6 +46,7 @@ def test_env(): assert env.inference_http_port == "1738" assert env.management_http_port == "1738" assert env.safe_port_range == "1111-2222" + assert "-XX:-UseContainerSupport" in env.vmargs @pytest.mark.parametrize("sagemaker_program", ["program.py", "program"]) diff --git a/tox.ini b/tox.ini index 9503285..0fc890f 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ passenv = # {posargs} can be passed in by additional arguments specified when invoking tox. # Can be used to specify which tests to run, e.g.: tox -- -s commands = - coverage run --rcfile .coveragerc_{envname} --source sagemaker_inference -m py.test {posargs} + coverage run --rcfile .coveragerc_{envname} --source sagemaker_inference -m pytest {posargs} {env:IGNORE_COVERAGE:} coverage report --rcfile .coveragerc_{envname} {env:IGNORE_COVERAGE:} coverage html --rcfile .coveragerc_{envname}