From 4952aa80e2946e3878067759b1bd0a19b96ad2fb Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 6 Jun 2024 10:09:06 -0700 Subject: [PATCH 1/2] Add chat.count_tokens Change-Id: Ibbc4a88e0beb0f202121f6e32142d2be1a93e9c7 --- google/generativeai/generative_models.py | 51 ++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 4 deletions(-) diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index 10744a948..2c7aab11c 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -4,12 +4,9 @@ from collections.abc import Iterable import textwrap -from typing import Any, Union, overload +from typing import overload import reprlib -# pylint: disable=bad-continuation, line-too-long - - import google.api_core.exceptions from google.generativeai import protos from google.generativeai import client @@ -504,6 +501,52 @@ def __init__( self._last_received: generation_types.BaseGenerateContentResponse | None = None self.enable_automatic_function_calling = enable_automatic_function_calling + def count_tokens( + self, + content: content_types.ContentType | None = None, + *, + tools: content_types.FunctionLibraryType | None = None, + tool_config: content_types.ToolConfigType | None = None, + request_options: helper_types.RequestOptionsType | None = None, + ): + history = self.history[:] + + if content is not None: + content = content_types.to_content(content) + if not content.role: + content.role = self._USER_ROLE + history.append(content) + + return self.model.count_tokens( + contents=history, + tools=tools, + tool_config=tool_config, + request_options=request_options, + ) + + def count_tokens_async( + self, + content: content_types.ContentType | None = None, + *, + tools: content_types.FunctionLibraryType | None = None, + tool_config: content_types.ToolConfigType | None = None, + request_options: helper_types.RequestOptionsType | None = None, + ): + history = self.history[:] + + if content is not None: + content = content_types.to_content(content) + if not content.role: + content.role = self._USER_ROLE + history.append(content) + + return await self.model.count_tokens( + contents=history, + tools=tools, + tool_config=tool_config, + request_options=request_options, + ) + def send_message( self, content: content_types.ContentType, From e852560c54fbe156d5889dd1296940a068eb0aa4 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Wed, 24 Jul 2024 08:44:09 -0700 Subject: [PATCH 2/2] Fix tests. Change-Id: Ifa2f917ff4b2cc2595bb8076459ec1289ae286d2 --- google/generativeai/generative_models.py | 2 +- tests/test_async_code_match.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index 85b33b2fa..ee6a81b69 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -529,7 +529,7 @@ def count_tokens( request_options=request_options, ) - def count_tokens_async( + async def count_tokens_async( self, content: content_types.ContentType | None = None, *, diff --git a/tests/test_async_code_match.py b/tests/test_async_code_match.py index 0ec4550d4..2e45a9973 100644 --- a/tests/test_async_code_match.py +++ b/tests/test_async_code_match.py @@ -62,7 +62,7 @@ def _inspect_decorator_exemption(self, node, fpath) -> bool: return False - def _execute_code_match(self, source, asource): + def _execute_code_match(self, source, asource, fpath): asource = ( asource.replace("anext", "next") .replace("aiter", "iter") @@ -73,7 +73,7 @@ def _execute_code_match(self, source, asource): .replace("ASYNC_", "") ) asource = re.sub(" *?# type: ignore", "", asource) - self.assertEqual(source, asource) + self.assertEqual(source, asource, f"Matching {fpath}") def test_code_match_for_async_methods(self): for fpath in (pathlib.Path(__file__).parent.parent / "google").rglob("*.py"): @@ -101,7 +101,7 @@ def test_code_match_for_async_methods(self): ) func_source = self._maybe_trim_docstring(snode) func_asource = self._maybe_trim_docstring(anode) - self._execute_code_match(func_source, func_asource) + self._execute_code_match(func_source, func_asource, fpath) # print(f"Matched {node.name}") else: code_match_funcs[node.name] = node