diff --git a/src/sagemaker/base_predictor.py b/src/sagemaker/base_predictor.py index a9b2cb021d..63e2aca415 100644 --- a/src/sagemaker/base_predictor.py +++ b/src/sagemaker/base_predictor.py @@ -154,6 +154,7 @@ def predict( inference_id=None, custom_attributes=None, component_name: Optional[str] = None, + target_container_hostname=None, ): """Return the inference from the specified endpoint. @@ -188,6 +189,9 @@ def predict( function (Default: None). component_name (str): Optional. Name of the Amazon SageMaker inference component corresponding the predictor. + target_container_hostname (str): Optional. If the endpoint hosts multiple containers + and is configured to use direct invocation, this parameter specifies the host name + of the container to invoke. (Default: None). Returns: object: Inference for the given input. If a deserializer was specified when creating @@ -203,6 +207,7 @@ def predict( target_variant=target_variant, inference_id=inference_id, custom_attributes=custom_attributes, + target_container_hostname=target_container_hostname, ) inference_component_name = component_name or self._get_component_name() diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index 1e4f6d0f0a..bfb7742183 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -135,6 +135,29 @@ def test_predict_call_with_inference_id(): assert result == RETURN_VALUE +def test_predict_call_with_target_container_hostname(): + sagemaker_session = empty_sagemaker_session() + predictor = Predictor(ENDPOINT, sagemaker_session) + + data = "untouched" + result = predictor.predict(data, target_container_hostname="test_target_container_hostname") + + assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called + + expected_request_args = { + "Accept": DEFAULT_ACCEPT, + "Body": data, + "ContentType": DEFAULT_CONTENT_TYPE, + "EndpointName": ENDPOINT, + "TargetContainerHostname": "test_target_container_hostname", + } + + _, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args + assert kwargs == expected_request_args + + assert result == RETURN_VALUE + + def test_multi_model_predict_call(): sagemaker_session = empty_sagemaker_session() predictor = Predictor(ENDPOINT, sagemaker_session)