Commit 1b83c5f

structured output - multi-modal input (#405)
1 parent c412292 commit 1b83c5f

7 files changed: +156 -23 lines changed


src/strands/agent/agent.py

Lines changed: 8 additions & 5 deletions
@@ -380,13 +380,13 @@ async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A

         return cast(AgentResult, event["result"])

-    def structured_output(self, output_model: Type[T], prompt: Optional[str] = None) -> T:
+    def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None) -> T:
         """This method allows you to get structured output from the agent.

         If you pass in a prompt, it will be added to the conversation history and the agent will respond to it.
         If you don't pass in a prompt, it will use only the conversation history to respond.

-        For smaller models, you may want to use the optional prompt string to add additional instructions to explicitly
+        For smaller models, you may want to use the optional prompt to add additional instructions to explicitly
         instruct the model to output the structured data.

         Args:
@@ -405,13 +405,15 @@ def execute() -> T:
             future = executor.submit(execute)
             return future.result()

-    async def structured_output_async(self, output_model: Type[T], prompt: Optional[str] = None) -> T:
+    async def structured_output_async(
+        self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None
+    ) -> T:
         """This method allows you to get structured output from the agent.

         If you pass in a prompt, it will be added to the conversation history and the agent will respond to it.
         If you don't pass in a prompt, it will use only the conversation history to respond.

-        For smaller models, you may want to use the optional prompt string to add additional instructions to explicitly
+        For smaller models, you may want to use the optional prompt to add additional instructions to explicitly
         instruct the model to output the structured data.

         Args:
@@ -430,7 +432,8 @@ async def structured_output_async(self, output_model: Type[T], prompt: Optional[

         # add the prompt as the last message
         if prompt:
-            self._append_message({"role": "user", "content": [{"text": prompt}]})
+            content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
+            self._append_message({"role": "user", "content": content})

         events = self.model.structured_output(output_model, self.messages)
         async for event in events:

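For reference, a minimal usage sketch of the new signature (not part of the commit): after this change, the prompt argument of structured_output can be either a plain string or a list of ContentBlock dicts, so image bytes can be sent alongside the extraction instructions. The PersonInfo model, the person.png path, and the bare Agent() construction below are illustrative assumptions.

from pydantic import BaseModel

from strands import Agent


class PersonInfo(BaseModel):
    """Fields to extract from the image."""  # hypothetical output model

    name: str
    hair_color: str


agent = Agent()  # assumes a default multi-modal-capable model is configured

# person.png is a placeholder path for any PNG image of a person
with open("person.png", "rb") as f:
    image_bytes = f.read()

# Pass a list of ContentBlock dicts instead of a plain string prompt
person = agent.structured_output(
    PersonInfo,
    [
        {"text": "Describe the person in this image"},
        {
            "image": {
                "format": "png",
                "source": {"bytes": image_bytes},
            }
        },
    ],
)
print(person.name, person.hair_color)
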
tests/strands/agent/test_agent.py

Lines changed: 22 additions & 0 deletions
@@ -959,6 +959,28 @@ def test_agent_structured_output(agent, user, agenerator):
     agent.model.structured_output.assert_called_once_with(type(user), [{"role": "user", "content": [{"text": prompt}]}])


+def test_agent_structured_output_multi_modal_input(agent, user, agenerator):
+    agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}]))
+
+    prompt = [
+        {"text": "Please describe the user in this image"},
+        {
+            "image": {
+                "format": "png",
+                "source": {
+                    "bytes": b"\x89PNG\r\n\x1a\n",
+                },
+            }
+        },
+    ]
+
+    tru_result = agent.structured_output(type(user), prompt)
+    exp_result = user
+    assert tru_result == exp_result
+
+    agent.model.structured_output.assert_called_once_with(type(user), [{"role": "user", "content": prompt}])
+
+
 @pytest.mark.asyncio
 async def test_agent_structured_output_in_async_context(agent, user, agenerator):
     agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}]))

tests_integ/models/test_model_anthropic.py

Lines changed: 33 additions & 6 deletions
@@ -12,7 +12,7 @@
 pytestmark = providers.anthropic.mark


-@pytest.fixture(scope="module")
+@pytest.fixture
 def model():
     return AnthropicModel(
         client_args={
@@ -23,7 +23,7 @@ def model():
     )


-@pytest.fixture(scope="module")
+@pytest.fixture
 def tools():
     @strands.tool
     def tool_time() -> str:
@@ -36,17 +36,17 @@ def tool_weather() -> str:
     return [tool_time, tool_weather]


-@pytest.fixture(scope="module")
+@pytest.fixture
 def system_prompt():
     return "You are an AI assistant."


-@pytest.fixture(scope="module")
+@pytest.fixture
 def agent(model, tools, system_prompt):
     return Agent(model=model, tools=tools, system_prompt=system_prompt)


-@pytest.fixture(scope="module")
+@pytest.fixture
 def weather():
     class Weather(BaseModel):
         """Extracts the time and weather from the user's message with the exact strings."""
@@ -57,6 +57,16 @@ class Weather(BaseModel):
     return Weather(time="12:00", weather="sunny")


+@pytest.fixture
+def yellow_color():
+    class Color(BaseModel):
+        """Describes a color."""
+
+        name: str
+
+    return Color(name="yellow")
+
+
 def test_agent_invoke(agent):
     result = agent("What is the time and weather in New York?")
     text = result.message["content"][0]["text"].lower()
@@ -97,7 +107,7 @@ async def test_agent_structured_output_async(agent, weather):
     assert tru_weather == exp_weather


-def test_multi_modal_input(agent, yellow_img):
+def test_invoke_multi_modal_input(agent, yellow_img):
     content = [
         {"text": "what is in this image"},
         {
@@ -113,3 +123,20 @@ def test_multi_modal_input(agent, yellow_img):
     text = result.message["content"][0]["text"].lower()

     assert "yellow" in text
+
+
+def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color):
+    content = [
+        {"text": "Is this image red, blue, or yellow?"},
+        {
+            "image": {
+                "format": "png",
+                "source": {
+                    "bytes": yellow_img,
+                },
+            },
+        },
+    ]
+    tru_color = agent.structured_output(type(yellow_color), content)
+    exp_color = yellow_color
+    assert tru_color == exp_color

tests_integ/models/test_model_bedrock.py

Lines changed: 28 additions & 1 deletion
@@ -37,6 +37,16 @@ def non_streaming_agent(non_streaming_model, system_prompt):
     return Agent(model=non_streaming_model, system_prompt=system_prompt, load_tools_from_directory=False)


+@pytest.fixture
+def yellow_color():
+    class Color(BaseModel):
+        """Describes a color."""
+
+        name: str
+
+    return Color(name="yellow")
+
+
 def test_streaming_agent(streaming_agent):
     """Test agent with streaming model."""
     result = streaming_agent("Hello!")
@@ -153,7 +163,7 @@ class Weather(BaseModel):
     assert result.weather == "sunny"


-def test_multi_modal_input(streaming_agent, yellow_img):
+def test_invoke_multi_modal_input(streaming_agent, yellow_img):
     content = [
         {"text": "what is in this image"},
         {
@@ -169,3 +179,20 @@ def test_multi_modal_input(streaming_agent, yellow_img):
     text = result.message["content"][0]["text"].lower()

     assert "yellow" in text
+
+
+def test_structured_output_multi_modal_input(streaming_agent, yellow_img, yellow_color):
+    content = [
+        {"text": "Is this image red, blue, or yellow?"},
+        {
+            "image": {
+                "format": "png",
+                "source": {
+                    "bytes": yellow_img,
+                },
+            },
+        },
+    ]
+    tru_color = streaming_agent.structured_output(type(yellow_color), content)
+    exp_color = yellow_color
+    assert tru_color == exp_color

tests_integ/models/test_model_litellm.py

Lines changed: 29 additions & 2 deletions
@@ -29,6 +29,16 @@ def agent(model, tools):
     return Agent(model=model, tools=tools)


+@pytest.fixture
+def yellow_color():
+    class Color(BaseModel):
+        """Describes a color."""
+
+        name: str
+
+    return Color(name="yellow")
+
+
 def test_agent(agent):
     result = agent("What is the time and weather in New York?")
     text = result.message["content"][0]["text"].lower()
@@ -49,9 +59,9 @@ class Weather(BaseModel):
     assert result.weather == "sunny"


-def test_multi_modal_input(agent, yellow_img):
+def test_invoke_multi_modal_input(agent, yellow_img):
     content = [
-        {"text": "what is in this image"},
+        {"text": "Is this image red, blue, or yellow?"},
         {
             "image": {
                 "format": "png",
@@ -65,3 +75,20 @@ def test_multi_modal_input(agent, yellow_img):
     text = result.message["content"][0]["text"].lower()

     assert "yellow" in text
+
+
+def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color):
+    content = [
+        {"text": "what is in this image"},
+        {
+            "image": {
+                "format": "png",
+                "source": {
+                    "bytes": yellow_img,
+                },
+            },
+        },
+    ]
+    tru_color = agent.structured_output(type(yellow_color), content)
+    exp_color = yellow_color
+    assert tru_color == exp_color

tests_integ/models/test_model_ollama.py

Lines changed: 4 additions & 4 deletions
@@ -10,12 +10,12 @@
 pytestmark = providers.ollama.mark


-@pytest.fixture(scope="module")
+@pytest.fixture
 def model():
     return OllamaModel(host="http://localhost:11434", model_id="llama3.3:70b")


-@pytest.fixture(scope="module")
+@pytest.fixture
 def tools():
     @strands.tool
     def tool_time() -> str:
@@ -28,12 +28,12 @@ def tool_weather() -> str:
     return [tool_time, tool_weather]


-@pytest.fixture(scope="module")
+@pytest.fixture
 def agent(model, tools):
     return Agent(model=model, tools=tools)


-@pytest.fixture(scope="module")
+@pytest.fixture
 def weather():
     class Weather(BaseModel):
         """Extracts the time and weather from the user's message with the exact strings."""

tests_integ/models/test_model_openai.py

Lines changed: 32 additions & 5 deletions
@@ -12,7 +12,7 @@
 pytestmark = providers.openai.mark


-@pytest.fixture(scope="module")
+@pytest.fixture
 def model():
     return OpenAIModel(
         model_id="gpt-4o",
@@ -22,7 +22,7 @@ def model():
     )


-@pytest.fixture(scope="module")
+@pytest.fixture
 def tools():
     @strands.tool
     def tool_time() -> str:
@@ -35,12 +35,12 @@ def tool_weather() -> str:
     return [tool_time, tool_weather]


-@pytest.fixture(scope="module")
+@pytest.fixture
 def agent(model, tools):
     return Agent(model=model, tools=tools)


-@pytest.fixture(scope="module")
+@pytest.fixture
 def weather():
     class Weather(BaseModel):
         """Extracts the time and weather from the user's message with the exact strings."""
@@ -51,6 +51,16 @@ class Weather(BaseModel):
     return Weather(time="12:00", weather="sunny")


+@pytest.fixture
+def yellow_color():
+    class Color(BaseModel):
+        """Describes a color."""
+
+        name: str
+
+    return Color(name="yellow")
+
+
 @pytest.fixture(scope="module")
 def test_image_path(request):
     return request.config.rootpath / "tests_integ" / "test_image.png"
@@ -96,7 +106,7 @@ async def test_agent_structured_output_async(agent, weather):
     assert tru_weather == exp_weather


-def test_multi_modal_input(agent, yellow_img):
+def test_invoke_multi_modal_input(agent, yellow_img):
     content = [
         {"text": "what is in this image"},
         {
@@ -114,6 +124,23 @@ def test_multi_modal_input(agent, yellow_img):
     assert "yellow" in text


+def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color):
+    content = [
+        {"text": "Is this image red, blue, or yellow?"},
+        {
+            "image": {
+                "format": "png",
+                "source": {
+                    "bytes": yellow_img,
+                },
+            },
+        },
+    ]
+    tru_color = agent.structured_output(type(yellow_color), content)
+    exp_color = yellow_color
+    assert tru_color == exp_color
+
+
 @pytest.mark.skip("https://github.com/strands-agents/sdk-python/issues/320")
 def test_tool_returning_images(model, yellow_img):
     @tool
