Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 89 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
from .._run_context import RunContext
from .._thinking_part import split_content_into_text_and_thinking
from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
from .._utils import (
guard_tool_call_id as _guard_tool_call_id,
now_utc as _now_utc,
number_to_datetime,
)
from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, MCPServerTool, WebSearchTool
from ..exceptions import UserError
from ..messages import (
Expand Down Expand Up @@ -54,6 +58,7 @@
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent

try:
import tiktoken
from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream
from openai.types import AllModels, chat, responses
from openai.types.chat import (
Expand Down Expand Up @@ -907,6 +912,23 @@ def _inline_text_file_part(text: str, *, media_type: str, identifier: str) -> Ch
)
return ChatCompletionContentPartTextParam(text=text, type='text')

async def count_tokens(
    self,
    messages: list[ModelMessage],
    model_settings: ModelSettings | None,
    model_request_parameters: ModelRequestParameters,
) -> usage.RequestUsage:
    """Count the number of tokens in the given messages.

    Counting is done locally with `tiktoken`, without calling the OpenAI API.

    Args:
        messages: The messages to count tokens for.
        model_settings: Settings for the request; merged with the model's defaults by `prepare_request`.
        model_request_parameters: Request parameters; customized by `prepare_request` before mapping.

    Returns:
        A `usage.RequestUsage` with `input_tokens` set to the estimated count.

    Raises:
        NotImplementedError: If the provider is not OpenAI itself (token counts for
            OpenAI-compatible providers may not match `tiktoken`'s).
    """
    if self.system != 'openai':
        raise NotImplementedError('Token counting is only supported for OpenAI system.')

    # Customize settings/parameters the same way a real request would, so the
    # mapped messages (and therefore the token count) match what would be sent.
    # This mirrors the other model classes' count_tokens methods.
    _, model_request_parameters = self.prepare_request(model_settings, model_request_parameters)

    openai_messages = await self._map_messages(messages, model_request_parameters)
    token_count = num_tokens_from_messages(openai_messages, self.model_name)

    return usage.RequestUsage(
        input_tokens=token_count,
    )


@deprecated(
'`OpenAIModel` was renamed to `OpenAIChatModel` to clearly distinguish it from `OpenAIResponsesModel` which '
Expand Down Expand Up @@ -1701,6 +1723,25 @@ async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessagePa
assert_never(item)
return responses.EasyInputMessageParam(role='user', content=content)

async def count_tokens(
    self,
    messages: list[ModelMessage],
    model_settings: ModelSettings | None,
    model_request_parameters: ModelRequestParameters,
) -> usage.RequestUsage:
    """Count the number of tokens in the given messages.

    Counting is done locally with `tiktoken`, without calling the OpenAI API.

    Args:
        messages: The messages to count tokens for.
        model_settings: Settings for the request; merged with the model's defaults by `prepare_request`.
        model_request_parameters: Request parameters; customized by `prepare_request` before mapping.

    Returns:
        A `usage.RequestUsage` with `input_tokens` set to the estimated count.

    Raises:
        NotImplementedError: If the provider is not OpenAI itself (token counts for
            OpenAI-compatible providers may not match `tiktoken`'s).
    """
    if self.system != 'openai':
        raise NotImplementedError('Token counting is only supported for OpenAI system.')

    # Customize settings/parameters the same way a real request would, so the
    # mapped messages (and therefore the token count) match what would be sent.
    # This mirrors the other model classes' count_tokens methods.
    model_settings, model_request_parameters = self.prepare_request(model_settings, model_request_parameters)

    _, openai_messages = await self._map_messages(
        messages, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters
    )
    token_count = num_tokens_from_messages(openai_messages, self.model_name)

    return usage.RequestUsage(
        input_tokens=token_count,
    )


@dataclass
class OpenAIStreamedResponse(StreamedResponse):
Expand Down Expand Up @@ -2333,3 +2374,50 @@ def _map_mcp_call(
provider_name=provider_name,
),
)


def num_tokens_from_messages(
    messages: list[chat.ChatCompletionMessageParam] | list[responses.ResponseInputItemParam],
    model: OpenAIModelName,
) -> int:
    """Return the number of tokens used by a list of messages.

    Counting is done locally with `tiktoken`, without calling the OpenAI API.
    The per-message overhead and final reply primer are taken from the OpenAI cookbook:
    https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken

    Args:
        messages: Chat Completions or Responses message params to count tokens for.
        model: The model name, used to select the tokenizer encoding.

    Returns:
        The estimated number of input tokens.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Model unknown to tiktoken: fall back to the newest encoding.
        encoding = tiktoken.get_encoding('o200k_base')

    # Every message carries a fixed wrapper overhead (see the cookbook link above).
    tokens_per_message = 3
    if 'gpt-5' in model:
        # NOTE: not documented by OpenAI — this value was reverse engineered by
        # comparing local counts against API-reported usage for gpt-5 models.
        final_primer = 2
    else:
        # Every reply is primed with <|start|>assistant<|message|>.
        final_primer = 3

    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for value in message.values():
            if isinstance(value, str):
                num_tokens += len(encoding.encode(value))
    num_tokens += final_primer
    return num_tokens
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ dependencies = [
# WARNING if you add optional groups, please update docs/install.md
logfire = ["logfire[httpx]>=3.14.1"]
# Models
openai = ["openai>=1.107.2"]
openai = ["openai>=1.107.2", "tiktoken>=0.12.0"]
cohere = ["cohere>=5.18.0; platform_system != 'Emscripten'"]
vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"]
google = ["google-genai>=1.51.0"]
Expand Down
Loading
Loading