Skip to content

Commit b7575f3

Browse files
authored
aws[patch]: support snake_case arguments in tools in ChatBedrockConverse (#441)
Following #437
1 parent ba6fc73 commit b7575f3

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

libs/aws/langchain_aws/chat_models/bedrock_converse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,7 @@ def _generate(
592592
logger.debug(f"input message to bedrock: {bedrock_messages}")
593593
logger.debug(f"System message to bedrock: {system}")
594594
params = self._converse_params(
595-
stop=stop, **_snake_to_camel_keys(kwargs, excluded_keys={"inputSchema", "function"})
595+
stop=stop, **_snake_to_camel_keys(kwargs, excluded_keys={"inputSchema", "properties"})
596596
)
597597
logger.debug(f"Input params: {params}")
598598
logger.info("Using Bedrock Converse API to generate response")
@@ -613,7 +613,7 @@ def _stream(
613613
) -> Iterator[ChatGenerationChunk]:
614614
bedrock_messages, system = _messages_to_bedrock(messages)
615615
params = self._converse_params(
616-
stop=stop, **_snake_to_camel_keys(kwargs, excluded_keys={"inputSchema", "function"})
616+
stop=stop, **_snake_to_camel_keys(kwargs, excluded_keys={"inputSchema", "properties"})
617617
)
618618
response = self.client.converse_stream(
619619
messages=bedrock_messages, system=system, **params

libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""Standard LangChain interface tests"""
22

3-
from typing import Literal, Type
3+
from typing import Literal, Optional, Type
44

55
import pytest
66
from langchain_core.exceptions import OutputParserException
77
from langchain_core.language_models import BaseChatModel
8-
from langchain_core.messages import HumanMessage
8+
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
99
from langchain_core.tools import BaseTool
1010
from langchain_tests.integration_tests import ChatModelIntegrationTests
1111
from pydantic import BaseModel, Field
@@ -181,6 +181,30 @@ def test_structured_output_snake_case() -> None:
181181
assert isinstance(chunk, ClassifyQuery)
182182

183183

184+
def test_tool_calling_snake_case() -> None:
185+
model = ChatBedrockConverse(model="anthropic.claude-3-sonnet-20240229-v1:0")
186+
187+
def classify_query(query_type: Literal["cat", "dog"]) -> None:
188+
pass
189+
190+
chat = model.bind_tools([classify_query], tool_choice="any")
191+
response = chat.invoke("How big are cats?")
192+
assert isinstance(response, AIMessage)
193+
assert len(response.tool_calls) == 1
194+
tool_call = response.tool_calls[0]
195+
assert tool_call["name"] == "classify_query"
196+
assert tool_call["args"] == {"query_type": "cat"}
197+
198+
full = None
199+
for chunk in chat.stream("How big are cats?"):
200+
full = chunk if full is None else full + chunk # type: ignore[assignment]
201+
assert isinstance(full, AIMessageChunk)
202+
assert len(full.tool_calls) == 1
203+
tool_call = full.tool_calls[0]
204+
assert tool_call["name"] == "classify_query"
205+
assert tool_call["args"] == {"query_type": "cat"}
206+
207+
184208
def test_structured_output_streaming() -> None:
185209
model = ChatBedrockConverse(
186210
model="anthropic.claude-3-sonnet-20240229-v1:0", temperature=0

0 commit comments

Comments
 (0)