Commit 9a79fbd

feat: Add parallel tool calling support for Meta/Llama models (#59)
* feat: Add parallel tool calling support for Meta/Llama models

  Add support for the `parallel_tool_calls` parameter to enable parallel function calling in Meta/Llama models, improving performance for multi-tool workflows.

  - Add parallel_tool_calls class parameter to OCIGenAIBase (default: False)
  - Add parallel_tool_calls parameter to bind_tools() method
  - Support hybrid approach: class-level default + per-binding override
  - Pass is_parallel_tool_calls to OCI API in MetaProvider
  - Add validation for Cohere models (raises error if attempted)
  - 9 comprehensive unit tests (all passing)
  - 4 integration tests with live OCI API (all passing)
  - No regression in existing tests

  Class-level default:

  ```python
  llm = ChatOCIGenAI(
      model_id="meta.llama-3.3-70b-instruct",
      parallel_tool_calls=True
  )
  ```

  Per-binding override:

  ```python
  llm_with_tools = llm.bind_tools(
      [tool1, tool2, tool3],
      parallel_tool_calls=True
  )
  ```

  - Up to N× speedup for N independent tool calls
  - Backward compatible (default: False)
  - Clear error messages for unsupported models
  - Follows existing parameter patterns

* Fix code formatting for line length compliance

* Update documentation to reflect broader model support for parallel tool calling

  - Update README to include all GenericChatRequest models (Grok, OpenAI, Mistral)
  - Update code comments and docstrings
  - Update error messages with the complete model list
  - Clarify that the feature works with GenericChatRequest, not just Meta/Llama

* Move integration test to correct folder structure

  Relocated test_parallel_tool_calling_integration.py to tests/integration_tests/chat_models/, following the repository convention for integration test organization.

* Add version filter for Llama parallel tool calling

  Only Llama 4+ models support parallel tool calling, based on testing.

  Parallel tool calling support:
  - Llama 4+: SUPPORTED (tested and verified with the real OCI API)
  - All Llama 3.x (3.0, 3.1, 3.2, 3.3): BLOCKED
  - Cohere: BLOCKED (existing behavior)
  - Other models (xAI Grok, OpenAI, Mistral): SUPPORTED

  Implementation:
  - Added a _supports_parallel_tool_calls() helper method with regex version parsing
  - Updated bind_tools() to validate the model version before enabling parallel calls
  - Provides clear error messages: "only available for Llama 4+ models"

  Unit tests added (8 tests, all mocked, no OCI connection):
  - test_version_filter_llama_3_0_blocked
  - test_version_filter_llama_3_1_blocked
  - test_version_filter_llama_3_2_blocked
  - test_version_filter_llama_3_3_blocked (Llama 3.3 doesn't support it either)
  - test_version_filter_llama_4_allowed
  - test_version_filter_other_models_allowed
  - test_version_filter_supports_parallel_tool_calls_method
  - Plus existing parallel tool calling tests updated to use Llama 4

* Fix linting issues after rebase

  - Fix line length violations in chat_models and llms
  - Replace print statements with logging in integration tests
  - Fix import sorting and remove unused imports
  - Fix an unused variable in a test

* Fix remaining linting issues in test files

* Move parallel tool call validation from bind_tools to provider

  - Validation now happens at request preparation time
  - Cohere validation remains in CohereProvider
  - Llama 3.x validation added to GenericProvider
  - Fixes failing unit tests

* Add Llama 3.x validation at bind_tools time

  - Llama 3.x validation happens early, at bind_tools time
  - Cohere validation happens at the provider level (_prepare_request time)
  - All 16 parallel tool calling tests now pass

* Fix line length issue in bind_tools validation

* Apply ruff formatting to parallel tool calling tests

* Move parallel_tool_calls to bind_tools only (remove class-level param)

* Update integration tests for bind_tools-only parallel_tool_calls

* Fix README to show bind_tools-only parallel_tool_calls usage

* Fix mypy type errors for LangChain 1.x compatibility

  - Add type: ignore[override] to bind_tools methods in oci_data_science.py and oci_generative_ai.py to handle the signature incompatibility with the BaseChatModel parent class
  - Remove unused type: ignore comments in oci_generative_ai.py
  - Add type: ignore[attr-defined] comments for RunnableBinding runtime attributes (kwargs, _prepare_request) in test_parallel_tool_calling.py
  - Fix test_parallel_tool_calling_integration.py to use getattr for tool_calls attribute access on BaseMessage
  - Fix test_tool_calling.py: import StructuredTool from langchain_core.tools
  - Fix test_oci_data_science.py: remove an unused type: ignore comment
  - Fix test_oci_generative_ai_responses_api.py: add type: ignore for a LangGraph invoke arg type

* Fix mypy errors for CI environment compatibility

  - Add type: ignore[unreachable] back to the BaseTool isinstance check in oci_generative_ai.py (CI mypy flags it as unreachable)
  - Remove type: ignore[override] from bind_tools (CI reports it as unused)
  - Fix test_oci_data_science.py: explicitly type the output variable and use explicit addition instead of += to avoid an assignment type error
  - Remove unused type: ignore comments from test files

* Fix Python 3.9 compatibility in test_oci_data_science.py

  - Use Optional[T] instead of T | None syntax for Python 3.9 compatibility
  - Add type: ignore[assignment] for AIMessageChunk addition

* Simplify parallel tool calls: use provider property instead of model_id parsing

  Addresses reviewer feedback:
  - Add a supports_parallel_tool_calls property to the Provider base class (False)
  - Override it in GenericProvider to return True (supports parallel calls)
  - CohereProvider inherits False (doesn't support parallel calls)
  - Remove the _supports_parallel_tool_calls method with its hacky model_id parsing
  - Simplify bind_tools to use the provider property for validation
  - Remove the Llama version-specific validation (let the API fail naturally)
  - Update unit tests to focus on provider-based validation

* Fix integration test for bind_tools validation timing

* Fix mypy linting issues for Python 3.9 compatibility

  - Reorder convert_to_oci_tool checks to avoid an unreachable-code warning
  - Fix the type annotation in test_stream_vllm to use BaseMessageChunk
1 parent 28bf80c commit 9a79fbd

File tree

6 files changed: +582 -24 lines changed


libs/oci/README.md

Lines changed: 22 additions & 1 deletion
@@ -62,7 +62,7 @@ embeddings.embed_query("What is the meaning of life?")
 ```
 
 ### 4. Use Structured Output
-`ChatOCIGenAI` supports structured output.
+`ChatOCIGenAI` supports structured output.
 
 <sub>**Note:** The default method is `function_calling`. If default method returns `None` (e.g. for Gemini models), try `json_schema` or `json_mode`.</sub>

@@ -126,6 +126,27 @@ messages = [
 response = client.invoke(messages)
 ```
 
+### 6. Use Parallel Tool Calling (Meta/Llama 4+ models only)
+Enable parallel tool calling to execute multiple tools simultaneously, improving performance for multi-tool workflows.
+
+```python
+from langchain_oci import ChatOCIGenAI
+
+llm = ChatOCIGenAI(
+    model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
+    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
+    compartment_id="MY_COMPARTMENT_ID",
+)
+
+# Enable parallel tool calling in bind_tools
+llm_with_tools = llm.bind_tools(
+    [get_weather, calculate_tip, get_population],
+    parallel_tool_calls=True  # Tools can execute simultaneously
+)
+```
+
+<sub>**Note:** Parallel tool calling is only supported for Llama 4+ models. Llama 3.x (including 3.3) and Cohere models will raise an error if this parameter is used.</sub>
+
 
 ## OCI Data Science Model Deployment Examples

libs/oci/langchain_oci/chat_models/oci_generative_ai.py

Lines changed: 61 additions & 16 deletions
@@ -209,6 +209,18 @@ def process_stream_tool_calls(
         """Process streaming tool calls from event data into chunks."""
         ...
 
+    @property
+    def supports_parallel_tool_calls(self) -> bool:
+        """Whether this provider supports parallel tool calling.
+
+        Parallel tool calling allows the model to call multiple tools
+        simultaneously in a single response.
+
+        Returns:
+            bool: True if parallel tool calling is supported, False otherwise.
+        """
+        return False
+
 
 class CohereProvider(Provider):
     """Provider implementation for Cohere."""
@@ -363,6 +375,14 @@ def messages_to_oci_params(
 
         This includes conversion of chat history and tool call results.
         """
+        # Cohere models don't support parallel tool calls
+        if kwargs.get("is_parallel_tool_calls"):
+            raise ValueError(
+                "Parallel tool calls are not supported for Cohere models. "
+                "This feature is only available for models using GenericChatRequest "
+                "(Meta, Llama, xAI Grok, OpenAI, Mistral)."
+            )
+
         is_force_single_step = kwargs.get("is_force_single_step", False)
         oci_chat_history = []

@@ -585,6 +605,11 @@ class GenericProvider(Provider):
 
     stop_sequence_key: str = "stop"
 
+    @property
+    def supports_parallel_tool_calls(self) -> bool:
+        """GenericProvider models support parallel tool calling."""
+        return True
+
     def __init__(self) -> None:
         from oci.generative_ai_inference import models

@@ -851,6 +876,10 @@ def _should_allow_more_tool_calls(
             result["tool_choice"] = self.oci_tool_choice_none()
         # else: Allow model to decide (default behavior)
 
+        # Add parallel tool calls support (GenericChatRequest models)
+        if "is_parallel_tool_calls" in kwargs:
+            result["is_parallel_tool_calls"] = kwargs["is_parallel_tool_calls"]
+
         return result
 
     def _process_message_content(
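The pass-through above copies the flag into the request parameters only when it was explicitly bound, so the service-side default applies otherwise. A reduced sketch (the helper name is hypothetical, not from the diff):

```python
def prepare_request_params(**kwargs):
    # Hypothetical reduction of the parameter assembly shown in the hunk
    # above: forward is_parallel_tool_calls only when the caller bound it,
    # leaving the key absent (not False) in all other cases.
    result = {}
    if "is_parallel_tool_calls" in kwargs:
        result["is_parallel_tool_calls"] = kwargs["is_parallel_tool_calls"]
    return result
```

Checking membership rather than truthiness means an explicit `parallel_tool_calls=False` binding would still be forwarded, while an unset one is omitted entirely.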
@@ -916,23 +945,9 @@ def convert_to_oci_tool(
         Raises:
             ValueError: If the tool type is not supported.
         """
-        if (isinstance(tool, type) and issubclass(tool, BaseModel)) or callable(tool):
-            as_json_schema_function = convert_to_openai_function(tool)
-            parameters = as_json_schema_function.get("parameters", {})
+        # Check BaseTool first since it's callable but needs special handling
+        if isinstance(tool, BaseTool):
             return self.oci_function_definition(
-                name=as_json_schema_function.get("name"),
-                description=as_json_schema_function.get(
-                    "description",
-                    as_json_schema_function.get("name"),
-                ),
-                parameters={
-                    "type": "object",
-                    "properties": parameters.get("properties", {}),
-                    "required": parameters.get("required", []),
-                },
-            )
-        elif isinstance(tool, BaseTool):  # type: ignore[unreachable]
-            return self.oci_function_definition(  # type: ignore[unreachable]
                 name=tool.name,
                 description=OCIUtils.remove_signature_from_tool_description(
                     tool.name, tool.description
@@ -953,6 +968,21 @@ def convert_to_oci_tool(
                 ],
             },
         )
+        if (isinstance(tool, type) and issubclass(tool, BaseModel)) or callable(tool):
+            as_json_schema_function = convert_to_openai_function(tool)
+            parameters = as_json_schema_function.get("parameters", {})
+            return self.oci_function_definition(
+                name=as_json_schema_function.get("name"),
+                description=as_json_schema_function.get(
+                    "description",
+                    as_json_schema_function.get("name"),
+                ),
+                parameters={
+                    "type": "object",
+                    "properties": parameters.get("properties", {}),
+                    "required": parameters.get("required", []),
+                },
+            )
         raise ValueError(
             f"Unsupported tool type {type(tool)}. "
             "Tool must be passed in as a BaseTool "
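The reordering in the two hunks above matters because BaseTool instances are themselves callable: with the broad `callable(tool)` branch first, a BaseTool would be routed through `convert_to_openai_function` instead of its dedicated branch. A toy reproduction with dummy classes (not the real BaseTool or converter):

```python
class DummyBaseTool:
    # Stand-in for a LangChain-style tool: tool objects are callable.
    name = "dummy"

    def __call__(self):
        return "ran"


def classify_old(tool):
    # Old order: the callable check fires first, so a DummyBaseTool
    # instance never reaches its dedicated branch.
    if callable(tool):
        return "json-schema branch"
    elif isinstance(tool, DummyBaseTool):
        return "base-tool branch"  # unreachable for DummyBaseTool


def classify_new(tool):
    # New order: check the specific type before the broad callable check.
    if isinstance(tool, DummyBaseTool):
        return "base-tool branch"
    if callable(tool):
        return "json-schema branch"
```

This is also why mypy flagged the old `elif isinstance(tool, BaseTool)` branch as unreachable, as noted in the commit message.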
@@ -1211,6 +1241,7 @@ def bind_tools(
         tool_choice: Optional[
             Union[dict, str, Literal["auto", "none", "required", "any"], bool]
         ] = None,
+        parallel_tool_calls: Optional[bool] = None,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
         """Bind tool-like objects to this chat model.
@@ -1231,6 +1262,11 @@ def bind_tools(
                 {"type": "function", "function": {"name": <<tool_name>>}}:
                     calls <<tool_name>> tool.
                 - False or None: no effect, default Meta behavior.
+            parallel_tool_calls: Whether to enable parallel function calling.
+                If True, the model can call multiple tools simultaneously.
+                If False or None (default), tools are called sequentially.
+                Supported for models using GenericChatRequest (Meta, xAI Grok,
+                OpenAI, Mistral). Not supported for Cohere models.
             kwargs: Any additional parameters are passed directly to
                 :meth:`~langchain_oci.chat_models.oci_generative_ai.ChatOCIGenAI.bind`.
         """
@@ -1240,6 +1276,15 @@ def bind_tools(
         if tool_choice is not None:
             kwargs["tool_choice"] = self._provider.process_tool_choice(tool_choice)
 
+        # Add parallel tool calls support (only when explicitly enabled)
+        if parallel_tool_calls:
+            if not self._provider.supports_parallel_tool_calls:
+                raise ValueError(
+                    "Parallel tool calls not supported for this provider. "
+                    "Only GenericChatRequest models support parallel tool calling."
+                )
+            kwargs["is_parallel_tool_calls"] = True
+
         return super().bind(tools=formatted_tools, **kwargs)
 
     def with_structured_output(
