33- Docs: https://aws.amazon.com/sagemaker-ai/
44"""
55
6+ import base64
67import json
78import logging
9+ import mimetypes
810import uuid
911from typing import Any , Iterable , Optional , TypedDict , Union
1012
@@ -38,21 +40,21 @@ class ModelConfig(TypedDict, total=False):
3840 """Configuration options for SageMaker models.
3941
4042 Attributes:
41- endpoint_name: The name of the SageMaker endpoint to invoke
42- inference_component_name: The name of the inference component to use
43- max_tokens: Maximum number of tokens to generate in the response
44- stop_sequences: List of sequences that will stop generation when encountered
45- temperature: Controls randomness in generation (higher = more random)
46- top_p: Controls diversity via nucleus sampling (alternative to temperature)
47- additional_args: Any additional arguments to include in the request
43+ additional_args: Any additional arguments to include in the request.
44+ endpoint_name: The name of the SageMaker endpoint to invoke.
45+ inference_component_name: The name of the inference component to use.
46+ max_tokens: Maximum number of tokens to generate in the response.
47+ stop_sequences: List of sequences that will stop generation when encountered.
48+ temperature: Controls randomness in generation (higher = more random).
49+ top_p: Controls diversity via nucleus sampling (alternative to temperature).
4850 """
51+ additional_args : Optional [dict [str , Any ]]
4952 endpoint_name : str
5053 inference_component_name : Optional [str ]
5154 max_tokens : Optional [int ]
5255 stop_sequences : Optional [list [str ]]
5356 temperature : Optional [float ]
5457 top_p : Optional [float ]
55- additional_args : Optional [dict [str , Any ]]
5658
5759 def __init__ (
5860 self ,
@@ -61,8 +63,7 @@ def __init__(
6163 inference_component_name : Optional [str ] = None ,
6264 boto_session : Optional [boto3 .Session ] = None ,
6365 boto_client_config : Optional [BotocoreConfig ] = None ,
64- retry_attempts : int = 3 ,
65- retry_delay : int = 30 ,
66+ region_name : Optional [str ] = None ,
6667 ** model_config : Unpack ["SageMakerAIModel.ModelConfig" ],
6768 ):
6869 """Initialize provider instance.
@@ -81,16 +82,12 @@ def __init__(
8182 inference_component_name = inference_component_name
8283 )
8384 self .update_config (** model_config )
84-
85- # Set retry configuration
86- self .retry_attempts = retry_attempts
87- self .retry_delay = retry_delay
8885
89- logger .debug ("endpoint=%s, config=%s | initializing" , self .config ["endpoint_name" ], self .config )
86+ # logger.debug("endpoint=%s, config=%s | initializing", self.config["endpoint_name"], self.config)
87+ logger .debug ("config=<%s> | initializing" , self .config )
9088
91- default_region = "us-west-2"
9289 session = boto_session or boto3 .Session (
93- region_name = default_region ,
90+ region_name = region_name ,
9491 )
9592 self .client = session .client (
9693 service_name = "sagemaker-runtime" ,
@@ -135,11 +132,14 @@ def format_message(message: Message, content: ContentBlock) -> dict[str, Any]:
135132 return {"role" : message ["role" ], "content" : content ["text" ]}
136133
137134 if "image" in content :
138- # Convert bytes to base64 string for JSON serialization
139- image_bytes = content ["image" ]["source" ]["bytes" ]
140- if isinstance (image_bytes , bytes ):
141- image_bytes = image_bytes .decode ('utf-8' ) if isinstance (image_bytes , bytes ) else image_bytes
142- return {"role" : message ["role" ], "images" : [image_bytes ]}
135+ mime_type = mimetypes .types_map .get (f".{ content ['image' ]['format' ]} " , "application/octet-stream" )
136+ image_data = base64 .b64encode (content ["image" ]["source" ]["bytes" ]).decode ("utf-8" )
137+ return {
138+ "image_url" : {
139+ "url" : f"data:{ mime_type } ;base64,{ image_data } " ,
140+ },
141+ "type" : "image_url" ,
142+ }
143143
144144 if "toolUse" in content :
145145 return {
0 commit comments