Description
Describe the feature you'd like
Support ModelDataSource in SageMaker Batch Transform, so that LLM batch transform jobs can be created from models whose container definition references uncompressed artifacts via an S3 prefix (S3DataSource) instead of a compressed model.tar.gz.
How would this feature be used? Please describe.
The example below attempts to run a Batch Transform job for Mistral-7B, with the uncompressed model artifacts stored in S3 and referenced through ModelDataSource (using the DJL LMI container):
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers
role = sagemaker.get_execution_role() # execution role for the endpoint
sess = sagemaker.session.Session() # sagemaker session for interacting with different AWS APIs
region = sess._region_name # region name of the current SageMaker Studio environment
account_id = sess.account_id()
from huggingface_hub import snapshot_download
from pathlib import Path
import os
# This will download the model into the current directory, wherever the Jupyter notebook is running
local_model_path = Path(".")
local_model_path.mkdir(exist_ok=True)
model_name = 'mistralai/Mistral-7B-v0.1'
# Only download pytorch checkpoint files
allow_patterns = ["*.json", "*.txt", "*.model", "*.safetensors", "*.bin", "*.chk", "*.pth"]
# Use snapshot_download to fetch the model, since it is stored in the repository using LFS
model_download_path = snapshot_download(
    repo_id=model_name,
    cache_dir=local_model_path,
    allow_patterns=allow_patterns,
    token='<HF TOKEN>'
)
%%writefile {model_download_path}/serving.properties
engine=Python
option.tensor_parallel_degree=max
option.model_id={{model_id}}
option.max_rolling_batch_size=16
option.rolling_batch=vllm
# Upload the uncompressed model artifacts to S3 first, so the S3 URI can be rendered into serving.properties
base_model_s3_uri = sess.upload_data(path=model_download_path, key_prefix="batch-transform-mistral/model")
print(f"Model uploaded to --- > {base_model_s3_uri}")

import jinja2
from pathlib import Path

# Render the serving.properties template (written above) with the S3 location of the model
jinja_env = jinja2.Environment()
serving_properties_path = Path(model_download_path) / "serving.properties"
template = jinja_env.from_string(serving_properties_path.open().read())
serving_properties_path.open("w").write(template.render(model_id=base_model_s3_uri))

# Re-upload the rendered serving.properties so the S3 prefix contains the final file
sess.upload_data(path=str(serving_properties_path), key_prefix="batch-transform-mistral/model")
#https://github.com/aws/sagemaker-python-sdk/blob/master/tests/unit/test_djl_inference.py#L31-L33
image_uri = image_uris.retrieve(
    framework="djl-lmi",
    region=region,
    version="0.28.0"
)
model_data = {
    "S3DataSource": {
        "S3Uri": f"{base_model_s3_uri}/",
        "S3DataType": "S3Prefix",
        "CompressionType": "None",
    }
}
# create your SageMaker Model
model = sagemaker.Model(image_uri=image_uri, model_data=model_data, role=role)
from sagemaker.utils import name_from_base
endpoint_name = name_from_base("lmi-batch-transform-mistral-gated")
# instance type you will deploy your model to
instance_type = "ml.g5.12xlarge"
# Creating the batch transformer object. If you have a large dataset you can
# divide it into smaller chunks and use more instances for faster inference
batch_transformer = model.transformer(
    instance_count=1,
    instance_type=instance_type,
    output_path=s3_output_data_path,
    assemble_with="Line",
    accept="text/csv",
    max_payload=1,
)
batch_transformer.env = hyper_params_dict
# Making the predictions on the input data
batch_transformer.transform(
    s3_input_data_path, content_type="application/jsonlines", split_type="Line"
)
batch_transformer.wait()
This throws the error:
---------------------------------------------------------------------------
ClientError Traceback (most recent call last)
Cell In[36], line 14
11 batch_transformer.env = hyper_params_dict
13 # Making the predictions on the input data
---> 14 batch_transformer.transform(
15 s3_input_data_path, content_type="application/jsonlines", split_type="Line"
16 )
18 batch_transformer.wait()
File /opt/conda/lib/python3.10/site-packages/sagemaker/workflow/pipeline_context.py:346, in runnable_by_pipeline.<locals>.wrapper(*args, **kwargs)
342 return context
344 return _StepArguments(retrieve_caller_name(self_instance), run_func, *args, **kwargs)
--> 346 return run_func(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/sagemaker/transformer.py:302, in Transformer.transform(self, data, data_type, content_type, compression_type, split_type, job_name, input_filter, output_filter, join_source, experiment_config, model_client_config, batch_data_capture_config, wait, logs)
292 experiment_config = check_and_get_run_experiment_config(experiment_config)
294 batch_data_capture_config = resolve_class_attribute_from_config(
295 None,
296 batch_data_capture_config,
(...)
299 sagemaker_session=self.sagemaker_session,
300 )
--> 302 self.latest_transform_job = _TransformJob.start_new(
303 self,
304 data,
305 data_type,
306 content_type,
307 compression_type,
308 split_type,
309 input_filter,
310 output_filter,
311 join_source,
312 experiment_config,
313 model_client_config,
314 batch_data_capture_config,
315 )
317 if wait:
318 self.latest_transform_job.wait(logs=logs)
File /opt/conda/lib/python3.10/site-packages/sagemaker/transformer.py:636, in _TransformJob.start_new(cls, transformer, data, data_type, content_type, compression_type, split_type, input_filter, output_filter, join_source, experiment_config, model_client_config, batch_data_capture_config)
619 """Placeholder docstring"""
621 transform_args = cls._get_transform_args(
622 transformer,
623 data,
(...)
633 batch_data_capture_config,
634 )
--> 636 transformer.sagemaker_session.transform(**transform_args)
638 return cls(transformer.sagemaker_session, transformer._current_job_name)
File /opt/conda/lib/python3.10/site-packages/sagemaker/session.py:3805, in Session.transform(self, job_name, model_name, strategy, max_concurrent_transforms, max_payload, input_config, output_config, resource_config, experiment_config, env, tags, data_processing, model_client_config, batch_data_capture_config)
3802 logger.debug("Transform request: %s", json.dumps(request, indent=4))
3803 self.sagemaker_client.create_transform_job(**request)
-> 3805 self._intercept_create_request(transform_request, submit, self.transform.__name__)
File /opt/conda/lib/python3.10/site-packages/sagemaker/session.py:6497, in Session._intercept_create_request(self, request, create, func_name)
6480 def _intercept_create_request(
6481 self,
6482 request: typing.Dict,
(...)
6485 # pylint: disable=unused-argument
6486 ):
6487 """This function intercepts the create job request.
6488
6489 PipelineSession inherits this Session class and will override
(...)
6495 func_name (str): the name of the function needed intercepting
6496 """
-> 6497 return create(request)
File /opt/conda/lib/python3.10/site-packages/sagemaker/session.py:3803, in Session.transform.<locals>.submit(request)
3801 logger.info("Creating transform job with name: %s", job_name)
3802 logger.debug("Transform request: %s", json.dumps(request, indent=4))
-> 3803 self.sagemaker_client.create_transform_job(**request)
File /opt/conda/lib/python3.10/site-packages/botocore/client.py:565, in ClientCreator._create_api_method.<locals>._api_call(self, *args, **kwargs)
561 raise TypeError(
562 f"{py_operation_name}() only accepts keyword arguments."
563 )
564 # The "self" in this scope is referring to the BaseClient.
--> 565 return self._make_api_call(operation_name, kwargs)
File /opt/conda/lib/python3.10/site-packages/botocore/client.py:1021, in BaseClient._make_api_call(self, operation_name, api_params)
1017 error_code = error_info.get("QueryErrorCode") or error_info.get(
1018 "Code"
1019 )
1020 error_class = self.exceptions.from_code(error_code)
-> 1021 raise error_class(parsed_response, operation_name)
1022 else:
1023 return parsed_response
ClientError: An error occurred (ValidationException) when calling the CreateTransformJob operation: SageMaker Batch currently doesn't support Model entity with container definitions which use ModelDataSource attribute
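For comparison, the same model_data with S3DataSource deploys to a real-time endpoint without this restriction; only CreateTransformJob rejects the ModelDataSource attribute. A minimal sketch of the working real-time path, reusing image_uri, model_data, role, instance_type, and endpoint_name from the snippet above (a hedged comparison, not a full repro):

# Real-time deployment accepts the same ModelDataSource-style model_data
model = sagemaker.Model(image_uri=image_uri, model_data=model_data, role=role)
predictor = model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,  # ml.g5.12xlarge, as above
    endpoint_name=endpoint_name,
)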
Describe alternatives you've considered
A clear and concise description of any alternative solutions or features you've considered.
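One possible workaround sketch (an assumption on my side, not verified end to end): repackage the uncompressed artifacts into a model.tar.gz and pass a plain S3 URI as model_data, since Batch Transform works with models that use the classic compressed-artifact path. The archive name and key prefix below are placeholders.

import tarfile

# Compress the downloaded model directory into a single archive
archive_path = "model.tar.gz"
with tarfile.open(archive_path, "w:gz") as tar:
    tar.add(model_download_path, arcname=".")

# Upload the archive and create a Model backed by a plain S3 URI
compressed_model_s3_uri = sess.upload_data(path=archive_path, key_prefix="batch-transform-mistral/model-targz")
tar_model = sagemaker.Model(image_uri=image_uri, model_data=compressed_model_s3_uri, role=role)
batch_transformer = tar_model.transformer(
    instance_count=1,
    instance_type=instance_type,
    output_path=s3_output_data_path,
    assemble_with="Line",
    accept="text/csv",
    max_payload=1,
)

The downside is that packing and unpacking a multi-gigabyte tarball adds significant time and storage overhead for 7B+ models, which is exactly what ModelDataSource avoids for real-time endpoints; native ModelDataSource support in Batch Transform would remove this step.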
Additional context
Add any other context or screenshots about the feature request here.