
Commit d40d43d

fix: formatting and linting
1 parent 2405769 commit d40d43d

3 files changed: +44 -121 lines changed

src/strands/models/bedrock.py

Lines changed: 22 additions & 86 deletions
@@ -136,18 +136,11 @@ def __init__(
             else:
                 new_user_agent = "strands-agents"
 
-            client_config = boto_client_config.merge(
-                BotocoreConfig(user_agent_extra=new_user_agent)
-            )
+            client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent))
         else:
             client_config = BotocoreConfig(user_agent_extra="strands-agents")
 
-        resolved_region = (
-            region_name
-            or session.region_name
-            or os.environ.get("AWS_REGION")
-            or DEFAULT_BEDROCK_REGION
-        )
+        resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION
 
         self.client = session.client(
             service_name="bedrock-runtime",
@@ -156,9 +149,7 @@ def __init__(
             region_name=resolved_region,
         )
 
-        logger.debug(
-            "region=<%s> | bedrock client created", self.client.meta.region_name
-        )
+        logger.debug("region=<%s> | bedrock client created", self.client.meta.region_name)
 
     @override
     def update_config(self, **model_config: Unpack[BedrockConfig]) -> None:  # type: ignore
@@ -199,11 +190,7 @@ def format_request(
             "messages": self._format_bedrock_messages(messages),
             "system": [
                 *([{"text": system_prompt}] if system_prompt else []),
-                *(
-                    [{"cachePoint": {"type": self.config["cache_prompt"]}}]
-                    if self.config.get("cache_prompt")
-                    else []
-                ),
+                *([{"cachePoint": {"type": self.config["cache_prompt"]}}] if self.config.get("cache_prompt") else []),
             ],
             **(
                 {
@@ -223,20 +210,12 @@ def format_request(
                 else {}
             ),
             **(
-                {
-                    "additionalModelRequestFields": self.config[
-                        "additional_request_fields"
-                    ]
-                }
+                {"additionalModelRequestFields": self.config["additional_request_fields"]}
                 if self.config.get("additional_request_fields")
                 else {}
             ),
             **(
-                {
-                    "additionalModelResponseFieldPaths": self.config[
-                        "additional_response_field_paths"
-                    ]
-                }
+                {"additionalModelResponseFieldPaths": self.config["additional_response_field_paths"]}
                 if self.config.get("additional_response_field_paths")
                 else {}
             ),
@@ -247,18 +226,13 @@ def format_request(
                         "guardrailVersion": self.config["guardrail_version"],
                         "trace": self.config.get("guardrail_trace", "enabled"),
                         **(
-                            {
-                                "streamProcessingMode": self.config.get(
-                                    "guardrail_stream_processing_mode"
-                                )
-                            }
+                            {"streamProcessingMode": self.config.get("guardrail_stream_processing_mode")}
                             if self.config.get("guardrail_stream_processing_mode")
                             else {}
                         ),
                     }
                 }
-                if self.config.get("guardrail_id")
-                and self.config.get("guardrail_version")
+                if self.config.get("guardrail_id") and self.config.get("guardrail_version")
                 else {}
             ),
             "inferenceConfig": {
@@ -273,14 +247,11 @@ def format_request(
             },
             **(
                 self.config["additional_args"]
-                if "additional_args" in self.config
-                and self.config["additional_args"] is not None
+                if "additional_args" in self.config and self.config["additional_args"] is not None
                 else {}
             ),
         }
 
-
-
     def _format_bedrock_messages(self, messages: Messages) -> Messages:
         """Format messages for Bedrock API compatibility.
 
@@ -312,7 +283,7 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages:
                 # DeepSeek models have issues with reasoningContent
                 if is_deepseek and "reasoningContent" in content_block:
                     continue
-
+
                 if "toolResult" in content_block:
                     # Create a new content block with only the cleaned toolResult
                     tool_result: ToolResult = content_block["toolResult"]
@@ -332,9 +303,7 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages:
 
             # Create new message with cleaned content (skip if empty for DeepSeek)
             if cleaned_content:
-                cleaned_message: Message = Message(
-                    content=cleaned_content, role=message["role"]
-                )
+                cleaned_message: Message = Message(content=cleaned_content, role=message["role"])
                 cleaned_messages.append(cleaned_message)
 
         return cleaned_messages
@@ -352,17 +321,11 @@ def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
         output_assessments = guardrail_data.get("outputAssessments", {})
 
         # Check input assessments
-        if any(
-            self._find_detected_and_blocked_policy(assessment)
-            for assessment in input_assessment.values()
-        ):
+        if any(self._find_detected_and_blocked_policy(assessment) for assessment in input_assessment.values()):
             return True
 
         # Check output assessments
-        if any(
-            self._find_detected_and_blocked_policy(assessment)
-            for assessment in output_assessments.values()
-        ):
+        if any(self._find_detected_and_blocked_policy(assessment) for assessment in output_assessments.values()):
             return True
 
         return False
@@ -437,9 +400,7 @@ def callback(event: Optional[StreamEvent] = None) -> None:
         loop = asyncio.get_event_loop()
         queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue()
 
-        thread = asyncio.to_thread(
-            self._stream, callback, messages, tool_specs, system_prompt
-        )
+        thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt)
         task = asyncio.create_task(thread)
 
         while True:
@@ -451,8 +412,6 @@ def callback(event: Optional[StreamEvent] = None) -> None:
 
         await task
 
-
-
     def _stream(
         self,
         callback: Callable[..., None],
@@ -538,10 +497,7 @@ def _stream(
             if e.response["Error"]["Code"] == "ThrottlingException":
                 raise ModelThrottledException(error_message) from e
 
-            if any(
-                overflow_message in error_message
-                for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES
-            ):
+            if any(overflow_message in error_message for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES):
                 logger.warning("bedrock threw context window overflow error")
                 raise ContextWindowOverflowException(e) from e
 
@@ -577,9 +533,7 @@ def _stream(
         callback()
         logger.debug("finished streaming response from model")
 
-    def _convert_non_streaming_to_streaming(
-        self, response: dict[str, Any]
-    ) -> Iterable[StreamEvent]:
+    def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
         """Convert a non-streaming response to the streaming format.
 
         Args:
@@ -609,9 +563,7 @@ def _convert_non_streaming_to_streaming(
                 # For tool use, we need to yield the input as a delta
                 input_value = json.dumps(content["toolUse"]["input"])
 
-                yield {
-                    "contentBlockDelta": {"delta": {"toolUse": {"input": input_value}}}
-                }
+                yield {"contentBlockDelta": {"delta": {"toolUse": {"input": input_value}}}}
             elif "text" in content:
                 # Then yield the text as a delta
                 yield {
@@ -623,13 +575,7 @@ def _convert_non_streaming_to_streaming(
                 # Then yield the reasoning content as a delta
                 yield {
                     "contentBlockDelta": {
-                        "delta": {
-                            "reasoningContent": {
-                                "text": content["reasoningContent"]["reasoningText"][
-                                    "text"
-                                ]
-                            }
-                        }
+                        "delta": {"reasoningContent": {"text": content["reasoningContent"]["reasoningText"]["text"]}}
                     }
                 }
 
@@ -638,9 +584,7 @@ def _convert_non_streaming_to_streaming(
                     "contentBlockDelta": {
                         "delta": {
                             "reasoningContent": {
-                                "signature": content["reasoningContent"][
-                                    "reasoningText"
-                                ]["signature"]
+                                "signature": content["reasoningContent"]["reasoningText"]["signature"]
                             }
                         }
                     }
@@ -707,11 +651,7 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool:
         # Check if input is a dictionary
         if isinstance(input, dict):
             # Check if current dictionary has action: BLOCKED and detected: true
-            if (
-                input.get("action") == "BLOCKED"
-                and input.get("detected")
-                and isinstance(input.get("detected"), bool)
-            ):
+            if input.get("action") == "BLOCKED" and input.get("detected") and isinstance(input.get("detected"), bool):
                 return True
 
             # Recursively check all values in the dictionary
@@ -762,9 +702,7 @@ async def structured_output(
         stop_reason, messages, _, _ = event["stop"]
 
         if stop_reason != "tool_use":
-            raise ValueError(
-                f'Model returned stop_reason: {stop_reason} instead of "tool_use".'
-            )
+            raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".')
 
         content = messages["content"]
         output_response: dict[str, Any] | None = None
@@ -777,8 +715,6 @@ async def structured_output(
                 continue
 
         if output_response is None:
-            raise ValueError(
-                "No valid tool use or tool use input was found in the Bedrock response."
-            )
+            raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")
 
         yield {"output": output_model(**output_response)}
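
Note: the collapsed resolved_region line earlier in this file keeps the same precedence as the multi-line expression it replaces — the first truthy value wins. A minimal standalone sketch of that fallback chain (resolve_region is a hypothetical helper and the default value is a placeholder, not the actual module code):

import os

DEFAULT_BEDROCK_REGION = "us-west-2"  # placeholder; the real constant lives in bedrock.py


def resolve_region(region_name=None, session_region=None):
    # Precedence: explicit argument > boto session region > AWS_REGION env var > library default.
    return region_name or session_region or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION


print(resolve_region("eu-west-1"))  # explicit argument wins
print(resolve_region())             # falls back to AWS_REGION, then the placeholder default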

src/strands/tools/registry.py

Lines changed: 3 additions & 3 deletions
@@ -192,9 +192,9 @@ def register_tool(self, tool: AgentTool) -> None:
 
         # Check duplicate tool name, throw on duplicate tool names except if hot_reloading is enabled
         if tool.tool_name in self.registry and not tool.supports_hot_reload:
-            raise ValueError(
-                f"Tool name '{tool.tool_name}' already exists. Cannot register tools with exact same name."
-            )
+            raise ValueError(
+                f"Tool name '{tool.tool_name}' already exists. Cannot register tools with exact same name."
+            )
 
         # Check for normalized name conflicts (- vs _)
         if self.registry.get(tool.tool_name) is None:
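
Note: the re-indented raise above sits inside the duplicate-name guard of register_tool. A toy sketch of that guard's behaviour (ToyRegistry and its method are hypothetical and trimmed to the check this hunk touches):

class ToyRegistry:
    def __init__(self):
        self.registry = {}

    def register_tool(self, name, tool, supports_hot_reload=False):
        # Throw on duplicate tool names unless hot reloading is enabled for the tool.
        if name in self.registry and not supports_hot_reload:
            raise ValueError(f"Tool name '{name}' already exists. Cannot register tools with exact same name.")
        self.registry[name] = tool


registry = ToyRegistry()
registry.register_tool("search", object())
try:
    registry.register_tool("search", object())
except ValueError as exc:
    print(exc)  # second registration without hot reload raises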

tests/strands/models/test_bedrock.py

Lines changed: 19 additions & 32 deletions
@@ -1307,63 +1307,50 @@ def test_format_request_cleans_tool_result_content_blocks(model, model_id):
 async def test_stream_deepseek_filters_reasoning_content(bedrock_client, alist):
     """Test that DeepSeek models filter reasoningContent from messages during streaming."""
     model = BedrockModel(model_id="us.deepseek.r1-v1:0")
-
+
     messages = [
+        {"role": "user", "content": [{"text": "Hello"}]},
         {
-            "role": "user",
-            "content": [{"text": "Hello"}]
-        },
-        {
-            "role": "assistant",
+            "role": "assistant",
             "content": [
                 {"text": "Response"},
                 {"reasoningContent": {"reasoningText": {"text": "Thinking..."}}},
-            ]
-        }
+            ],
+        },
     ]
-
+
     bedrock_client.converse_stream.return_value = {"stream": []}
-
+
     await alist(model.stream(messages))
-
+
     # Verify the request was made with filtered messages (no reasoningContent)
     call_args = bedrock_client.converse_stream.call_args[1]
     sent_messages = call_args["messages"]
-
+
     assert len(sent_messages) == 2
     assert sent_messages[0]["content"] == [{"text": "Hello"}]
     assert sent_messages[1]["content"] == [{"text": "Response"}]
 
+
 @pytest.mark.asyncio
 async def test_stream_deepseek_skips_empty_messages(bedrock_client, alist):
     """Test that DeepSeek models skip messages that would be empty after filtering reasoningContent."""
     model = BedrockModel(model_id="us.deepseek.r1-v1:0")
-
+
     messages = [
-        {
-            "role": "user",
-            "content": [{"text": "Hello"}]
-        },
-        {
-            "role": "assistant",
-            "content": [
-                {"reasoningContent": {"reasoningText": {"text": "Only reasoning..."}}}
-            ]
-        },
-        {
-            "role": "user",
-            "content": [{"text": "Follow up"}]
-        }
+        {"role": "user", "content": [{"text": "Hello"}]},
+        {"role": "assistant", "content": [{"reasoningContent": {"reasoningText": {"text": "Only reasoning..."}}}]},
+        {"role": "user", "content": [{"text": "Follow up"}]},
     ]
-
+
     bedrock_client.converse_stream.return_value = {"stream": []}
-
+
     await alist(model.stream(messages))
-
+
     # Verify the request was made with only non-empty messages
     call_args = bedrock_client.converse_stream.call_args[1]
     sent_messages = call_args["messages"]
-
+
     assert len(sent_messages) == 2
     assert sent_messages[0]["content"] == [{"text": "Hello"}]
-    assert sent_messages[1]["content"] == [{"text": "Follow up"}]
+    assert sent_messages[1]["content"] == [{"text": "Follow up"}]
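
Note: both tests assert the same behaviour for DeepSeek model IDs: reasoningContent blocks are dropped from outgoing messages, and any message left empty afterwards is skipped entirely. A standalone sketch of that behaviour (filter_reasoning_content is a hypothetical helper, not the actual _format_bedrock_messages implementation):

def filter_reasoning_content(messages):
    # Drop reasoningContent blocks and skip messages that end up empty.
    cleaned = []
    for message in messages:
        content = [block for block in message["content"] if "reasoningContent" not in block]
        if content:
            cleaned.append({"role": message["role"], "content": content})
    return cleaned


messages = [
    {"role": "user", "content": [{"text": "Hello"}]},
    {"role": "assistant", "content": [{"reasoningContent": {"reasoningText": {"text": "Only reasoning..."}}}]},
    {"role": "user", "content": [{"text": "Follow up"}]},
]
assert filter_reasoning_content(messages) == [
    {"role": "user", "content": [{"text": "Hello"}]},
    {"role": "user", "content": [{"text": "Follow up"}]},
]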
