From c01dde749687f86e7891ec403eed3f98d4fcfb50 Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Fri, 20 Jan 2023 23:59:12 +0000 Subject: [PATCH 1/8] Add environment variable with VMARGS --- src/sagemaker_inference/environment.py | 15 +++++++++------ src/sagemaker_inference/model_server.py | 2 +- src/sagemaker_inference/parameters.py | 1 + test/unit/test_environment.py | 2 ++ 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/sagemaker_inference/environment.py b/src/sagemaker_inference/environment.py index 3415097..039aa43 100644 --- a/src/sagemaker_inference/environment.py +++ b/src/sagemaker_inference/environment.py @@ -25,6 +25,7 @@ DEFAULT_MODEL_SERVER_TIMEOUT = "60" DEFAULT_STARTUP_TIMEOUT = "600" # 10 minutes DEFAULT_HTTP_PORT = "8080" +DEFAULT_VMARGS = "-XX:-UseContainerSupport" SAGEMAKER_BASE_PATH = os.path.join("/opt", "ml") # type: str @@ -70,15 +71,12 @@ def __init__(self): os.environ.get(parameters.MODEL_SERVER_TIMEOUT_ENV, DEFAULT_MODEL_SERVER_TIMEOUT) ) self._model_server_workers = os.environ.get(parameters.MODEL_SERVER_WORKERS_ENV) - self._startup_timeout = int( - os.environ.get(parameters.STARTUP_TIMEOUT_ENV, DEFAULT_STARTUP_TIMEOUT) - ) - self._default_accept = os.environ.get( - parameters.DEFAULT_INVOCATIONS_ACCEPT_ENV, content_types.JSON - ) + self._startup_timeout = int(os.environ.get(parameters.STARTUP_TIMEOUT_ENV, DEFAULT_STARTUP_TIMEOUT)) + self._default_accept = os.environ.get(parameters.DEFAULT_INVOCATIONS_ACCEPT_ENV, content_types.JSON) self._inference_http_port = os.environ.get(parameters.BIND_TO_PORT_ENV, DEFAULT_HTTP_PORT) self._management_http_port = os.environ.get(parameters.BIND_TO_PORT_ENV, DEFAULT_HTTP_PORT) self._safe_port_range = os.environ.get(parameters.SAFE_PORT_RANGE_ENV) + self._vmargs = os.environ.get(parameters.VMARGS, DEFAULT_VMARGS) @staticmethod def _parse_module_name(program_param): @@ -140,3 +138,8 @@ def safe_port_range(self): # type: () -> str specified by SageMaker for handling pings and invocations. """ return self._safe_port_range + + @property + def vmargs(self): # type: () -> str + """str: vmargs can be provided for the JVM, to be overriden""" + return self._vmargs diff --git a/src/sagemaker_inference/model_server.py b/src/sagemaker_inference/model_server.py index 447b626..4d3be2e 100644 --- a/src/sagemaker_inference/model_server.py +++ b/src/sagemaker_inference/model_server.py @@ -159,7 +159,7 @@ def _generate_mms_config_properties(env, handler_service=None): "default_workers_per_model": env.model_server_workers, "inference_address": "http://0.0.0.0:{}".format(env.inference_http_port), "management_address": "http://0.0.0.0:{}".format(env.management_http_port), - "vmargs": "-XX:-UseContainerSupport", + "vmargs": env.vmargs, } # If provided, add handler service to user config if handler_service: diff --git a/src/sagemaker_inference/parameters.py b/src/sagemaker_inference/parameters.py index c6b7d2d..dd0cf58 100644 --- a/src/sagemaker_inference/parameters.py +++ b/src/sagemaker_inference/parameters.py @@ -24,3 +24,4 @@ BIND_TO_PORT_ENV = "SAGEMAKER_BIND_TO_PORT" # type: str SAFE_PORT_RANGE_ENV = "SAGEMAKER_SAFE_PORT_RANGE" # type: str MULTI_MODEL_ENV = "SAGEMAKER_MULTI_MODEL" # type: str +VMARGS = "VMARGS" # type: str \ No newline at end of file diff --git a/test/unit/test_environment.py b/test/unit/test_environment.py index f54c43e..a8cb57a 100644 --- a/test/unit/test_environment.py +++ b/test/unit/test_environment.py @@ -28,6 +28,7 @@ parameters.DEFAULT_INVOCATIONS_ACCEPT_ENV: "text/html", parameters.BIND_TO_PORT_ENV: "1738", parameters.SAFE_PORT_RANGE_ENV: "1111-2222", + parameters.VMARGS: "-XX:-UseContainerSupport", }, clear=True, ) @@ -45,6 +46,7 @@ def test_env(): assert env.inference_http_port == "1738" assert env.management_http_port == "1738" assert env.safe_port_range == "1111-2222" + assert "-XX:-UseContainerSupport" in env.vmargs @pytest.mark.parametrize("sagemaker_program", ["program.py", "program"]) From e7f44732ca1f6928be92c29610e7cb1a79acaadf Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Sat, 21 Jan 2023 00:21:49 +0000 Subject: [PATCH 2/8] Fix linting error --- src/sagemaker_inference/environment.py | 8 ++++++-- src/sagemaker_inference/parameters.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/sagemaker_inference/environment.py b/src/sagemaker_inference/environment.py index 039aa43..0f6e20c 100644 --- a/src/sagemaker_inference/environment.py +++ b/src/sagemaker_inference/environment.py @@ -71,8 +71,12 @@ def __init__(self): os.environ.get(parameters.MODEL_SERVER_TIMEOUT_ENV, DEFAULT_MODEL_SERVER_TIMEOUT) ) self._model_server_workers = os.environ.get(parameters.MODEL_SERVER_WORKERS_ENV) - self._startup_timeout = int(os.environ.get(parameters.STARTUP_TIMEOUT_ENV, DEFAULT_STARTUP_TIMEOUT)) - self._default_accept = os.environ.get(parameters.DEFAULT_INVOCATIONS_ACCEPT_ENV, content_types.JSON) + self._startup_timeout = int( + os.environ.get(parameters.STARTUP_TIMEOUT_ENV, DEFAULT_STARTUP_TIMEOUT) + ) + self._default_accept = os.environ.get( + parameters.DEFAULT_INVOCATIONS_ACCEPT_ENV, content_types.JSON + ) self._inference_http_port = os.environ.get(parameters.BIND_TO_PORT_ENV, DEFAULT_HTTP_PORT) self._management_http_port = os.environ.get(parameters.BIND_TO_PORT_ENV, DEFAULT_HTTP_PORT) self._safe_port_range = os.environ.get(parameters.SAFE_PORT_RANGE_ENV) diff --git a/src/sagemaker_inference/parameters.py b/src/sagemaker_inference/parameters.py index dd0cf58..26771f3 100644 --- a/src/sagemaker_inference/parameters.py +++ b/src/sagemaker_inference/parameters.py @@ -24,4 +24,4 @@ BIND_TO_PORT_ENV = "SAGEMAKER_BIND_TO_PORT" # type: str SAFE_PORT_RANGE_ENV = "SAGEMAKER_SAFE_PORT_RANGE" # type: str MULTI_MODEL_ENV = "SAGEMAKER_MULTI_MODEL" # type: str -VMARGS = "VMARGS" # type: str \ No newline at end of file +VMARGS = "VMARGS" # type: str From 729df77dea9cf7805b0a1e101c1740cefcd6f830 Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Sat, 21 Jan 2023 00:27:08 +0000 Subject: [PATCH 3/8] Fix flake8 --- src/sagemaker_inference/parameters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker_inference/parameters.py b/src/sagemaker_inference/parameters.py index 26771f3..0adbcce 100644 --- a/src/sagemaker_inference/parameters.py +++ b/src/sagemaker_inference/parameters.py @@ -24,4 +24,4 @@ BIND_TO_PORT_ENV = "SAGEMAKER_BIND_TO_PORT" # type: str SAFE_PORT_RANGE_ENV = "SAGEMAKER_SAFE_PORT_RANGE" # type: str MULTI_MODEL_ENV = "SAGEMAKER_MULTI_MODEL" # type: str -VMARGS = "VMARGS" # type: str +VMARGS = "VMARGS" # type: str From 5b3579c57b10f7c4d0bdad0f180351d25f839910 Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Wed, 25 Jan 2023 00:28:14 +0000 Subject: [PATCH 4/8] Fix mock calls in default handler test --- test/unit/test_default_handler_service.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/unit/test_default_handler_service.py b/test/unit/test_default_handler_service.py index a6d2ba8..5efb194 100644 --- a/test/unit/test_default_handler_service.py +++ b/test/unit/test_default_handler_service.py @@ -43,7 +43,7 @@ def test_handle(): result = handler_service.handle(DATA, CONTEXT) assert result == TRANSFORMED_RESULT - assert transformer.transform.called_once_with(DATA, CONTEXT) + assert transformer.transform.assert_called_once_with(DATA, CONTEXT) def test_initialize(): @@ -57,4 +57,4 @@ def getitem(key): context.system_properties.__getitem__.side_effect = getitem DefaultHandlerService(transformer).initialize(context) - assert transformer.validate_and_initialize().called_once() + assert transformer.validate_and_initialize().assert_called_once() From 2c9e7d3f90f10d53c524a0b12bf9436332766f14 Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Wed, 25 Jan 2023 00:37:41 +0000 Subject: [PATCH 5/8] Remove duplicate assert --- test/unit/test_default_handler_service.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/unit/test_default_handler_service.py b/test/unit/test_default_handler_service.py index 5efb194..7f9ef74 100644 --- a/test/unit/test_default_handler_service.py +++ b/test/unit/test_default_handler_service.py @@ -43,7 +43,7 @@ def test_handle(): result = handler_service.handle(DATA, CONTEXT) assert result == TRANSFORMED_RESULT - assert transformer.transform.assert_called_once_with(DATA, CONTEXT) + transformer.transform.assert_called_once_with(DATA, CONTEXT) def test_initialize(): @@ -57,4 +57,4 @@ def getitem(key): context.system_properties.__getitem__.side_effect = getitem DefaultHandlerService(transformer).initialize(context) - assert transformer.validate_and_initialize().assert_called_once() + transformer.validate_and_initialize().assert_called_once() From 98f4119f55104341ee9107372fd952c3ece87098 Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Wed, 25 Jan 2023 01:06:50 +0000 Subject: [PATCH 6/8] Fix function call --- test/unit/test_default_handler_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/unit/test_default_handler_service.py b/test/unit/test_default_handler_service.py index 7f9ef74..f353797 100644 --- a/test/unit/test_default_handler_service.py +++ b/test/unit/test_default_handler_service.py @@ -57,4 +57,4 @@ def getitem(key): context.system_properties.__getitem__.side_effect = getitem DefaultHandlerService(transformer).initialize(context) - transformer.validate_and_initialize().assert_called_once() + transformer.validate_and_initialize.assert_called_once() From 030a53281a55117822ef7869ffc714a5706ebad7 Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Sat, 28 Jan 2023 01:14:35 +0000 Subject: [PATCH 7/8] Use pytest instead of py.test --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 9503285..0fc890f 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ passenv = # {posargs} can be passed in by additional arguments specified when invoking tox. # Can be used to specify which tests to run, e.g.: tox -- -s commands = - coverage run --rcfile .coveragerc_{envname} --source sagemaker_inference -m py.test {posargs} + coverage run --rcfile .coveragerc_{envname} --source sagemaker_inference -m pytest {posargs} {env:IGNORE_COVERAGE:} coverage report --rcfile .coveragerc_{envname} {env:IGNORE_COVERAGE:} coverage html --rcfile .coveragerc_{envname} From d44ea0889206ef1cb458ab23ab76edf1a8a93113 Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Tue, 31 Jan 2023 18:55:41 +0000 Subject: [PATCH 8/8] Change env variable name --- src/sagemaker_inference/environment.py | 2 +- src/sagemaker_inference/parameters.py | 2 +- test/unit/test_environment.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sagemaker_inference/environment.py b/src/sagemaker_inference/environment.py index 0f6e20c..89c68e7 100644 --- a/src/sagemaker_inference/environment.py +++ b/src/sagemaker_inference/environment.py @@ -80,7 +80,7 @@ def __init__(self): self._inference_http_port = os.environ.get(parameters.BIND_TO_PORT_ENV, DEFAULT_HTTP_PORT) self._management_http_port = os.environ.get(parameters.BIND_TO_PORT_ENV, DEFAULT_HTTP_PORT) self._safe_port_range = os.environ.get(parameters.SAFE_PORT_RANGE_ENV) - self._vmargs = os.environ.get(parameters.VMARGS, DEFAULT_VMARGS) + self._vmargs = os.environ.get(parameters.MODEL_SERVER_VMARGS, DEFAULT_VMARGS) @staticmethod def _parse_module_name(program_param): diff --git a/src/sagemaker_inference/parameters.py b/src/sagemaker_inference/parameters.py index 0adbcce..d1f438b 100644 --- a/src/sagemaker_inference/parameters.py +++ b/src/sagemaker_inference/parameters.py @@ -20,8 +20,8 @@ DEFAULT_INVOCATIONS_ACCEPT_ENV = "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT" # type: str MODEL_SERVER_WORKERS_ENV = "SAGEMAKER_MODEL_SERVER_WORKERS" # type: str MODEL_SERVER_TIMEOUT_ENV = "SAGEMAKER_MODEL_SERVER_TIMEOUT" # type: str +MODEL_SERVER_VMARGS = "SAGEMAKER_MODEL_SERVER_VMARGS" # type: str STARTUP_TIMEOUT_ENV = "SAGEMAKER_STARTUP_TIMEOUT" # type: str BIND_TO_PORT_ENV = "SAGEMAKER_BIND_TO_PORT" # type: str SAFE_PORT_RANGE_ENV = "SAGEMAKER_SAFE_PORT_RANGE" # type: str MULTI_MODEL_ENV = "SAGEMAKER_MULTI_MODEL" # type: str -VMARGS = "VMARGS" # type: str diff --git a/test/unit/test_environment.py b/test/unit/test_environment.py index a8cb57a..afb90de 100644 --- a/test/unit/test_environment.py +++ b/test/unit/test_environment.py @@ -28,7 +28,7 @@ parameters.DEFAULT_INVOCATIONS_ACCEPT_ENV: "text/html", parameters.BIND_TO_PORT_ENV: "1738", parameters.SAFE_PORT_RANGE_ENV: "1111-2222", - parameters.VMARGS: "-XX:-UseContainerSupport", + parameters.MODEL_SERVER_VMARGS: "-XX:-UseContainerSupport", }, clear=True, )