26 changes: 25 additions & 1 deletion libs/oci/README.md
@@ -62,7 +62,7 @@ embeddings.embed_query("What is the meaning of life?")
```

### 4. Use Structured Output
`ChatOCIGenAI` supports structured output.

<sub>**Note:** The default method is `function_calling`. If the default method returns `None` (e.g. for Gemini models), try `json_schema` or `json_mode`.</sub>

@@ -79,6 +79,30 @@ structured_llm = llm.with_structured_output(Joke)
structured_llm.invoke("Tell me a joke about programming")
```
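
If the default `function_calling` method returns `None` for your model, the note above suggests switching methods. A minimal sketch, reusing the `llm` and `Joke` definitions from the example above and assuming the standard LangChain `method` argument of `with_structured_output`:

```python
# Request structured output via JSON schema instead of function calling.
structured_llm = llm.with_structured_output(Joke, method="json_schema")
structured_llm.invoke("Tell me a joke about programming")
```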

### 5. Use Parallel Tool Calling (Meta/Llama models only)
Enable parallel tool calling so the model can request multiple tool calls in a single response, letting multi-tool workflows execute tools concurrently and improving performance.

```python
from langchain_oci import ChatOCIGenAI

# Option 1: Set at class level for all tool bindings
llm = ChatOCIGenAI(
    model_id="meta.llama-3.3-70b-instruct",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="MY_COMPARTMENT_ID",
    parallel_tool_calls=True,  # Enable parallel tool calling
)

# Option 2: Set per-binding
llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct")
llm_with_tools = llm.bind_tools(
    [get_weather, calculate_tip, get_population],
    parallel_tool_calls=True,  # Tools can execute simultaneously
)
```

<sub>**Note:** Parallel tool calling is only supported for Meta/Llama models. Cohere models will raise an error if this parameter is used.</sub>
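
As an illustrative follow-up (not taken from this diff), a minimal sketch of consuming the parallel tool calls returned by the bound model; the tool definitions are placeholders, and `response.tool_calls` is the standard LangChain `AIMessage` field:

```python
from langchain_core.tools import tool
from langchain_oci import ChatOCIGenAI

@tool
def get_weather(city: str) -> str:
    """Get the current weather for a city."""
    return f"Sunny in {city}"

@tool
def get_population(city: str) -> str:
    """Get the population of a city."""
    return f"{city} has about 2.7 million residents"

llm = ChatOCIGenAI(
    model_id="meta.llama-3.3-70b-instruct",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="MY_COMPARTMENT_ID",
)
llm_with_tools = llm.bind_tools([get_weather, get_population], parallel_tool_calls=True)

response = llm_with_tools.invoke("What's the weather and population of Chicago?")
# With parallel tool calling enabled, a single response may contain several tool calls.
for call in response.tool_calls:
    print(call["name"], call["args"])
```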


## OCI Data Science Model Deployment Examples

23 changes: 23 additions & 0 deletions libs/oci/langchain_oci/chat_models/oci_generative_ai.py
@@ -342,6 +342,13 @@ def messages_to_oci_params(

        This includes conversion of chat history and tool call results.
        """
        # Cohere models don't support parallel tool calls
        if kwargs.get("is_parallel_tool_calls"):
            raise ValueError(
                "Parallel tool calls are not supported for Cohere models. "
                "This feature is only available for Meta/Llama models using GenericChatRequest."
            )

        is_force_single_step = kwargs.get("is_force_single_step", False)
        oci_chat_history = []
@@ -829,6 +836,10 @@ def _should_allow_more_tool_calls(
result["tool_choice"] = self.oci_tool_choice_none()
# else: Allow model to decide (default behavior)

# Add parallel tool calls support for Meta/Llama models
if "is_parallel_tool_calls" in kwargs:
result["is_parallel_tool_calls"] = kwargs["is_parallel_tool_calls"]

return result

def _process_message_content(
@@ -1186,6 +1197,7 @@ def bind_tools(
        tool_choice: Optional[
            Union[dict, str, Literal["auto", "none", "required", "any"], bool]
        ] = None,
        parallel_tool_calls: Optional[bool] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.
@@ -1206,6 +1218,11 @@ def bind_tools(
{"type": "function", "function": {"name": <<tool_name>>}}:
calls <<tool_name>> tool.
- False or None: no effect, default Meta behavior.
parallel_tool_calls: Whether to enable parallel function calling.
If True, the model can call multiple tools simultaneously.
If False, tools are called sequentially.
If None (default), uses the class-level parallel_tool_calls setting.
Only supported for Meta/Llama models using GenericChatRequest.
kwargs: Any additional parameters are passed directly to
:meth:`~langchain_oci.chat_models.oci_generative_ai.ChatOCIGenAI.bind`.
"""
@@ -1215,6 +1232,12 @@ def bind_tools(
        if tool_choice is not None:
            kwargs["tool_choice"] = self._provider.process_tool_choice(tool_choice)

        # Add parallel tool calls support
        # Use bind-time parameter if provided, else fall back to class default
        use_parallel = (
            parallel_tool_calls
            if parallel_tool_calls is not None
            else self.parallel_tool_calls
        )
        if use_parallel:
            kwargs["is_parallel_tool_calls"] = True

        return super().bind(tools=formatted_tools, **kwargs)

    def with_structured_output(
6 changes: 6 additions & 0 deletions libs/oci/langchain_oci/llms/oci_generative_ai.py
@@ -120,6 +120,12 @@ class OCIGenAIBase(BaseModel, ABC):
"""Maximum tool calls before forcing final answer.
Prevents infinite loops while allowing multi-step orchestration."""

parallel_tool_calls: bool = False
"""Whether to enable parallel function calling during tool use.
If True, the model can call multiple tools simultaneously.
Only supported for Meta/Llama models using GenericChatRequest.
Default: False for backward compatibility."""

model_config = ConfigDict(
extra="forbid", arbitrary_types_allowed=True, protected_namespaces=()
)
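
To make the interaction between this class-level field and the `bind_tools` argument concrete, here is a minimal sketch of the precedence rules; the client is mocked so no OCI call is made, and the assertions mirror the unit tests below rather than new behavior:

```python
from unittest.mock import MagicMock

from langchain_oci import ChatOCIGenAI

def add(x: int, y: int) -> int:
    """Add two numbers."""
    return x + y

llm = ChatOCIGenAI(
    model_id="meta.llama-3.3-70b-instruct",
    parallel_tool_calls=True,  # class-level default
    client=MagicMock(),        # mocked client; no real OCI call
)

# No bind-time argument: the class-level default applies.
bound = llm.bind_tools([add])
assert bound.kwargs.get("is_parallel_tool_calls") is True

# An explicit bind-time value overrides the class-level default.
bound = llm.bind_tools([add], parallel_tool_calls=False)
assert "is_parallel_tool_calls" not in bound.kwargs
```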
199 changes: 199 additions & 0 deletions libs/oci/tests/unit_tests/chat_models/test_parallel_tool_calling.py
@@ -0,0 +1,199 @@
"""Unit tests for parallel tool calling feature."""
import pytest
from unittest.mock import MagicMock

from langchain_core.messages import HumanMessage
from langchain_oci.chat_models import ChatOCIGenAI


@pytest.mark.requires("oci")
def test_parallel_tool_calls_class_level():
"""Test class-level parallel_tool_calls parameter."""
oci_gen_ai_client = MagicMock()
llm = ChatOCIGenAI(
model_id="meta.llama-3.3-70b-instruct",
parallel_tool_calls=True,
client=oci_gen_ai_client
)
assert llm.parallel_tool_calls is True


@pytest.mark.requires("oci")
def test_parallel_tool_calls_default_false():
"""Test that parallel_tool_calls defaults to False."""
oci_gen_ai_client = MagicMock()
llm = ChatOCIGenAI(
model_id="meta.llama-3.3-70b-instruct",
client=oci_gen_ai_client
)
assert llm.parallel_tool_calls is False


@pytest.mark.requires("oci")
def test_parallel_tool_calls_bind_tools_explicit_true():
"""Test parallel_tool_calls=True in bind_tools."""
oci_gen_ai_client = MagicMock()
llm = ChatOCIGenAI(
model_id="meta.llama-3.3-70b-instruct",
client=oci_gen_ai_client
)

def tool1(x: int) -> int:
"""Tool 1."""
return x + 1

def tool2(x: int) -> int:
"""Tool 2."""
return x * 2

llm_with_tools = llm.bind_tools(
[tool1, tool2],
parallel_tool_calls=True
)

assert llm_with_tools.kwargs.get("is_parallel_tool_calls") is True


@pytest.mark.requires("oci")
def test_parallel_tool_calls_bind_tools_explicit_false():
"""Test parallel_tool_calls=False in bind_tools."""
oci_gen_ai_client = MagicMock()
llm = ChatOCIGenAI(
model_id="meta.llama-3.3-70b-instruct",
client=oci_gen_ai_client
)

def tool1(x: int) -> int:
"""Tool 1."""
return x + 1

llm_with_tools = llm.bind_tools(
[tool1],
parallel_tool_calls=False
)

# When explicitly False, should not set the parameter
assert "is_parallel_tool_calls" not in llm_with_tools.kwargs


@pytest.mark.requires("oci")
def test_parallel_tool_calls_bind_tools_uses_class_default():
"""Test that bind_tools uses class default when not specified."""
oci_gen_ai_client = MagicMock()
llm = ChatOCIGenAI(
model_id="meta.llama-3.3-70b-instruct",
parallel_tool_calls=True, # Set class default
client=oci_gen_ai_client
)

def tool1(x: int) -> int:
"""Tool 1."""
return x + 1

# Don't specify parallel_tool_calls in bind_tools
llm_with_tools = llm.bind_tools([tool1])

# Should use class default (True)
assert llm_with_tools.kwargs.get("is_parallel_tool_calls") is True


@pytest.mark.requires("oci")
def test_parallel_tool_calls_bind_tools_overrides_class_default():
"""Test that bind_tools parameter overrides class default."""
oci_gen_ai_client = MagicMock()
llm = ChatOCIGenAI(
model_id="meta.llama-3.3-70b-instruct",
parallel_tool_calls=True, # Set class default to True
client=oci_gen_ai_client
)

def tool1(x: int) -> int:
"""Tool 1."""
return x + 1

# Override with False in bind_tools
llm_with_tools = llm.bind_tools([tool1], parallel_tool_calls=False)

# Should not set the parameter when explicitly False
assert "is_parallel_tool_calls" not in llm_with_tools.kwargs


@pytest.mark.requires("oci")
def test_parallel_tool_calls_passed_to_oci_api_meta():
"""Test that is_parallel_tool_calls is passed to OCI API for Meta models."""
from oci.generative_ai_inference import models

oci_gen_ai_client = MagicMock()
llm = ChatOCIGenAI(
model_id="meta.llama-3.3-70b-instruct",
client=oci_gen_ai_client
)

def get_weather(city: str) -> str:
"""Get weather for a city."""
return f"Weather in {city}"

llm_with_tools = llm.bind_tools([get_weather], parallel_tool_calls=True)

# Prepare a request
request = llm_with_tools._prepare_request(
[HumanMessage(content="What's the weather?")],
stop=None,
stream=False,
**llm_with_tools.kwargs
)

# Verify is_parallel_tool_calls is in the request
assert hasattr(request.chat_request, 'is_parallel_tool_calls')
assert request.chat_request.is_parallel_tool_calls is True


@pytest.mark.requires("oci")
def test_parallel_tool_calls_cohere_raises_error():
"""Test that Cohere models raise error for parallel tool calls."""
oci_gen_ai_client = MagicMock()
llm = ChatOCIGenAI(
model_id="cohere.command-r-plus",
client=oci_gen_ai_client
)

def tool1(x: int) -> int:
"""Tool 1."""
return x + 1

llm_with_tools = llm.bind_tools([tool1], parallel_tool_calls=True)

# Should raise ValueError when trying to prepare request
with pytest.raises(ValueError, match="not supported for Cohere"):
llm_with_tools._prepare_request(
[HumanMessage(content="test")],
stop=None,
stream=False,
**llm_with_tools.kwargs
)


@pytest.mark.requires("oci")
def test_parallel_tool_calls_cohere_class_level_raises_error():
"""Test that Cohere models with class-level parallel_tool_calls raise error."""
oci_gen_ai_client = MagicMock()
llm = ChatOCIGenAI(
model_id="cohere.command-r-plus",
parallel_tool_calls=True, # Set at class level
client=oci_gen_ai_client
)

def tool1(x: int) -> int:
"""Tool 1."""
return x + 1

llm_with_tools = llm.bind_tools([tool1]) # Uses class default

# Should raise ValueError when trying to prepare request
with pytest.raises(ValueError, match="not supported for Cohere"):
llm_with_tools._prepare_request(
[HumanMessage(content="test")],
stop=None,
stream=False,
**llm_with_tools.kwargs
)
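
Finally, a hedged sketch of how calling code might guard against the Cohere restriction at runtime; the exception type and message come from the provider check added in this PR, while the tool, endpoint, and prompt are placeholders:

```python
from langchain_oci import ChatOCIGenAI

def get_weather(city: str) -> str:
    """Get the current weather for a city."""
    return f"Sunny in {city}"

llm = ChatOCIGenAI(
    model_id="cohere.command-r-plus",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="MY_COMPARTMENT_ID",
)
llm_with_tools = llm.bind_tools([get_weather], parallel_tool_calls=True)

try:
    llm_with_tools.invoke("What's the weather in Chicago?")
except ValueError:
    # Cohere models reject is_parallel_tool_calls; rebind without it and retry.
    llm_with_tools = llm.bind_tools([get_weather])
    llm_with_tools.invoke("What's the weather in Chicago?")
```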