98 changes: 88 additions & 10 deletions libs/oci/langchain_oci/chat_models/oci_generative_ai.py
@@ -367,9 +367,7 @@ def messages_to_oci_params(
self.oci_chat_message[self.get_role(msg)](
tool_results=[
self.oci_tool_result(
call=self.oci_tool_call(
name=msg.name, parameters={}
),
call=self.oci_tool_call(name=msg.name, parameters={}),
outputs=[{"output": msg.content}],
)
],
@@ -381,9 +379,17 @@ def messages_to_oci_params(
for i, message in enumerate(messages[::-1]):
current_turn.append(message)
if isinstance(message, HumanMessage):
if len(messages) > i and isinstance(messages[len(messages) - i - 2], ToolMessage):
# add dummy message REPEATING the tool_result to avoid the error about ToolMessage needing to be followed by an AI message
oci_chat_history.append(self.oci_chat_message['CHATBOT'](message=messages[len(messages) - i - 2].content))
if len(messages) > i and isinstance(
messages[len(messages) - i - 2], ToolMessage
):
# add dummy message REPEATING the tool_result to avoid
# the error about ToolMessage needing to be followed
# by an AI message
oci_chat_history.append(
self.oci_chat_message["CHATBOT"](
message=messages[len(messages) - i - 2].content
)
)
break
current_turn = list(reversed(current_turn))

@@ -713,8 +719,8 @@ def messages_to_oci_params(
else:
oci_message = self.oci_chat_message[role](content=tool_content)
elif isinstance(message, AIMessage) and (
message.tool_calls or
message.additional_kwargs.get("tool_calls")):
message.tool_calls or message.additional_kwargs.get("tool_calls")
):
# Process content and tool calls for assistant messages
content = self._process_message_content(message.content)
tool_calls = []
@@ -736,11 +742,78 @@ def messages_to_oci_params(
oci_message = self.oci_chat_message[role](content=content)
oci_messages.append(oci_message)

return {
result = {
"messages": oci_messages,
"api_format": self.chat_api_format,
}

# BUGFIX: Intelligently manage tool_choice to prevent infinite loops
# while allowing legitimate multi-step tool orchestration.
# This addresses a known issue with Meta Llama models that
# continue calling tools even after receiving results.

def _should_allow_more_tool_calls(
messages: List[BaseMessage],
max_tool_calls: int
) -> bool:
"""
Determine if the model should be allowed to call more tools.

Returns False (force stop) if:
- Tool call limit exceeded
- Infinite loop detected (same tool called repeatedly with same args)

Returns True otherwise to allow multi-step tool orchestration.

Args:
messages: Conversation history
max_tool_calls: Maximum number of tool calls before forcing stop
"""
# Count tool calls executed so far (one ToolMessage per completed call)
tool_call_count = sum(
1 for msg in messages
if isinstance(msg, ToolMessage)
)

# Safety limit: prevent runaway tool calling
if tool_call_count >= max_tool_calls:
return False

# Detect infinite loop: same tool called with same arguments in succession
recent_calls = []
for msg in reversed(messages):
if hasattr(msg, 'tool_calls') and msg.tool_calls:
for tc in msg.tool_calls:
# Create signature: (tool_name, sorted_args)
try:
args_str = json.dumps(tc.get('args', {}), sort_keys=True)
signature = (tc.get('name', ''), args_str)

# Check if this exact call was made in the last 2 calls
if signature in recent_calls[-2:]:
return False # Infinite loop detected

recent_calls.append(signature)
except Exception:
# If we can't serialize args, be conservative and continue
pass

# Only inspect the most recent tool calls (stop once 4 signatures are collected)
if len(recent_calls) >= 4:
break

return True

has_tool_results = any(isinstance(msg, ToolMessage) for msg in messages)
if has_tool_results and "tools" in kwargs and "tool_choice" not in kwargs:
max_tool_calls = kwargs.get("max_sequential_tool_calls", 8)
if not _should_allow_more_tool_calls(messages, max_tool_calls):
# Force model to stop and provide final answer
result["tool_choice"] = self.oci_tool_choice_none()
# else: Allow model to decide (default behavior)

return result

def _process_message_content(
self, content: Union[str, List[Union[str, Dict]]]
) -> List[Any]:
@@ -934,6 +1007,7 @@ def process_stream_tool_calls(

class MetaProvider(GenericProvider):
"""Provider for Meta models. This provider is for backward compatibility."""

pass


@@ -1050,7 +1124,11 @@ def _prepare_request(
"Please make sure you have the oci package installed."
) from ex

oci_params = self._provider.messages_to_oci_params(messages, **kwargs)
oci_params = self._provider.messages_to_oci_params(
messages,
max_sequential_tool_calls=self.max_sequential_tool_calls,
**kwargs
)

oci_params["is_stream"] = stream
_model_kwargs = self.model_kwargs or {}
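For reviewers: the loop-detection heuristic above reduces to comparing (tool name, canonicalized args) signatures of consecutive calls. A self-contained sketch of that idea, runnable on its own; the helper name detect_repeat and the sample calls are ours, not part of the diff:

import json
from typing import Any, Dict, List, Tuple

def detect_repeat(tool_calls: List[Dict[str, Any]]) -> bool:
    """Return True when the same (name, args) signature repeats within the
    two most recent calls, mirroring the heuristic that
    _should_allow_more_tool_calls uses to cut loops."""
    recent: List[Tuple[str, str]] = []
    for tc in tool_calls:
        # Canonicalize args so {"a": 1, "b": 2} and {"b": 2, "a": 1} match.
        sig = (tc.get("name", ""), json.dumps(tc.get("args", {}), sort_keys=True))
        if sig in recent[-2:]:
            return True
        recent.append(sig)
    return False

# The second identical call trips the detector:
calls = [
    {"name": "get_weather", "args": {"city": "Austin"}},
    {"name": "get_weather", "args": {"city": "Austin"}},
]
assert detect_repeat(calls) is True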
4 changes: 4 additions & 0 deletions libs/oci/langchain_oci/llms/oci_generative_ai.py
@@ -116,6 +116,10 @@ class OCIGenAIBase(BaseModel, ABC):
is_stream: bool = False
"""Whether to stream back partial progress"""

max_sequential_tool_calls: int = 8
"""Maximum tool calls before forcing final answer.
Prevents infinite loops while allowing multi-step orchestration."""

model_config = ConfigDict(
extra="forbid", arbitrary_types_allowed=True, protected_namespaces=()
)
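With the new field in place, callers can tune the cap per model instance. A minimal usage sketch, assuming ChatOCIGenAI is exported at the package top level; the model id, endpoint, and OCID below are illustrative placeholders:

from langchain_oci import ChatOCIGenAI  # assumed top-level export

llm = ChatOCIGenAI(
    model_id="meta.llama-3.3-70b-instruct",  # illustrative model id
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",  # placeholder
    compartment_id="ocid1.compartment.oc1..example",  # placeholder OCID
    max_sequential_tool_calls=4,  # force a final answer after 4 tool results
)

Once the limit is hit or a repeated call is detected, messages_to_oci_params sets tool_choice to NONE, so the model must produce a final answer instead of calling another tool.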