Skip to content

Commit 72bf79a

Browse files
committed
[anthropic] add count_tokens
1 parent ca678a2 commit 72bf79a

File tree

1 file changed

+89
-19
lines changed

1 file changed

+89
-19
lines changed

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 89 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555

5656

5757
try:
58-
from anthropic import NOT_GIVEN, APIStatusError, AsyncStream, omit as OMIT
58+
from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropicBedrock, AsyncStream, omit as OMIT
5959
from anthropic.types.beta import (
6060
BetaBase64PDFBlockParam,
6161
BetaBase64PDFSourceParam,
@@ -76,6 +76,7 @@
7676
BetaMemoryTool20250818Param,
7777
BetaMessage,
7878
BetaMessageParam,
79+
BetaMessageTokensCount,
7980
BetaMetadataParam,
8081
BetaPlainTextSourceParam,
8182
BetaRawContentBlockDeltaEvent,
@@ -239,6 +240,23 @@ async def request(
239240
model_response = self._process_response(response)
240241
return model_response
241242

async def count_tokens(
    self,
    messages: list[ModelMessage],
    model_settings: ModelSettings | None,
    model_request_parameters: ModelRequestParameters,
) -> usage.RequestUsage:
    """Count the input tokens the given messages would consume.

    Delegates to Anthropic's token-counting endpoint via
    `_messages_count_tokens` and reports the result as a `RequestUsage`
    carrying only `input_tokens`.
    """
    # Normalize settings/parameters the same way a real request does.
    model_settings, model_request_parameters = self.prepare_request(
        model_settings,
        model_request_parameters,
    )

    settings = cast(AnthropicModelSettings, model_settings or {})
    response = await self._messages_count_tokens(messages, settings, model_request_parameters)

    return usage.RequestUsage(input_tokens=response.input_tokens)
242260
@asynccontextmanager
243261
async def request_stream(
244262
self,
@@ -310,28 +328,12 @@ async def _messages_create(
310328
tools = self._get_tools(model_request_parameters, model_settings)
311329
tools, mcp_servers, beta_features = self._add_builtin_tools(tools, model_request_parameters)
312330

313-
tool_choice: BetaToolChoiceParam | None
314-
315-
if not tools:
316-
tool_choice = None
317-
else:
318-
if not model_request_parameters.allow_text_output:
319-
tool_choice = {'type': 'any'}
320-
else:
321-
tool_choice = {'type': 'auto'}
322-
323-
if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
324-
tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls
331+
tool_choice = self._infer_tool_choice(tools, model_settings, model_request_parameters)
325332

326333
system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters, model_settings)
327334

328335
try:
329-
extra_headers = model_settings.get('extra_headers', {})
330-
extra_headers.setdefault('User-Agent', get_user_agent())
331-
if beta_features:
332-
if 'anthropic-beta' in extra_headers:
333-
beta_features.insert(0, extra_headers['anthropic-beta'])
334-
extra_headers['anthropic-beta'] = ','.join(beta_features)
336+
extra_headers = self._map_extra_headers(beta_features, model_settings)
335337

336338
return await self.client.beta.messages.create(
337339
max_tokens=model_settings.get('max_tokens', 4096),
@@ -356,6 +358,43 @@ async def _messages_create(
356358
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
357359
raise # pragma: lax no cover
358360

async def _messages_count_tokens(
    self,
    messages: list[ModelMessage],
    model_settings: AnthropicModelSettings,
    model_request_parameters: ModelRequestParameters,
) -> BetaMessageTokensCount:
    """Call Anthropic's token-counting endpoint for the given request.

    Builds the same payload as `_messages_create` (tools, tool choice,
    system prompt, mapped messages, headers) so the count reflects what an
    actual request would send.

    Raises:
        UserError: the Bedrock client does not expose this endpoint.
        ModelHTTPError: for 4xx/5xx responses from the API.
    """
    if isinstance(self.client, AsyncAnthropicBedrock):
        raise UserError('AsyncAnthropicBedrock client does not support `count_tokens` api.')

    # standalone function to make it easier to override
    tool_defs = self._get_tools(model_request_parameters, model_settings)
    tool_defs, mcp_servers, beta_features = self._add_builtin_tools(tool_defs, model_request_parameters)
    choice = self._infer_tool_choice(tool_defs, model_settings, model_request_parameters)
    system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters, model_settings)

    try:
        headers = self._map_extra_headers(beta_features, model_settings)
        return await self.client.beta.messages.count_tokens(
            system=system_prompt or OMIT,
            messages=anthropic_messages,
            model=self._model_name,
            tools=tool_defs or OMIT,
            tool_choice=choice or OMIT,
            mcp_servers=mcp_servers or OMIT,
            thinking=model_settings.get('anthropic_thinking', OMIT),
            timeout=model_settings.get('timeout', NOT_GIVEN),
            extra_headers=headers,
            extra_body=model_settings.get('extra_body'),
        )
    except APIStatusError as e:
        status_code = e.status_code
        if status_code >= 400:
            raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
        raise  # pragma: lax no cover
359398
def _process_response(self, response: BetaMessage) -> ModelResponse:
360399
"""Process a non-streamed response, and prepare a message to return."""
361400
items: list[ModelResponsePart] = []
@@ -492,6 +531,37 @@ def _add_builtin_tools(
492531
)
493532
return tools, mcp_servers, beta_features
494533

def _infer_tool_choice(
    self,
    tools: list[BetaToolUnionParam],
    model_settings: AnthropicModelSettings,
    model_request_parameters: ModelRequestParameters,
) -> BetaToolChoiceParam | None:
    """Derive the `tool_choice` payload from the available tools and settings.

    Returns `None` when no tools are defined. Otherwise the choice is
    `'any'` when text output is not allowed (a tool call is required) or
    `'auto'`, with parallel tool use disabled when the settings say so.
    """
    if not tools:
        return None

    mode = 'auto' if model_request_parameters.allow_text_output else 'any'
    choice: BetaToolChoiceParam = {'type': mode}

    if 'parallel_tool_calls' in model_settings:
        choice['disable_parallel_tool_use'] = not model_settings['parallel_tool_calls']

    return choice
def _map_extra_headers(self, beta_features: list[str], model_settings: AnthropicModelSettings) -> dict[str, str]:
    """Build request headers, merging `beta_features` into `anthropic-beta`.

    Returns a fresh dict. The caller's `extra_headers` mapping inside
    `model_settings` and the passed `beta_features` list are never mutated —
    otherwise the injected `User-Agent` and the merged `anthropic-beta`
    value would leak back into the user's settings and accumulate across
    requests.
    """
    # Copy so setdefault/assignment below can't mutate the user's settings.
    extra_headers = dict(model_settings.get('extra_headers', {}))
    extra_headers.setdefault('User-Agent', get_user_agent())
    if beta_features:
        features = list(beta_features)
        if 'anthropic-beta' in extra_headers:
            # Preserve caller-supplied beta flags ahead of ours.
            features.insert(0, extra_headers['anthropic-beta'])
        extra_headers['anthropic-beta'] = ','.join(features)
    return extra_headers
495565
async def _map_message( # noqa: C901
496566
self,
497567
messages: list[ModelMessage],

0 commit comments

Comments
 (0)