WIP: Add count_tokens for OpenAI models #3447
base: main
Changes from all commits
d7f0b87
80a61f1
cc8cbf0
c1be8c1
1332cd8
cb5da87
6396f5d
46cd331
86a0b89
bacf788
@@ -17,7 +17,11 @@
 from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
 from .._run_context import RunContext
 from .._thinking_part import split_content_into_text_and_thinking
-from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
+from .._utils import (
+    guard_tool_call_id as _guard_tool_call_id,
+    now_utc as _now_utc,
+    number_to_datetime,
+)
 from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, MCPServerTool, WebSearchTool
 from ..exceptions import UserError
 from ..messages import (
@@ -54,6 +58,7 @@
 from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent

 try:
+    import tiktoken
     from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream
     from openai.types import AllModels, chat, responses
     from openai.types.chat import (
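The new `tiktoken` import sits inside the existing `try:` block, so a missing optional dependency surfaces as one informative error rather than a bare `ModuleNotFoundError` at import time. A minimal sketch of that pattern, with an illustrative error message rather than the repository's actual wording:

```python
# Illustrative sketch only: the real module guards several optional imports and
# its actual except clause / error message may differ from what is shown here.
try:
    import tiktoken
    from openai import AsyncOpenAI
except ImportError as _import_error:
    raise ImportError(
        'Please install the OpenAI optional dependencies (e.g. `pip install "pydantic-ai-slim[openai]"`) '
        'to use OpenAI models and local token counting.'
    ) from _import_error
```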
@@ -907,6 +912,23 @@ def _inline_text_file_part(text: str, *, media_type: str, identifier: str) -> Ch
         )
         return ChatCompletionContentPartTextParam(text=text, type='text')

+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> usage.RequestUsage:
+        """Count the number of tokens in the given messages."""
+        if self.system != 'openai':
+            raise NotImplementedError('Token counting is only supported for OpenAI system.')
+
+        openai_messages = await self._map_messages(messages, model_request_parameters)
+        token_count = num_tokens_from_messages(openai_messages, self.model_name)
+
+        return usage.RequestUsage(
+            input_tokens=token_count,
+        )
+

 @deprecated(
     '`OpenAIModel` was renamed to `OpenAIChatModel` to clearly distinguish it from `OpenAIResponsesModel` which '

A review conversation on the `num_tokens_from_messages(...)` call was marked as resolved by DouweM.
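For orientation, here is a hypothetical usage sketch (not part of the diff) showing how the new method could be called directly. It assumes default `ModelRequestParameters` are sufficient and that `OPENAI_API_KEY` is set so the model can be constructed, even though no API request is made:

```python
import asyncio

from pydantic_ai.messages import ModelRequest, UserPromptPart
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.models.openai import OpenAIChatModel


async def main() -> None:
    model = OpenAIChatModel('gpt-4o')
    messages = [ModelRequest(parts=[UserPromptPart(content='Hello, world!')])]
    # Token counting happens locally via tiktoken; nothing is sent to the API.
    request_usage = await model.count_tokens(messages, None, ModelRequestParameters())
    print(request_usage.input_tokens)


asyncio.run(main())
```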
@@ -1701,6 +1723,25 @@ async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessagePa
                     assert_never(item)
         return responses.EasyInputMessageParam(role='user', content=content)

+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> usage.RequestUsage:
+        """Count the number of tokens in the given messages."""
+        if self.system != 'openai':
+            raise NotImplementedError('Token counting is only supported for OpenAI system.')
+
+        _, openai_messages = await self._map_messages(
+            messages, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters
+        )
+        token_count = num_tokens_from_messages(openai_messages, self.model_name)
+
+        return usage.RequestUsage(
+            input_tokens=token_count,
+        )
+

 @dataclass
 class OpenAIStreamedResponse(StreamedResponse):

Collaborator review comment on the `count_tokens` definition: While we're at it, let's update the docstring for …
@@ -2333,3 +2374,50 @@ def _map_mcp_call(
             provider_name=provider_name,
         ),
     )
+
+
+def num_tokens_from_messages(
+    messages: list[chat.ChatCompletionMessageParam] | list[responses.ResponseInputItemParam],
+    model: OpenAIModelName,
+) -> int:
+    """Return the number of tokens used by a list of messages."""
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        encoding = tiktoken.get_encoding('o200k_base')
+    if model in {
+        'gpt-3.5-turbo-0125',
+        'gpt-4-0314',
+        'gpt-4-32k-0314',
+        'gpt-4-0613',
+        'gpt-4-32k-0613',
+        'gpt-4o-mini-2024-07-18',
+        'gpt-4o-2024-08-06',
+    }:
+        tokens_per_message = 3
+        final_primer = 3  # every reply is primed with <|start|>assistant<|message|>
+    elif model in {
+        'gpt-5-2025-08-07',
+    }:
+        tokens_per_message = 3
+        final_primer = 2
+    elif 'gpt-3.5-turbo' in model:
+        return num_tokens_from_messages(messages, model='gpt-3.5-turbo-0125')
+    elif 'gpt-4o-mini' in model:
+        return num_tokens_from_messages(messages, model='gpt-4o-mini-2024-07-18')
+    elif 'gpt-4o' in model:
+        return num_tokens_from_messages(messages, model='gpt-4o-2024-08-06')
+    elif 'gpt-4' in model:
+        return num_tokens_from_messages(messages, model='gpt-4-0613')
+    elif 'gpt-5' in model:
+        return num_tokens_from_messages(messages, model='gpt-5-2025-08-07')
+    else:
+        raise NotImplementedError(f"""num_tokens_from_messages() is not implemented for model {model}.""")
+    num_tokens = 0
+    for message in messages:
+        num_tokens += tokens_per_message
+        for value in message.values():
+            if isinstance(value, str):
+                num_tokens += len(encoding.encode(value))
+    num_tokens += final_primer
+    return num_tokens

Collaborator review comments on this hunk:
- On the `num_tokens_from_messages` definition: Please make this a private function.
- On `final_primer = 3`: Let's include a link to the doc we took this from.
- On `final_primer = 2` (the gpt-5 branch): Let's make it explicit that this one was "reverse engineered".
- On the `NotImplementedError` fallback: Let's simplify all of this as …
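The per-family constants above (3 tokens per message, a 3-token reply primer) match OpenAI's "How to count tokens with tiktoken" cookbook example, which is presumably the document the reviewer wants linked; the gpt-5 primer value appears to be determined empirically, per the "reverse engineered" comment. A standalone illustration of the heuristic (not part of the diff):

```python
# Illustration of the counting heuristic the helper applies: a fixed per-message
# overhead, plus the encoded length of every string value, plus a final primer
# for the assistant reply.
import tiktoken

encoding = tiktoken.get_encoding('o200k_base')  # the fallback encoding used above
example_messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'Hello!'},
]
tokens_per_message, final_primer = 3, 3  # constants for the gpt-4o family above

num_tokens = final_primer
for message in example_messages:
    num_tokens += tokens_per_message
    for value in message.values():
        if isinstance(value, str):
            num_tokens += len(encoding.encode(value))
print(num_tokens)  # estimated prompt tokens for this two-message conversation
```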
Review comment: We should call `self.prepare_request` before this call, like we do in the other model classes' `count_tokens` methods.
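A rough sketch of what that suggestion could look like in the chat model's `count_tokens`; this is a hypothetical revision, not PR code, and it assumes `prepare_request` returns the merged settings and customized request parameters as the other model classes use it:

```python
# Hypothetical revision (excerpt of a method inside OpenAIChatModel, not standalone code);
# assumes prepare_request(...) returns (model_settings, model_request_parameters).
async def count_tokens(
    self,
    messages: list[ModelMessage],
    model_settings: ModelSettings | None,
    model_request_parameters: ModelRequestParameters,
) -> usage.RequestUsage:
    """Count the number of tokens in the given messages."""
    if self.system != 'openai':
        raise NotImplementedError('Token counting is only supported for OpenAI system.')

    # Apply profile defaults and parameter customization before mapping messages,
    # mirroring the other models' count_tokens implementations.
    model_settings, model_request_parameters = self.prepare_request(model_settings, model_request_parameters)
    openai_messages = await self._map_messages(messages, model_request_parameters)
    return usage.RequestUsage(input_tokens=num_tokens_from_messages(openai_messages, self.model_name))
```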