Commit cd9df32

Initial fixes to PR change requests
1 parent bf06097 commit cd9df32

1 file changed: +22 -22 lines


src/strands/models/sagemaker.py

Lines changed: 22 additions & 22 deletions
```diff
@@ -3,8 +3,10 @@
 - Docs: https://aws.amazon.com/sagemaker-ai/
 """
 
+import base64
 import json
 import logging
+import mimetypes
 import uuid
 from typing import Any, Iterable, Optional, TypedDict, Union
 
@@ -38,21 +40,21 @@ class ModelConfig(TypedDict, total=False):
         """Configuration options for SageMaker models.
 
         Attributes:
-            endpoint_name: The name of the SageMaker endpoint to invoke
-            inference_component_name: The name of the inference component to use
-            max_tokens: Maximum number of tokens to generate in the response
-            stop_sequences: List of sequences that will stop generation when encountered
-            temperature: Controls randomness in generation (higher = more random)
-            top_p: Controls diversity via nucleus sampling (alternative to temperature)
-            additional_args: Any additional arguments to include in the request
+            additional_args: Any additional arguments to include in the request.
+            endpoint_name: The name of the SageMaker endpoint to invoke.
+            inference_component_name: The name of the inference component to use.
+            max_tokens: Maximum number of tokens to generate in the response.
+            stop_sequences: List of sequences that will stop generation when encountered.
+            temperature: Controls randomness in generation (higher = more random).
+            top_p: Controls diversity via nucleus sampling (alternative to temperature).
         """
+        additional_args: Optional[dict[str, Any]]
         endpoint_name: str
         inference_component_name: Optional[str]
         max_tokens: Optional[int]
         stop_sequences: Optional[list[str]]
         temperature: Optional[float]
         top_p: Optional[float]
-        additional_args: Optional[dict[str, Any]]
 
     def __init__(
         self,
```
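For orientation, the `ModelConfig` keys reordered above are plain `TypedDict` fields, so callers supply them as a dict or as keyword arguments through `**model_config`. A minimal sketch with hypothetical values (only the key names and the `SageMakerAIModel.ModelConfig` type come from the diff):

```python
from strands.models.sagemaker import SageMakerAIModel

# Hypothetical values; the keys mirror the ModelConfig fields shown above.
config: SageMakerAIModel.ModelConfig = {
    "endpoint_name": "my-llm-endpoint",
    "max_tokens": 1024,
    "temperature": 0.7,
    "stop_sequences": ["</answer>"],
    "additional_args": {"top_k": 50},
}
```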
```diff
@@ -61,8 +63,7 @@ def __init__(
         inference_component_name: Optional[str] = None,
         boto_session: Optional[boto3.Session] = None,
         boto_client_config: Optional[BotocoreConfig] = None,
-        retry_attempts: int = 3,
-        retry_delay: int = 30,
+        region_name: Optional[str] = None,
         **model_config: Unpack["SageMakerAIModel.ModelConfig"],
     ):
         """Initialize provider instance.
@@ -81,16 +82,12 @@ def __init__(
                 inference_component_name=inference_component_name
             )
         self.update_config(**model_config)
-
-        # Set retry configuration
-        self.retry_attempts = retry_attempts
-        self.retry_delay = retry_delay
 
-        logger.debug("endpoint=%s, config=%s | initializing", self.config["endpoint_name"], self.config)
+        # logger.debug("endpoint=%s, config=%s | initializing", self.config["endpoint_name"], self.config)
+        logger.debug("config=<%s> | initializing", self.config)
 
-        default_region = "us-west-2"
         session = boto_session or boto3.Session(
-            region_name=default_region,
+            region_name=region_name,
         )
         self.client = session.client(
             service_name="sagemaker-runtime",
```
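With `retry_attempts` and `retry_delay` removed from the constructor, region selection now flows through the new `region_name` parameter into `boto3.Session`, and retry tuning can presumably be handled via the existing `boto_client_config` argument instead. A minimal usage sketch under those assumptions (endpoint and region values are placeholders; the `retries` dict is standard botocore configuration, not something this commit adds):

```python
from botocore.config import Config as BotocoreConfig

from strands.models.sagemaker import SageMakerAIModel

# Placeholder endpoint/region; region_name and boto_client_config are the
# constructor parameters visible in the diff above.
model = SageMakerAIModel(
    endpoint_name="my-llm-endpoint",
    region_name="us-east-1",
    boto_client_config=BotocoreConfig(retries={"max_attempts": 3, "mode": "standard"}),
    max_tokens=1024,
)
```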
```diff
@@ -135,11 +132,14 @@ def format_message(message: Message, content: ContentBlock) -> dict[str, Any]:
             return {"role": message["role"], "content": content["text"]}
 
         if "image" in content:
-            # Convert bytes to base64 string for JSON serialization
-            image_bytes = content["image"]["source"]["bytes"]
-            if isinstance(image_bytes, bytes):
-                image_bytes = image_bytes.decode('utf-8') if isinstance(image_bytes, bytes) else image_bytes
-            return {"role": message["role"], "images": [image_bytes]}
+            mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream")
+            image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8")
+            return {
+                "image_url": {
+                    "url": f"data:{mime_type};base64,{image_data}",
+                },
+                "type": "image_url",
+            }
 
         if "toolUse" in content:
             return {
```
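The rewritten image branch replaces the old `role`/`images` shape with an OpenAI-style `image_url` content part: the block's `format` is mapped to a MIME type via `mimetypes` and the raw bytes are base64-encoded into a data URL. A standalone sketch of that same transformation, using a placeholder content block whose key layout matches what the new code indexes into:

```python
import base64
import mimetypes

# Placeholder content block; real blocks carry actual image bytes.
content = {
    "image": {
        "format": "png",
        "source": {"bytes": b"\x89PNG\r\n\x1a\n..."},
    }
}

# Same lookup as the new code: "png" -> ".png" -> "image/png",
# with application/octet-stream as the fallback for unknown formats.
mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream")
image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8")

part = {
    "image_url": {"url": f"data:{mime_type};base64,{image_data}"},
    "type": "image_url",
}
print(part["image_url"]["url"])  # data:image/png;base64,iVBORw0KGgouLi4=
```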

0 commit comments
