diff --git a/google/generativeai/files.py b/google/generativeai/files.py index b2581bdcd..0d574fe25 100644 --- a/google/generativeai/files.py +++ b/google/generativeai/files.py @@ -22,12 +22,22 @@ from google.generativeai import protos from itertools import islice from io import IOBase +import asyncio from google.generativeai.types import file_types from google.generativeai.client import get_default_file_client -__all__ = ["upload_file", "get_file", "list_files", "delete_file"] +__all__ = [ + "upload_file", + "get_file", + "list_files", + "delete_file", + "upload_file_async", + "get_file_async", + "list_files_async", + "delete_file_async", +] mimetypes.add_type("image/webp", ".webp") @@ -88,6 +98,10 @@ def upload_file( return file_types.File(response) +async def upload_file_async(*args, **kwargs): + return await asyncio.to_thread(upload_file, *args, **kwargs) + + def list_files(page_size=100) -> Iterable[file_types.File]: """Calls the API to list files using a supported file service.""" client = get_default_file_client() @@ -97,6 +111,10 @@ def list_files(page_size=100) -> Iterable[file_types.File]: yield file_types.File(proto) +async def list_files_async(*args, **kwargs): + return await asyncio.to_thread(list_files, *args, **kwargs) + + def get_file(name: str) -> file_types.File: """Calls the API to retrieve a specified file using a supported file service.""" if "/" not in name: @@ -105,6 +123,10 @@ def get_file(name: str) -> file_types.File: return file_types.File(client.get_file(name=name)) +async def get_file_async(*args, **kwargs): + return await asyncio.to_thread(get_file, *args, **kwargs) + + def delete_file(name: str | file_types.File | protos.File): """Calls the API to permanently delete a specified file using a supported file service.""" if isinstance(name, (file_types.File, protos.File)): @@ -114,3 +136,7 @@ def delete_file(name: str | file_types.File | protos.File): request = protos.DeleteFileRequest(name=name) client = get_default_file_client() client.delete_file(request=request) + + +async def delete_file_async(*args, **kwargs): + return await asyncio.to_thread(get_file, *args, **kwargs) diff --git a/samples/files.py b/samples/files.py index 8f98365aa..461e293d2 100644 --- a/samples/files.py +++ b/samples/files.py @@ -12,11 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import unittest from absl.testing import absltest import google import google.generativeai as genai import pathlib +import tempfile +import asyncio media = pathlib.Path(__file__).parents[1] / "third_party" @@ -127,5 +130,29 @@ def test_files_delete(self): # [END files_delete] +class AsyncTests(absltest.TestCase, unittest.IsolatedAsyncioTestCase): + async def test_upload_file_async(self): + import google.generativeai.files as files + + tempdir = pathlib.Path(tempfile.mkdtemp()) + results = [] + + async def create_and_upload_file(n: int): + fname = tempdir / str(n) + fname.write_text(str(n)) + file_obj = await files.upload_file_async(fname, mime_type="text/plain") + results.append(file_obj) + + tasks = [] + for n in range(5): + tasks.append(asyncio.create_task(create_and_upload_file(n))) + + for task in tasks: + await task + + self.assertLen(results, 5) + self.assertEqual(sorted(int(f.display_name) for f in results), list(range(5))) + + if __name__ == "__main__": absltest.main()