diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 02c3d9089..be96d55e2 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -72,7 +72,7 @@ def __init__(self, *, client_args: Optional[dict[str, Any]] = None, **model_conf logger.debug("config=<%s> | initializing", self.config) client_args = client_args or {} - self.client = anthropic.Anthropic(**client_args) + self.client = anthropic.AsyncAnthropic(**client_args) @override def update_config(self, **model_config: Unpack[AnthropicConfig]) -> None: # type: ignore[override] @@ -358,8 +358,8 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any] ModelThrottledException: If the request is throttled by Anthropic. """ try: - with self.client.messages.stream(**request) as stream: - for event in stream: + async with self.client.messages.stream(**request) as stream: + async for event in stream: if event.type in AnthropicModel.EVENT_TYPES: yield event.model_dump() diff --git a/tests-integ/test_model_anthropic.py b/tests-integ/test_model_anthropic.py index 50033f8f7..142558105 100644 --- a/tests-integ/test_model_anthropic.py +++ b/tests-integ/test_model_anthropic.py @@ -8,7 +8,7 @@ from strands.models.anthropic import AnthropicModel -@pytest.fixture +@pytest.fixture(scope="module") def model(): return AnthropicModel( client_args={ @@ -19,7 +19,7 @@ def model(): ) -@pytest.fixture +@pytest.fixture(scope="module") def tools(): @strands.tool def tool_time() -> str: @@ -32,18 +32,29 @@ def tool_weather() -> str: return [tool_time, tool_weather] -@pytest.fixture +@pytest.fixture(scope="module") def system_prompt(): return "You are an AI assistant." -@pytest.fixture +@pytest.fixture(scope="module") def agent(model, tools, system_prompt): return Agent(model=model, tools=tools, system_prompt=system_prompt) +@pytest.fixture(scope="module") +def weather(): + class Weather(BaseModel): + """Extracts the time and weather from the user's message with the exact strings.""" + + time: str + weather: str + + return Weather(time="12:00", weather="sunny") + + @pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") -def test_agent(agent): +def test_agent_invoke(agent): result = agent("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() @@ -51,13 +62,37 @@ def test_agent(agent): @pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") -def test_structured_output(model): - class Weather(BaseModel): - time: str - weather: str +@pytest.mark.asyncio +async def test_agent_invoke_async(agent): + result = await agent.invoke_async("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) - agent = Agent(model=model) - result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") - assert isinstance(result, Weather) - assert result.time == "12:00" - assert result.weather == "sunny" + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +@pytest.mark.asyncio +async def test_agent_stream_async(agent): + stream = agent.stream_async("What is the time and weather in New York?") + async for event in stream: + _ = event + + result = event["result"] + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_structured_output(agent, weather): + tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny") + exp_weather = weather + assert tru_weather == exp_weather + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +@pytest.mark.asyncio +async def test_agent_structured_output_async(agent, weather): + tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny") + exp_weather = weather + assert tru_weather == exp_weather diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 66046b7a8..fa1eb8616 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -11,7 +11,7 @@ @pytest.fixture def anthropic_client(): - with unittest.mock.patch.object(strands.models.anthropic.anthropic, "Anthropic") as mock_client_cls: + with unittest.mock.patch.object(strands.models.anthropic.anthropic, "AsyncAnthropic") as mock_client_cls: yield mock_client_cls.return_value @@ -625,7 +625,7 @@ def test_format_chunk_unknown(model): @pytest.mark.asyncio -async def test_stream(anthropic_client, model, alist): +async def test_stream(anthropic_client, model, agenerator, alist): mock_event_1 = unittest.mock.Mock( type="message_start", dict=lambda: {"type": "message_start"}, @@ -646,9 +646,9 @@ async def test_stream(anthropic_client, model, alist): ), ) - mock_stream = unittest.mock.MagicMock() - mock_stream.__iter__.return_value = iter([mock_event_1, mock_event_2, mock_event_3]) - anthropic_client.messages.stream.return_value.__enter__.return_value = mock_stream + mock_context = unittest.mock.AsyncMock() + mock_context.__aenter__.return_value = agenerator([mock_event_1, mock_event_2, mock_event_3]) + anthropic_client.messages.stream.return_value = mock_context request = {"model": "m1"} response = model.stream(request) @@ -705,7 +705,7 @@ async def test_stream_bad_request_error(anthropic_client, model): @pytest.mark.asyncio -async def test_structured_output(anthropic_client, model, test_output_model_cls, alist): +async def test_structured_output(anthropic_client, model, test_output_model_cls, agenerator, alist): messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] events = [ @@ -749,9 +749,9 @@ async def test_structured_output(anthropic_client, model, test_output_model_cls, ), ] - mock_stream = unittest.mock.MagicMock() - mock_stream.__iter__.return_value = iter(events) - anthropic_client.messages.stream.return_value.__enter__.return_value = mock_stream + mock_context = unittest.mock.AsyncMock() + mock_context.__aenter__.return_value = agenerator(events) + anthropic_client.messages.stream.return_value = mock_context stream = model.structured_output(test_output_model_cls, messages) events = await alist(stream)