Skip to content

Commit 2e65d06

Browse files
[MISC] Add Token Usage to AIMessage Response (#48)
* Add total_tokens to AiMessage: add total_tokens to AiMessage in the Cohere provider and the generic provider.
* Revert "Add total_tokens to AiMessage": this reverts commit c125cbb.
* Add total_tokens to AIMessage: add total tokens to AIMessage in the Cohere provider and the generic provider.
* Add unit test for total_tokens.
1 parent 5ce1280 commit 2e65d06

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

libs/oci/langchain_oci/chat_models/oci_generative_ai.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,11 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]:
242242
"is_search_required": response.data.chat_response.is_search_required,
243243
"finish_reason": response.data.chat_response.finish_reason,
244244
}
245+
246+
# Include token usage if available
247+
if hasattr(response.data.chat_response, "usage") and response.data.chat_response.usage:
248+
generation_info["total_tokens"] = response.data.chat_response.usage.total_tokens
249+
245250
# Include tool calls if available
246251
if self.chat_tool_calls(response):
247252
generation_info["tool_calls"] = self.format_response_tool_calls(
@@ -602,6 +607,11 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]:
602607
"finish_reason": response.data.chat_response.choices[0].finish_reason,
603608
"time_created": str(response.data.chat_response.time_created),
604609
}
610+
611+
# Include token usage if available
612+
if hasattr(response.data.chat_response, "usage") and response.data.chat_response.usage:
613+
generation_info["total_tokens"] = response.data.chat_response.usage.total_tokens
614+
605615
if self.chat_tool_calls(response):
606616
generation_info["tool_calls"] = self.format_response_tool_calls(
607617
self.chat_tool_calls(response)

libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ def mocked_response(*args): # type: ignore[no-untyped-def]
5555
"citations": None,
5656
"documents": None,
5757
"tool_calls": None,
58+
"usage": MockResponseDict(
59+
{
60+
"total_tokens": 50,
61+
}
62+
),
5863
}
5964
),
6065
"model_id": "cohere.command-r-16k",
@@ -116,6 +121,11 @@ def mocked_response(*args): # type: ignore[no-untyped-def]
116121
)
117122
],
118123
"time_created": "2025-08-14T10:00:01.100000+00:00",
124+
"usage": MockResponseDict(
125+
{
126+
"total_tokens": 75,
127+
}
128+
),
119129
}
120130
),
121131
"model_id": "meta.llama-3.3-70b-instruct",
@@ -141,6 +151,13 @@ def mocked_response(*args): # type: ignore[no-untyped-def]
141151
expected = "Assistant chat reply."
142152
actual = llm.invoke(messages, temperature=0.2)
143153
assert actual.content == expected
154+
155+
# Test total_tokens in additional_kwargs
156+
assert "total_tokens" in actual.additional_kwargs
157+
if provider == "cohere":
158+
assert actual.additional_kwargs["total_tokens"] == 50
159+
elif provider == "meta":
160+
assert actual.additional_kwargs["total_tokens"] == 75
144161

145162

146163
@pytest.mark.requires("oci")

0 commit comments

Comments (0)