@@ -37,11 +37,11 @@ class HuggingFacePredictor(Predictor):
3737 """
3838
3939 def __init__ (
40- self ,
41- endpoint_name ,
42- sagemaker_session = None ,
43- serializer = NumpySerializer (),
44- deserializer = NumpyDeserializer (),
40+ self ,
41+ endpoint_name ,
42+ sagemaker_session = None ,
43+ serializer = NumpySerializer (),
44+ deserializer = NumpyDeserializer (),
4545 ):
4646 """Initialize an ``HuggingFacePredictor``.
4747
@@ -89,18 +89,18 @@ class HuggingFaceModel(FrameworkModel):
8989 _framework_name = "huggingface"
9090
9191 def __init__ (
92- self ,
93- model_data ,
94- role ,
95- entry_point ,
96- transformers_version = None ,
97- tensorflow_version = None ,
98- pytorch_version = None ,
99- py_version = None ,
100- image_uri = None ,
101- predictor_cls = HuggingFacePredictor ,
102- model_server_workers = None ,
103- ** kwargs
92+ self ,
93+ model_data ,
94+ role ,
95+ entry_point ,
96+ transformers_version = None ,
97+ tensorflow_version = None ,
98+ pytorch_version = None ,
99+ py_version = None ,
100+ image_uri = None ,
101+ predictor_cls = HuggingFacePredictor ,
102+ model_server_workers = None ,
103+ ** kwargs ,
104104 ):
105105 """Initialize a PyTorchModel.
106106
@@ -152,7 +152,11 @@ def __init__(
152152 :class:`~sagemaker.model.Model`.
153153 """
154154 validate_version_or_image_args (transformers_version , py_version , image_uri )
155- _validate_pt_tf_versions (pytorch_version = pytorch_version ,tensorflow_version = tensorflow_version ,image_uri = image_uri )
155+ _validate_pt_tf_versions (
156+ pytorch_version = pytorch_version ,
157+ tensorflow_version = tensorflow_version ,
158+ image_uri = image_uri ,
159+ )
156160 if py_version == "py2" :
157161 raise ValueError ("py2 is not supported with HuggingFace images" )
158162 self .framework_version = transformers_version
@@ -167,19 +171,19 @@ def __init__(
167171 self .model_server_workers = model_server_workers
168172
169173 def register (
170- self ,
171- content_types ,
172- response_types ,
173- inference_instances ,
174- transform_instances ,
175- model_package_name = None ,
176- model_package_group_name = None ,
177- image_uri = None ,
178- model_metrics = None ,
179- metadata_properties = None ,
180- marketplace_cert = False ,
181- approval_status = None ,
182- description = None ,
174+ self ,
175+ content_types ,
176+ response_types ,
177+ inference_instances ,
178+ transform_instances ,
179+ model_package_name = None ,
180+ model_package_group_name = None ,
181+ image_uri = None ,
182+ model_metrics = None ,
183+ metadata_properties = None ,
184+ marketplace_cert = False ,
185+ approval_status = None ,
186+ description = None ,
183187 ):
184188 """Creates a model package for creating SageMaker models or listing on Marketplace.
185189
@@ -290,9 +294,7 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
290294 f"tensorflow{ self .tensorflow_version } " # pylint: disable=no-member
291295 )
292296 else :
293- base_framework_version = (
294- f"pytorch{ self .pytorch_version } " # pylint: disable=no-member
295- )
297+ base_framework_version = f"pytorch{ self .pytorch_version } " # pylint: disable=no-member
296298 return image_uris .retrieve (
297299 self ._framework_name ,
298300 region_name ,
0 commit comments