@@ -55,7 +55,7 @@


 try:
-    from anthropic import NOT_GIVEN, APIStatusError, AsyncStream, omit as OMIT
+    from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropicBedrock, AsyncStream, omit as OMIT
     from anthropic.types.beta import (
         BetaBase64PDFBlockParam,
         BetaBase64PDFSourceParam,
@@ -76,6 +76,7 @@
         BetaMemoryTool20250818Param,
         BetaMessage,
         BetaMessageParam,
+        BetaMessageTokensCount,
         BetaMetadataParam,
         BetaPlainTextSourceParam,
         BetaRawContentBlockDeltaEvent,
@@ -239,6 +240,23 @@ async def request(
         model_response = self._process_response(response)
         return model_response

+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> usage.RequestUsage:
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
+
+        response = await self._messages_count_tokens(
+            messages, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
+        )
+
+        return usage.RequestUsage(input_tokens=response.input_tokens)
+
     @asynccontextmanager
     async def request_stream(
         self,
@@ -310,28 +328,12 @@ async def _messages_create(
         tools = self._get_tools(model_request_parameters, model_settings)
         tools, mcp_servers, beta_features = self._add_builtin_tools(tools, model_request_parameters)

-        tool_choice: BetaToolChoiceParam | None
-
-        if not tools:
-            tool_choice = None
-        else:
-            if not model_request_parameters.allow_text_output:
-                tool_choice = {'type': 'any'}
-            else:
-                tool_choice = {'type': 'auto'}
-
-            if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
-                tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls
+        tool_choice = self._infer_tool_choice(tools, model_settings, model_request_parameters)

         system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters, model_settings)

         try:
-            extra_headers = model_settings.get('extra_headers', {})
-            extra_headers.setdefault('User-Agent', get_user_agent())
-            if beta_features:
-                if 'anthropic-beta' in extra_headers:
-                    beta_features.insert(0, extra_headers['anthropic-beta'])
-                extra_headers['anthropic-beta'] = ','.join(beta_features)
+            extra_headers = self._map_extra_headers(beta_features, model_settings)

             return await self.client.beta.messages.create(
                 max_tokens=model_settings.get('max_tokens', 4096),
@@ -356,6 +358,43 @@ async def _messages_create(
                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
             raise  # pragma: lax no cover

+    async def _messages_count_tokens(
+        self,
+        messages: list[ModelMessage],
+        model_settings: AnthropicModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> BetaMessageTokensCount:
+        if isinstance(self.client, AsyncAnthropicBedrock):
+            raise UserError('AsyncAnthropicBedrock client does not support `count_tokens` api.')
+
+        # standalone function to make it easier to override
+        tools = self._get_tools(model_request_parameters, model_settings)
+        tools, mcp_servers, beta_features = self._add_builtin_tools(tools, model_request_parameters)
+
+        tool_choice = self._infer_tool_choice(tools, model_settings, model_request_parameters)
+
+        system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters, model_settings)
+
+        try:
+            extra_headers = self._map_extra_headers(beta_features, model_settings)
+
+            return await self.client.beta.messages.count_tokens(
+                system=system_prompt or OMIT,
+                messages=anthropic_messages,
+                model=self._model_name,
+                tools=tools or OMIT,
+                tool_choice=tool_choice or OMIT,
+                mcp_servers=mcp_servers or OMIT,
+                thinking=model_settings.get('anthropic_thinking', OMIT),
+                timeout=model_settings.get('timeout', NOT_GIVEN),
+                extra_headers=extra_headers,
+                extra_body=model_settings.get('extra_body'),
+            )
+        except APIStatusError as e:
+            if (status_code := e.status_code) >= 400:
+                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+            raise  # pragma: lax no cover
+
     def _process_response(self, response: BetaMessage) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         items: list[ModelResponsePart] = []
@@ -492,6 +531,37 @@ def _add_builtin_tools(
                 )
         return tools, mcp_servers, beta_features

+    def _infer_tool_choice(
+        self,
+        tools: list[BetaToolUnionParam],
+        model_settings: AnthropicModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> BetaToolChoiceParam | None:
+        if not tools:
+            return None
+        else:
+            tool_choice: BetaToolChoiceParam
+
+            if not model_request_parameters.allow_text_output:
+                tool_choice = {'type': 'any'}
+            else:
+                tool_choice = {'type': 'auto'}
+
+            if 'parallel_tool_calls' in model_settings:
+                tool_choice['disable_parallel_tool_use'] = not model_settings['parallel_tool_calls']
+
+            return tool_choice
+
+    def _map_extra_headers(self, beta_features: list[str], model_settings: AnthropicModelSettings) -> dict[str, str]:
+        """Apply beta_features to extra_headers in model_settings."""
+        extra_headers = model_settings.get('extra_headers', {})
+        extra_headers.setdefault('User-Agent', get_user_agent())
+        if beta_features:
+            if 'anthropic-beta' in extra_headers:
+                beta_features.insert(0, extra_headers['anthropic-beta'])
+            extra_headers['anthropic-beta'] = ','.join(beta_features)
+        return extra_headers
+
     async def _map_message(  # noqa: C901
         self,
         messages: list[ModelMessage],
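
A minimal usage sketch of the `count_tokens` method added above, not part of the diff: it assumes pydantic_ai's public `AnthropicModel`, the `ModelRequest`/`UserPromptPart` message types, a default-constructed `ModelRequestParameters()`, and an `ANTHROPIC_API_KEY` in the environment, since the call is forwarded to the live `client.beta.messages.count_tokens` endpoint. The model name is illustrative and exact import paths may differ between versions.

import asyncio

from pydantic_ai.messages import ModelRequest, UserPromptPart
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.models.anthropic import AnthropicModel


async def main() -> None:
    # Illustrative model id; any Anthropic model name accepted by the API works here.
    model = AnthropicModel('claude-sonnet-4-0')

    # count_tokens() maps the API response onto usage.RequestUsage(input_tokens=...);
    # it raises UserError when the underlying client is AsyncAnthropicBedrock.
    request_usage = await model.count_tokens(
        [ModelRequest(parts=[UserPromptPart(content='Hello, Claude!')])],
        model_settings=None,
        model_request_parameters=ModelRequestParameters(),
    )
    print(request_usage.input_tokens)


asyncio.run(main())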