Skip to content

Commit 587f785

Browse files
authored
langchain_oci bugfix: include ToolMessages in chat history (#8)
* include ToolMessages in chat history * fix linting / format issues * use self.tool_call and self.tool_result
1 parent 546b8c4 commit 587f785

File tree

9 files changed

+52
-33
lines changed

9 files changed

+52
-33
lines changed

libs/oci/langchain_oci/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
# Copyright (c) 2025 Oracle and/or its affiliates.
22
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
33

4-
from langchain_oci.chat_models.oci_generative_ai import ChatOCIGenAI
54
from langchain_oci.chat_models.oci_data_science import (
65
ChatOCIModelDeployment,
76
ChatOCIModelDeploymentTGI,
8-
ChatOCIModelDeploymentVLLM
7+
ChatOCIModelDeploymentVLLM,
8+
)
9+
from langchain_oci.chat_models.oci_generative_ai import ChatOCIGenAI
10+
from langchain_oci.embeddings.oci_data_science_model_deployment_endpoint import (
11+
OCIModelDeploymentEndpointEmbeddings,
912
)
1013
from langchain_oci.embeddings.oci_generative_ai import OCIGenAIEmbeddings
11-
from langchain_oci.embeddings.oci_data_science_model_deployment_endpoint import OCIModelDeploymentEndpointEmbeddings
1214
from langchain_oci.llms.oci_data_science_model_deployment_endpoint import (
1315
BaseOCIModelDeployment,
1416
OCIModelDeploymentLLM,

libs/oci/langchain_oci/chat_models/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,13 @@
44
from langchain_oci.chat_models.oci_data_science import (
55
ChatOCIModelDeployment,
66
ChatOCIModelDeploymentTGI,
7-
ChatOCIModelDeploymentVLLM
7+
ChatOCIModelDeploymentVLLM,
88
)
99
from langchain_oci.chat_models.oci_generative_ai import ChatOCIGenAI
1010

11-
__all__ = ["ChatOCIGenAI", "ChatOCIModelDeployment", "ChatOCIModelDeploymentTGI", "ChatOCIModelDeploymentVLLM"]
11+
__all__ = [
12+
"ChatOCIGenAI",
13+
"ChatOCIModelDeployment",
14+
"ChatOCIModelDeploymentTGI",
15+
"ChatOCIModelDeploymentVLLM",
16+
]

libs/oci/langchain_oci/chat_models/oci_generative_ai.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,19 @@ def messages_to_oci_params(
362362
message=msg_content, tool_calls=tool_calls
363363
)
364364
)
365+
elif isinstance(msg, ToolMessage):
366+
oci_chat_history.append(
367+
self.oci_chat_message[self.get_role(msg)](
368+
tool_results=[
369+
self.oci_tool_result(
370+
call=self.oci_tool_call(
371+
name=msg.name, parameters={}
372+
),
373+
outputs=[{"output": msg.content}],
374+
)
375+
],
376+
)
377+
)
365378

366379
# Process current turn messages in reverse order until a HumanMessage
367380
current_turn = []
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# Copyright (c) 2025 Oracle and/or its affiliates.
22
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
33

4-
from langchain_oci.embeddings.oci_data_science_model_deployment_endpoint import OCIModelDeploymentEndpointEmbeddings
4+
from langchain_oci.embeddings.oci_data_science_model_deployment_endpoint import (
5+
OCIModelDeploymentEndpointEmbeddings,
6+
)
57
from langchain_oci.embeddings.oci_generative_ai import OCIGenAIEmbeddings
68

79
__all__ = ["OCIModelDeploymentEndpointEmbeddings", "OCIGenAIEmbeddings"]

libs/oci/langchain_oci/embeddings/oci_data_science_model_deployment_endpoint.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
# Copyright (c) 2025 Oracle and/or its affiliates.
22
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
33

4+
from typing import Any, Callable, Dict, List, Mapping, Optional
5+
6+
import requests
47
from langchain_core.embeddings import Embeddings
58
from langchain_core.language_models.llms import create_base_retry_decorator
69
from langchain_core.utils import get_from_dict_or_env
710
from pydantic import BaseModel, Field, model_validator
8-
import requests
9-
from typing import Any, Callable, Dict, List, Mapping, Optional
10-
1111

1212
DEFAULT_HEADER = {
1313
"Content-Type": "application/json",
@@ -39,7 +39,7 @@ class OCIModelDeploymentEndpointEmbeddings(BaseModel, Embeddings):
3939
embeddings = OCIModelDeploymentEndpointEmbeddings(
4040
endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<md_ocid>/predict",
4141
)
42-
""" # noqa: E501
42+
""" # noqa: E501
4343

4444
auth: dict = Field(default_factory=dict, exclude=True)
4545
"""ADS auth dictionary for OCI authentication:

libs/oci/langchain_oci/embeddings/oci_generative_ai.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,13 @@ def make_security_token_signer(oci_config): # type: ignore[no-untyped-def]
143143
oci_config=client_kwargs["config"]
144144
)
145145
elif values["auth_type"] == OCIAuthType(3).name:
146-
client_kwargs[
147-
"signer"
148-
] = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
146+
client_kwargs["signer"] = (
147+
oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
148+
)
149149
elif values["auth_type"] == OCIAuthType(4).name:
150-
client_kwargs[
151-
"signer"
152-
] = oci.auth.signers.get_resource_principals_signer()
150+
client_kwargs["signer"] = (
151+
oci.auth.signers.get_resource_principals_signer()
152+
)
153153
else:
154154
raise ValueError("Please provide valid value to auth_type")
155155

libs/oci/langchain_oci/llms/oci_data_science_model_deployment_endpoint.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33

44
"""LLM for OCI data science model deployment endpoint."""
55

6-
from contextlib import asynccontextmanager
76
import json
87
import logging
98
import traceback
9+
from contextlib import asynccontextmanager
1010
from typing import (
1111
Any,
1212
AsyncGenerator,
@@ -793,6 +793,7 @@ class OCIModelDeploymentTGI(OCIModelDeploymentLLM):
793793
)
794794
795795
"""
796+
796797
max_tokens: int = 256
797798
"""Denotes the number of tokens to predict per generation."""
798799

@@ -943,6 +944,7 @@ class OCIModelDeploymentVLLM(OCIModelDeploymentLLM):
943944
)
944945
945946
"""
947+
946948
max_tokens: int = 256
947949
"""Denotes the number of tokens to predict per generation."""
948950

libs/oci/langchain_oci/llms/oci_generative_ai.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,10 @@
2222
class Provider(ABC):
2323
@property
2424
@abstractmethod
25-
def stop_sequence_key(self) -> str:
26-
...
25+
def stop_sequence_key(self) -> str: ...
2726

2827
@abstractmethod
29-
def completion_response_to_text(self, response: Any) -> str:
30-
...
28+
def completion_response_to_text(self, response: Any) -> str: ...
3129

3230

3331
class CohereProvider(Provider):
@@ -159,13 +157,13 @@ def make_security_token_signer(oci_config): # type: ignore[no-untyped-def]
159157
oci_config=client_kwargs["config"]
160158
)
161159
elif values["auth_type"] == OCIAuthType(3).name:
162-
client_kwargs[
163-
"signer"
164-
] = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
160+
client_kwargs["signer"] = (
161+
oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
162+
)
165163
elif values["auth_type"] == OCIAuthType(4).name:
166-
client_kwargs[
167-
"signer"
168-
] = oci.auth.signers.get_resource_principals_signer()
164+
client_kwargs["signer"] = (
165+
oci.auth.signers.get_resource_principals_signer()
166+
)
169167
else:
170168
raise ValueError(
171169
"Please provide valid value to auth_type, "

libs/oci/tests/unit_tests/embeddings/test_oci_model_deployment_endpoint.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""Test OCI Data Science Model Deployment Endpoint."""
22

3-
import responses
43
import pytest
4+
import responses
55
from pytest_mock import MockerFixture
6+
67
from langchain_oci.embeddings import OCIModelDeploymentEndpointEmbeddings
78

89

@@ -17,11 +18,7 @@ def test_embedding_call(mocker: MockerFixture) -> None:
1718
responses.POST,
1819
endpoint,
1920
json={
20-
"data": [
21-
{
22-
"embedding": expected_output
23-
}
24-
],
21+
"data": [{"embedding": expected_output}],
2522
},
2623
status=200,
2724
)

0 commit comments

Comments
 (0)