
Commit c125cbb

Add total_tokens to AIMessage
Add total_tokens to AIMessage in the Cohere provider and the Generic provider
1 parent 5ce1280 commit c125cbb

File tree: 2 files changed, +58 -0 lines changed


libs/oci/PR_DESCRIPTION.md

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
# Add Token Usage to AIMessage Response

## Summary
Adds `total_tokens` to `AIMessage.additional_kwargs` for non-streaming chat responses, enabling users to track token consumption when using `ChatOCIGenAI`.

## Problem
When using `ChatOCIGenAI.invoke()`, token usage information (`prompt_tokens`, `completion_tokens`, `total_tokens`) from the OCI Generative AI API was not accessible in the `AIMessage` response, even though the raw OCI API returns this data.

## Solution
Extract token usage from the OCI API response and add `total_tokens` to `additional_kwargs` in non-streaming mode.

### Changes Made
**File:** `langchain_oci/chat_models/oci_generative_ai.py`

1. **CohereProvider.chat_generation_info()** (lines 246-248)
   - Extract `usage.total_tokens` from `response.data.chat_response.usage`
   - Add it to `generation_info["total_tokens"]`
2. **GenericProvider.chat_generation_info()** (lines 611-613)
   - The same extraction, for Meta/Llama models

## Usage

### Before
```python
response = chat.invoke("What is the capital of France?")
# No way to access token usage
```

### After
```python
response = chat.invoke("What is the capital of France?")
print(response.additional_kwargs.get("total_tokens"))  # 26
```
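
With `total_tokens` exposed, per-call usage can be accumulated across a session, e.g. for rough cost tracking. A minimal sketch, assuming `chat` is an already-configured `ChatOCIGenAI` instance running with `is_stream=False`; the helper name is hypothetical:

```python
# Hypothetical helper: sum total_tokens over several non-streaming calls.
def total_usage(chat, prompts):
    total = 0
    for prompt in prompts:
        response = chat.invoke(prompt)
        # The key is absent in streaming mode, so default to 0.
        total += response.additional_kwargs.get("total_tokens", 0)
    return total

print(total_usage(chat, ["What is the capital of France?", "And of Germany?"]))
```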

## Limitations
- **Streaming mode:** Token usage is NOT available when `is_stream=True` because the OCI Generative AI streaming API does not include usage statistics in stream events.
- **Non-streaming only:** Use `is_stream=False` to get token usage information (see the defensive sketch below).
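
Because the key is only populated in non-streaming mode, code that may run with either setting should treat it as optional. A minimal defensive sketch, again assuming a configured `chat` instance:

```python
response = chat.invoke("What is the capital of France?")

# total_tokens is only set when is_stream=False; treat it as optional.
tokens = response.additional_kwargs.get("total_tokens")
if tokens is not None:
    print(f"Used {tokens} tokens")
else:
    print("Token usage unavailable (e.g. streaming mode)")
```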

## Testing
Tested with (see the sketch after this list):
- ✅ Cohere Command-R models (`cohere.command-r-plus-08-2024`)
- ✅ Meta Llama models (`meta.llama-3.3-70b-instruct`)
- ✅ Non-streaming mode (`is_stream=False`)
- ❌ Streaming mode (not supported by the OCI API)
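
For reference, a sketch of the kind of invocation used to exercise the change; the import path, service endpoint, and compartment OCID are assumptions/placeholders, not values taken from this commit:

```python
from langchain_oci.chat_models import ChatOCIGenAI

# Placeholder endpoint and compartment OCID; substitute real OCI values.
chat = ChatOCIGenAI(
    model_id="cohere.command-r-plus-08-2024",  # or "meta.llama-3.3-70b-instruct"
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="ocid1.compartment.oc1..<placeholder>",
    is_stream=False,  # token usage is only reported in non-streaming mode
)

response = chat.invoke("What is the capital of France?")
assert "total_tokens" in response.additional_kwargs
```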

## Backward Compatibility
✅ Fully backward compatible: existing code continues to work unchanged.

libs/oci/langchain_oci/chat_models/oci_generative_ai.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -242,6 +242,11 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]:
             "is_search_required": response.data.chat_response.is_search_required,
             "finish_reason": response.data.chat_response.finish_reason,
         }
+
+        # Include token usage if available
+        if hasattr(response.data.chat_response, "usage") and response.data.chat_response.usage:
+            generation_info["total_tokens"] = response.data.chat_response.usage.total_tokens
+
         # Include tool calls if available
         if self.chat_tool_calls(response):
             generation_info["tool_calls"] = self.format_response_tool_calls(
@@ -602,6 +607,11 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]:
             "finish_reason": response.data.chat_response.choices[0].finish_reason,
             "time_created": str(response.data.chat_response.time_created),
         }
+
+        # Include token usage if available
+        if hasattr(response.data.chat_response, "usage") and response.data.chat_response.usage:
+            generation_info["total_tokens"] = response.data.chat_response.usage.total_tokens
+
         if self.chat_tool_calls(response):
             generation_info["tool_calls"] = self.format_response_tool_calls(
                 self.chat_tool_calls(response)
```
